1use std::{collections::HashMap, str::FromStr, sync::Arc};
19
20#[allow(unused_imports)]
21use arrow::{
22 array::{
23 BinaryArray, BinaryBuilder, Float64Array, Float64Builder, Int64Array, Int64Builder,
24 StringArray, StringBuilder, UInt8Array, UInt64Array,
25 },
26 datatypes::{DataType, Field, Schema},
27 error::ArrowError,
28 record_batch::RecordBatch,
29};
30#[allow(unused_imports)]
31use nautilus_core::Params;
32use nautilus_model::{
33 identifiers::InstrumentId,
34 instruments::betting::BettingInstrument,
35 types::{price::Price, quantity::Quantity},
36};
37#[allow(unused)]
38use rust_decimal::Decimal;
39#[allow(unused)]
40use serde_json::Value;
41use ustr::Ustr;
42
43use crate::arrow::{
44 ArrowSchemaProvider, EncodeToRecordBatch, EncodingError, KEY_INSTRUMENT_ID,
45 KEY_PRICE_PRECISION, KEY_SIZE_PRECISION, extract_column,
46};
47
48impl ArrowSchemaProvider for BettingInstrument {
49 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
50 let fields = vec![
51 Field::new("id", DataType::Utf8, false),
52 Field::new("venue_name", DataType::Utf8, false),
53 Field::new("currency", DataType::Utf8, false),
54 Field::new("event_type_id", DataType::UInt64, false),
55 Field::new("event_type_name", DataType::Utf8, false),
56 Field::new("competition_id", DataType::UInt64, false),
57 Field::new("competition_name", DataType::Utf8, false),
58 Field::new("event_id", DataType::UInt64, false),
59 Field::new("event_name", DataType::Utf8, false),
60 Field::new("event_country_code", DataType::Utf8, false),
61 Field::new("event_open_date", DataType::UInt64, false),
62 Field::new("betting_type", DataType::Utf8, false),
63 Field::new("market_id", DataType::Utf8, false),
64 Field::new("market_name", DataType::Utf8, false),
65 Field::new("market_type", DataType::Utf8, false),
66 Field::new("market_start_time", DataType::UInt64, false),
67 Field::new("selection_id", DataType::UInt64, false),
68 Field::new("selection_name", DataType::Utf8, false),
69 Field::new("selection_handicap", DataType::Float64, false),
70 Field::new("price_precision", DataType::UInt8, false),
71 Field::new("size_precision", DataType::UInt8, false),
72 Field::new("info", DataType::Binary, true), Field::new("ts_event", DataType::UInt64, false),
74 Field::new("ts_init", DataType::UInt64, false),
75 ];
76
77 let mut final_metadata = HashMap::new();
78 final_metadata.insert("class".to_string(), "BettingInstrument".to_string());
79
80 if let Some(meta) = metadata {
81 final_metadata.extend(meta);
82 }
83
84 Schema::new_with_metadata(fields, final_metadata)
85 }
86}
87
88impl EncodeToRecordBatch for BettingInstrument {
89 fn encode_batch(
90 #[allow(unused)] metadata: &HashMap<String, String>,
91 data: &[Self],
92 ) -> Result<RecordBatch, ArrowError> {
93 let mut id_builder = StringBuilder::new();
94 let mut venue_name_builder = StringBuilder::new();
95 let mut currency_builder = StringBuilder::new();
96 let mut event_type_id_builder = UInt64Array::builder(data.len());
97 let mut event_type_name_builder = StringBuilder::new();
98 let mut competition_id_builder = UInt64Array::builder(data.len());
99 let mut competition_name_builder = StringBuilder::new();
100 let mut event_id_builder = UInt64Array::builder(data.len());
101 let mut event_name_builder = StringBuilder::new();
102 let mut event_country_code_builder = StringBuilder::new();
103 let mut event_open_date_builder = UInt64Array::builder(data.len());
104 let mut betting_type_builder = StringBuilder::new();
105 let mut market_id_builder = StringBuilder::new();
106 let mut market_name_builder = StringBuilder::new();
107 let mut market_type_builder = StringBuilder::new();
108 let mut market_start_time_builder = UInt64Array::builder(data.len());
109 let mut selection_id_builder = UInt64Array::builder(data.len());
110 let mut selection_name_builder = StringBuilder::new();
111 let mut selection_handicap_builder = Float64Array::builder(data.len());
112 let mut price_precision_builder = UInt8Array::builder(data.len());
113 let mut size_precision_builder = UInt8Array::builder(data.len());
114 let mut info_builder = BinaryBuilder::new();
115 let mut ts_event_builder = UInt64Array::builder(data.len());
116 let mut ts_init_builder = UInt64Array::builder(data.len());
117
118 for bi in data {
119 id_builder.append_value(bi.id.to_string());
120 let venue_name = bi.id.venue.to_string();
122 venue_name_builder.append_value(venue_name);
123 currency_builder.append_value(bi.currency.to_string());
124 event_type_id_builder.append_value(bi.event_type_id);
125 event_type_name_builder.append_value(bi.event_type_name);
126 competition_id_builder.append_value(bi.competition_id);
127 competition_name_builder.append_value(bi.competition_name);
128 event_id_builder.append_value(bi.event_id);
129 event_name_builder.append_value(bi.event_name);
130 event_country_code_builder.append_value(bi.event_country_code);
131 event_open_date_builder.append_value(bi.event_open_date.as_u64());
132 betting_type_builder.append_value(bi.betting_type);
133 market_id_builder.append_value(bi.market_id);
134 market_name_builder.append_value(bi.market_name);
135 market_type_builder.append_value(bi.market_type);
136 market_start_time_builder.append_value(bi.market_start_time.as_u64());
137 selection_id_builder.append_value(bi.selection_id);
138 selection_name_builder.append_value(bi.selection_name);
139 selection_handicap_builder.append_value(bi.selection_handicap);
140 price_precision_builder.append_value(bi.price_precision);
141 size_precision_builder.append_value(bi.size_precision);
142
143 if let Some(ref info) = bi.info {
145 match serde_json::to_vec(info) {
146 Ok(json_bytes) => {
147 info_builder.append_value(json_bytes);
148 }
149 Err(e) => {
150 return Err(ArrowError::InvalidArgumentError(format!(
151 "Failed to serialize info dict to JSON: {e}"
152 )));
153 }
154 }
155 } else {
156 info_builder.append_null();
157 }
158
159 ts_event_builder.append_value(bi.ts_event.as_u64());
160 ts_init_builder.append_value(bi.ts_init.as_u64());
161 }
162
163 let mut final_metadata = metadata.clone();
164 final_metadata.insert("class".to_string(), "BettingInstrument".to_string());
165
166 RecordBatch::try_new(
167 Self::get_schema(Some(final_metadata)).into(),
168 vec![
169 Arc::new(id_builder.finish()),
170 Arc::new(venue_name_builder.finish()),
171 Arc::new(currency_builder.finish()),
172 Arc::new(event_type_id_builder.finish()),
173 Arc::new(event_type_name_builder.finish()),
174 Arc::new(competition_id_builder.finish()),
175 Arc::new(competition_name_builder.finish()),
176 Arc::new(event_id_builder.finish()),
177 Arc::new(event_name_builder.finish()),
178 Arc::new(event_country_code_builder.finish()),
179 Arc::new(event_open_date_builder.finish()),
180 Arc::new(betting_type_builder.finish()),
181 Arc::new(market_id_builder.finish()),
182 Arc::new(market_name_builder.finish()),
183 Arc::new(market_type_builder.finish()),
184 Arc::new(market_start_time_builder.finish()),
185 Arc::new(selection_id_builder.finish()),
186 Arc::new(selection_name_builder.finish()),
187 Arc::new(selection_handicap_builder.finish()),
188 Arc::new(price_precision_builder.finish()),
189 Arc::new(size_precision_builder.finish()),
190 Arc::new(info_builder.finish()),
191 Arc::new(ts_event_builder.finish()),
192 Arc::new(ts_init_builder.finish()),
193 ],
194 )
195 }
196
197 fn metadata(&self) -> HashMap<String, String> {
198 let mut metadata = HashMap::new();
199 metadata.insert(KEY_INSTRUMENT_ID.to_string(), self.id.to_string());
200 metadata.insert(
201 KEY_PRICE_PRECISION.to_string(),
202 self.price_precision.to_string(),
203 );
204 metadata.insert(
205 KEY_SIZE_PRECISION.to_string(),
206 self.size_precision.to_string(),
207 );
208 metadata
209 }
210}
211
212pub fn decode_betting_instrument_batch(
219 #[allow(unused)] metadata: &HashMap<String, String>,
220 record_batch: &RecordBatch,
221) -> Result<Vec<BettingInstrument>, EncodingError> {
222 let cols = record_batch.columns();
223 let num_rows = record_batch.num_rows();
224
225 let id_values = extract_column::<StringArray>(cols, "id", 0, DataType::Utf8)?;
226 let _venue_name_values = extract_column::<StringArray>(cols, "venue_name", 1, DataType::Utf8)?; let currency_values = extract_column::<StringArray>(cols, "currency", 2, DataType::Utf8)?;
228 let event_type_id_values =
229 extract_column::<UInt64Array>(cols, "event_type_id", 3, DataType::UInt64)?;
230 let event_type_name_values =
231 extract_column::<StringArray>(cols, "event_type_name", 4, DataType::Utf8)?;
232 let competition_id_values =
233 extract_column::<UInt64Array>(cols, "competition_id", 5, DataType::UInt64)?;
234 let competition_name_values =
235 extract_column::<StringArray>(cols, "competition_name", 6, DataType::Utf8)?;
236 let event_id_values = extract_column::<UInt64Array>(cols, "event_id", 7, DataType::UInt64)?;
237 let event_name_values = extract_column::<StringArray>(cols, "event_name", 8, DataType::Utf8)?;
238 let event_country_code_values =
239 extract_column::<StringArray>(cols, "event_country_code", 9, DataType::Utf8)?;
240 let event_open_date_values =
241 extract_column::<UInt64Array>(cols, "event_open_date", 10, DataType::UInt64)?;
242 let betting_type_values =
243 extract_column::<StringArray>(cols, "betting_type", 11, DataType::Utf8)?;
244 let market_id_values = extract_column::<StringArray>(cols, "market_id", 12, DataType::Utf8)?;
245 let market_name_values =
246 extract_column::<StringArray>(cols, "market_name", 13, DataType::Utf8)?;
247 let market_type_values =
248 extract_column::<StringArray>(cols, "market_type", 14, DataType::Utf8)?;
249 let market_start_time_values =
250 extract_column::<UInt64Array>(cols, "market_start_time", 15, DataType::UInt64)?;
251 let selection_id_values =
252 extract_column::<UInt64Array>(cols, "selection_id", 16, DataType::UInt64)?;
253 let selection_name_values =
254 extract_column::<StringArray>(cols, "selection_name", 17, DataType::Utf8)?;
255 let selection_handicap_values =
256 extract_column::<Float64Array>(cols, "selection_handicap", 18, DataType::Float64)?;
257 let price_precision_values =
258 extract_column::<UInt8Array>(cols, "price_precision", 19, DataType::UInt8)?;
259 let size_precision_values =
260 extract_column::<UInt8Array>(cols, "size_precision", 20, DataType::UInt8)?;
261 let info_values = cols
262 .get(21)
263 .ok_or_else(|| EncodingError::MissingColumn("info", 21))?;
264 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 22, DataType::UInt64)?;
265 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 23, DataType::UInt64)?;
266
267 let mut result = Vec::with_capacity(num_rows);
268
269 for i in 0..num_rows {
270 let id = InstrumentId::from_str(id_values.value(i))
271 .map_err(|e| EncodingError::ParseError("id", format!("row {i}: {e}")))?;
272 let currency = super::decode_currency(
273 currency_values.value(i),
274 "currency",
275 "betting_instrument.currency",
276 i,
277 )?;
278 let event_type_id = event_type_id_values.value(i);
279 let event_type_name = Ustr::from(event_type_name_values.value(i));
280 let competition_id = competition_id_values.value(i);
281 let competition_name = Ustr::from(competition_name_values.value(i));
282 let event_id = event_id_values.value(i);
283 let event_name = Ustr::from(event_name_values.value(i));
284 let event_country_code = Ustr::from(event_country_code_values.value(i));
285 let event_open_date = nautilus_core::UnixNanos::from(event_open_date_values.value(i));
286 let betting_type = Ustr::from(betting_type_values.value(i));
287 let market_id = Ustr::from(market_id_values.value(i));
288 let market_name = Ustr::from(market_name_values.value(i));
289 let market_type = Ustr::from(market_type_values.value(i));
290 let market_start_time = nautilus_core::UnixNanos::from(market_start_time_values.value(i));
291 let selection_id = selection_id_values.value(i);
292 let selection_name = Ustr::from(selection_name_values.value(i));
293 let selection_handicap = selection_handicap_values.value(i);
294 let price_prec = price_precision_values.value(i);
295 let size_prec = size_precision_values.value(i);
296
297 let info = if info_values.is_null(i) {
299 None
300 } else {
301 let info_bytes = info_values
302 .as_any()
303 .downcast_ref::<BinaryArray>()
304 .ok_or_else(|| EncodingError::ParseError("info", format!("row {i}: invalid type")))?
305 .value(i);
306
307 match serde_json::from_slice::<Params>(info_bytes) {
308 Ok(info_dict) => Some(info_dict),
309 Err(e) => {
310 return Err(EncodingError::ParseError(
311 "info",
312 format!("row {i}: failed to deserialize JSON: {e}"),
313 ));
314 }
315 }
316 };
317
318 let ts_event = nautilus_core::UnixNanos::from(ts_event_values.value(i));
319 let ts_init = nautilus_core::UnixNanos::from(ts_init_values.value(i));
320
321 let price_increment = Price::new(0.01, price_prec);
325 let size_increment = Quantity::new(1.0, size_prec);
326
327 let raw_symbol = id.symbol;
329
330 let betting_instrument = BettingInstrument::new(
331 id,
332 raw_symbol,
333 event_type_id,
334 event_type_name,
335 competition_id,
336 competition_name,
337 event_id,
338 event_name,
339 event_country_code,
340 event_open_date,
341 betting_type,
342 market_id,
343 market_name,
344 market_type,
345 market_start_time,
346 selection_id,
347 selection_name,
348 selection_handicap,
349 currency,
350 price_prec,
351 size_prec,
352 price_increment,
353 size_increment,
354 None, None, None, None, None, None, None, None, None, None, info,
365 ts_event,
366 ts_init,
367 );
368
369 result.push(betting_instrument);
370 }
371
372 Ok(result)
373}