nautilus_serialization/arrow/display/
bar.rs1use std::sync::Arc;
19
20use arrow::{
21 array::{Float64Builder, StringBuilder, TimestampNanosecondBuilder},
22 datatypes::Schema,
23 error::ArrowError,
24 record_batch::RecordBatch,
25};
26use nautilus_model::data::Bar;
27
28use super::{
29 float64_field, price_to_f64, quantity_to_f64, timestamp_field, unix_nanos_to_i64, utf8_field,
30};
31
32#[must_use]
34pub fn bars_schema() -> Schema {
35 Schema::new(vec![
36 utf8_field("instrument_id", false),
37 utf8_field("bar_type", false),
38 float64_field("open", false),
39 float64_field("high", false),
40 float64_field("low", false),
41 float64_field("close", false),
42 float64_field("volume", false),
43 timestamp_field("ts_event", false),
44 timestamp_field("ts_init", false),
45 ])
46}
47
48pub fn encode_bars(data: &[Bar]) -> Result<RecordBatch, ArrowError> {
62 let mut instrument_id_builder = StringBuilder::new();
63 let mut bar_type_builder = StringBuilder::new();
64 let mut open_builder = Float64Builder::with_capacity(data.len());
65 let mut high_builder = Float64Builder::with_capacity(data.len());
66 let mut low_builder = Float64Builder::with_capacity(data.len());
67 let mut close_builder = Float64Builder::with_capacity(data.len());
68 let mut volume_builder = Float64Builder::with_capacity(data.len());
69 let mut ts_event_builder = TimestampNanosecondBuilder::with_capacity(data.len());
70 let mut ts_init_builder = TimestampNanosecondBuilder::with_capacity(data.len());
71
72 for bar in data {
73 instrument_id_builder.append_value(bar.instrument_id().to_string());
74 bar_type_builder.append_value(bar.bar_type.to_string());
75 open_builder.append_value(price_to_f64(&bar.open));
76 high_builder.append_value(price_to_f64(&bar.high));
77 low_builder.append_value(price_to_f64(&bar.low));
78 close_builder.append_value(price_to_f64(&bar.close));
79 volume_builder.append_value(quantity_to_f64(&bar.volume));
80 ts_event_builder.append_value(unix_nanos_to_i64(bar.ts_event.as_u64()));
81 ts_init_builder.append_value(unix_nanos_to_i64(bar.ts_init.as_u64()));
82 }
83
84 RecordBatch::try_new(
85 Arc::new(bars_schema()),
86 vec![
87 Arc::new(instrument_id_builder.finish()),
88 Arc::new(bar_type_builder.finish()),
89 Arc::new(open_builder.finish()),
90 Arc::new(high_builder.finish()),
91 Arc::new(low_builder.finish()),
92 Arc::new(close_builder.finish()),
93 Arc::new(volume_builder.finish()),
94 Arc::new(ts_event_builder.finish()),
95 Arc::new(ts_init_builder.finish()),
96 ],
97 )
98}
99
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use arrow::{
        array::{Array, Float64Array, StringArray, TimestampNanosecondArray},
        datatypes::{DataType, TimeUnit},
    };
    use nautilus_model::{
        data::BarType,
        types::{Price, Quantity},
    };
    use rstest::rstest;

    use super::*;

    /// Bar type string shared by the AAPL fixtures.
    const AAPL_BAR_TYPE: &str = "AAPL.XNAS-1-MINUTE-LAST-INTERNAL";

    /// Bar type string for the second instrument in the mixed-instrument test.
    const MSFT_BAR_TYPE: &str = "MSFT.XNAS-1-MINUTE-LAST-INTERNAL";

    /// Builds a bar with a fixed volume of 1,100, `ts_event = ts` and
    /// `ts_init = ts + 1`.
    fn make_bar(
        bar_type_str: &str,
        open: &str,
        high: &str,
        low: &str,
        close: &str,
        ts: u64,
    ) -> Bar {
        let ts_event = ts;
        let ts_init = ts + 1;
        Bar::new(
            BarType::from_str(bar_type_str).unwrap(),
            Price::from(open),
            Price::from(high),
            Price::from(low),
            Price::from(close),
            Quantity::from(1_100),
            ts_event.into(),
            ts_init.into(),
        )
    }

    /// Downcasts the column at `index` to the concrete array type `T`,
    /// panicking if the column has a different type.
    fn column_as<T: 'static>(batch: &RecordBatch, index: usize) -> &T {
        batch.column(index).as_any().downcast_ref::<T>().unwrap()
    }

    /// Asserts that two floats differ by less than `1e-9`.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 1e-9);
    }

    #[rstest]
    fn test_encode_bars_schema() {
        let batch = encode_bars(&[]).unwrap();
        let fields = batch.schema().fields().clone();

        let expected_names = [
            "instrument_id",
            "bar_type",
            "open",
            "high",
            "low",
            "close",
            "volume",
            "ts_event",
            "ts_init",
        ];
        assert_eq!(fields.len(), 9);
        for (field, expected) in fields.iter().zip(expected_names) {
            assert_eq!(field.name(), expected);
        }

        // Spot-check one column of each data type family.
        assert_eq!(fields[0].data_type(), &DataType::Utf8);
        assert_eq!(fields[1].data_type(), &DataType::Utf8);
        assert_eq!(fields[2].data_type(), &DataType::Float64);
        assert_eq!(fields[6].data_type(), &DataType::Float64);
        assert_eq!(
            fields[7].data_type(),
            &DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
    }

    #[rstest]
    fn test_encode_bars_values() {
        let bars = vec![
            make_bar(AAPL_BAR_TYPE, "100.10", "102.00", "100.00", "101.00", 1_000),
            make_bar(AAPL_BAR_TYPE, "100.20", "102.00", "100.00", "101.00", 2_000),
        ];
        let batch = encode_bars(&bars).unwrap();

        assert_eq!(batch.num_rows(), 2);

        // Downcast every column up front so a type mismatch fails the test
        // before any value is inspected.
        let instrument_ids = column_as::<StringArray>(&batch, 0);
        let bar_types = column_as::<StringArray>(&batch, 1);
        let opens = column_as::<Float64Array>(&batch, 2);
        let highs = column_as::<Float64Array>(&batch, 3);
        let lows = column_as::<Float64Array>(&batch, 4);
        let closes = column_as::<Float64Array>(&batch, 5);
        let volumes = column_as::<Float64Array>(&batch, 6);
        let event_times = column_as::<TimestampNanosecondArray>(&batch, 7);
        let init_times = column_as::<TimestampNanosecondArray>(&batch, 8);

        assert_eq!(instrument_ids.value(0), "AAPL.XNAS");
        assert_eq!(bar_types.value(0), "AAPL.XNAS-1-MINUTE-LAST-INTERNAL");
        assert_close(opens.value(0), 100.10);
        assert_close(opens.value(1), 100.20);
        assert_close(highs.value(0), 102.00);
        assert_close(lows.value(0), 100.00);
        assert_close(closes.value(0), 101.00);
        assert_close(volumes.value(0), 1_100.0);
        assert_eq!(event_times.value(0), 1_000);
        assert_eq!(init_times.value(1), 2_001);
    }

    #[rstest]
    fn test_encode_bars_empty() {
        let empty: &[Bar] = &[];
        let batch = encode_bars(empty).unwrap();
        assert_eq!(batch.num_rows(), 0);
    }

    #[rstest]
    fn test_encode_bars_mixed_instruments() {
        let bars = vec![
            make_bar(AAPL_BAR_TYPE, "100.10", "102.00", "100.00", "101.00", 1),
            make_bar(MSFT_BAR_TYPE, "250.00", "251.00", "249.00", "250.50", 2),
        ];
        let batch = encode_bars(&bars).unwrap();

        // The instrument ID column is derived per row, not per batch.
        let instrument_ids = column_as::<StringArray>(&batch, 0);
        assert_eq!(instrument_ids.value(0), "AAPL.XNAS");
        assert_eq!(instrument_ids.value(1), "MSFT.XNAS");
    }
}