Skip to main content

nautilus_serialization/arrow/display/
bar.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Display-mode Arrow encoder for [`Bar`].
17
18use std::sync::Arc;
19
20use arrow::{
21    array::{Float64Builder, StringBuilder, TimestampNanosecondBuilder},
22    datatypes::Schema,
23    error::ArrowError,
24    record_batch::RecordBatch,
25};
26use nautilus_model::data::Bar;
27
28use super::{
29    float64_field, price_to_f64, quantity_to_f64, timestamp_field, unix_nanos_to_i64, utf8_field,
30};
31
32/// Returns the display-mode Arrow schema for [`Bar`].
33#[must_use]
34pub fn bars_schema() -> Schema {
35    Schema::new(vec![
36        utf8_field("instrument_id", false),
37        utf8_field("bar_type", false),
38        float64_field("open", false),
39        float64_field("high", false),
40        float64_field("low", false),
41        float64_field("close", false),
42        float64_field("volume", false),
43        timestamp_field("ts_event", false),
44        timestamp_field("ts_init", false),
45    ])
46}
47
48/// Encodes bars as a display-friendly Arrow [`RecordBatch`].
49///
50/// Emits `Float64` columns for OHLCV values, `Utf8` columns for the
51/// instrument ID and bar type, and `Timestamp(Nanosecond)` columns for
52/// event and init times. Mixed-instrument batches are supported. Precision
53/// is lost on the conversion to `f64`; use
54/// [`crate::arrow::bars_to_arrow_record_batch_bytes`] for catalog storage.
55///
56/// Returns an empty [`RecordBatch`] with the correct schema when `data` is empty.
57///
58/// # Errors
59///
60/// Returns an [`ArrowError`] if the Arrow `RecordBatch` cannot be constructed.
61pub fn encode_bars(data: &[Bar]) -> Result<RecordBatch, ArrowError> {
62    let mut instrument_id_builder = StringBuilder::new();
63    let mut bar_type_builder = StringBuilder::new();
64    let mut open_builder = Float64Builder::with_capacity(data.len());
65    let mut high_builder = Float64Builder::with_capacity(data.len());
66    let mut low_builder = Float64Builder::with_capacity(data.len());
67    let mut close_builder = Float64Builder::with_capacity(data.len());
68    let mut volume_builder = Float64Builder::with_capacity(data.len());
69    let mut ts_event_builder = TimestampNanosecondBuilder::with_capacity(data.len());
70    let mut ts_init_builder = TimestampNanosecondBuilder::with_capacity(data.len());
71
72    for bar in data {
73        instrument_id_builder.append_value(bar.instrument_id().to_string());
74        bar_type_builder.append_value(bar.bar_type.to_string());
75        open_builder.append_value(price_to_f64(&bar.open));
76        high_builder.append_value(price_to_f64(&bar.high));
77        low_builder.append_value(price_to_f64(&bar.low));
78        close_builder.append_value(price_to_f64(&bar.close));
79        volume_builder.append_value(quantity_to_f64(&bar.volume));
80        ts_event_builder.append_value(unix_nanos_to_i64(bar.ts_event.as_u64()));
81        ts_init_builder.append_value(unix_nanos_to_i64(bar.ts_init.as_u64()));
82    }
83
84    RecordBatch::try_new(
85        Arc::new(bars_schema()),
86        vec![
87            Arc::new(instrument_id_builder.finish()),
88            Arc::new(bar_type_builder.finish()),
89            Arc::new(open_builder.finish()),
90            Arc::new(high_builder.finish()),
91            Arc::new(low_builder.finish()),
92            Arc::new(close_builder.finish()),
93            Arc::new(volume_builder.finish()),
94            Arc::new(ts_event_builder.finish()),
95            Arc::new(ts_init_builder.finish()),
96        ],
97    )
98}
99
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use arrow::{
        array::{Array, Float64Array, StringArray, TimestampNanosecondArray},
        datatypes::{DataType, TimeUnit},
    };
    use nautilus_model::{
        data::BarType,
        types::{Price, Quantity},
    };
    use rstest::rstest;

    use super::*;

    /// Bar type strings shared by the fixtures below.
    const AAPL_BAR_TYPE: &str = "AAPL.XNAS-1-MINUTE-LAST-INTERNAL";
    const MSFT_BAR_TYPE: &str = "MSFT.XNAS-1-MINUTE-LAST-INTERNAL";

    /// Absolute tolerance used when comparing decoded `f64` values.
    const TOLERANCE: f64 = 1e-9;

    /// Builds a bar from string prices with a fixed volume of `1_100`.
    ///
    /// `ts` becomes `ts_event`, and `ts + 1` becomes `ts_init`.
    fn make_bar(
        bar_type_str: &str,
        open: &str,
        high: &str,
        low: &str,
        close: &str,
        ts: u64,
    ) -> Bar {
        let ts_event = ts;
        let ts_init = ts + 1;
        Bar::new(
            BarType::from_str(bar_type_str).unwrap(),
            Price::from(open),
            Price::from(high),
            Price::from(low),
            Price::from(close),
            Quantity::from(1_100),
            ts_event.into(),
            ts_init.into(),
        )
    }

    /// Returns column `index` of `batch` downcast to the concrete array type `T`.
    ///
    /// Panics if the column is not a `T`, which fails the calling test.
    fn typed_column<T: 'static>(batch: &RecordBatch, index: usize) -> &T {
        batch.column(index).as_any().downcast_ref::<T>().unwrap()
    }

    #[rstest]
    fn test_encode_bars_schema() {
        let batch = encode_bars(&[]).unwrap();
        let schema = batch.schema();
        let fields = schema.fields();

        // Expected name per position; a data type is listed only where it is checked.
        let timestamp_ns = DataType::Timestamp(TimeUnit::Nanosecond, None);
        let expected: [(&str, Option<DataType>); 9] = [
            ("instrument_id", Some(DataType::Utf8)),
            ("bar_type", Some(DataType::Utf8)),
            ("open", Some(DataType::Float64)),
            ("high", None),
            ("low", None),
            ("close", None),
            ("volume", Some(DataType::Float64)),
            ("ts_event", Some(timestamp_ns)),
            ("ts_init", None),
        ];

        assert_eq!(fields.len(), 9);
        for (field, (name, data_type)) in fields.iter().zip(expected) {
            assert_eq!(field.name(), name);
            if let Some(data_type) = data_type {
                assert_eq!(field.data_type(), &data_type);
            }
        }
    }

    #[rstest]
    fn test_encode_bars_values() {
        let bars = [
            make_bar(AAPL_BAR_TYPE, "100.10", "102.00", "100.00", "101.00", 1_000),
            make_bar(AAPL_BAR_TYPE, "100.20", "102.00", "100.00", "101.00", 2_000),
        ];
        let batch = encode_bars(&bars).unwrap();

        assert_eq!(batch.num_rows(), 2);

        let instrument_ids = typed_column::<StringArray>(&batch, 0);
        let bar_types = typed_column::<StringArray>(&batch, 1);
        let opens = typed_column::<Float64Array>(&batch, 2);
        let highs = typed_column::<Float64Array>(&batch, 3);
        let lows = typed_column::<Float64Array>(&batch, 4);
        let closes = typed_column::<Float64Array>(&batch, 5);
        let volumes = typed_column::<Float64Array>(&batch, 6);
        let ts_events = typed_column::<TimestampNanosecondArray>(&batch, 7);
        let ts_inits = typed_column::<TimestampNanosecondArray>(&batch, 8);

        assert_eq!(instrument_ids.value(0), "AAPL.XNAS");
        assert_eq!(bar_types.value(0), AAPL_BAR_TYPE);

        // (column, row, expected value) triples for the floating-point checks.
        let float_checks = [
            (opens, 0, 100.10),
            (opens, 1, 100.20),
            (highs, 0, 102.00),
            (lows, 0, 100.00),
            (closes, 0, 101.00),
            (volumes, 0, 1_100.0),
        ];
        for (column, row, expected) in float_checks {
            assert!((column.value(row) - expected).abs() < TOLERANCE);
        }

        assert_eq!(ts_events.value(0), 1_000);
        assert_eq!(ts_inits.value(1), 2_001);
    }

    #[rstest]
    fn test_encode_bars_empty() {
        let batch = encode_bars(&[]).unwrap();
        assert_eq!(batch.num_rows(), 0);
    }

    #[rstest]
    fn test_encode_bars_mixed_instruments() {
        let bars = [
            make_bar(AAPL_BAR_TYPE, "100.10", "102.00", "100.00", "101.00", 1),
            make_bar(MSFT_BAR_TYPE, "250.00", "251.00", "249.00", "250.50", 2),
        ];
        let batch = encode_bars(&bars).unwrap();

        let instrument_ids = typed_column::<StringArray>(&batch, 0);
        assert_eq!(instrument_ids.value(0), "AAPL.XNAS");
        assert_eq!(instrument_ids.value(1), "MSFT.XNAS");
    }
}