nautilus_serialization/arrow/custom.rs
// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this code except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

//! Custom data: registration and dynamic decoding.
//!
//! - **Registration:** Call [`ensure_custom_data_registered::<T>()`] once (e.g. before using the
//!   catalog) for each custom data type `T` produced by the `#[custom_data]` macro. For Python
//!   bindings, also call [`nautilus_model::data::register_rust_extractor::<T>()`].
//! - **Decoder:** [`CustomDataDecoder`] provides [`ArrowSchemaProvider`] and
//!   [`DecodeDataFromRecordBatch`] for Parquet-backed custom data decoded at runtime by type name.
//!   Types must be registered via [`ensure_custom_data_registered::<T>()`] before use.
use std::sync::Arc;

use arrow::record_batch::RecordBatch;
use nautilus_model::data::{
    ArrowDecoder, ArrowEncoder, CustomData, CustomDataTrait, Data, DataType,
    decode_custom_from_arrow, ensure_arrow_registered, ensure_custom_data_json_registered,
    get_arrow_schema,
};

use super::{ArrowSchemaProvider, DecodeDataFromRecordBatch, EncodeToRecordBatch};
35
/// Trait for custom data types that support Arrow schema and record batch encoding.
///
/// Used as a type bound by the `#[custom_data]` macro; catalog encoding goes through
/// the registry, not this trait directly.
///
/// Implemented by the `#[custom_data]` macro for Rust custom data types. Python custom
/// types use the registry encoder registered by `register_custom_data_class` instead.
pub trait CustomDataSerialize: CustomDataTrait {
    /// Returns the Arrow schema for this custom data type.
    ///
    /// # Errors
    /// Returns an error if schema construction fails.
    fn schema(&self) -> anyhow::Result<arrow::datatypes::Schema>;

    /// Encodes a batch of custom data items to an Arrow `RecordBatch`.
    ///
    /// # Errors
    /// Returns an error if encoding fails (e.g. type mismatch or Arrow error).
    fn encode_record_batch(
        &self,
        items: &[Arc<dyn CustomDataTrait>],
    ) -> anyhow::Result<RecordBatch>;
}
58
59/// Registers a custom data type in the JSON and Arrow registries. Call once per type
60/// (e.g. at catalog decode or before querying custom data).
61///
62/// Each distinct type `T` is registered at most once (per process). Safe to call
63/// multiple times for the same `T`.
64///
65/// For types exposed to Python, also call [`nautilus_model::data::register_rust_extractor::<T>()`].
66pub fn ensure_custom_data_registered<T>()
67where
68    T: CustomDataTrait
69        + ArrowSchemaProvider
70        + EncodeToRecordBatch
71        + DecodeDataFromRecordBatch
72        + Clone
73        + Send
74        + Sync
75        + 'static,
76{
77    let type_name = T::type_name_static();
78
79    // Skip if already registered
80    if get_arrow_schema(type_name).is_some() {
81        return;
82    }
83
84    let _ = ensure_custom_data_json_registered::<T>();
85
86    let schema = Arc::new(T::get_schema(None));
87
88    let encoder: ArrowEncoder = Box::new(|items: &[Arc<dyn CustomDataTrait>]| {
89        let typed: Result<Vec<T>, _> = items
90            .iter()
91            .map(|b| {
92                b.as_any()
93                    .downcast_ref::<T>()
94                    .cloned()
95                    .ok_or_else(|| anyhow::anyhow!("Expected {}", T::type_name_static()))
96            })
97            .collect();
98        let typed = typed?;
99        let metadata = typed
100            .first()
101            .map(EncodeToRecordBatch::metadata)
102            .unwrap_or_default();
103        EncodeToRecordBatch::encode_batch(&metadata, &typed).map_err(|e| anyhow::anyhow!("{e}"))
104    });
105
106    let decoder: ArrowDecoder = Box::new(|metadata, batch| {
107        T::decode_data_batch(metadata, batch).map_err(|e| anyhow::anyhow!("{e}"))
108    });
109
110    let _ = ensure_arrow_registered(type_name, schema, encoder, decoder);
111}
112
/// Decoder for custom data types that are identified at runtime by metadata (e.g. `type_name`).
///
/// Only Rust-registered custom types (e.g. `RustTestCustomData`, `MacroYieldCurveData`) can be
/// decoded. Unknown types return an error.
///
/// **Important:** The caller must ensure that any Rust custom data types are registered
/// via [`ensure_custom_data_registered::<T>()`] before use.
#[derive(Debug)]
pub struct CustomDataDecoder;
122
123impl ArrowSchemaProvider for CustomDataDecoder {
124    fn get_schema(
125        metadata: Option<std::collections::HashMap<String, String>>,
126    ) -> arrow::datatypes::Schema {
127        if let Some(metadata) = metadata
128            && let Some(type_name) = metadata.get("type_name")
129            && let Some(schema) = get_arrow_schema(type_name)
130        {
131            return (*schema).clone();
132        }
133
134        // Unknown type - return minimal schema (caller should not use this for decode)
135        arrow::datatypes::Schema::new(vec![arrow::datatypes::Field::new(
136            "dummy",
137            arrow::datatypes::DataType::Int64,
138            true,
139        )])
140    }
141}
142
143/// Strips the data_type column from a record batch and returns the parsed DataType.
144/// Returns (batch, None) if there is no data_type column.
145fn strip_data_type_column(
146    batch: &RecordBatch,
147) -> Result<(RecordBatch, Option<DataType>), super::EncodingError> {
148    use super::extract_column_string;
149
150    let Some(data_type_col_idx) = batch
151        .schema()
152        .fields()
153        .iter()
154        .position(|f| f.name() == "data_type")
155    else {
156        return Ok((batch.clone(), None));
157    };
158
159    if batch.num_rows() == 0 {
160        return Ok((batch.clone(), None));
161    }
162
163    let cols = batch.columns();
164    let string_col = extract_column_string(cols, "data_type", data_type_col_idx).map_err(|e| {
165        super::EncodingError::ParseError("custom_data", format!("data_type column: {e}"))
166    })?;
167    let first_value = string_col.value(0);
168    let data_type = DataType::from_persistence_json(first_value)
169        .map_err(|e| super::EncodingError::ParseError("custom_data", e.to_string()))?;
170
171    let new_fields: Vec<_> = batch
172        .schema()
173        .fields()
174        .iter()
175        .enumerate()
176        .filter(|(i, _)| *i != data_type_col_idx)
177        .map(|(_, f)| f.clone())
178        .collect();
179    let new_columns: Vec<Arc<dyn arrow::array::Array>> = batch
180        .columns()
181        .iter()
182        .enumerate()
183        .filter(|(i, _)| *i != data_type_col_idx)
184        .map(|(_, c)| Arc::clone(c))
185        .collect();
186    let new_schema =
187        arrow::datatypes::Schema::new_with_metadata(new_fields, batch.schema().metadata().clone());
188    let stripped_batch = RecordBatch::try_new(Arc::new(new_schema), new_columns)
189        .map_err(|e| super::EncodingError::ParseError("custom_data", e.to_string()))?;
190
191    Ok((stripped_batch, Some(data_type)))
192}
193
194impl DecodeDataFromRecordBatch for CustomDataDecoder {
195    fn decode_data_batch(
196        metadata: &std::collections::HashMap<String, String>,
197        record_batch: RecordBatch,
198    ) -> Result<Vec<Data>, super::EncodingError> {
199        let type_name = metadata
200            .get("type_name")
201            .cloned()
202            .unwrap_or_else(|| "Unknown".to_string());
203
204        let (batch_to_decode, restored_data_type) = strip_data_type_column(&record_batch)?;
205
206        if batch_to_decode.num_rows() == 0 {
207            return Ok(Vec::new());
208        }
209
210        let data = match decode_custom_from_arrow(&type_name, metadata, batch_to_decode) {
211            Ok(Some(d)) => d,
212            Ok(None) => {
213                return Err(super::EncodingError::ParseError(
214                    "custom_data",
215                    format!(
216                        "unknown custom data type '{type_name}'; only Rust-registered types are supported"
217                    ),
218                ));
219            }
220            Err(e) => {
221                return Err(super::EncodingError::ParseError(
222                    "custom_data",
223                    format!("decode_custom_from_arrow: {e}"),
224                ));
225            }
226        };
227
228        if let Some(dt) = restored_data_type {
229            Ok(data
230                .into_iter()
231                .map(|d| {
232                    if let Data::Custom(c) = d {
233                        Data::Custom(CustomData::new(Arc::clone(&c.data), dt.clone()))
234                    } else {
235                        d
236                    }
237                })
238                .collect())
239        } else {
240            Ok(data)
241        }
242    }
243}