// nautilus_serialization/arrow/custom.rs
use std::sync::Arc;
26
27use arrow::record_batch::RecordBatch;
28use nautilus_model::data::{
29 ArrowDecoder, ArrowEncoder, CustomData, CustomDataTrait, Data, DataType,
30 decode_custom_from_arrow, ensure_arrow_registered, ensure_custom_data_json_registered,
31 get_arrow_schema,
32};
33
34use super::{ArrowSchemaProvider, DecodeDataFromRecordBatch, EncodeToRecordBatch};
35
/// Arrow serialization support for custom data types.
///
/// Implementors supply the Arrow schema for their type and encode collections
/// of type-erased [`CustomDataTrait`] items into record batches.
pub trait CustomDataSerialize: CustomDataTrait {
    /// Returns the Arrow schema describing this type's record-batch layout.
    fn schema(&self) -> anyhow::Result<arrow::datatypes::Schema>;

    /// Encodes the given type-erased items into a single [`RecordBatch`].
    ///
    /// # Errors
    ///
    /// Returns an error if encoding fails (for example, when an item is not
    /// of the expected concrete type).
    fn encode_record_batch(
        &self,
        items: &[Arc<dyn CustomDataTrait>],
    ) -> anyhow::Result<RecordBatch>;
}
58
59pub fn ensure_custom_data_registered<T>()
67where
68 T: CustomDataTrait
69 + ArrowSchemaProvider
70 + EncodeToRecordBatch
71 + DecodeDataFromRecordBatch
72 + Clone
73 + Send
74 + Sync
75 + 'static,
76{
77 let type_name = T::type_name_static();
78
79 if get_arrow_schema(type_name).is_some() {
81 return;
82 }
83
84 let _ = ensure_custom_data_json_registered::<T>();
85
86 let schema = Arc::new(T::get_schema(None));
87
88 let encoder: ArrowEncoder = Box::new(|items: &[Arc<dyn CustomDataTrait>]| {
89 let typed: Result<Vec<T>, _> = items
90 .iter()
91 .map(|b| {
92 b.as_any()
93 .downcast_ref::<T>()
94 .cloned()
95 .ok_or_else(|| anyhow::anyhow!("Expected {}", T::type_name_static()))
96 })
97 .collect();
98 let typed = typed?;
99 let metadata = typed
100 .first()
101 .map(EncodeToRecordBatch::metadata)
102 .unwrap_or_default();
103 EncodeToRecordBatch::encode_batch(&metadata, &typed).map_err(|e| anyhow::anyhow!("{e}"))
104 });
105
106 let decoder: ArrowDecoder = Box::new(|metadata, batch| {
107 T::decode_data_batch(metadata, batch).map_err(|e| anyhow::anyhow!("{e}"))
108 });
109
110 let _ = ensure_arrow_registered(type_name, schema, encoder, decoder);
111}
112
/// Generic Arrow decoder for dynamically registered custom data types.
///
/// Dispatches to the type-specific codec looked up via the `type_name`
/// entry in the record-batch metadata.
#[derive(Debug)]
pub struct CustomDataDecoder;
122
123impl ArrowSchemaProvider for CustomDataDecoder {
124 fn get_schema(
125 metadata: Option<std::collections::HashMap<String, String>>,
126 ) -> arrow::datatypes::Schema {
127 if let Some(metadata) = metadata
128 && let Some(type_name) = metadata.get("type_name")
129 && let Some(schema) = get_arrow_schema(type_name)
130 {
131 return (*schema).clone();
132 }
133
134 arrow::datatypes::Schema::new(vec![arrow::datatypes::Field::new(
136 "dummy",
137 arrow::datatypes::DataType::Int64,
138 true,
139 )])
140 }
141}
142
143fn strip_data_type_column(
146 batch: &RecordBatch,
147) -> Result<(RecordBatch, Option<DataType>), super::EncodingError> {
148 use super::extract_column_string;
149
150 let Some(data_type_col_idx) = batch
151 .schema()
152 .fields()
153 .iter()
154 .position(|f| f.name() == "data_type")
155 else {
156 return Ok((batch.clone(), None));
157 };
158
159 if batch.num_rows() == 0 {
160 return Ok((batch.clone(), None));
161 }
162
163 let cols = batch.columns();
164 let string_col = extract_column_string(cols, "data_type", data_type_col_idx).map_err(|e| {
165 super::EncodingError::ParseError("custom_data", format!("data_type column: {e}"))
166 })?;
167 let first_value = string_col.value(0);
168 let data_type = DataType::from_persistence_json(first_value)
169 .map_err(|e| super::EncodingError::ParseError("custom_data", e.to_string()))?;
170
171 let new_fields: Vec<_> = batch
172 .schema()
173 .fields()
174 .iter()
175 .enumerate()
176 .filter(|(i, _)| *i != data_type_col_idx)
177 .map(|(_, f)| f.clone())
178 .collect();
179 let new_columns: Vec<Arc<dyn arrow::array::Array>> = batch
180 .columns()
181 .iter()
182 .enumerate()
183 .filter(|(i, _)| *i != data_type_col_idx)
184 .map(|(_, c)| Arc::clone(c))
185 .collect();
186 let new_schema =
187 arrow::datatypes::Schema::new_with_metadata(new_fields, batch.schema().metadata().clone());
188 let stripped_batch = RecordBatch::try_new(Arc::new(new_schema), new_columns)
189 .map_err(|e| super::EncodingError::ParseError("custom_data", e.to_string()))?;
190
191 Ok((stripped_batch, Some(data_type)))
192}
193
194impl DecodeDataFromRecordBatch for CustomDataDecoder {
195 fn decode_data_batch(
196 metadata: &std::collections::HashMap<String, String>,
197 record_batch: RecordBatch,
198 ) -> Result<Vec<Data>, super::EncodingError> {
199 let type_name = metadata
200 .get("type_name")
201 .cloned()
202 .unwrap_or_else(|| "Unknown".to_string());
203
204 let (batch_to_decode, restored_data_type) = strip_data_type_column(&record_batch)?;
205
206 if batch_to_decode.num_rows() == 0 {
207 return Ok(Vec::new());
208 }
209
210 let data = match decode_custom_from_arrow(&type_name, metadata, batch_to_decode) {
211 Ok(Some(d)) => d,
212 Ok(None) => {
213 return Err(super::EncodingError::ParseError(
214 "custom_data",
215 format!(
216 "unknown custom data type '{type_name}'; only Rust-registered types are supported"
217 ),
218 ));
219 }
220 Err(e) => {
221 return Err(super::EncodingError::ParseError(
222 "custom_data",
223 format!("decode_custom_from_arrow: {e}"),
224 ));
225 }
226 };
227
228 if let Some(dt) = restored_data_type {
229 Ok(data
230 .into_iter()
231 .map(|d| {
232 if let Data::Custom(c) = d {
233 Data::Custom(CustomData::new(Arc::clone(&c.data), dt.clone()))
234 } else {
235 d
236 }
237 })
238 .collect())
239 } else {
240 Ok(data)
241 }
242 }
243}