1use std::{collections::HashMap, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, Int8Array, UInt8Array, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::{Data, custom::CustomData},
26 enums::{FromU8, OrderSide},
27 types::fixed::PRECISION_BYTES,
28};
29use nautilus_serialization::arrow::{
30 ArrowSchemaProvider, DecodeDataFromRecordBatch, EncodeToRecordBatch, EncodingError,
31 decode_price, decode_quantity, extract_column, validate_precision_bytes,
32};
33
34use super::parse_metadata;
35use crate::types::DatabentoImbalance;
36
37impl ArrowSchemaProvider for DatabentoImbalance {
38 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
39 let fields = vec![
40 Field::new(
41 "ref_price",
42 DataType::FixedSizeBinary(PRECISION_BYTES),
43 false,
44 ),
45 Field::new(
46 "cont_book_clr_price",
47 DataType::FixedSizeBinary(PRECISION_BYTES),
48 false,
49 ),
50 Field::new(
51 "auct_interest_clr_price",
52 DataType::FixedSizeBinary(PRECISION_BYTES),
53 false,
54 ),
55 Field::new(
56 "paired_qty",
57 DataType::FixedSizeBinary(PRECISION_BYTES),
58 false,
59 ),
60 Field::new(
61 "total_imbalance_qty",
62 DataType::FixedSizeBinary(PRECISION_BYTES),
63 false,
64 ),
65 Field::new("side", DataType::UInt8, false),
66 Field::new("significant_imbalance", DataType::Int8, false),
67 Field::new("ts_event", DataType::UInt64, false),
68 Field::new("ts_recv", DataType::UInt64, false),
69 Field::new("ts_init", DataType::UInt64, false),
70 ];
71
72 match metadata {
73 Some(metadata) => Schema::new_with_metadata(fields, metadata),
74 None => Schema::new(fields),
75 }
76 }
77}
78
79impl EncodeToRecordBatch for DatabentoImbalance {
80 #[expect(clippy::unnecessary_cast)] fn encode_batch(
82 metadata: &HashMap<String, String>,
83 data: &[Self],
84 ) -> Result<RecordBatch, ArrowError> {
85 let mut ref_price_builder =
86 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
87 let mut cont_book_clr_price_builder =
88 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
89 let mut auct_interest_clr_price_builder =
90 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
91 let mut paired_qty_builder =
92 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
93 let mut total_imbalance_qty_builder =
94 FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
95 let mut side_builder = UInt8Array::builder(data.len());
96 let mut significant_imbalance_builder = Int8Array::builder(data.len());
97 let mut ts_event_builder = UInt64Array::builder(data.len());
98 let mut ts_recv_builder = UInt64Array::builder(data.len());
99 let mut ts_init_builder = UInt64Array::builder(data.len());
100
101 for item in data {
102 ref_price_builder
103 .append_value(item.ref_price.raw.to_le_bytes())
104 .unwrap();
105 cont_book_clr_price_builder
106 .append_value(item.cont_book_clr_price.raw.to_le_bytes())
107 .unwrap();
108 auct_interest_clr_price_builder
109 .append_value(item.auct_interest_clr_price.raw.to_le_bytes())
110 .unwrap();
111 paired_qty_builder
112 .append_value(item.paired_qty.raw.to_le_bytes())
113 .unwrap();
114 total_imbalance_qty_builder
115 .append_value(item.total_imbalance_qty.raw.to_le_bytes())
116 .unwrap();
117 side_builder.append_value(item.side as u8);
118 significant_imbalance_builder.append_value(item.significant_imbalance as i8);
119 ts_event_builder.append_value(item.ts_event.as_u64());
120 ts_recv_builder.append_value(item.ts_recv.as_u64());
121 ts_init_builder.append_value(item.ts_init.as_u64());
122 }
123
124 RecordBatch::try_new(
125 Self::get_schema(Some(metadata.clone())).into(),
126 vec![
127 Arc::new(ref_price_builder.finish()),
128 Arc::new(cont_book_clr_price_builder.finish()),
129 Arc::new(auct_interest_clr_price_builder.finish()),
130 Arc::new(paired_qty_builder.finish()),
131 Arc::new(total_imbalance_qty_builder.finish()),
132 Arc::new(side_builder.finish()),
133 Arc::new(significant_imbalance_builder.finish()),
134 Arc::new(ts_event_builder.finish()),
135 Arc::new(ts_recv_builder.finish()),
136 Arc::new(ts_init_builder.finish()),
137 ],
138 )
139 }
140
141 fn metadata(&self) -> HashMap<String, String> {
142 Self::get_metadata(
143 &self.instrument_id,
144 self.ref_price.precision,
145 self.paired_qty.precision,
146 )
147 }
148}
149
150impl DecodeDataFromRecordBatch for DatabentoImbalance {
151 fn decode_data_batch(
152 metadata: &HashMap<String, String>,
153 record_batch: RecordBatch,
154 ) -> Result<Vec<Data>, EncodingError> {
155 let items = decode_imbalance_batch(metadata, &record_batch)?;
156 Ok(items
157 .into_iter()
158 .map(|item| Data::Custom(CustomData::from_arc(Arc::new(item))))
159 .collect())
160 }
161}
162
163pub fn decode_imbalance_batch(
169 metadata: &HashMap<String, String>,
170 record_batch: &RecordBatch,
171) -> Result<Vec<DatabentoImbalance>, EncodingError> {
172 let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
173 let cols = record_batch.columns();
174
175 let ref_price_values = extract_column::<FixedSizeBinaryArray>(
176 cols,
177 "ref_price",
178 0,
179 DataType::FixedSizeBinary(PRECISION_BYTES),
180 )?;
181 let cont_book_clr_price_values = extract_column::<FixedSizeBinaryArray>(
182 cols,
183 "cont_book_clr_price",
184 1,
185 DataType::FixedSizeBinary(PRECISION_BYTES),
186 )?;
187 let auct_interest_clr_price_values = extract_column::<FixedSizeBinaryArray>(
188 cols,
189 "auct_interest_clr_price",
190 2,
191 DataType::FixedSizeBinary(PRECISION_BYTES),
192 )?;
193 let paired_qty_values = extract_column::<FixedSizeBinaryArray>(
194 cols,
195 "paired_qty",
196 3,
197 DataType::FixedSizeBinary(PRECISION_BYTES),
198 )?;
199 let total_imbalance_qty_values = extract_column::<FixedSizeBinaryArray>(
200 cols,
201 "total_imbalance_qty",
202 4,
203 DataType::FixedSizeBinary(PRECISION_BYTES),
204 )?;
205 let side_values = extract_column::<UInt8Array>(cols, "side", 5, DataType::UInt8)?;
206 let significant_imbalance_values =
207 extract_column::<Int8Array>(cols, "significant_imbalance", 6, DataType::Int8)?;
208 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 7, DataType::UInt64)?;
209 let ts_recv_values = extract_column::<UInt64Array>(cols, "ts_recv", 8, DataType::UInt64)?;
210 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 9, DataType::UInt64)?;
211
212 validate_precision_bytes(ref_price_values, "ref_price")?;
213 validate_precision_bytes(cont_book_clr_price_values, "cont_book_clr_price")?;
214 validate_precision_bytes(auct_interest_clr_price_values, "auct_interest_clr_price")?;
215 validate_precision_bytes(paired_qty_values, "paired_qty")?;
216 validate_precision_bytes(total_imbalance_qty_values, "total_imbalance_qty")?;
217
218 (0..record_batch.num_rows())
219 .map(|row| {
220 let ref_price = decode_price(
221 ref_price_values.value(row),
222 price_precision,
223 "ref_price",
224 row,
225 )?;
226 let cont_book_clr_price = decode_price(
227 cont_book_clr_price_values.value(row),
228 price_precision,
229 "cont_book_clr_price",
230 row,
231 )?;
232 let auct_interest_clr_price = decode_price(
233 auct_interest_clr_price_values.value(row),
234 price_precision,
235 "auct_interest_clr_price",
236 row,
237 )?;
238 let paired_qty = decode_quantity(
239 paired_qty_values.value(row),
240 size_precision,
241 "paired_qty",
242 row,
243 )?;
244 let total_imbalance_qty = decode_quantity(
245 total_imbalance_qty_values.value(row),
246 size_precision,
247 "total_imbalance_qty",
248 row,
249 )?;
250 let side_value = side_values.value(row);
251 let side = OrderSide::from_u8(side_value).ok_or_else(|| {
252 EncodingError::ParseError(
253 stringify!(OrderSide),
254 format!("Invalid enum value, was {side_value}"),
255 )
256 })?;
257 let significant_imbalance = significant_imbalance_values.value(row) as std::ffi::c_char;
258
259 Ok(DatabentoImbalance {
260 instrument_id,
261 ref_price,
262 cont_book_clr_price,
263 auct_interest_clr_price,
264 paired_qty,
265 total_imbalance_qty,
266 side,
267 significant_imbalance,
268 ts_event: ts_event_values.value(row).into(),
269 ts_recv: ts_recv_values.value(row).into(),
270 ts_init: ts_init_values.value(row).into(),
271 })
272 })
273 .collect()
274}
275
276pub fn imbalance_to_arrow_record_batch(
283 data: &[DatabentoImbalance],
284) -> Result<RecordBatch, EncodingError> {
285 if data.is_empty() {
286 return Err(EncodingError::EmptyData);
287 }
288
289 let metadata = DatabentoImbalance::chunk_metadata(data);
290 DatabentoImbalance::encode_batch(&metadata, data).map_err(EncodingError::ArrowError)
291}
292
#[cfg(test)]
mod tests {
    use nautilus_model::{
        enums::OrderSide,
        identifiers::InstrumentId,
        types::{Price, Quantity},
    };
    use nautilus_serialization::arrow::{
        ArrowSchemaProvider, EncodeToRecordBatch, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
        KEY_SIZE_PRECISION,
    };
    use rstest::rstest;

    use super::*;

    /// Metadata matching the precisions of the values built by `test_imbalance`.
    fn test_metadata() -> HashMap<String, String> {
        HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), "AAPL.XNAS".to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
            (KEY_SIZE_PRECISION.to_string(), "0".to_string()),
        ])
    }

    /// Builds a representative imbalance for the given instrument.
    fn test_imbalance(instrument_id: InstrumentId) -> DatabentoImbalance {
        DatabentoImbalance::new(
            instrument_id,
            Price::from("100.50"),
            Price::from("100.45"),
            Price::from("100.55"),
            Quantity::from("1000"),
            Quantity::from("500"),
            OrderSide::Buy,
            b'Y' as std::ffi::c_char,
            1.into(),
            2.into(),
            3.into(),
        )
    }

    #[rstest]
    fn test_get_schema() {
        let schema = DatabentoImbalance::get_schema(None);
        let names: Vec<&str> = schema
            .fields()
            .iter()
            .map(|field| field.name().as_str())
            .collect();

        // Column order is part of the wire contract: decoding looks columns up by index.
        assert_eq!(
            names,
            [
                "ref_price",
                "cont_book_clr_price",
                "auct_interest_clr_price",
                "paired_qty",
                "total_imbalance_qty",
                "side",
                "significant_imbalance",
                "ts_event",
                "ts_recv",
                "ts_init",
            ]
        );
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let data = vec![test_imbalance(instrument_id)];
        let batch = DatabentoImbalance::encode_batch(&metadata, &data).unwrap();

        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 10);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let original = vec![test_imbalance(instrument_id)];
        let batch = DatabentoImbalance::encode_batch(&metadata, &original).unwrap();
        let decoded = decode_imbalance_batch(&metadata, &batch).unwrap();

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].instrument_id, instrument_id);
        // Whole-value equality covers every field, including any added later.
        assert_eq!(decoded, original);
    }

    #[rstest]
    fn test_encode_decode_multiple_rows() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let imb1 = test_imbalance(instrument_id);
        let mut imb2 = test_imbalance(instrument_id);
        imb2.side = OrderSide::Sell;
        imb2.ref_price = Price::from("101.00");
        imb2.ts_event = 100.into();
        let mut imb3 = test_imbalance(instrument_id);
        imb3.side = OrderSide::NoOrderSide;
        imb3.significant_imbalance = b'N' as std::ffi::c_char;
        let original = vec![imb1, imb2, imb3];

        let batch = DatabentoImbalance::encode_batch(&metadata, &original).unwrap();
        assert_eq!(batch.num_rows(), 3);

        let decoded = decode_imbalance_batch(&metadata, &batch).unwrap();
        assert_eq!(decoded.len(), 3);
        assert_eq!(decoded, original);
    }

    #[rstest]
    fn test_imbalance_to_arrow_record_batch_round_trip() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let original = vec![test_imbalance(instrument_id)];
        let batch = imbalance_to_arrow_record_batch(&original).unwrap();
        let metadata = batch.schema().metadata().clone();
        let decoded = decode_imbalance_batch(&metadata, &batch).unwrap();

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded, original);
    }

    #[rstest]
    fn test_get_schema_with_metadata() {
        let metadata = test_metadata();
        let schema = DatabentoImbalance::get_schema(Some(metadata.clone()));
        assert_eq!(schema.metadata(), &metadata);
        assert_eq!(schema.fields().len(), 10);
    }

    #[rstest]
    fn test_imbalance_to_arrow_record_batch_empty() {
        let result = imbalance_to_arrow_record_batch(&[]);
        // Pin the specific variant so an unrelated failure cannot satisfy the test.
        assert!(matches!(result, Err(EncodingError::EmptyData)));
    }

    #[rstest]
    fn test_decode_missing_metadata_returns_error() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let data = vec![test_imbalance(instrument_id)];
        let batch = DatabentoImbalance::encode_batch(&metadata, &data).unwrap();

        let empty_metadata = HashMap::new();
        let result = decode_imbalance_batch(&empty_metadata, &batch);
        assert!(result.is_err());
    }

    #[rstest]
    fn test_decode_data_batch_produces_custom_data() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let original = vec![test_imbalance(instrument_id)];
        let batch = DatabentoImbalance::encode_batch(&metadata, &original).unwrap();
        let data_vec = DatabentoImbalance::decode_data_batch(&metadata, batch).unwrap();

        assert_eq!(data_vec.len(), 1);
        match &data_vec[0] {
            Data::Custom(custom) => {
                assert_eq!(custom.data.type_name(), "DatabentoImbalance");
                let imbalance = custom
                    .data
                    .as_any()
                    .downcast_ref::<DatabentoImbalance>()
                    .unwrap();
                assert_eq!(imbalance.instrument_id, instrument_id);
                assert_eq!(imbalance, &original[0]);
            }
            other => panic!("Expected Data::Custom, was {other:?}"),
        }
    }

    #[rstest]
    fn test_decode_data_batch_multiple_rows() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let mut imb2 = test_imbalance(instrument_id);
        imb2.side = OrderSide::Sell;
        imb2.ts_event = 100.into();
        let original = vec![test_imbalance(instrument_id), imb2];
        let batch = DatabentoImbalance::encode_batch(&metadata, &original).unwrap();
        let data_vec = DatabentoImbalance::decode_data_batch(&metadata, batch).unwrap();

        assert_eq!(data_vec.len(), 2);
        for (expected, data) in original.iter().zip(data_vec.iter()) {
            match data {
                Data::Custom(custom) => {
                    let imbalance = custom
                        .data
                        .as_any()
                        .downcast_ref::<DatabentoImbalance>()
                        .unwrap();
                    assert_eq!(imbalance, expected);
                }
                other => panic!("Expected Data::Custom, was {other:?}"),
            }
        }
    }

    #[rstest]
    fn test_ipc_stream_round_trip() {
        use std::io::Cursor;

        use arrow::ipc::{reader::StreamReader, writer::StreamWriter};

        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let original = vec![test_imbalance(instrument_id), {
            let mut imb = test_imbalance(instrument_id);
            imb.side = OrderSide::Sell;
            imb.ref_price = Price::from("101.25");
            imb.ts_event = 100.into();
            imb
        }];
        let batch = imbalance_to_arrow_record_batch(&original).unwrap();

        let mut cursor = Cursor::new(Vec::new());
        {
            let mut writer = StreamWriter::try_new(&mut cursor, &batch.schema()).unwrap();
            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }

        let buffer = cursor.into_inner();
        let reader = StreamReader::try_new(Cursor::new(buffer), None).unwrap();
        let mut decoded = Vec::new();

        for batch_result in reader {
            let batch = batch_result.unwrap();
            let metadata = batch.schema().metadata().clone();
            decoded.extend(decode_imbalance_batch(&metadata, &batch).unwrap());
        }

        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded, original);
    }
}