Skip to main content

nautilus_serialization/arrow/
snapshot.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::collections::HashMap;
17
18use arrow::{datatypes::Schema, error::ArrowError, record_batch::RecordBatch};
19use nautilus_model::events::{OrderSnapshot, PositionSnapshot};
20
21use super::{
22    ArrowSchemaProvider, DecodeTypedFromRecordBatch, EncodeToRecordBatch, EncodingError,
23    KEY_INSTRUMENT_ID,
24    json::{JsonFieldSpec, decode_batch, encode_batch, metadata_for_type, schema_for_type},
25};
26
/// Arrow column layout used to encode/decode [`OrderSnapshot`] record batches.
///
/// Each entry names a snapshot field and the Arrow representation used for it
/// (`utf8`, `u64`, `f64`, `boolean`, or `utf8_json`). The boolean argument appears to be
/// the column's nullability flag: it is `true` for fields such as `venue_order_id` and
/// `price` and `false` for identifiers like `trader_id` (NOTE(review): confirm against
/// `JsonFieldSpec`). Entry order presumably defines the column order of the schema built
/// by `schema_for_type`, so do not reorder casually.
///
/// Quantities, prices and the `limit_offset`/`trailing_offset` decimals are stored as
/// UTF-8 text rather than floats. The round-trip test in this file checks that the decimal
/// offsets survive with full precision.
const ORDER_SNAPSHOT_FIELDS: &[JsonFieldSpec] = &[
    // Identifiers
    JsonFieldSpec::utf8("trader_id", false),
    JsonFieldSpec::utf8("strategy_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("client_order_id", false),
    JsonFieldSpec::utf8("venue_order_id", true),
    JsonFieldSpec::utf8("position_id", true),
    JsonFieldSpec::utf8("account_id", true),
    JsonFieldSpec::utf8("last_trade_id", true),
    // Order parameters (numeric domain values kept as text)
    JsonFieldSpec::utf8("order_type", false),
    JsonFieldSpec::utf8("order_side", false),
    JsonFieldSpec::utf8("quantity", false),
    JsonFieldSpec::utf8("price", true),
    JsonFieldSpec::utf8("trigger_price", true),
    JsonFieldSpec::utf8("trigger_type", true),
    JsonFieldSpec::utf8("limit_offset", true),
    JsonFieldSpec::utf8("trailing_offset", true),
    JsonFieldSpec::utf8("trailing_offset_type", true),
    JsonFieldSpec::utf8("time_in_force", false),
    JsonFieldSpec::u64("expire_time", true),
    // Fill state
    JsonFieldSpec::utf8("filled_qty", false),
    JsonFieldSpec::utf8("liquidity_side", true),
    JsonFieldSpec::f64("avg_px", true),
    JsonFieldSpec::f64("slippage", true),
    // `utf8_json` columns presumably hold JSON-encoded collections — verify in `json` module
    JsonFieldSpec::utf8_json("commissions", false),
    JsonFieldSpec::utf8("status", false),
    // Execution flags
    JsonFieldSpec::boolean("is_post_only", false),
    JsonFieldSpec::boolean("is_reduce_only", false),
    JsonFieldSpec::boolean("is_quote_quantity", false),
    JsonFieldSpec::utf8("display_qty", true),
    JsonFieldSpec::utf8("emulation_trigger", true),
    JsonFieldSpec::utf8("trigger_instrument_id", true),
    // Contingency / execution-algorithm linkage
    JsonFieldSpec::utf8("contingency_type", true),
    JsonFieldSpec::utf8("order_list_id", true),
    JsonFieldSpec::utf8_json("linked_order_ids", true),
    JsonFieldSpec::utf8("parent_order_id", true),
    JsonFieldSpec::utf8("exec_algorithm_id", true),
    JsonFieldSpec::utf8_json("exec_algorithm_params", true),
    JsonFieldSpec::utf8("exec_spawn_id", true),
    JsonFieldSpec::utf8_json("tags", true),
    // Bookkeeping and timestamps
    JsonFieldSpec::utf8("init_id", false),
    JsonFieldSpec::u64("ts_init", false),
    JsonFieldSpec::u64("ts_last", false),
];
71
/// Arrow column layout used to encode/decode [`PositionSnapshot`] record batches.
///
/// Entries follow the same conventions as `ORDER_SNAPSHOT_FIELDS`: field name, Arrow
/// representation, and a boolean that appears to be the nullability flag. The flag is
/// `true` exactly for the fields the `null_optionals` test in this file sets to `None`
/// (e.g. `closing_order_id`, `avg_px_close`, `ts_closed`).
///
/// `signed_qty`, `avg_px_open`, `avg_px_close` and `realized_return` are plain `f64`
/// columns. Quantities, currencies and PnL amounts are stored as UTF-8 text.
const POSITION_SNAPSHOT_FIELDS: &[JsonFieldSpec] = &[
    // Identifiers
    JsonFieldSpec::utf8("trader_id", false),
    JsonFieldSpec::utf8("strategy_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("position_id", false),
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("opening_order_id", false),
    JsonFieldSpec::utf8("closing_order_id", true),
    // Direction and size
    JsonFieldSpec::utf8("entry", false),
    JsonFieldSpec::utf8("side", false),
    JsonFieldSpec::f64("signed_qty", false),
    JsonFieldSpec::utf8("quantity", false),
    JsonFieldSpec::utf8("peak_qty", false),
    // Currencies
    JsonFieldSpec::utf8("quote_currency", false),
    JsonFieldSpec::utf8("base_currency", true),
    JsonFieldSpec::utf8("settlement_currency", false),
    // Prices and performance
    JsonFieldSpec::f64("avg_px_open", false),
    JsonFieldSpec::f64("avg_px_close", true),
    JsonFieldSpec::f64("realized_return", true),
    JsonFieldSpec::utf8("realized_pnl", true),
    JsonFieldSpec::utf8("unrealized_pnl", true),
    // `utf8_json` presumably JSON-encodes the list of commissions — verify in `json` module
    JsonFieldSpec::utf8_json("commissions", false),
    // Lifecycle timestamps (nanoseconds, per the `UnixNanos` values in the tests)
    JsonFieldSpec::u64("duration_ns", true),
    JsonFieldSpec::u64("ts_opened", false),
    JsonFieldSpec::u64("ts_closed", true),
    JsonFieldSpec::u64("ts_init", false),
    JsonFieldSpec::u64("ts_last", false),
];
100
101fn instrument_metadata(type_name: &'static str, instrument_id: &str) -> HashMap<String, String> {
102    let mut metadata = metadata_for_type(type_name);
103    metadata.insert(KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string());
104    metadata
105}
106
/// Implements the Arrow schema/encode/decode traits for a snapshot event type.
///
/// Arguments, in order:
/// - the snapshot type receiving the impls;
/// - the type-name string handed to the shared JSON helpers;
/// - the `&[JsonFieldSpec]` table describing the batch columns.
///
/// Every generated method delegates to `schema_for_type`, `encode_batch` or
/// `decode_batch`. Per-item metadata is built by `instrument_metadata` from the
/// item's `instrument_id`.
macro_rules! impl_snapshot_arrow {
    ($snapshot:ty, $name:expr, $specs:expr) => {
        impl DecodeTypedFromRecordBatch for $snapshot {
            fn decode_typed_batch(
                metadata: &HashMap<String, String>,
                record_batch: RecordBatch,
            ) -> Result<Vec<Self>, EncodingError> {
                // The type name is passed through as the expected type of the batch.
                decode_batch(metadata, &record_batch, $specs, Some($name))
            }
        }

        impl EncodeToRecordBatch for $snapshot {
            fn metadata(&self) -> HashMap<String, String> {
                let instrument_id = self.instrument_id.to_string();
                instrument_metadata($name, &instrument_id)
            }

            fn encode_batch(
                metadata: &HashMap<String, String>,
                data: &[Self],
            ) -> Result<RecordBatch, ArrowError> {
                encode_batch($name, metadata, data, $specs)
            }
        }

        impl ArrowSchemaProvider for $snapshot {
            fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
                schema_for_type($name, metadata, $specs)
            }
        }
    };
}
138
139impl_snapshot_arrow!(OrderSnapshot, "OrderSnapshot", ORDER_SNAPSHOT_FIELDS);
140impl_snapshot_arrow!(
141    PositionSnapshot,
142    "PositionSnapshot",
143    POSITION_SNAPSHOT_FIELDS
144);
145
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use nautilus_core::UnixNanos;
    use nautilus_model::{
        enums::{OrderSide, OrderType, PositionSide},
        identifiers::{AccountId, ClientOrderId, InstrumentId, PositionId, StrategyId, TraderId},
        orders::OrderTestBuilder,
        types::{Currency, Money, Price, Quantity},
    };
    use rstest::rstest;
    use rust_decimal::Decimal;

    use super::*;

    #[rstest]
    fn test_order_snapshot_round_trip_preserves_decimal_precision() {
        let order = OrderTestBuilder::new(OrderType::TrailingStopLimit)
            .instrument_id(InstrumentId::from("BTCUSDT.BINANCE"))
            .side(OrderSide::Buy)
            .price(Price::from("50000"))
            .trigger_price(Price::from("50500"))
            .limit_offset(Decimal::from_str("0.123456789123456789").unwrap())
            .trailing_offset(Decimal::from_str("0.987654321987654321").unwrap())
            .quantity(Quantity::from("0.5"))
            .build();
        let snapshot = OrderSnapshot::from(order);
        let metadata = snapshot.metadata();

        // The per-item metadata must carry the instrument ID of the snapshot.
        assert_eq!(
            metadata.get(KEY_INSTRUMENT_ID).map(String::as_str),
            Some("BTCUSDT.BINANCE")
        );

        let batch =
            OrderSnapshot::encode_batch(&metadata, std::slice::from_ref(&snapshot)).unwrap();
        let decoded = OrderSnapshot::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded, vec![snapshot]);
    }

    /// Builds a closed EUR/USD position snapshot with every optional field populated.
    fn make_position_snapshot() -> PositionSnapshot {
        PositionSnapshot {
            trader_id: TraderId::from("TRADER-001"),
            strategy_id: StrategyId::from("EMA-CROSS"),
            instrument_id: InstrumentId::from("EURUSD.SIM"),
            position_id: PositionId::from("P-001"),
            account_id: AccountId::from("SIM-001"),
            opening_order_id: ClientOrderId::from("O-1"),
            closing_order_id: Some(ClientOrderId::from("O-2")),
            entry: OrderSide::Buy,
            side: PositionSide::Long,
            signed_qty: 100.0,
            quantity: Quantity::from("100"),
            peak_qty: Quantity::from("100"),
            quote_currency: Currency::USD(),
            base_currency: Some(Currency::EUR()),
            settlement_currency: Currency::USD(),
            avg_px_open: 1.0500,
            avg_px_close: Some(1.0600),
            realized_return: Some(0.0095),
            realized_pnl: Some(Money::new(100.0, Currency::USD())),
            unrealized_pnl: Some(Money::new(50.0, Currency::USD())),
            commissions: vec![Money::new(2.0, Currency::USD())],
            duration_ns: Some(3_600_000_000_000),
            ts_opened: UnixNanos::from(1_000_000_000),
            ts_closed: Some(UnixNanos::from(4_600_000_000)),
            ts_init: UnixNanos::from(2_000_000_000),
            ts_last: UnixNanos::from(4_600_000_000),
        }
    }

    #[rstest]
    fn test_position_snapshot_metadata_includes_instrument_id() {
        let snapshot = make_position_snapshot();
        let metadata = snapshot.metadata();

        assert_eq!(
            metadata.get(KEY_INSTRUMENT_ID).map(String::as_str),
            Some("EURUSD.SIM")
        );
    }

    #[rstest]
    fn test_position_snapshot_round_trip() {
        let snapshot = make_position_snapshot();
        let metadata = snapshot.metadata();
        let batch =
            PositionSnapshot::encode_batch(&metadata, std::slice::from_ref(&snapshot)).unwrap();
        let decoded =
            PositionSnapshot::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded, vec![snapshot]);
    }

    #[rstest]
    fn test_position_snapshot_round_trip_null_optionals() {
        let mut snapshot = make_position_snapshot();
        snapshot.closing_order_id = None;
        snapshot.base_currency = None;
        snapshot.avg_px_close = None;
        snapshot.realized_return = None;
        snapshot.realized_pnl = None;
        snapshot.unrealized_pnl = None;
        snapshot.duration_ns = None;
        snapshot.ts_closed = None;

        let metadata = snapshot.metadata();
        let batch =
            PositionSnapshot::encode_batch(&metadata, std::slice::from_ref(&snapshot)).unwrap();
        let decoded =
            PositionSnapshot::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded, vec![snapshot]);
    }

    /// Rows whose optional columns alternate between present and null must decode
    /// back to their original values, in order. This guards per-row null handling,
    /// which single-row tests cannot exercise.
    #[rstest]
    fn test_position_snapshot_round_trip_multiple_rows_mixed_nulls() {
        let first = make_position_snapshot();
        let mut second = make_position_snapshot();
        second.position_id = PositionId::from("P-002");
        second.closing_order_id = None;
        second.avg_px_close = None;
        second.realized_return = None;
        second.realized_pnl = None;
        second.duration_ns = None;
        second.ts_closed = None;

        let metadata = first.metadata();
        let snapshots = vec![first, second];
        let batch = PositionSnapshot::encode_batch(&metadata, &snapshots).unwrap();
        assert_eq!(batch.num_rows(), 2);

        let decoded =
            PositionSnapshot::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded, snapshots);
    }
}