Skip to main content

nautilus_databento/arrow/
imbalance.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, sync::Arc};
17
18use arrow::{
19    array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, Int8Array, UInt8Array, UInt64Array},
20    datatypes::{DataType, Field, Schema},
21    error::ArrowError,
22    record_batch::RecordBatch,
23};
24use nautilus_model::{
25    data::{Data, custom::CustomData},
26    enums::{FromU8, OrderSide},
27    types::fixed::PRECISION_BYTES,
28};
29use nautilus_serialization::arrow::{
30    ArrowSchemaProvider, DecodeDataFromRecordBatch, EncodeToRecordBatch, EncodingError,
31    decode_price, decode_quantity, extract_column, validate_precision_bytes,
32};
33
34use super::parse_metadata;
35use crate::types::DatabentoImbalance;
36
37impl ArrowSchemaProvider for DatabentoImbalance {
38    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
39        let fields = vec![
40            Field::new(
41                "ref_price",
42                DataType::FixedSizeBinary(PRECISION_BYTES),
43                false,
44            ),
45            Field::new(
46                "cont_book_clr_price",
47                DataType::FixedSizeBinary(PRECISION_BYTES),
48                false,
49            ),
50            Field::new(
51                "auct_interest_clr_price",
52                DataType::FixedSizeBinary(PRECISION_BYTES),
53                false,
54            ),
55            Field::new(
56                "paired_qty",
57                DataType::FixedSizeBinary(PRECISION_BYTES),
58                false,
59            ),
60            Field::new(
61                "total_imbalance_qty",
62                DataType::FixedSizeBinary(PRECISION_BYTES),
63                false,
64            ),
65            Field::new("side", DataType::UInt8, false),
66            Field::new("significant_imbalance", DataType::Int8, false),
67            Field::new("ts_event", DataType::UInt64, false),
68            Field::new("ts_recv", DataType::UInt64, false),
69            Field::new("ts_init", DataType::UInt64, false),
70        ];
71
72        match metadata {
73            Some(metadata) => Schema::new_with_metadata(fields, metadata),
74            None => Schema::new(fields),
75        }
76    }
77}
78
79impl EncodeToRecordBatch for DatabentoImbalance {
80    #[expect(clippy::unnecessary_cast)] // c_char is u8 on some targets
81    fn encode_batch(
82        metadata: &HashMap<String, String>,
83        data: &[Self],
84    ) -> Result<RecordBatch, ArrowError> {
85        let mut ref_price_builder =
86            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
87        let mut cont_book_clr_price_builder =
88            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
89        let mut auct_interest_clr_price_builder =
90            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
91        let mut paired_qty_builder =
92            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
93        let mut total_imbalance_qty_builder =
94            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
95        let mut side_builder = UInt8Array::builder(data.len());
96        let mut significant_imbalance_builder = Int8Array::builder(data.len());
97        let mut ts_event_builder = UInt64Array::builder(data.len());
98        let mut ts_recv_builder = UInt64Array::builder(data.len());
99        let mut ts_init_builder = UInt64Array::builder(data.len());
100
101        for item in data {
102            ref_price_builder
103                .append_value(item.ref_price.raw.to_le_bytes())
104                .unwrap();
105            cont_book_clr_price_builder
106                .append_value(item.cont_book_clr_price.raw.to_le_bytes())
107                .unwrap();
108            auct_interest_clr_price_builder
109                .append_value(item.auct_interest_clr_price.raw.to_le_bytes())
110                .unwrap();
111            paired_qty_builder
112                .append_value(item.paired_qty.raw.to_le_bytes())
113                .unwrap();
114            total_imbalance_qty_builder
115                .append_value(item.total_imbalance_qty.raw.to_le_bytes())
116                .unwrap();
117            side_builder.append_value(item.side as u8);
118            significant_imbalance_builder.append_value(item.significant_imbalance as i8);
119            ts_event_builder.append_value(item.ts_event.as_u64());
120            ts_recv_builder.append_value(item.ts_recv.as_u64());
121            ts_init_builder.append_value(item.ts_init.as_u64());
122        }
123
124        RecordBatch::try_new(
125            Self::get_schema(Some(metadata.clone())).into(),
126            vec![
127                Arc::new(ref_price_builder.finish()),
128                Arc::new(cont_book_clr_price_builder.finish()),
129                Arc::new(auct_interest_clr_price_builder.finish()),
130                Arc::new(paired_qty_builder.finish()),
131                Arc::new(total_imbalance_qty_builder.finish()),
132                Arc::new(side_builder.finish()),
133                Arc::new(significant_imbalance_builder.finish()),
134                Arc::new(ts_event_builder.finish()),
135                Arc::new(ts_recv_builder.finish()),
136                Arc::new(ts_init_builder.finish()),
137            ],
138        )
139    }
140
141    fn metadata(&self) -> HashMap<String, String> {
142        Self::get_metadata(
143            &self.instrument_id,
144            self.ref_price.precision,
145            self.paired_qty.precision,
146        )
147    }
148}
149
150impl DecodeDataFromRecordBatch for DatabentoImbalance {
151    fn decode_data_batch(
152        metadata: &HashMap<String, String>,
153        record_batch: RecordBatch,
154    ) -> Result<Vec<Data>, EncodingError> {
155        let items = decode_imbalance_batch(metadata, &record_batch)?;
156        Ok(items
157            .into_iter()
158            .map(|item| Data::Custom(CustomData::from_arc(Arc::new(item))))
159            .collect())
160    }
161}
162
163/// Decodes a `RecordBatch` into a vector of [`DatabentoImbalance`].
164///
165/// # Errors
166///
167/// Returns an `EncodingError` if decoding fails.
168pub fn decode_imbalance_batch(
169    metadata: &HashMap<String, String>,
170    record_batch: &RecordBatch,
171) -> Result<Vec<DatabentoImbalance>, EncodingError> {
172    let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
173    let cols = record_batch.columns();
174
175    let ref_price_values = extract_column::<FixedSizeBinaryArray>(
176        cols,
177        "ref_price",
178        0,
179        DataType::FixedSizeBinary(PRECISION_BYTES),
180    )?;
181    let cont_book_clr_price_values = extract_column::<FixedSizeBinaryArray>(
182        cols,
183        "cont_book_clr_price",
184        1,
185        DataType::FixedSizeBinary(PRECISION_BYTES),
186    )?;
187    let auct_interest_clr_price_values = extract_column::<FixedSizeBinaryArray>(
188        cols,
189        "auct_interest_clr_price",
190        2,
191        DataType::FixedSizeBinary(PRECISION_BYTES),
192    )?;
193    let paired_qty_values = extract_column::<FixedSizeBinaryArray>(
194        cols,
195        "paired_qty",
196        3,
197        DataType::FixedSizeBinary(PRECISION_BYTES),
198    )?;
199    let total_imbalance_qty_values = extract_column::<FixedSizeBinaryArray>(
200        cols,
201        "total_imbalance_qty",
202        4,
203        DataType::FixedSizeBinary(PRECISION_BYTES),
204    )?;
205    let side_values = extract_column::<UInt8Array>(cols, "side", 5, DataType::UInt8)?;
206    let significant_imbalance_values =
207        extract_column::<Int8Array>(cols, "significant_imbalance", 6, DataType::Int8)?;
208    let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 7, DataType::UInt64)?;
209    let ts_recv_values = extract_column::<UInt64Array>(cols, "ts_recv", 8, DataType::UInt64)?;
210    let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 9, DataType::UInt64)?;
211
212    validate_precision_bytes(ref_price_values, "ref_price")?;
213    validate_precision_bytes(cont_book_clr_price_values, "cont_book_clr_price")?;
214    validate_precision_bytes(auct_interest_clr_price_values, "auct_interest_clr_price")?;
215    validate_precision_bytes(paired_qty_values, "paired_qty")?;
216    validate_precision_bytes(total_imbalance_qty_values, "total_imbalance_qty")?;
217
218    (0..record_batch.num_rows())
219        .map(|row| {
220            let ref_price = decode_price(
221                ref_price_values.value(row),
222                price_precision,
223                "ref_price",
224                row,
225            )?;
226            let cont_book_clr_price = decode_price(
227                cont_book_clr_price_values.value(row),
228                price_precision,
229                "cont_book_clr_price",
230                row,
231            )?;
232            let auct_interest_clr_price = decode_price(
233                auct_interest_clr_price_values.value(row),
234                price_precision,
235                "auct_interest_clr_price",
236                row,
237            )?;
238            let paired_qty = decode_quantity(
239                paired_qty_values.value(row),
240                size_precision,
241                "paired_qty",
242                row,
243            )?;
244            let total_imbalance_qty = decode_quantity(
245                total_imbalance_qty_values.value(row),
246                size_precision,
247                "total_imbalance_qty",
248                row,
249            )?;
250            let side_value = side_values.value(row);
251            let side = OrderSide::from_u8(side_value).ok_or_else(|| {
252                EncodingError::ParseError(
253                    stringify!(OrderSide),
254                    format!("Invalid enum value, was {side_value}"),
255                )
256            })?;
257            let significant_imbalance = significant_imbalance_values.value(row) as std::ffi::c_char;
258
259            Ok(DatabentoImbalance {
260                instrument_id,
261                ref_price,
262                cont_book_clr_price,
263                auct_interest_clr_price,
264                paired_qty,
265                total_imbalance_qty,
266                side,
267                significant_imbalance,
268                ts_event: ts_event_values.value(row).into(),
269                ts_recv: ts_recv_values.value(row).into(),
270                ts_init: ts_init_values.value(row).into(),
271            })
272        })
273        .collect()
274}
275
276/// Encodes a vector of [`DatabentoImbalance`] into an Arrow `RecordBatch`.
277///
278/// # Errors
279///
280/// Returns an error if `data` is empty or encoding fails.
281// Guarded by empty check
282pub fn imbalance_to_arrow_record_batch(
283    data: &[DatabentoImbalance],
284) -> Result<RecordBatch, EncodingError> {
285    if data.is_empty() {
286        return Err(EncodingError::EmptyData);
287    }
288
289    let metadata = DatabentoImbalance::chunk_metadata(data);
290    DatabentoImbalance::encode_batch(&metadata, data).map_err(EncodingError::ArrowError)
291}
292
#[cfg(test)]
mod tests {
    use nautilus_model::{
        enums::OrderSide,
        identifiers::InstrumentId,
        types::{Price, Quantity},
    };
    use nautilus_serialization::arrow::{
        ArrowSchemaProvider, EncodeToRecordBatch, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
        KEY_SIZE_PRECISION,
    };
    use rstest::rstest;

    use super::*;

    fn test_metadata() -> HashMap<String, String> {
        HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), "AAPL.XNAS".to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
            (KEY_SIZE_PRECISION.to_string(), "0".to_string()),
        ])
    }

    fn test_imbalance(instrument_id: InstrumentId) -> DatabentoImbalance {
        DatabentoImbalance::new(
            instrument_id,
            Price::from("100.50"),
            Price::from("100.45"),
            Price::from("100.55"),
            Quantity::from("1000"),
            Quantity::from("500"),
            OrderSide::Buy,
            b'Y' as std::ffi::c_char,
            1.into(),
            2.into(),
            3.into(),
        )
    }

    #[rstest]
    fn test_get_schema() {
        let schema = DatabentoImbalance::get_schema(None);
        assert_eq!(schema.fields().len(), 10);
        assert_eq!(schema.field(0).name(), "ref_price");
        assert_eq!(schema.field(5).name(), "side");
        assert_eq!(schema.field(9).name(), "ts_init");
    }

    #[rstest]
    fn test_encode_batch() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let data = vec![test_imbalance(instrument_id)];
        let batch = DatabentoImbalance::encode_batch(&metadata, &data).unwrap();

        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 10);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let original = vec![test_imbalance(instrument_id)];
        let batch = DatabentoImbalance::encode_batch(&metadata, &original).unwrap();
        let decoded = decode_imbalance_batch(&metadata, &batch).unwrap();

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].instrument_id, instrument_id);
        assert_eq!(decoded[0].ref_price, original[0].ref_price);
        assert_eq!(
            decoded[0].cont_book_clr_price,
            original[0].cont_book_clr_price
        );
        assert_eq!(
            decoded[0].auct_interest_clr_price,
            original[0].auct_interest_clr_price
        );
        assert_eq!(decoded[0].paired_qty, original[0].paired_qty);
        assert_eq!(
            decoded[0].total_imbalance_qty,
            original[0].total_imbalance_qty
        );
        assert_eq!(decoded[0].side, original[0].side);
        assert_eq!(
            decoded[0].significant_imbalance,
            original[0].significant_imbalance
        );
        assert_eq!(decoded[0].ts_event, original[0].ts_event);
        assert_eq!(decoded[0].ts_recv, original[0].ts_recv);
        assert_eq!(decoded[0].ts_init, original[0].ts_init);
    }

    #[rstest]
    fn test_encode_decode_multiple_rows() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let imb1 = test_imbalance(instrument_id);
        let mut imb2 = test_imbalance(instrument_id);
        imb2.side = OrderSide::Sell;
        imb2.ref_price = Price::from("101.00");
        imb2.ts_event = 100.into();
        let mut imb3 = test_imbalance(instrument_id);
        imb3.side = OrderSide::NoOrderSide;
        imb3.significant_imbalance = b'N' as std::ffi::c_char;
        let original = vec![imb1, imb2, imb3];

        let batch = DatabentoImbalance::encode_batch(&metadata, &original).unwrap();
        assert_eq!(batch.num_rows(), 3);

        let decoded = decode_imbalance_batch(&metadata, &batch).unwrap();
        assert_eq!(decoded.len(), 3);
        for (orig, dec) in original.iter().zip(decoded.iter()) {
            assert_eq!(dec.instrument_id, orig.instrument_id);
            assert_eq!(dec.ref_price, orig.ref_price);
            assert_eq!(dec.side, orig.side);
            assert_eq!(dec.significant_imbalance, orig.significant_imbalance);
            assert_eq!(dec.ts_event, orig.ts_event);
        }
    }

    #[rstest]
    fn test_imbalance_to_arrow_record_batch_round_trip() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let original = vec![test_imbalance(instrument_id)];
        let batch = imbalance_to_arrow_record_batch(&original).unwrap();
        let metadata = batch.schema().metadata().clone();
        let decoded = decode_imbalance_batch(&metadata, &batch).unwrap();

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].ref_price, original[0].ref_price);
        assert_eq!(decoded[0].paired_qty, original[0].paired_qty);
    }

    #[rstest]
    fn test_get_schema_with_metadata() {
        let metadata = test_metadata();
        let schema = DatabentoImbalance::get_schema(Some(metadata.clone()));
        assert_eq!(schema.metadata(), &metadata);
        assert_eq!(schema.fields().len(), 10);
    }

    #[rstest]
    fn test_imbalance_to_arrow_record_batch_empty() {
        let result = imbalance_to_arrow_record_batch(&[]);
        assert!(result.is_err());
    }

    #[rstest]
    fn test_decode_missing_metadata_returns_error() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let data = vec![test_imbalance(instrument_id)];
        let batch = DatabentoImbalance::encode_batch(&metadata, &data).unwrap();

        let empty_metadata = HashMap::new();
        let result = decode_imbalance_batch(&empty_metadata, &batch);
        assert!(result.is_err());
    }

    #[rstest]
    fn test_decode_invalid_side_returns_error() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let data = vec![test_imbalance(instrument_id)];
        let batch = DatabentoImbalance::encode_batch(&metadata, &data).unwrap();

        // Swap the `side` column (index 5) for a value outside the `OrderSide` range
        let mut columns = batch.columns().to_vec();
        columns[5] = Arc::new(UInt8Array::from(vec![u8::MAX]));
        let corrupted = RecordBatch::try_new(batch.schema(), columns).unwrap();

        let result = decode_imbalance_batch(&metadata, &corrupted);
        assert!(matches!(result, Err(EncodingError::ParseError(_, _))));
    }

    #[rstest]
    fn test_decode_data_batch_produces_custom_data() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let original = vec![test_imbalance(instrument_id)];
        let batch = DatabentoImbalance::encode_batch(&metadata, &original).unwrap();
        let data_vec = DatabentoImbalance::decode_data_batch(&metadata, batch).unwrap();

        assert_eq!(data_vec.len(), 1);
        match &data_vec[0] {
            Data::Custom(custom) => {
                assert_eq!(custom.data.type_name(), "DatabentoImbalance");
                let imbalance = custom
                    .data
                    .as_any()
                    .downcast_ref::<DatabentoImbalance>()
                    .unwrap();
                assert_eq!(imbalance.instrument_id, instrument_id);
                assert_eq!(imbalance.ref_price, original[0].ref_price);
                assert_eq!(imbalance.paired_qty, original[0].paired_qty);
                assert_eq!(imbalance.side, original[0].side);
                assert_eq!(imbalance.ts_event, original[0].ts_event);
                assert_eq!(imbalance.ts_init, original[0].ts_init);
            }
            other => panic!("Expected Data::Custom, was {other:?}"),
        }
    }

    #[rstest]
    fn test_decode_data_batch_multiple_rows() {
        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let metadata = test_metadata();
        let mut imb2 = test_imbalance(instrument_id);
        imb2.side = OrderSide::Sell;
        imb2.ts_event = 100.into();
        let original = vec![test_imbalance(instrument_id), imb2];
        let batch = DatabentoImbalance::encode_batch(&metadata, &original).unwrap();
        let data_vec = DatabentoImbalance::decode_data_batch(&metadata, batch).unwrap();

        assert_eq!(data_vec.len(), 2);
        for (i, data) in data_vec.iter().enumerate() {
            match data {
                Data::Custom(custom) => {
                    let imbalance = custom
                        .data
                        .as_any()
                        .downcast_ref::<DatabentoImbalance>()
                        .unwrap();
                    assert_eq!(imbalance.instrument_id, original[i].instrument_id);
                    assert_eq!(imbalance.side, original[i].side);
                    assert_eq!(imbalance.ts_event, original[i].ts_event);
                }
                other => panic!("Expected Data::Custom, was {other:?}"),
            }
        }
    }

    #[rstest]
    fn test_ipc_stream_round_trip() {
        use std::io::Cursor;

        use arrow::ipc::{reader::StreamReader, writer::StreamWriter};

        let instrument_id = InstrumentId::from("AAPL.XNAS");
        let original = vec![test_imbalance(instrument_id), {
            let mut imb = test_imbalance(instrument_id);
            imb.side = OrderSide::Sell;
            imb.ref_price = Price::from("101.25");
            imb.ts_event = 100.into();
            imb
        }];
        let batch = imbalance_to_arrow_record_batch(&original).unwrap();

        let mut cursor = Cursor::new(Vec::new());
        {
            let mut writer = StreamWriter::try_new(&mut cursor, &batch.schema()).unwrap();
            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }

        let buffer = cursor.into_inner();
        let reader = StreamReader::try_new(Cursor::new(buffer), None).unwrap();
        let mut decoded = Vec::new();

        for batch_result in reader {
            let batch = batch_result.unwrap();
            let metadata = batch.schema().metadata().clone();
            decoded.extend(decode_imbalance_batch(&metadata, &batch).unwrap());
        }

        assert_eq!(decoded.len(), 2);
        for (orig, dec) in original.iter().zip(decoded.iter()) {
            assert_eq!(dec, orig);
        }
    }
}