Skip to main content

nautilus_databento/arrow/
statistics.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, sync::Arc};
17
18use arrow::{
19    array::{
20        FixedSizeBinaryArray, FixedSizeBinaryBuilder, Int32Array, UInt8Array, UInt16Array,
21        UInt32Array, UInt64Array,
22    },
23    datatypes::{DataType, Field, Schema},
24    error::ArrowError,
25    record_batch::RecordBatch,
26};
27use nautilus_model::{
28    data::{Data, custom::CustomData},
29    enums::FromU8,
30    types::{
31        PRICE_UNDEF, QUANTITY_UNDEF,
32        fixed::{FIXED_PRECISION, PRECISION_BYTES},
33    },
34};
35use nautilus_serialization::arrow::{
36    ArrowSchemaProvider, DecodeDataFromRecordBatch, EncodeToRecordBatch, EncodingError,
37    decode_price_with_sentinel, decode_quantity_with_sentinel, extract_column,
38    validate_precision_bytes,
39};
40
41use super::parse_metadata;
42use crate::{
43    enums::{DatabentoStatisticType, DatabentoStatisticUpdateAction},
44    types::DatabentoStatistics,
45};
46
47impl ArrowSchemaProvider for DatabentoStatistics {
48    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
49        let fields = vec![
50            Field::new("stat_type", DataType::UInt8, false),
51            Field::new("update_action", DataType::UInt8, false),
52            Field::new("price", DataType::FixedSizeBinary(PRECISION_BYTES), false),
53            Field::new(
54                "quantity",
55                DataType::FixedSizeBinary(PRECISION_BYTES),
56                false,
57            ),
58            Field::new("channel_id", DataType::UInt16, false),
59            Field::new("stat_flags", DataType::UInt8, false),
60            Field::new("sequence", DataType::UInt32, false),
61            Field::new("ts_ref", DataType::UInt64, false),
62            Field::new("ts_in_delta", DataType::Int32, false),
63            Field::new("ts_event", DataType::UInt64, false),
64            Field::new("ts_recv", DataType::UInt64, false),
65            Field::new("ts_init", DataType::UInt64, false),
66        ];
67
68        match metadata {
69            Some(metadata) => Schema::new_with_metadata(fields, metadata),
70            None => Schema::new(fields),
71        }
72    }
73}
74
75impl EncodeToRecordBatch for DatabentoStatistics {
76    fn encode_batch(
77        metadata: &HashMap<String, String>,
78        data: &[Self],
79    ) -> Result<RecordBatch, ArrowError> {
80        let mut stat_type_builder = UInt8Array::builder(data.len());
81        let mut update_action_builder = UInt8Array::builder(data.len());
82        let mut price_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
83        let mut quantity_builder =
84            FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
85        let mut channel_id_builder = UInt16Array::builder(data.len());
86        let mut stat_flags_builder = UInt8Array::builder(data.len());
87        let mut sequence_builder = UInt32Array::builder(data.len());
88        let mut ts_ref_builder = UInt64Array::builder(data.len());
89        let mut ts_in_delta_builder = Int32Array::builder(data.len());
90        let mut ts_event_builder = UInt64Array::builder(data.len());
91        let mut ts_recv_builder = UInt64Array::builder(data.len());
92        let mut ts_init_builder = UInt64Array::builder(data.len());
93
94        for item in data {
95            stat_type_builder.append_value(item.stat_type as u8);
96            update_action_builder.append_value(item.update_action as u8);
97            let price_raw = item.price.map_or(PRICE_UNDEF, |p| p.raw);
98            price_builder.append_value(price_raw.to_le_bytes()).unwrap();
99            let quantity_raw = item.quantity.map_or(QUANTITY_UNDEF, |q| q.raw);
100            quantity_builder
101                .append_value(quantity_raw.to_le_bytes())
102                .unwrap();
103            channel_id_builder.append_value(item.channel_id);
104            stat_flags_builder.append_value(item.stat_flags);
105            sequence_builder.append_value(item.sequence);
106            ts_ref_builder.append_value(item.ts_ref.as_u64());
107            ts_in_delta_builder.append_value(item.ts_in_delta);
108            ts_event_builder.append_value(item.ts_event.as_u64());
109            ts_recv_builder.append_value(item.ts_recv.as_u64());
110            ts_init_builder.append_value(item.ts_init.as_u64());
111        }
112
113        RecordBatch::try_new(
114            Self::get_schema(Some(metadata.clone())).into(),
115            vec![
116                Arc::new(stat_type_builder.finish()),
117                Arc::new(update_action_builder.finish()),
118                Arc::new(price_builder.finish()),
119                Arc::new(quantity_builder.finish()),
120                Arc::new(channel_id_builder.finish()),
121                Arc::new(stat_flags_builder.finish()),
122                Arc::new(sequence_builder.finish()),
123                Arc::new(ts_ref_builder.finish()),
124                Arc::new(ts_in_delta_builder.finish()),
125                Arc::new(ts_event_builder.finish()),
126                Arc::new(ts_recv_builder.finish()),
127                Arc::new(ts_init_builder.finish()),
128            ],
129        )
130    }
131
132    fn metadata(&self) -> HashMap<String, String> {
133        Self::get_metadata(
134            &self.instrument_id,
135            self.price.map_or(FIXED_PRECISION, |p| p.precision),
136            self.quantity.map_or(FIXED_PRECISION, |q| q.precision),
137        )
138    }
139
140    fn chunk_metadata(chunk: &[Self]) -> HashMap<String, String> {
141        let first = chunk
142            .first()
143            .expect("Chunk should have at least one element to encode");
144
145        let price_precision = chunk
146            .iter()
147            .find_map(|s| s.price.map(|p| p.precision))
148            .unwrap_or(FIXED_PRECISION);
149        let size_precision = chunk
150            .iter()
151            .find_map(|s| s.quantity.map(|q| q.precision))
152            .unwrap_or(FIXED_PRECISION);
153
154        Self::get_metadata(&first.instrument_id, price_precision, size_precision)
155    }
156}
157
158impl DecodeDataFromRecordBatch for DatabentoStatistics {
159    fn decode_data_batch(
160        metadata: &HashMap<String, String>,
161        record_batch: RecordBatch,
162    ) -> Result<Vec<Data>, EncodingError> {
163        let items = decode_statistics_batch(metadata, &record_batch)?;
164        Ok(items
165            .into_iter()
166            .map(|item| Data::Custom(CustomData::from_arc(Arc::new(item))))
167            .collect())
168    }
169}
170
171/// Decodes a `RecordBatch` into a vector of [`DatabentoStatistics`].
172///
173/// # Errors
174///
175/// Returns an `EncodingError` if decoding fails.
176pub fn decode_statistics_batch(
177    metadata: &HashMap<String, String>,
178    record_batch: &RecordBatch,
179) -> Result<Vec<DatabentoStatistics>, EncodingError> {
180    let (instrument_id, price_precision, size_precision) = parse_metadata(metadata)?;
181    let cols = record_batch.columns();
182
183    let stat_type_values = extract_column::<UInt8Array>(cols, "stat_type", 0, DataType::UInt8)?;
184    let update_action_values =
185        extract_column::<UInt8Array>(cols, "update_action", 1, DataType::UInt8)?;
186    let price_values = extract_column::<FixedSizeBinaryArray>(
187        cols,
188        "price",
189        2,
190        DataType::FixedSizeBinary(PRECISION_BYTES),
191    )?;
192    let quantity_values = extract_column::<FixedSizeBinaryArray>(
193        cols,
194        "quantity",
195        3,
196        DataType::FixedSizeBinary(PRECISION_BYTES),
197    )?;
198    let channel_id_values = extract_column::<UInt16Array>(cols, "channel_id", 4, DataType::UInt16)?;
199    let stat_flags_values = extract_column::<UInt8Array>(cols, "stat_flags", 5, DataType::UInt8)?;
200    let sequence_values = extract_column::<UInt32Array>(cols, "sequence", 6, DataType::UInt32)?;
201    let ts_ref_values = extract_column::<UInt64Array>(cols, "ts_ref", 7, DataType::UInt64)?;
202    let ts_in_delta_values = extract_column::<Int32Array>(cols, "ts_in_delta", 8, DataType::Int32)?;
203    let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 9, DataType::UInt64)?;
204    let ts_recv_values = extract_column::<UInt64Array>(cols, "ts_recv", 10, DataType::UInt64)?;
205    let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 11, DataType::UInt64)?;
206
207    validate_precision_bytes(price_values, "price")?;
208    validate_precision_bytes(quantity_values, "quantity")?;
209
210    (0..record_batch.num_rows())
211        .map(|row| {
212            let stat_type_value = stat_type_values.value(row);
213            let stat_type = DatabentoStatisticType::from_u8(stat_type_value).ok_or_else(|| {
214                EncodingError::ParseError(
215                    stringify!(DatabentoStatisticType),
216                    format!("Invalid enum value, was {stat_type_value}"),
217                )
218            })?;
219            let update_action_value = update_action_values.value(row);
220            let update_action = DatabentoStatisticUpdateAction::from_u8(update_action_value)
221                .ok_or_else(|| {
222                    EncodingError::ParseError(
223                        stringify!(DatabentoStatisticUpdateAction),
224                        format!("Invalid enum value, was {update_action_value}"),
225                    )
226                })?;
227
228            let price_decoded =
229                decode_price_with_sentinel(price_values.value(row), price_precision, "price", row)?;
230            let price = if price_decoded.raw == PRICE_UNDEF {
231                None
232            } else {
233                Some(price_decoded)
234            };
235
236            let quantity_decoded = decode_quantity_with_sentinel(
237                quantity_values.value(row),
238                size_precision,
239                "quantity",
240                row,
241            )?;
242            let quantity = if quantity_decoded.raw == QUANTITY_UNDEF {
243                None
244            } else {
245                Some(quantity_decoded)
246            };
247
248            Ok(DatabentoStatistics {
249                instrument_id,
250                stat_type,
251                update_action,
252                price,
253                quantity,
254                channel_id: channel_id_values.value(row),
255                stat_flags: stat_flags_values.value(row),
256                sequence: sequence_values.value(row),
257                ts_ref: ts_ref_values.value(row).into(),
258                ts_in_delta: ts_in_delta_values.value(row),
259                ts_event: ts_event_values.value(row).into(),
260                ts_recv: ts_recv_values.value(row).into(),
261                ts_init: ts_init_values.value(row).into(),
262            })
263        })
264        .collect()
265}
266
267/// Encodes a vector of [`DatabentoStatistics`] into an Arrow `RecordBatch`.
268///
269/// # Errors
270///
271/// Returns an error if `data` is empty or encoding fails.
272// Guarded by empty check
273pub fn statistics_to_arrow_record_batch(
274    data: &[DatabentoStatistics],
275) -> Result<RecordBatch, EncodingError> {
276    if data.is_empty() {
277        return Err(EncodingError::EmptyData);
278    }
279
280    let metadata = DatabentoStatistics::chunk_metadata(data);
281    DatabentoStatistics::encode_batch(&metadata, data).map_err(EncodingError::ArrowError)
282}
283
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use nautilus_model::{
        identifiers::InstrumentId,
        types::{Price, Quantity, fixed::FIXED_PRECISION},
    };
    use nautilus_serialization::arrow::{
        ArrowSchemaProvider, EncodeToRecordBatch, KEY_INSTRUMENT_ID, KEY_PRICE_PRECISION,
        KEY_SIZE_PRECISION,
    };
    use rstest::rstest;

    use super::*;

    /// Instrument used by every test in this module.
    fn instrument() -> InstrumentId {
        InstrumentId::from("ESM4.GLBX")
    }

    /// Metadata matching the precisions of the values built by `test_statistics`.
    fn test_metadata() -> HashMap<String, String> {
        HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), "ESM4.GLBX".to_string()),
            (KEY_PRICE_PRECISION.to_string(), "2".to_string()),
            (KEY_SIZE_PRECISION.to_string(), "0".to_string()),
        ])
    }

    /// Opening-price record carrying both a price and a quantity.
    fn test_statistics(instrument_id: InstrumentId) -> DatabentoStatistics {
        DatabentoStatistics::new(
            instrument_id,
            DatabentoStatisticType::OpeningPrice,
            DatabentoStatisticUpdateAction::Added,
            Some(Price::from("5000.50")),
            Some(Quantity::from("100")),
            1,
            0,
            42,
            1_000_000_000.into(),
            500,
            2_000_000_000.into(),
            3_000_000_000.into(),
            4_000_000_000.into(),
        )
    }

    /// Cleared-volume record with neither a price nor a quantity.
    fn empty_statistics(instrument_id: InstrumentId) -> DatabentoStatistics {
        DatabentoStatistics::new(
            instrument_id,
            DatabentoStatisticType::ClearedVolume,
            DatabentoStatisticUpdateAction::Added,
            None,
            None,
            1,
            0,
            42,
            1_000_000_000.into(),
            500,
            2_000_000_000.into(),
            3_000_000_000.into(),
            4_000_000_000.into(),
        )
    }

    /// Cleared-volume record carrying a quantity but no price.
    fn quantity_only_statistics(instrument_id: InstrumentId) -> DatabentoStatistics {
        DatabentoStatistics::new(
            instrument_id,
            DatabentoStatisticType::ClearedVolume,
            DatabentoStatisticUpdateAction::Added,
            None,
            Some(Quantity::from("200")),
            2,
            1,
            43,
            2_000_000_000.into(),
            600,
            3_000_000_000.into(),
            4_000_000_000.into(),
            5_000_000_000.into(),
        )
    }

    /// Encodes `original` with `metadata` and decodes the resulting batch again.
    fn round_trip(
        metadata: &HashMap<String, String>,
        original: &[DatabentoStatistics],
    ) -> Vec<DatabentoStatistics> {
        let batch = DatabentoStatistics::encode_batch(metadata, original).unwrap();
        decode_statistics_batch(metadata, &batch).unwrap()
    }

    /// Unwraps the `Data::Custom` payload, panicking on any other variant.
    fn expect_custom(data: &Data) -> &CustomData {
        match data {
            Data::Custom(custom) => custom,
            other => panic!("Expected Data::Custom, was {other:?}"),
        }
    }

    #[rstest]
    fn test_get_schema() {
        let schema = DatabentoStatistics::get_schema(None);

        assert_eq!(schema.fields().len(), 12);
        assert_eq!(schema.field(0).name(), "stat_type");
        assert_eq!(schema.field(11).name(), "ts_init");
    }

    #[rstest]
    fn test_encode_batch() {
        let data = vec![test_statistics(instrument())];
        let batch = DatabentoStatistics::encode_batch(&test_metadata(), &data).unwrap();

        assert_eq!(batch.num_rows(), 1);
        assert_eq!(batch.num_columns(), 12);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let instrument_id = instrument();
        let original = vec![test_statistics(instrument_id)];
        let decoded = round_trip(&test_metadata(), &original);

        assert_eq!(decoded.len(), 1);
        let (dec, orig) = (&decoded[0], &original[0]);
        assert_eq!(dec.instrument_id, instrument_id);
        assert_eq!(dec.stat_type, orig.stat_type);
        assert_eq!(dec.update_action, orig.update_action);
        assert_eq!(dec.price, orig.price);
        assert_eq!(dec.quantity, orig.quantity);
        assert_eq!(dec.channel_id, orig.channel_id);
        assert_eq!(dec.stat_flags, orig.stat_flags);
        assert_eq!(dec.sequence, orig.sequence);
        assert_eq!(dec.ts_ref, orig.ts_ref);
        assert_eq!(dec.ts_in_delta, orig.ts_in_delta);
        assert_eq!(dec.ts_event, orig.ts_event);
        assert_eq!(dec.ts_recv, orig.ts_recv);
        assert_eq!(dec.ts_init, orig.ts_init);
    }

    #[rstest]
    fn test_encode_decode_round_trip_with_none_values() {
        let original = vec![empty_statistics(instrument())];
        let decoded = round_trip(&test_metadata(), &original);

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].price, None);
        assert_eq!(decoded[0].quantity, None);
    }

    #[rstest]
    fn test_chunk_metadata_uses_first_non_none_precision() {
        let instrument_id = instrument();
        let data = vec![
            empty_statistics(instrument_id),
            test_statistics(instrument_id),
        ];

        let batch = statistics_to_arrow_record_batch(&data).unwrap();
        let metadata = batch.schema().metadata().clone();
        let decoded = decode_statistics_batch(&metadata, &batch).unwrap();

        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0].price, None);
        assert_eq!(decoded[0].quantity, None);
        assert_eq!(decoded[1].price, data[1].price);
        assert_eq!(decoded[1].quantity, data[1].quantity);
    }

    #[rstest]
    fn test_encode_decode_multiple_rows() {
        let instrument_id = instrument();
        let metadata = test_metadata();
        let price_only = DatabentoStatistics::new(
            instrument_id,
            DatabentoStatisticType::ClearedVolume,
            DatabentoStatisticUpdateAction::Added,
            Some(Price::from("5100.25")),
            None,
            2,
            1,
            43,
            2_000_000_000.into(),
            600,
            3_000_000_000.into(),
            4_000_000_000.into(),
            5_000_000_000.into(),
        );
        let quantity_only = DatabentoStatistics::new(
            instrument_id,
            DatabentoStatisticType::OpeningPrice,
            DatabentoStatisticUpdateAction::Added,
            None,
            Some(Quantity::from("200")),
            3,
            0,
            44,
            3_000_000_000.into(),
            700,
            4_000_000_000.into(),
            5_000_000_000.into(),
            6_000_000_000.into(),
        );
        let original = vec![test_statistics(instrument_id), price_only, quantity_only];

        let batch = DatabentoStatistics::encode_batch(&metadata, &original).unwrap();
        assert_eq!(batch.num_rows(), 3);

        let decoded = decode_statistics_batch(&metadata, &batch).unwrap();
        assert_eq!(decoded.len(), 3);
        for (orig, dec) in original.iter().zip(&decoded) {
            assert_eq!(dec.instrument_id, orig.instrument_id);
            assert_eq!(dec.stat_type, orig.stat_type);
            assert_eq!(dec.price, orig.price);
            assert_eq!(dec.quantity, orig.quantity);
            assert_eq!(dec.channel_id, orig.channel_id);
            assert_eq!(dec.sequence, orig.sequence);
        }
    }

    #[rstest]
    fn test_statistics_to_arrow_record_batch_round_trip() {
        let original = vec![test_statistics(instrument())];
        let batch = statistics_to_arrow_record_batch(&original).unwrap();
        let metadata = batch.schema().metadata().clone();
        let decoded = decode_statistics_batch(&metadata, &batch).unwrap();

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].price, original[0].price);
        assert_eq!(decoded[0].quantity, original[0].quantity);
    }

    #[rstest]
    fn test_chunk_metadata_all_none_uses_fixed_precision() {
        let data = vec![empty_statistics(instrument())];
        let metadata = DatabentoStatistics::chunk_metadata(&data);
        let expected = FIXED_PRECISION.to_string();

        assert_eq!(metadata.get(KEY_PRICE_PRECISION).unwrap(), &expected);
        assert_eq!(metadata.get(KEY_SIZE_PRECISION).unwrap(), &expected);
    }

    #[rstest]
    fn test_all_none_metadata_decodes_real_prices_correctly() {
        let price = Price::from("5000.50");
        let quantity = Quantity::from("100");
        let stats = DatabentoStatistics::new(
            instrument(),
            DatabentoStatisticType::OpeningPrice,
            DatabentoStatisticUpdateAction::Added,
            Some(price),
            Some(quantity),
            1,
            0,
            42,
            1_000_000_000.into(),
            500,
            2_000_000_000.into(),
            3_000_000_000.into(),
            4_000_000_000.into(),
        );

        // Encode with FIXED_PRECISION metadata (as if from an all-None chunk)
        let metadata = HashMap::from([
            (KEY_INSTRUMENT_ID.to_string(), "ESM4.GLBX".to_string()),
            (KEY_PRICE_PRECISION.to_string(), FIXED_PRECISION.to_string()),
            (KEY_SIZE_PRECISION.to_string(), FIXED_PRECISION.to_string()),
        ]);

        let decoded = round_trip(&metadata, &[stats]);

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0].price.unwrap().as_f64(), price.as_f64());
        assert_eq!(decoded[0].quantity.unwrap().as_f64(), quantity.as_f64());
    }

    #[rstest]
    fn test_get_schema_with_metadata() {
        let metadata = test_metadata();
        let schema = DatabentoStatistics::get_schema(Some(metadata.clone()));

        assert_eq!(schema.metadata(), &metadata);
        assert_eq!(schema.fields().len(), 12);
    }

    #[rstest]
    fn test_decode_missing_metadata_returns_error() {
        let data = vec![test_statistics(instrument())];
        let batch = DatabentoStatistics::encode_batch(&test_metadata(), &data).unwrap();

        let empty_metadata = HashMap::new();
        assert!(decode_statistics_batch(&empty_metadata, &batch).is_err());
    }

    #[rstest]
    fn test_statistics_to_arrow_record_batch_empty() {
        assert!(statistics_to_arrow_record_batch(&[]).is_err());
    }

    #[rstest]
    fn test_decode_data_batch_produces_custom_data() {
        let instrument_id = instrument();
        let metadata = test_metadata();
        let original = vec![test_statistics(instrument_id)];
        let batch = DatabentoStatistics::encode_batch(&metadata, &original).unwrap();
        let data_vec = DatabentoStatistics::decode_data_batch(&metadata, batch).unwrap();

        assert_eq!(data_vec.len(), 1);
        let custom = expect_custom(&data_vec[0]);
        assert_eq!(custom.data.type_name(), "DatabentoStatistics");

        let stats = custom
            .data
            .as_any()
            .downcast_ref::<DatabentoStatistics>()
            .unwrap();
        assert_eq!(stats.instrument_id, instrument_id);
        assert_eq!(stats.stat_type, original[0].stat_type);
        assert_eq!(stats.price, original[0].price);
        assert_eq!(stats.quantity, original[0].quantity);
        assert_eq!(stats.ts_event, original[0].ts_event);
        assert_eq!(stats.ts_init, original[0].ts_init);
    }

    #[rstest]
    fn test_decode_data_batch_multiple_rows() {
        let instrument_id = instrument();
        let metadata = test_metadata();
        let original = vec![
            test_statistics(instrument_id),
            quantity_only_statistics(instrument_id),
        ];
        let batch = DatabentoStatistics::encode_batch(&metadata, &original).unwrap();
        let data_vec = DatabentoStatistics::decode_data_batch(&metadata, batch).unwrap();

        assert_eq!(data_vec.len(), 2);
        for (data, expected) in data_vec.iter().zip(&original) {
            let stats = expect_custom(data)
                .data
                .as_any()
                .downcast_ref::<DatabentoStatistics>()
                .unwrap();
            assert_eq!(stats.instrument_id, expected.instrument_id);
            assert_eq!(stats.stat_type, expected.stat_type);
            assert_eq!(stats.price, expected.price);
            assert_eq!(stats.quantity, expected.quantity);
        }
    }

    #[rstest]
    fn test_ipc_stream_round_trip() {
        use std::io::Cursor;

        use arrow::ipc::{reader::StreamReader, writer::StreamWriter};

        let instrument_id = instrument();
        let original = vec![
            test_statistics(instrument_id),
            quantity_only_statistics(instrument_id),
        ];
        let batch = statistics_to_arrow_record_batch(&original).unwrap();

        // Write the batch to an in-memory IPC stream.
        let mut cursor = Cursor::new(Vec::new());
        {
            let mut writer = StreamWriter::try_new(&mut cursor, &batch.schema()).unwrap();
            writer.write(&batch).unwrap();
            writer.finish().unwrap();
        }

        // Read it back, decoding each batch with the metadata carried by its schema.
        let reader = StreamReader::try_new(Cursor::new(cursor.into_inner()), None).unwrap();
        let mut decoded = Vec::new();
        for batch_result in reader {
            let batch = batch_result.unwrap();
            let metadata = batch.schema().metadata().clone();
            decoded.extend(decode_statistics_batch(&metadata, &batch).unwrap());
        }

        assert_eq!(decoded.len(), 2);
        for (orig, dec) in original.iter().zip(&decoded) {
            assert_eq!(dec, orig);
        }
    }
}