1use std::{collections::HashMap, str::FromStr, sync::Arc};
19
20use arrow::{
21 array::{BinaryArray, BinaryBuilder, StringArray, StringBuilder, UInt8Array, UInt64Array},
22 datatypes::{DataType, Field, Schema},
23 error::ArrowError,
24 record_batch::RecordBatch,
25};
26use nautilus_core::Params;
27use nautilus_model::{
28 identifiers::{InstrumentId, Symbol},
29 instruments::equity::Equity,
30 types::{price::Price, quantity::Quantity},
31};
32#[allow(unused)]
33use rust_decimal::Decimal;
34#[allow(unused)]
35use serde_json::Value;
36use ustr::Ustr;
37
38use crate::arrow::{
39 ArrowSchemaProvider, EncodeToRecordBatch, EncodingError, KEY_INSTRUMENT_ID,
40 KEY_PRICE_PRECISION, extract_column,
41};
42
43impl ArrowSchemaProvider for Equity {
44 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
45 let fields = vec![
46 Field::new("id", DataType::Utf8, false),
47 Field::new("raw_symbol", DataType::Utf8, false),
48 Field::new("currency", DataType::Utf8, false),
49 Field::new("price_precision", DataType::UInt8, false),
50 Field::new("price_increment", DataType::Utf8, false),
51 Field::new("lot_size", DataType::Utf8, true), Field::new("isin", DataType::Utf8, true), Field::new("max_quantity", DataType::Utf8, true), Field::new("min_quantity", DataType::Utf8, true), Field::new("max_price", DataType::Utf8, true), Field::new("min_price", DataType::Utf8, true), Field::new("margin_init", DataType::Utf8, false),
58 Field::new("margin_maint", DataType::Utf8, false),
59 Field::new("maker_fee", DataType::Utf8, false),
60 Field::new("taker_fee", DataType::Utf8, false),
61 Field::new("info", DataType::Binary, true), Field::new("ts_event", DataType::UInt64, false),
63 Field::new("ts_init", DataType::UInt64, false),
64 ];
65
66 let mut final_metadata = HashMap::new();
67 final_metadata.insert("class".to_string(), "Equity".to_string());
68
69 if let Some(meta) = metadata {
70 final_metadata.extend(meta);
71 }
72
73 Schema::new_with_metadata(fields, final_metadata)
74 }
75}
76
77impl EncodeToRecordBatch for Equity {
78 fn encode_batch(
79 #[allow(unused)] metadata: &HashMap<String, String>,
80 data: &[Self],
81 ) -> Result<RecordBatch, ArrowError> {
82 let mut id_builder = StringBuilder::new();
83 let mut raw_symbol_builder = StringBuilder::new();
84 let mut currency_builder = StringBuilder::new();
85 let mut price_precision_builder = UInt8Array::builder(data.len());
86 let mut price_increment_builder = StringBuilder::new();
87 let mut lot_size_builder = StringBuilder::new();
88 let mut isin_builder = StringBuilder::new();
89 let mut max_quantity_builder = StringBuilder::new();
90 let mut min_quantity_builder = StringBuilder::new();
91 let mut max_price_builder = StringBuilder::new();
92 let mut min_price_builder = StringBuilder::new();
93 let mut margin_init_builder = StringBuilder::new();
94 let mut margin_maint_builder = StringBuilder::new();
95 let mut maker_fee_builder = StringBuilder::new();
96 let mut taker_fee_builder = StringBuilder::new();
97 let mut info_builder = BinaryBuilder::new();
98 let mut ts_event_builder = UInt64Array::builder(data.len());
99 let mut ts_init_builder = UInt64Array::builder(data.len());
100
101 for equity in data {
102 id_builder.append_value(equity.id.to_string());
103 raw_symbol_builder.append_value(equity.raw_symbol);
104 currency_builder.append_value(equity.currency.to_string());
105 price_precision_builder.append_value(equity.price_precision);
106 price_increment_builder.append_value(equity.price_increment.to_string());
107
108 if let Some(lot_size) = equity.lot_size {
109 lot_size_builder.append_value(lot_size.to_string());
110 } else {
111 lot_size_builder.append_null();
112 }
113
114 if let Some(isin) = equity.isin {
115 isin_builder.append_value(isin);
116 } else {
117 isin_builder.append_null();
118 }
119
120 if let Some(max_qty) = equity.max_quantity {
121 max_quantity_builder.append_value(max_qty.to_string());
122 } else {
123 max_quantity_builder.append_null();
124 }
125
126 if let Some(min_qty) = equity.min_quantity {
127 min_quantity_builder.append_value(min_qty.to_string());
128 } else {
129 min_quantity_builder.append_null();
130 }
131
132 if let Some(max_p) = equity.max_price {
133 max_price_builder.append_value(max_p.to_string());
134 } else {
135 max_price_builder.append_null();
136 }
137
138 if let Some(min_p) = equity.min_price {
139 min_price_builder.append_value(min_p.to_string());
140 } else {
141 min_price_builder.append_null();
142 }
143
144 margin_init_builder.append_value(equity.margin_init.to_string());
145 margin_maint_builder.append_value(equity.margin_maint.to_string());
146 maker_fee_builder.append_value(equity.maker_fee.to_string());
147 taker_fee_builder.append_value(equity.taker_fee.to_string());
148
149 if let Some(ref info) = equity.info {
151 match serde_json::to_vec(info) {
152 Ok(json_bytes) => {
153 info_builder.append_value(json_bytes);
154 }
155 Err(e) => {
156 return Err(ArrowError::InvalidArgumentError(format!(
157 "Failed to serialize info dict to JSON: {e}"
158 )));
159 }
160 }
161 } else {
162 info_builder.append_null();
163 }
164
165 ts_event_builder.append_value(equity.ts_event.as_u64());
166 ts_init_builder.append_value(equity.ts_init.as_u64());
167 }
168
169 let mut final_metadata = metadata.clone();
170 final_metadata.insert("class".to_string(), "Equity".to_string());
171
172 RecordBatch::try_new(
173 Self::get_schema(Some(final_metadata)).into(),
174 vec![
175 Arc::new(id_builder.finish()),
176 Arc::new(raw_symbol_builder.finish()),
177 Arc::new(currency_builder.finish()),
178 Arc::new(price_precision_builder.finish()),
179 Arc::new(price_increment_builder.finish()),
180 Arc::new(lot_size_builder.finish()),
181 Arc::new(isin_builder.finish()),
182 Arc::new(max_quantity_builder.finish()),
183 Arc::new(min_quantity_builder.finish()),
184 Arc::new(max_price_builder.finish()),
185 Arc::new(min_price_builder.finish()),
186 Arc::new(margin_init_builder.finish()),
187 Arc::new(margin_maint_builder.finish()),
188 Arc::new(maker_fee_builder.finish()),
189 Arc::new(taker_fee_builder.finish()),
190 Arc::new(info_builder.finish()),
191 Arc::new(ts_event_builder.finish()),
192 Arc::new(ts_init_builder.finish()),
193 ],
194 )
195 }
196
197 fn metadata(&self) -> HashMap<String, String> {
198 let mut metadata = HashMap::new();
199 metadata.insert(KEY_INSTRUMENT_ID.to_string(), self.id.to_string());
200 metadata.insert(
201 KEY_PRICE_PRECISION.to_string(),
202 self.price_precision.to_string(),
203 );
204 metadata
205 }
206}
207
208pub fn decode_equity_batch(
215 #[allow(unused)] metadata: &HashMap<String, String>,
216 record_batch: &RecordBatch,
217) -> Result<Vec<Equity>, EncodingError> {
218 let cols = record_batch.columns();
219 let num_rows = record_batch.num_rows();
220
221 let id_values = extract_column::<StringArray>(cols, "id", 0, DataType::Utf8)?;
223 let raw_symbol_values = extract_column::<StringArray>(cols, "raw_symbol", 1, DataType::Utf8)?;
224 let currency_values = extract_column::<StringArray>(cols, "currency", 2, DataType::Utf8)?;
225 let price_precision_values =
226 extract_column::<UInt8Array>(cols, "price_precision", 3, DataType::UInt8)?;
227 let price_increment_values =
228 extract_column::<StringArray>(cols, "price_increment", 4, DataType::Utf8)?;
229 let lot_size_values = cols
230 .get(5)
231 .ok_or_else(|| EncodingError::MissingColumn("lot_size", 5))?;
232 let isin_values = cols
233 .get(6)
234 .ok_or_else(|| EncodingError::MissingColumn("isin", 6))?;
235 let max_quantity_values = cols
236 .get(7)
237 .ok_or_else(|| EncodingError::MissingColumn("max_quantity", 7))?;
238 let min_quantity_values = cols
239 .get(8)
240 .ok_or_else(|| EncodingError::MissingColumn("min_quantity", 8))?;
241 let max_price_values = cols
242 .get(9)
243 .ok_or_else(|| EncodingError::MissingColumn("max_price", 9))?;
244 let min_price_values = cols
245 .get(10)
246 .ok_or_else(|| EncodingError::MissingColumn("min_price", 10))?;
247 let margin_init_values =
248 extract_column::<StringArray>(cols, "margin_init", 11, DataType::Utf8)?;
249 let margin_maint_values =
250 extract_column::<StringArray>(cols, "margin_maint", 12, DataType::Utf8)?;
251 let maker_fee_values = extract_column::<StringArray>(cols, "maker_fee", 13, DataType::Utf8)?;
252 let taker_fee_values = extract_column::<StringArray>(cols, "taker_fee", 14, DataType::Utf8)?;
253 let info_values = cols
254 .get(15)
255 .ok_or_else(|| EncodingError::MissingColumn("info", 15))?;
256 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 16, DataType::UInt64)?;
257 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 17, DataType::UInt64)?;
258
259 let mut result = Vec::with_capacity(num_rows);
260
261 for i in 0..num_rows {
262 let id = InstrumentId::from_str(id_values.value(i))
263 .map_err(|e| EncodingError::ParseError("id", format!("row {i}: {e}")))?;
264 let raw_symbol = Symbol::from(raw_symbol_values.value(i));
265 let currency =
266 super::decode_currency(currency_values.value(i), "currency", "equity.currency", i)?;
267 let price_prec = price_precision_values.value(i);
268
269 let price_increment = Price::from_str(price_increment_values.value(i))
270 .map_err(|e| EncodingError::ParseError("price_increment", format!("row {i}: {e}")))?;
271
272 let lot_size = if lot_size_values.is_null(i) {
273 None
274 } else {
275 let lot_size_str = lot_size_values
276 .as_any()
277 .downcast_ref::<StringArray>()
278 .ok_or_else(|| {
279 EncodingError::ParseError("lot_size", format!("row {i}: invalid type"))
280 })?
281 .value(i);
282 Some(
283 Quantity::from_str(lot_size_str)
284 .map_err(|e| EncodingError::ParseError("lot_size", format!("row {i}: {e}")))?,
285 )
286 };
287
288 let isin = if isin_values.is_null(i) {
289 None
290 } else {
291 let isin_str = isin_values
292 .as_any()
293 .downcast_ref::<StringArray>()
294 .ok_or_else(|| EncodingError::ParseError("isin", format!("row {i}: invalid type")))?
295 .value(i);
296 Some(Ustr::from(isin_str))
297 };
298
299 let max_quantity =
300 if max_quantity_values.is_null(i) {
301 None
302 } else {
303 let max_qty_str = max_quantity_values
304 .as_any()
305 .downcast_ref::<StringArray>()
306 .ok_or_else(|| {
307 EncodingError::ParseError("max_quantity", format!("row {i}: invalid type"))
308 })?
309 .value(i);
310 Some(Quantity::from_str(max_qty_str).map_err(|e| {
311 EncodingError::ParseError("max_quantity", format!("row {i}: {e}"))
312 })?)
313 };
314
315 let min_quantity =
316 if min_quantity_values.is_null(i) {
317 None
318 } else {
319 let min_qty_str = min_quantity_values
320 .as_any()
321 .downcast_ref::<StringArray>()
322 .ok_or_else(|| {
323 EncodingError::ParseError("min_quantity", format!("row {i}: invalid type"))
324 })?
325 .value(i);
326 Some(Quantity::from_str(min_qty_str).map_err(|e| {
327 EncodingError::ParseError("min_quantity", format!("row {i}: {e}"))
328 })?)
329 };
330
331 let max_price = if max_price_values.is_null(i) {
332 None
333 } else {
334 let max_p_str = max_price_values
335 .as_any()
336 .downcast_ref::<StringArray>()
337 .ok_or_else(|| {
338 EncodingError::ParseError("max_price", format!("row {i}: invalid type"))
339 })?
340 .value(i);
341 Some(
342 Price::from_str(max_p_str)
343 .map_err(|e| EncodingError::ParseError("max_price", format!("row {i}: {e}")))?,
344 )
345 };
346
347 let min_price = if min_price_values.is_null(i) {
348 None
349 } else {
350 let min_p_str = min_price_values
351 .as_any()
352 .downcast_ref::<StringArray>()
353 .ok_or_else(|| {
354 EncodingError::ParseError("min_price", format!("row {i}: invalid type"))
355 })?
356 .value(i);
357 Some(
358 Price::from_str(min_p_str)
359 .map_err(|e| EncodingError::ParseError("min_price", format!("row {i}: {e}")))?,
360 )
361 };
362
363 let margin_init = Decimal::from_str(margin_init_values.value(i))
364 .map_err(|e| EncodingError::ParseError("margin_init", format!("row {i}: {e}")))?;
365 let margin_maint = Decimal::from_str(margin_maint_values.value(i))
366 .map_err(|e| EncodingError::ParseError("margin_maint", format!("row {i}: {e}")))?;
367 let maker_fee = Decimal::from_str(maker_fee_values.value(i))
368 .map_err(|e| EncodingError::ParseError("maker_fee", format!("row {i}: {e}")))?;
369 let taker_fee = Decimal::from_str(taker_fee_values.value(i))
370 .map_err(|e| EncodingError::ParseError("taker_fee", format!("row {i}: {e}")))?;
371
372 let info = if info_values.is_null(i) {
374 None
375 } else {
376 let info_bytes = info_values
377 .as_any()
378 .downcast_ref::<BinaryArray>()
379 .ok_or_else(|| EncodingError::ParseError("info", format!("row {i}: invalid type")))?
380 .value(i);
381
382 match serde_json::from_slice::<Params>(info_bytes) {
383 Ok(info_dict) => Some(info_dict),
384 Err(e) => {
385 return Err(EncodingError::ParseError(
386 "info",
387 format!("row {i}: failed to deserialize JSON: {e}"),
388 ));
389 }
390 }
391 };
392
393 let ts_event = nautilus_core::UnixNanos::from(ts_event_values.value(i));
394 let ts_init = nautilus_core::UnixNanos::from(ts_init_values.value(i));
395
396 let equity = Equity::new(
397 id,
398 raw_symbol,
399 isin,
400 currency,
401 price_prec,
402 price_increment,
403 lot_size,
404 max_quantity,
405 min_quantity,
406 max_price,
407 min_price,
408 Some(margin_init),
409 Some(margin_maint),
410 Some(maker_fee),
411 Some(taker_fee),
412 info,
413 ts_event,
414 ts_init,
415 );
416
417 result.push(equity);
418 }
419
420 Ok(result)
421}