Skip to main content

nautilus_serialization/arrow/display/
trade.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Display-mode Arrow encoder for [`TradeTick`].
17
18use std::sync::Arc;
19
20use arrow::{
21    array::{Float64Builder, StringBuilder, TimestampNanosecondBuilder},
22    datatypes::Schema,
23    error::ArrowError,
24    record_batch::RecordBatch,
25};
26use nautilus_model::data::TradeTick;
27
28use super::{
29    float64_field, price_to_f64, quantity_to_f64, timestamp_field, unix_nanos_to_i64, utf8_field,
30};
31
32/// Returns the display-mode Arrow schema for [`TradeTick`].
33#[must_use]
34pub fn trades_schema() -> Schema {
35    Schema::new(vec![
36        utf8_field("instrument_id", false),
37        float64_field("price", false),
38        float64_field("size", false),
39        utf8_field("aggressor_side", false),
40        utf8_field("trade_id", false),
41        timestamp_field("ts_event", false),
42        timestamp_field("ts_init", false),
43    ])
44}
45
46/// Encodes trades as a display-friendly Arrow [`RecordBatch`].
47///
48/// Emits `Float64` columns for price and size, `Utf8` columns for the
49/// instrument ID, aggressor side, and trade ID, and `Timestamp(Nanosecond)`
50/// columns for event and init times. Mixed-instrument batches are supported.
51/// Precision is lost on the conversion to `f64`; use
52/// [`crate::arrow::trades_to_arrow_record_batch_bytes`] for catalog storage.
53///
54/// Returns an empty [`RecordBatch`] with the correct schema when `data` is empty.
55///
56/// # Errors
57///
58/// Returns an [`ArrowError`] if the Arrow `RecordBatch` cannot be constructed.
59pub fn encode_trades(data: &[TradeTick]) -> Result<RecordBatch, ArrowError> {
60    let mut instrument_id_builder = StringBuilder::new();
61    let mut price_builder = Float64Builder::with_capacity(data.len());
62    let mut size_builder = Float64Builder::with_capacity(data.len());
63    let mut aggressor_side_builder = StringBuilder::new();
64    let mut trade_id_builder = StringBuilder::new();
65    let mut ts_event_builder = TimestampNanosecondBuilder::with_capacity(data.len());
66    let mut ts_init_builder = TimestampNanosecondBuilder::with_capacity(data.len());
67
68    for trade in data {
69        instrument_id_builder.append_value(trade.instrument_id.to_string());
70        price_builder.append_value(price_to_f64(&trade.price));
71        size_builder.append_value(quantity_to_f64(&trade.size));
72        aggressor_side_builder.append_value(format!("{}", trade.aggressor_side));
73        trade_id_builder.append_value(trade.trade_id.to_string());
74        ts_event_builder.append_value(unix_nanos_to_i64(trade.ts_event.as_u64()));
75        ts_init_builder.append_value(unix_nanos_to_i64(trade.ts_init.as_u64()));
76    }
77
78    RecordBatch::try_new(
79        Arc::new(trades_schema()),
80        vec![
81            Arc::new(instrument_id_builder.finish()),
82            Arc::new(price_builder.finish()),
83            Arc::new(size_builder.finish()),
84            Arc::new(aggressor_side_builder.finish()),
85            Arc::new(trade_id_builder.finish()),
86            Arc::new(ts_event_builder.finish()),
87            Arc::new(ts_init_builder.finish()),
88        ],
89    )
90}
91
#[cfg(test)]
mod tests {
    use arrow::{
        array::{Array, Float64Array, StringArray, TimestampNanosecondArray},
        datatypes::{DataType, TimeUnit},
    };
    use nautilus_model::{
        enums::AggressorSide,
        identifiers::{InstrumentId, TradeId},
        types::{Price, Quantity},
    };
    use rstest::rstest;

    use super::*;

    /// Builds a trade with a fixed size of 1,000 and `ts_init = ts + 1`.
    fn make_trade(
        instrument_id: &str,
        price: &str,
        aggressor_side: AggressorSide,
        trade_id: &str,
        ts: u64,
    ) -> TradeTick {
        TradeTick {
            instrument_id: InstrumentId::from(instrument_id),
            price: Price::from(price),
            size: Quantity::from(1_000),
            aggressor_side,
            trade_id: TradeId::new(trade_id),
            ts_event: ts.into(),
            ts_init: (ts + 1).into(),
        }
    }

    /// Downcasts column `index` of `batch` to the concrete array type `T`.
    fn column_as<T: 'static>(batch: &RecordBatch, index: usize) -> &T {
        batch
            .column(index)
            .as_any()
            .downcast_ref::<T>()
            .unwrap_or_else(|| panic!("column {index} has an unexpected array type"))
    }

    #[rstest]
    fn test_encode_trades_schema() {
        let batch = encode_trades(&[]).unwrap();
        let schema = batch.schema();
        let ts_type = DataType::Timestamp(TimeUnit::Nanosecond, None);
        let expected = [
            ("instrument_id", DataType::Utf8),
            ("price", DataType::Float64),
            ("size", DataType::Float64),
            ("aggressor_side", DataType::Utf8),
            ("trade_id", DataType::Utf8),
            ("ts_event", ts_type.clone()),
            ("ts_init", ts_type),
        ];

        assert_eq!(schema.fields().len(), expected.len());
        for (field, (name, data_type)) in schema.fields().iter().zip(expected.iter()) {
            assert_eq!(field.name().as_str(), *name);
            assert_eq!(field.data_type(), data_type);
        }
        // The batch schema must be exactly the published display schema.
        assert_eq!(*schema, trades_schema());
    }

    #[rstest]
    fn test_encode_trades_values() {
        let trades = vec![
            make_trade("AAPL.XNAS", "100.10", AggressorSide::Buyer, "T-1", 1_000),
            make_trade("AAPL.XNAS", "100.20", AggressorSide::Seller, "T-2", 2_000),
        ];
        let batch = encode_trades(&trades).unwrap();

        assert_eq!(batch.num_rows(), 2);

        let instrument_id_col = column_as::<StringArray>(&batch, 0);
        let price_col = column_as::<Float64Array>(&batch, 1);
        let size_col = column_as::<Float64Array>(&batch, 2);
        let aggressor_col = column_as::<StringArray>(&batch, 3);
        let trade_id_col = column_as::<StringArray>(&batch, 4);
        let ts_event_col = column_as::<TimestampNanosecondArray>(&batch, 5);
        let ts_init_col = column_as::<TimestampNanosecondArray>(&batch, 6);

        // Every row appends a value to every column, so no nulls are expected.
        for index in 0..batch.num_columns() {
            assert_eq!(batch.column(index).null_count(), 0, "column {index}");
        }

        assert_eq!(instrument_id_col.value(0), "AAPL.XNAS");
        assert_eq!(instrument_id_col.value(1), "AAPL.XNAS");
        assert!((price_col.value(0) - 100.10).abs() < 1e-9);
        assert!((price_col.value(1) - 100.20).abs() < 1e-9);
        assert!((size_col.value(0) - 1_000.0).abs() < 1e-9);
        assert!((size_col.value(1) - 1_000.0).abs() < 1e-9);
        assert_eq!(aggressor_col.value(0), AggressorSide::Buyer.to_string());
        assert_eq!(aggressor_col.value(1), AggressorSide::Seller.to_string());
        assert_eq!(trade_id_col.value(0), "T-1");
        assert_eq!(trade_id_col.value(1), "T-2");
        assert_eq!(ts_event_col.value(0), 1_000);
        assert_eq!(ts_event_col.value(1), 2_000);
        assert_eq!(ts_init_col.value(0), 1_001);
        assert_eq!(ts_init_col.value(1), 2_001);
    }

    #[rstest]
    fn test_encode_trades_empty() {
        let batch = encode_trades(&[]).unwrap();
        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.num_columns(), 7);
    }

    #[rstest]
    fn test_encode_trades_mixed_instruments() {
        let trades = vec![
            make_trade("AAPL.XNAS", "100.10", AggressorSide::Buyer, "A-1", 1),
            make_trade("MSFT.XNAS", "250.00", AggressorSide::Seller, "M-1", 2),
        ];
        let batch = encode_trades(&trades).unwrap();
        assert_eq!(batch.num_rows(), 2);

        let instrument_id_col = column_as::<StringArray>(&batch, 0);
        assert_eq!(instrument_id_col.value(0), "AAPL.XNAS");
        assert_eq!(instrument_id_col.value(1), "MSFT.XNAS");
    }
}