use std::{collections::HashMap, str::FromStr, sync::Arc};

use arrow::{
    array::{BinaryArray, BinaryBuilder, StringArray, StringBuilder, UInt8Array, UInt64Array},
    datatypes::{DataType, Field, Schema},
    error::ArrowError,
    record_batch::RecordBatch,
};
use nautilus_core::{Params, UnixNanos};
use nautilus_model::{
    identifiers::{InstrumentId, Symbol},
    instruments::index_instrument::IndexInstrument,
    types::{price::Price, quantity::Quantity},
};
#[allow(unused)]
use rust_decimal::Decimal;
#[allow(unused)]
use serde_json::Value;

use crate::arrow::{
    ArrowSchemaProvider, EncodeToRecordBatch, EncodingError, KEY_INSTRUMENT_ID,
    KEY_PRICE_PRECISION, extract_column,
};
42impl ArrowSchemaProvider for IndexInstrument {
43 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
44 let fields = vec![
45 Field::new("id", DataType::Utf8, false),
46 Field::new("raw_symbol", DataType::Utf8, false),
47 Field::new("currency", DataType::Utf8, false),
48 Field::new("price_precision", DataType::UInt8, false),
49 Field::new("price_increment", DataType::Utf8, false),
50 Field::new("size_precision", DataType::UInt8, false),
51 Field::new("size_increment", DataType::Utf8, false),
52 Field::new("info", DataType::Binary, true), Field::new("ts_event", DataType::UInt64, false),
54 Field::new("ts_init", DataType::UInt64, false),
55 ];
56
57 let mut final_metadata = HashMap::new();
58 final_metadata.insert("class".to_string(), "IndexInstrument".to_string());
59
60 if let Some(meta) = metadata {
61 final_metadata.extend(meta);
62 }
63
64 Schema::new_with_metadata(fields, final_metadata)
65 }
66}
67
68impl EncodeToRecordBatch for IndexInstrument {
69 fn encode_batch(
70 #[allow(unused)] metadata: &HashMap<String, String>,
71 data: &[Self],
72 ) -> Result<RecordBatch, ArrowError> {
73 let mut id_builder = StringBuilder::new();
74 let mut raw_symbol_builder = StringBuilder::new();
75 let mut currency_builder = StringBuilder::new();
76 let mut price_precision_builder = UInt8Array::builder(data.len());
77 let mut size_precision_builder = UInt8Array::builder(data.len());
78 let mut price_increment_builder = StringBuilder::new();
79 let mut size_increment_builder = StringBuilder::new();
80 let mut info_builder = BinaryBuilder::new();
81 let mut ts_event_builder = UInt64Array::builder(data.len());
82 let mut ts_init_builder = UInt64Array::builder(data.len());
83
84 for index in data {
85 id_builder.append_value(index.id.to_string());
86 raw_symbol_builder.append_value(index.raw_symbol);
87 currency_builder.append_value(index.currency.to_string());
88 price_precision_builder.append_value(index.price_precision);
89 price_increment_builder.append_value(index.price_increment.to_string());
90 size_precision_builder.append_value(index.size_precision);
91 size_increment_builder.append_value(index.size_increment.to_string());
92
93 if let Some(ref info) = index.info {
95 match serde_json::to_vec(info) {
96 Ok(json_bytes) => {
97 info_builder.append_value(json_bytes);
98 }
99 Err(e) => {
100 return Err(ArrowError::InvalidArgumentError(format!(
101 "Failed to serialize info dict to JSON: {e}"
102 )));
103 }
104 }
105 } else {
106 info_builder.append_null();
107 }
108
109 ts_event_builder.append_value(index.ts_event.as_u64());
110 ts_init_builder.append_value(index.ts_init.as_u64());
111 }
112
113 let mut final_metadata = metadata.clone();
114 final_metadata.insert("class".to_string(), "IndexInstrument".to_string());
115
116 RecordBatch::try_new(
117 Self::get_schema(Some(final_metadata)).into(),
118 vec![
119 Arc::new(id_builder.finish()),
120 Arc::new(raw_symbol_builder.finish()),
121 Arc::new(currency_builder.finish()),
122 Arc::new(price_precision_builder.finish()),
123 Arc::new(price_increment_builder.finish()),
124 Arc::new(size_precision_builder.finish()),
125 Arc::new(size_increment_builder.finish()),
126 Arc::new(info_builder.finish()),
127 Arc::new(ts_event_builder.finish()),
128 Arc::new(ts_init_builder.finish()),
129 ],
130 )
131 }
132
133 fn metadata(&self) -> HashMap<String, String> {
134 let mut metadata = HashMap::new();
135 metadata.insert(KEY_INSTRUMENT_ID.to_string(), self.id.to_string());
136 metadata.insert(
137 KEY_PRICE_PRECISION.to_string(),
138 self.price_precision.to_string(),
139 );
140 metadata
141 }
142}
143
144pub fn decode_index_instrument_batch(
151 #[allow(unused)] metadata: &HashMap<String, String>,
152 record_batch: &RecordBatch,
153) -> Result<Vec<IndexInstrument>, EncodingError> {
154 let cols = record_batch.columns();
155 let num_rows = record_batch.num_rows();
156
157 let id_values = extract_column::<StringArray>(cols, "id", 0, DataType::Utf8)?;
158 let raw_symbol_values = extract_column::<StringArray>(cols, "raw_symbol", 1, DataType::Utf8)?;
159 let currency_values = extract_column::<StringArray>(cols, "currency", 2, DataType::Utf8)?;
160 let price_precision_values =
161 extract_column::<UInt8Array>(cols, "price_precision", 3, DataType::UInt8)?;
162 let price_increment_values =
163 extract_column::<StringArray>(cols, "price_increment", 4, DataType::Utf8)?;
164 let size_precision_values =
165 extract_column::<UInt8Array>(cols, "size_precision", 5, DataType::UInt8)?;
166 let size_increment_values =
167 extract_column::<StringArray>(cols, "size_increment", 6, DataType::Utf8)?;
168 let info_values = cols
169 .get(7)
170 .ok_or_else(|| EncodingError::MissingColumn("info", 7))?;
171 let ts_event_values = extract_column::<UInt64Array>(cols, "ts_event", 8, DataType::UInt64)?;
172 let ts_init_values = extract_column::<UInt64Array>(cols, "ts_init", 9, DataType::UInt64)?;
173
174 let mut result = Vec::with_capacity(num_rows);
175
176 for i in 0..num_rows {
177 let id = InstrumentId::from_str(id_values.value(i))
178 .map_err(|e| EncodingError::ParseError("id", format!("row {i}: {e}")))?;
179 let raw_symbol = Symbol::from(raw_symbol_values.value(i));
180 let currency = super::decode_currency(
181 currency_values.value(i),
182 "currency",
183 "index_instrument.currency",
184 i,
185 )?;
186 let price_prec = price_precision_values.value(i);
187 let size_prec = size_precision_values.value(i);
188
189 let price_increment = Price::from_str(price_increment_values.value(i))
190 .map_err(|e| EncodingError::ParseError("price_increment", format!("row {i}: {e}")))?;
191 let size_increment = Quantity::from_str(size_increment_values.value(i))
192 .map_err(|e| EncodingError::ParseError("size_increment", format!("row {i}: {e}")))?;
193
194 let info = if info_values.is_null(i) {
196 None
197 } else {
198 let info_bytes = info_values
199 .as_any()
200 .downcast_ref::<BinaryArray>()
201 .ok_or_else(|| EncodingError::ParseError("info", format!("row {i}: invalid type")))?
202 .value(i);
203
204 match serde_json::from_slice::<Params>(info_bytes) {
205 Ok(info_dict) => Some(info_dict),
206 Err(e) => {
207 return Err(EncodingError::ParseError(
208 "info",
209 format!("row {i}: failed to deserialize JSON: {e}"),
210 ));
211 }
212 }
213 };
214
215 let ts_event = UnixNanos::from(ts_event_values.value(i));
216 let ts_init = UnixNanos::from(ts_init_values.value(i));
217
218 let index_instrument = IndexInstrument::new(
219 id,
220 raw_symbol,
221 currency,
222 price_prec,
223 size_prec,
224 price_increment,
225 size_increment,
226 info,
227 ts_event,
228 ts_init,
229 );
230
231 result.push(index_instrument);
232 }
233
234 Ok(result)
235}