1use std::{collections::HashMap, str::FromStr, sync::Arc};
17
18use arrow::{
19 array::{FixedSizeBinaryArray, FixedSizeBinaryBuilder, StringBuilder, UInt64Array},
20 datatypes::{DataType, Field, Schema},
21 error::ArrowError,
22 record_batch::RecordBatch,
23};
24use nautilus_model::{
25 data::{Data, bar::BarType, custom::CustomData},
26 types::fixed::PRECISION_BYTES,
27};
28use nautilus_serialization::arrow::{
29 ArrowSchemaProvider, DecodeDataFromRecordBatch, EncodeToRecordBatch, EncodingError,
30 KEY_PRICE_PRECISION, KEY_SIZE_PRECISION, decode_price, decode_quantity, extract_column,
31 extract_column_string, validate_precision_bytes,
32};
33use rust_decimal::Decimal;
34
35use crate::common::bar::BinanceBar;
36
/// Metadata key under which the bar type is stored on encoded record batches.
const KEY_BAR_TYPE: &str = "bar_type";

39fn parse_metadata(metadata: &HashMap<String, String>) -> Result<(BarType, u8, u8), EncodingError> {
40 let bar_type_str = metadata
41 .get(KEY_BAR_TYPE)
42 .ok_or_else(|| EncodingError::MissingMetadata(KEY_BAR_TYPE))?;
43 let bar_type = BarType::from_str(bar_type_str)
44 .map_err(|e| EncodingError::ParseError(KEY_BAR_TYPE, e.to_string()))?;
45
46 let price_precision = metadata
47 .get(KEY_PRICE_PRECISION)
48 .ok_or_else(|| EncodingError::MissingMetadata(KEY_PRICE_PRECISION))?
49 .parse::<u8>()
50 .map_err(|e| EncodingError::ParseError(KEY_PRICE_PRECISION, e.to_string()))?;
51
52 let size_precision = metadata
53 .get(KEY_SIZE_PRECISION)
54 .ok_or_else(|| EncodingError::MissingMetadata(KEY_SIZE_PRECISION))?
55 .parse::<u8>()
56 .map_err(|e| EncodingError::ParseError(KEY_SIZE_PRECISION, e.to_string()))?;
57
58 Ok((bar_type, price_precision, size_precision))
59}
60
61impl ArrowSchemaProvider for BinanceBar {
62 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
63 let fields = vec![
66 Field::new("open", DataType::FixedSizeBinary(PRECISION_BYTES), false),
67 Field::new("high", DataType::FixedSizeBinary(PRECISION_BYTES), false),
68 Field::new("low", DataType::FixedSizeBinary(PRECISION_BYTES), false),
69 Field::new("close", DataType::FixedSizeBinary(PRECISION_BYTES), false),
70 Field::new("volume", DataType::FixedSizeBinary(PRECISION_BYTES), false),
71 Field::new("quote_volume", DataType::Utf8, false),
72 Field::new("count", DataType::UInt64, false),
73 Field::new("taker_buy_base_volume", DataType::Utf8, false),
74 Field::new("taker_buy_quote_volume", DataType::Utf8, false),
75 Field::new("ts_event", DataType::UInt64, false),
76 Field::new("ts_init", DataType::UInt64, false),
77 ];
78
79 match metadata {
80 Some(metadata) => Schema::new_with_metadata(fields, metadata),
81 None => Schema::new(fields),
82 }
83 }
84}
85
86impl EncodeToRecordBatch for BinanceBar {
87 fn encode_batch(
88 metadata: &HashMap<String, String>,
89 data: &[Self],
90 ) -> Result<RecordBatch, ArrowError> {
91 let mut open_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
92 let mut high_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
93 let mut low_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
94 let mut close_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
95 let mut volume_builder = FixedSizeBinaryBuilder::with_capacity(data.len(), PRECISION_BYTES);
96 let mut quote_volume_builder = StringBuilder::with_capacity(data.len(), data.len() * 20);
97 let mut count_builder = UInt64Array::builder(data.len());
98 let mut taker_buy_base_volume_builder =
99 StringBuilder::with_capacity(data.len(), data.len() * 20);
100 let mut taker_buy_quote_volume_builder =
101 StringBuilder::with_capacity(data.len(), data.len() * 20);
102 let mut ts_event_builder = UInt64Array::builder(data.len());
103 let mut ts_init_builder = UInt64Array::builder(data.len());
104
105 for bar in data {
106 open_builder
107 .append_value(bar.open.raw.to_le_bytes())
108 .unwrap();
109 high_builder
110 .append_value(bar.high.raw.to_le_bytes())
111 .unwrap();
112 low_builder.append_value(bar.low.raw.to_le_bytes()).unwrap();
113 close_builder
114 .append_value(bar.close.raw.to_le_bytes())
115 .unwrap();
116 volume_builder
117 .append_value(bar.volume.raw.to_le_bytes())
118 .unwrap();
119 quote_volume_builder.append_value(bar.quote_volume.to_string());
120 count_builder.append_value(bar.count);
121 taker_buy_base_volume_builder.append_value(bar.taker_buy_base_volume.to_string());
122 taker_buy_quote_volume_builder.append_value(bar.taker_buy_quote_volume.to_string());
123 ts_event_builder.append_value(bar.ts_event.as_u64());
124 ts_init_builder.append_value(bar.ts_init.as_u64());
125 }
126
127 RecordBatch::try_new(
128 Self::get_schema(Some(metadata.clone())).into(),
129 vec![
130 Arc::new(open_builder.finish()),
131 Arc::new(high_builder.finish()),
132 Arc::new(low_builder.finish()),
133 Arc::new(close_builder.finish()),
134 Arc::new(volume_builder.finish()),
135 Arc::new(quote_volume_builder.finish()),
136 Arc::new(count_builder.finish()),
137 Arc::new(taker_buy_base_volume_builder.finish()),
138 Arc::new(taker_buy_quote_volume_builder.finish()),
139 Arc::new(ts_event_builder.finish()),
140 Arc::new(ts_init_builder.finish()),
141 ],
142 )
143 }
144
145 fn metadata(&self) -> HashMap<String, String> {
146 let mut metadata = Self::get_metadata(&self.bar_type);
147 metadata.insert(
148 KEY_PRICE_PRECISION.to_string(),
149 self.open.precision.to_string(),
150 );
151 metadata.insert(
152 KEY_SIZE_PRECISION.to_string(),
153 self.volume.precision.to_string(),
154 );
155 metadata
156 }
157}
158
159#[expect(clippy::missing_panics_doc)] pub fn binance_bar_to_arrow_record_batch(
166 data: &[BinanceBar],
167) -> Result<RecordBatch, EncodingError> {
168 if data.is_empty() {
169 return Err(EncodingError::EmptyData);
170 }
171
172 let first = data
173 .first()
174 .expect("Chunk should have at least one element to encode");
175 let metadata = first.metadata();
176 BinanceBar::encode_batch(&metadata, data).map_err(EncodingError::ArrowError)
177}
178
179pub fn decode_binance_bar_batch(
185 metadata: &HashMap<String, String>,
186 record_batch: &RecordBatch,
187) -> Result<Vec<BinanceBar>, EncodingError> {
188 let (bar_type, price_precision, size_precision) = parse_metadata(metadata)?;
189 let cols = record_batch.columns();
190
191 let open_values = extract_column::<FixedSizeBinaryArray>(
192 cols,
193 "open",
194 0,
195 DataType::FixedSizeBinary(PRECISION_BYTES),
196 )?;
197 let high_values = extract_column::<FixedSizeBinaryArray>(
198 cols,
199 "high",
200 1,
201 DataType::FixedSizeBinary(PRECISION_BYTES),
202 )?;
203 let low_values = extract_column::<FixedSizeBinaryArray>(
204 cols,
205 "low",
206 2,
207 DataType::FixedSizeBinary(PRECISION_BYTES),
208 )?;
209 let close_values = extract_column::<FixedSizeBinaryArray>(
210 cols,
211 "close",
212 3,
213 DataType::FixedSizeBinary(PRECISION_BYTES),
214 )?;
215 let volume_values = extract_column::<FixedSizeBinaryArray>(
216 cols,
217 "volume",
218 4,
219 DataType::FixedSizeBinary(PRECISION_BYTES),
220 )?;
221 let quote_volume_values = extract_column_string(cols, "quote_volume", 5)?;
222 let count_values = extract_column::<UInt64Array>(cols, "count", 6, DataType::UInt64)?;
223 let taker_buy_base_volume_values = extract_column_string(cols, "taker_buy_base_volume", 7)?;
224 let taker_buy_quote_volume_values = extract_column_string(cols, "taker_buy_quote_volume", 8)?;
225 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 9, DataType::UInt64)?;
226 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 10, DataType::UInt64)?;
227
228 validate_precision_bytes(open_values, "open")?;
229 validate_precision_bytes(high_values, "high")?;
230 validate_precision_bytes(low_values, "low")?;
231 validate_precision_bytes(close_values, "close")?;
232 validate_precision_bytes(volume_values, "volume")?;
233
234 (0..record_batch.num_rows())
235 .map(|row| {
236 let open = decode_price(open_values.value(row), price_precision, "open", row)?;
237 let high = decode_price(high_values.value(row), price_precision, "high", row)?;
238 let low = decode_price(low_values.value(row), price_precision, "low", row)?;
239 let close = decode_price(close_values.value(row), price_precision, "close", row)?;
240 let volume = decode_quantity(volume_values.value(row), size_precision, "volume", row)?;
241
242 let quote_volume = Decimal::from_str(quote_volume_values.value(row))
243 .map_err(|e| EncodingError::ParseError("quote_volume", e.to_string()))?;
244 let taker_buy_base_volume = Decimal::from_str(taker_buy_base_volume_values.value(row))
245 .map_err(|e| EncodingError::ParseError("taker_buy_base_volume", e.to_string()))?;
246 let taker_buy_quote_volume =
247 Decimal::from_str(taker_buy_quote_volume_values.value(row)).map_err(|e| {
248 EncodingError::ParseError("taker_buy_quote_volume", e.to_string())
249 })?;
250
251 Ok(BinanceBar::new(
252 bar_type,
253 open,
254 high,
255 low,
256 close,
257 volume,
258 quote_volume,
259 count_values.value(row),
260 taker_buy_base_volume,
261 taker_buy_quote_volume,
262 ts_event_values.value(row).into(),
263 ts_init_values.value(row).into(),
264 ))
265 })
266 .collect()
267}
268
269impl DecodeDataFromRecordBatch for BinanceBar {
270 fn decode_data_batch(
271 metadata: &HashMap<String, String>,
272 record_batch: RecordBatch,
273 ) -> Result<Vec<Data>, EncodingError> {
274 let items = decode_binance_bar_batch(metadata, &record_batch)?;
275 Ok(items
276 .into_iter()
277 .map(|item| Data::Custom(CustomData::from_arc(Arc::new(item))))
278 .collect())
279 }
280}
281
#[cfg(test)]
mod tests {
    use nautilus_model::types::{Price, Quantity};
    use rstest::rstest;
    use rust_decimal_macros::dec;

    use super::*;

    /// Builds a representative bar for round-trip tests.
    fn stub_binance_bar() -> BinanceBar {
        BinanceBar::new(
            BarType::from("BTCUSDT.BINANCE-1-MINUTE-LAST-EXTERNAL"),
            Price::from("0.01634790"),
            Price::from("0.01640000"),
            Price::from("0.01575800"),
            Price::from("0.01577100"),
            Quantity::from("148976.11427815"),
            dec!(2434.19055334),
            100,
            dec!(1756.87402397),
            dec!(28.46694368),
            1_650_000_000_000_000_000u64.into(),
            1_650_000_000_000_000_000u64.into(),
        )
    }

    #[rstest]
    fn test_get_schema() {
        let schema = BinanceBar::get_schema(None);
        assert_eq!(schema.fields().len(), 11);
        assert_eq!(schema.field(0).name(), "open");
        assert_eq!(schema.field(5).name(), "quote_volume");
        assert_eq!(schema.field(5).data_type(), &DataType::Utf8);
        assert_eq!(schema.field(6).name(), "count");
        assert_eq!(schema.field(6).data_type(), &DataType::UInt64);
    }

    #[rstest]
    fn test_encode_decode_round_trip() {
        let bar = stub_binance_bar();
        let metadata = bar.metadata();
        let bars = [bar.clone()];

        let batch = BinanceBar::encode_batch(&metadata, &bars).unwrap();
        let decoded = decode_binance_bar_batch(&metadata, &batch).unwrap();

        assert_eq!(decoded.len(), 1);
        assert_eq!(decoded[0], bar);
    }

    #[rstest]
    fn test_encode_decode_multiple_bars() {
        let first = stub_binance_bar();
        let second = BinanceBar::new(
            BarType::from("BTCUSDT.BINANCE-1-MINUTE-LAST-EXTERNAL"),
            Price::from("0.01700000"),
            Price::from("0.01710000"),
            Price::from("0.01690000"),
            Price::from("0.01695000"),
            Quantity::from("50000.00000000"),
            dec!(1000.00000000),
            50,
            dec!(500.00000000),
            dec!(10.00000000),
            1_650_000_060_000_000_000u64.into(),
            1_650_000_060_000_000_000u64.into(),
        );

        let metadata = first.metadata();
        let bars = [first.clone(), second.clone()];

        let batch = BinanceBar::encode_batch(&metadata, &bars).unwrap();
        let decoded = decode_binance_bar_batch(&metadata, &batch).unwrap();

        assert_eq!(decoded.len(), 2);
        assert_eq!(decoded[0], first);
        assert_eq!(decoded[1], second);
    }

    #[rstest]
    fn test_decode_data_batch_returns_custom_data() {
        let bar = stub_binance_bar();
        let metadata = bar.metadata();
        let bars = [bar];

        let batch = BinanceBar::encode_batch(&metadata, &bars).unwrap();
        let decoded = BinanceBar::decode_data_batch(&metadata, batch).unwrap();

        assert_eq!(decoded.len(), 1);
        assert!(matches!(decoded[0], Data::Custom(_)));
    }
}
372}