nautilus_serialization/arrow/display/
trade.rs1use std::sync::Arc;
19
20use arrow::{
21 array::{Float64Builder, StringBuilder, TimestampNanosecondBuilder},
22 datatypes::Schema,
23 error::ArrowError,
24 record_batch::RecordBatch,
25};
26use nautilus_model::data::TradeTick;
27
28use super::{
29 float64_field, price_to_f64, quantity_to_f64, timestamp_field, unix_nanos_to_i64, utf8_field,
30};
31
32#[must_use]
34pub fn trades_schema() -> Schema {
35 Schema::new(vec![
36 utf8_field("instrument_id", false),
37 float64_field("price", false),
38 float64_field("size", false),
39 utf8_field("aggressor_side", false),
40 utf8_field("trade_id", false),
41 timestamp_field("ts_event", false),
42 timestamp_field("ts_init", false),
43 ])
44}
45
46pub fn encode_trades(data: &[TradeTick]) -> Result<RecordBatch, ArrowError> {
60 let mut instrument_id_builder = StringBuilder::new();
61 let mut price_builder = Float64Builder::with_capacity(data.len());
62 let mut size_builder = Float64Builder::with_capacity(data.len());
63 let mut aggressor_side_builder = StringBuilder::new();
64 let mut trade_id_builder = StringBuilder::new();
65 let mut ts_event_builder = TimestampNanosecondBuilder::with_capacity(data.len());
66 let mut ts_init_builder = TimestampNanosecondBuilder::with_capacity(data.len());
67
68 for trade in data {
69 instrument_id_builder.append_value(trade.instrument_id.to_string());
70 price_builder.append_value(price_to_f64(&trade.price));
71 size_builder.append_value(quantity_to_f64(&trade.size));
72 aggressor_side_builder.append_value(format!("{}", trade.aggressor_side));
73 trade_id_builder.append_value(trade.trade_id.to_string());
74 ts_event_builder.append_value(unix_nanos_to_i64(trade.ts_event.as_u64()));
75 ts_init_builder.append_value(unix_nanos_to_i64(trade.ts_init.as_u64()));
76 }
77
78 RecordBatch::try_new(
79 Arc::new(trades_schema()),
80 vec![
81 Arc::new(instrument_id_builder.finish()),
82 Arc::new(price_builder.finish()),
83 Arc::new(size_builder.finish()),
84 Arc::new(aggressor_side_builder.finish()),
85 Arc::new(trade_id_builder.finish()),
86 Arc::new(ts_event_builder.finish()),
87 Arc::new(ts_init_builder.finish()),
88 ],
89 )
90}
91
#[cfg(test)]
mod tests {
    use arrow::{
        array::{Array, Float64Array, StringArray, TimestampNanosecondArray},
        datatypes::{DataType, TimeUnit},
    };
    use nautilus_model::{
        enums::AggressorSide,
        identifiers::{InstrumentId, TradeId},
        types::{Price, Quantity},
    };
    use rstest::rstest;

    use super::*;

    /// Builds a trade with a fixed size of 1,000, `ts_event = ts` and
    /// `ts_init = ts + 1`.
    fn make_trade(
        instrument_id: &str,
        price: &str,
        aggressor_side: AggressorSide,
        trade_id: &str,
        ts: u64,
    ) -> TradeTick {
        TradeTick {
            instrument_id: InstrumentId::from(instrument_id),
            price: Price::from(price),
            size: Quantity::from(1_000),
            aggressor_side,
            trade_id: TradeId::new(trade_id),
            ts_event: ts.into(),
            ts_init: (ts + 1).into(),
        }
    }

    /// Downcasts column `index` of `batch` to the concrete array type `T`.
    ///
    /// Panics if the column is not a `T`, which here means the encoder produced
    /// the wrong Arrow type for that column.
    fn column<T: 'static>(batch: &RecordBatch, index: usize) -> &T {
        batch
            .column(index)
            .as_any()
            .downcast_ref::<T>()
            .unwrap_or_else(|| panic!("column {index} has an unexpected array type"))
    }

    #[rstest]
    fn test_encode_trades_schema() {
        let batch = encode_trades(&[]).unwrap();
        let schema = batch.schema();

        // The batch must carry exactly the schema advertised by `trades_schema`.
        assert_eq!(schema.as_ref(), &trades_schema());

        let expected = [
            ("instrument_id", DataType::Utf8),
            ("price", DataType::Float64),
            ("size", DataType::Float64),
            ("aggressor_side", DataType::Utf8),
            ("trade_id", DataType::Utf8),
            ("ts_event", DataType::Timestamp(TimeUnit::Nanosecond, None)),
            ("ts_init", DataType::Timestamp(TimeUnit::Nanosecond, None)),
        ];

        let fields = schema.fields();
        assert_eq!(fields.len(), expected.len());
        for (field, (name, data_type)) in fields.iter().zip(expected.iter()) {
            assert_eq!(field.name().as_str(), *name);
            assert_eq!(field.data_type(), data_type);
            assert!(!field.is_nullable(), "field `{name}` should not be nullable");
        }
    }

    #[rstest]
    fn test_encode_trades_values() {
        let trades = [
            make_trade("AAPL.XNAS", "100.10", AggressorSide::Buyer, "T-1", 1_000),
            make_trade("AAPL.XNAS", "100.20", AggressorSide::Seller, "T-2", 2_000),
        ];
        let batch = encode_trades(&trades).unwrap();

        assert_eq!(batch.num_rows(), 2);

        let instrument_id_col = column::<StringArray>(&batch, 0);
        let price_col = column::<Float64Array>(&batch, 1);
        let size_col = column::<Float64Array>(&batch, 2);
        let aggressor_col = column::<StringArray>(&batch, 3);
        let trade_id_col = column::<StringArray>(&batch, 4);
        let ts_event_col = column::<TimestampNanosecondArray>(&batch, 5);
        let ts_init_col = column::<TimestampNanosecondArray>(&batch, 6);

        assert_eq!(instrument_id_col.value(0), "AAPL.XNAS");
        assert_eq!(instrument_id_col.value(1), "AAPL.XNAS");
        assert!((price_col.value(0) - 100.10).abs() < 1e-9);
        assert!((price_col.value(1) - 100.20).abs() < 1e-9);
        assert!((size_col.value(0) - 1_000.0).abs() < 1e-9);
        assert!((size_col.value(1) - 1_000.0).abs() < 1e-9);
        assert_eq!(aggressor_col.value(0), AggressorSide::Buyer.to_string());
        assert_eq!(aggressor_col.value(1), AggressorSide::Seller.to_string());
        assert_eq!(trade_id_col.value(0), "T-1");
        assert_eq!(trade_id_col.value(1), "T-2");
        assert_eq!(ts_event_col.value(0), 1_000);
        assert_eq!(ts_event_col.value(1), 2_000);
        assert_eq!(ts_init_col.value(0), 1_001);
        assert_eq!(ts_init_col.value(1), 2_001);

        // Every column is declared non-nullable, so no nulls may be emitted.
        for col in batch.columns() {
            assert_eq!(col.null_count(), 0);
        }
    }

    #[rstest]
    fn test_encode_trades_empty() {
        let batch = encode_trades(&[]).unwrap();
        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.num_columns(), 7);
    }

    #[rstest]
    fn test_encode_trades_mixed_instruments() {
        let trades = [
            make_trade("AAPL.XNAS", "100.10", AggressorSide::Buyer, "A-1", 1),
            make_trade("MSFT.XNAS", "250.00", AggressorSide::Seller, "M-1", 2),
        ];
        let batch = encode_trades(&trades).unwrap();

        let instrument_id_col = column::<StringArray>(&batch, 0);
        let trade_id_col = column::<StringArray>(&batch, 4);

        assert_eq!(instrument_id_col.value(0), "AAPL.XNAS");
        assert_eq!(instrument_id_col.value(1), "MSFT.XNAS");
        assert_eq!(trade_id_col.value(0), "A-1");
        assert_eq!(trade_id_col.value(1), "M-1");
    }
}