1use std::{collections::HashMap, str::FromStr, sync::Arc};
19
20use arrow::{
21 array::{BinaryArray, BinaryBuilder, StringArray, StringBuilder, UInt8Array, UInt64Array},
22 datatypes::{DataType, Field, Schema},
23 error::ArrowError,
24 record_batch::RecordBatch,
25};
26#[allow(unused_imports)]
27use nautilus_core::Params;
28use nautilus_model::{
29 enums::AssetClass,
30 identifiers::{InstrumentId, Symbol},
31 instruments::binary_option::BinaryOption,
32 types::{price::Price, quantity::Quantity},
33};
34#[allow(unused)]
35use rust_decimal::Decimal;
36#[allow(unused)]
37use serde_json::Value;
38use ustr::Ustr;
39
40use crate::arrow::{
41 ArrowSchemaProvider, EncodeToRecordBatch, EncodingError, KEY_INSTRUMENT_ID,
42 KEY_PRICE_PRECISION, KEY_SIZE_PRECISION, extract_column,
43};
44
45fn asset_class_to_string(ac: AssetClass) -> String {
47 match ac {
48 AssetClass::FX => "FX".to_string(),
49 AssetClass::Equity => "Equity".to_string(),
50 AssetClass::Commodity => "Commodity".to_string(),
51 AssetClass::Debt => "Debt".to_string(),
52 AssetClass::Index => "Index".to_string(),
53 AssetClass::Cryptocurrency => "Cryptocurrency".to_string(),
54 AssetClass::Alternative => "Alternative".to_string(),
55 }
56}
57
58fn asset_class_from_str(s: &str) -> Result<AssetClass, EncodingError> {
60 match s {
61 "FX" => Ok(AssetClass::FX),
62 "Equity" => Ok(AssetClass::Equity),
63 "Commodity" => Ok(AssetClass::Commodity),
64 "Debt" => Ok(AssetClass::Debt),
65 "Index" => Ok(AssetClass::Index),
66 "Cryptocurrency" => Ok(AssetClass::Cryptocurrency),
67 "Alternative" => Ok(AssetClass::Alternative),
68 _ => Err(EncodingError::ParseError(
69 "asset_class",
70 format!("Unknown asset class: {s}"),
71 )),
72 }
73}
74
75impl ArrowSchemaProvider for BinaryOption {
76 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
77 let fields = vec![
78 Field::new("id", DataType::Utf8, false),
79 Field::new("raw_symbol", DataType::Utf8, false),
80 Field::new("asset_class", DataType::Utf8, false),
81 Field::new("currency", DataType::Utf8, false),
82 Field::new("price_precision", DataType::UInt8, false),
83 Field::new("size_precision", DataType::UInt8, false),
84 Field::new("price_increment", DataType::Utf8, false),
85 Field::new("size_increment", DataType::Utf8, false),
86 Field::new("activation_ns", DataType::UInt64, false),
87 Field::new("expiration_ns", DataType::UInt64, false),
88 Field::new("maker_fee", DataType::Utf8, false),
89 Field::new("taker_fee", DataType::Utf8, false),
90 Field::new("max_quantity", DataType::Utf8, true), Field::new("min_quantity", DataType::Utf8, true), Field::new("outcome", DataType::Utf8, true), Field::new("description", DataType::Utf8, true), Field::new("info", DataType::Binary, true), Field::new("ts_event", DataType::UInt64, false),
96 Field::new("ts_init", DataType::UInt64, false),
97 ];
98
99 let mut final_metadata = HashMap::new();
100 final_metadata.insert("class".to_string(), "BinaryOption".to_string());
101
102 if let Some(meta) = metadata {
103 final_metadata.extend(meta);
104 }
105
106 Schema::new_with_metadata(fields, final_metadata)
107 }
108}
109
110impl EncodeToRecordBatch for BinaryOption {
111 fn encode_batch(
112 #[allow(unused)] metadata: &HashMap<String, String>,
113 data: &[Self],
114 ) -> Result<RecordBatch, ArrowError> {
115 let mut id_builder = StringBuilder::new();
116 let mut raw_symbol_builder = StringBuilder::new();
117 let mut asset_class_builder = StringBuilder::new();
118 let mut currency_builder = StringBuilder::new();
119 let mut price_precision_builder = UInt8Array::builder(data.len());
120 let mut size_precision_builder = UInt8Array::builder(data.len());
121 let mut price_increment_builder = StringBuilder::new();
122 let mut size_increment_builder = StringBuilder::new();
123 let mut activation_ns_builder = UInt64Array::builder(data.len());
124 let mut expiration_ns_builder = UInt64Array::builder(data.len());
125 let mut maker_fee_builder = StringBuilder::new();
126 let mut taker_fee_builder = StringBuilder::new();
127 let mut max_quantity_builder = StringBuilder::new();
128 let mut min_quantity_builder = StringBuilder::new();
129 let mut outcome_builder = StringBuilder::new();
130 let mut description_builder = StringBuilder::new();
131 let mut info_builder = BinaryBuilder::new();
132 let mut ts_event_builder = UInt64Array::builder(data.len());
133 let mut ts_init_builder = UInt64Array::builder(data.len());
134
135 for bo in data {
136 id_builder.append_value(bo.id.to_string());
137 raw_symbol_builder.append_value(bo.raw_symbol);
138 asset_class_builder.append_value(asset_class_to_string(bo.asset_class));
139 currency_builder.append_value(bo.currency.to_string());
140 price_precision_builder.append_value(bo.price_precision);
141 size_precision_builder.append_value(bo.size_precision);
142 price_increment_builder.append_value(bo.price_increment.to_string());
143 size_increment_builder.append_value(bo.size_increment.to_string());
144 activation_ns_builder.append_value(bo.activation_ns.as_u64());
145 expiration_ns_builder.append_value(bo.expiration_ns.as_u64());
146 maker_fee_builder.append_value(bo.maker_fee.to_string());
147 taker_fee_builder.append_value(bo.taker_fee.to_string());
148
149 if let Some(max_qty) = bo.max_quantity {
150 max_quantity_builder.append_value(max_qty.to_string());
151 } else {
152 max_quantity_builder.append_null();
153 }
154
155 if let Some(min_qty) = bo.min_quantity {
156 min_quantity_builder.append_value(min_qty.to_string());
157 } else {
158 min_quantity_builder.append_null();
159 }
160
161 if let Some(outcome) = bo.outcome {
162 outcome_builder.append_value(outcome);
163 } else {
164 outcome_builder.append_null();
165 }
166
167 if let Some(desc) = bo.description {
168 description_builder.append_value(desc);
169 } else {
170 description_builder.append_null();
171 }
172
173 if let Some(ref info) = bo.info {
175 match serde_json::to_vec(info) {
176 Ok(json_bytes) => {
177 info_builder.append_value(json_bytes);
178 }
179 Err(e) => {
180 return Err(ArrowError::InvalidArgumentError(format!(
181 "Failed to serialize info dict to JSON: {e}"
182 )));
183 }
184 }
185 } else {
186 info_builder.append_null();
187 }
188
189 ts_event_builder.append_value(bo.ts_event.as_u64());
190 ts_init_builder.append_value(bo.ts_init.as_u64());
191 }
192
193 let mut final_metadata = metadata.clone();
194 final_metadata.insert("class".to_string(), "BinaryOption".to_string());
195
196 RecordBatch::try_new(
197 Self::get_schema(Some(final_metadata)).into(),
198 vec![
199 Arc::new(id_builder.finish()),
200 Arc::new(raw_symbol_builder.finish()),
201 Arc::new(asset_class_builder.finish()),
202 Arc::new(currency_builder.finish()),
203 Arc::new(price_precision_builder.finish()),
204 Arc::new(size_precision_builder.finish()),
205 Arc::new(price_increment_builder.finish()),
206 Arc::new(size_increment_builder.finish()),
207 Arc::new(activation_ns_builder.finish()),
208 Arc::new(expiration_ns_builder.finish()),
209 Arc::new(maker_fee_builder.finish()),
210 Arc::new(taker_fee_builder.finish()),
211 Arc::new(max_quantity_builder.finish()),
212 Arc::new(min_quantity_builder.finish()),
213 Arc::new(outcome_builder.finish()),
214 Arc::new(description_builder.finish()),
215 Arc::new(info_builder.finish()),
216 Arc::new(ts_event_builder.finish()),
217 Arc::new(ts_init_builder.finish()),
218 ],
219 )
220 }
221
222 fn metadata(&self) -> HashMap<String, String> {
223 let mut metadata = HashMap::new();
224 metadata.insert(KEY_INSTRUMENT_ID.to_string(), self.id.to_string());
225 metadata.insert(
226 KEY_PRICE_PRECISION.to_string(),
227 self.price_precision.to_string(),
228 );
229 metadata.insert(
230 KEY_SIZE_PRECISION.to_string(),
231 self.size_precision.to_string(),
232 );
233 metadata
234 }
235}
236
237pub fn decode_binary_option_batch(
244 #[allow(unused)] metadata: &HashMap<String, String>,
245 record_batch: &RecordBatch,
246) -> Result<Vec<BinaryOption>, EncodingError> {
247 let cols = record_batch.columns();
248 let num_rows = record_batch.num_rows();
249
250 let id_values = extract_column::<StringArray>(cols, "id", 0, DataType::Utf8)?;
251 let raw_symbol_values = extract_column::<StringArray>(cols, "raw_symbol", 1, DataType::Utf8)?;
252 let asset_class_values = extract_column::<StringArray>(cols, "asset_class", 2, DataType::Utf8)?;
253 let currency_values = extract_column::<StringArray>(cols, "currency", 3, DataType::Utf8)?;
254 let price_precision_values =
255 extract_column::<UInt8Array>(cols, "price_precision", 4, DataType::UInt8)?;
256 let size_precision_values =
257 extract_column::<UInt8Array>(cols, "size_precision", 5, DataType::UInt8)?;
258 let price_increment_values =
259 extract_column::<StringArray>(cols, "price_increment", 6, DataType::Utf8)?;
260 let size_increment_values =
261 extract_column::<StringArray>(cols, "size_increment", 7, DataType::Utf8)?;
262 let activation_ns_values =
263 extract_column::<UInt64Array>(cols, "activation_ns", 8, DataType::UInt64)?;
264 let expiration_ns_values =
265 extract_column::<UInt64Array>(cols, "expiration_ns", 9, DataType::UInt64)?;
266 let maker_fee_values = extract_column::<StringArray>(cols, "maker_fee", 10, DataType::Utf8)?;
267 let taker_fee_values = extract_column::<StringArray>(cols, "taker_fee", 11, DataType::Utf8)?;
268 let max_quantity_values = cols
269 .get(12)
270 .ok_or_else(|| EncodingError::MissingColumn("max_quantity", 12))?;
271 let min_quantity_values = cols
272 .get(13)
273 .ok_or_else(|| EncodingError::MissingColumn("min_quantity", 13))?;
274 let outcome_values = cols
275 .get(14)
276 .ok_or_else(|| EncodingError::MissingColumn("outcome", 14))?;
277 let description_values = cols
278 .get(15)
279 .ok_or_else(|| EncodingError::MissingColumn("description", 15))?;
280 let info_values = cols
281 .get(16)
282 .ok_or_else(|| EncodingError::MissingColumn("info", 16))?;
283 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 17, DataType::UInt64)?;
284 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 18, DataType::UInt64)?;
285
286 let mut result = Vec::with_capacity(num_rows);
287
288 for i in 0..num_rows {
289 let id = InstrumentId::from_str(id_values.value(i))
290 .map_err(|e| EncodingError::ParseError("id", format!("row {i}: {e}")))?;
291 let raw_symbol = Symbol::from(raw_symbol_values.value(i));
292 let asset_class = asset_class_from_str(asset_class_values.value(i))?;
293 let currency = super::decode_currency(
294 currency_values.value(i),
295 "currency",
296 "binary_option.currency",
297 i,
298 )?;
299 let price_prec = price_precision_values.value(i);
300 let size_prec = size_precision_values.value(i);
301
302 let price_increment = Price::from_str(price_increment_values.value(i))
303 .map_err(|e| EncodingError::ParseError("price_increment", format!("row {i}: {e}")))?;
304 let size_increment = Quantity::from_str(size_increment_values.value(i))
305 .map_err(|e| EncodingError::ParseError("size_increment", format!("row {i}: {e}")))?;
306
307 let activation_ns = nautilus_core::UnixNanos::from(activation_ns_values.value(i));
308 let expiration_ns = nautilus_core::UnixNanos::from(expiration_ns_values.value(i));
309
310 let maker_fee = Decimal::from_str(maker_fee_values.value(i))
311 .map_err(|e| EncodingError::ParseError("maker_fee", format!("row {i}: {e}")))?;
312 let taker_fee = Decimal::from_str(taker_fee_values.value(i))
313 .map_err(|e| EncodingError::ParseError("taker_fee", format!("row {i}: {e}")))?;
314
315 let max_quantity =
316 if max_quantity_values.is_null(i) {
317 None
318 } else {
319 let max_qty_str = max_quantity_values
320 .as_any()
321 .downcast_ref::<StringArray>()
322 .ok_or_else(|| {
323 EncodingError::ParseError("max_quantity", format!("row {i}: invalid type"))
324 })?
325 .value(i);
326 Some(Quantity::from_str(max_qty_str).map_err(|e| {
327 EncodingError::ParseError("max_quantity", format!("row {i}: {e}"))
328 })?)
329 };
330
331 let min_quantity =
332 if min_quantity_values.is_null(i) {
333 None
334 } else {
335 let min_qty_str = min_quantity_values
336 .as_any()
337 .downcast_ref::<StringArray>()
338 .ok_or_else(|| {
339 EncodingError::ParseError("min_quantity", format!("row {i}: invalid type"))
340 })?
341 .value(i);
342 Some(Quantity::from_str(min_qty_str).map_err(|e| {
343 EncodingError::ParseError("min_quantity", format!("row {i}: {e}"))
344 })?)
345 };
346
347 let outcome = if outcome_values.is_null(i) {
348 None
349 } else {
350 let outcome_str = outcome_values
351 .as_any()
352 .downcast_ref::<StringArray>()
353 .ok_or_else(|| {
354 EncodingError::ParseError("outcome", format!("row {i}: invalid type"))
355 })?
356 .value(i);
357 Some(Ustr::from(outcome_str))
358 };
359
360 let description = if description_values.is_null(i) {
361 None
362 } else {
363 let desc_str = description_values
364 .as_any()
365 .downcast_ref::<StringArray>()
366 .ok_or_else(|| {
367 EncodingError::ParseError("description", format!("row {i}: invalid type"))
368 })?
369 .value(i);
370 Some(Ustr::from(desc_str))
371 };
372
373 let info = if info_values.is_null(i) {
375 None
376 } else {
377 let info_bytes = info_values
378 .as_any()
379 .downcast_ref::<BinaryArray>()
380 .ok_or_else(|| EncodingError::ParseError("info", format!("row {i}: invalid type")))?
381 .value(i);
382
383 match serde_json::from_slice::<Params>(info_bytes) {
384 Ok(info_dict) => Some(info_dict),
385 Err(e) => {
386 return Err(EncodingError::ParseError(
387 "info",
388 format!("row {i}: failed to deserialize JSON: {e}"),
389 ));
390 }
391 }
392 };
393
394 let ts_event = nautilus_core::UnixNanos::from(ts_event_values.value(i));
395 let ts_init = nautilus_core::UnixNanos::from(ts_init_values.value(i));
396
397 let binary_option = BinaryOption::new(
398 id,
399 raw_symbol,
400 asset_class,
401 currency,
402 activation_ns,
403 expiration_ns,
404 price_prec,
405 size_prec,
406 price_increment,
407 size_increment,
408 outcome,
409 description,
410 max_quantity,
411 min_quantity,
412 None, None, None, None, None, None, Some(maker_fee),
419 Some(taker_fee),
420 info,
421 ts_event,
422 ts_init,
423 );
424
425 result.push(binary_option);
426 }
427
428 Ok(result)
429}