1use std::collections::HashMap;
17
18use arrow::{datatypes::Schema, error::ArrowError, record_batch::RecordBatch};
19use nautilus_model::events::{OrderSnapshot, PositionSnapshot};
20
21use super::{
22 ArrowSchemaProvider, DecodeTypedFromRecordBatch, EncodeToRecordBatch, EncodingError,
23 KEY_INSTRUMENT_ID,
24 json::{JsonFieldSpec, decode_batch, encode_batch, metadata_for_type, schema_for_type},
25};
26
/// Arrow column layout for [`OrderSnapshot`] record batches.
///
/// Entries are listed in column order, one per snapshot field; the same table is
/// handed to the schema, encode and decode helpers, so reordering or renaming an
/// entry changes the on-disk layout.
///
/// The boolean argument is `true` for the fields that look optional on the
/// snapshot (e.g. `venue_order_id`, `price`, `expire_time`) and `false` for the
/// rest — presumably the column's nullability flag; NOTE(review): confirm
/// against `JsonFieldSpec`.
///
/// Many values, including `quantity`, `price` and the decimal
/// `limit_offset`/`trailing_offset`, are stored as `utf8`; the round-trip test in
/// this file checks that 18-digit decimal offsets come back unchanged.
/// `utf8_json` entries (`commissions`, `linked_order_ids`,
/// `exec_algorithm_params`, `tags`) presumably hold JSON-serialized
/// collections — TODO confirm in the `json` module.
const ORDER_SNAPSHOT_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("trader_id", false),
    JsonFieldSpec::utf8("strategy_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("client_order_id", false),
    JsonFieldSpec::utf8("venue_order_id", true),
    JsonFieldSpec::utf8("position_id", true),
    JsonFieldSpec::utf8("account_id", true),
    JsonFieldSpec::utf8("last_trade_id", true),
    JsonFieldSpec::utf8("order_type", false),
    JsonFieldSpec::utf8("order_side", false),
    JsonFieldSpec::utf8("quantity", false),
    JsonFieldSpec::utf8("price", true),
    JsonFieldSpec::utf8("trigger_price", true),
    JsonFieldSpec::utf8("trigger_type", true),
    // Decimal offsets kept as text rather than f64 (see precision note above).
    JsonFieldSpec::utf8("limit_offset", true),
    JsonFieldSpec::utf8("trailing_offset", true),
    JsonFieldSpec::utf8("trailing_offset_type", true),
    JsonFieldSpec::utf8("time_in_force", false),
    JsonFieldSpec::u64("expire_time", true),
    JsonFieldSpec::utf8("filled_qty", false),
    JsonFieldSpec::utf8("liquidity_side", true),
    JsonFieldSpec::f64("avg_px", true),
    JsonFieldSpec::f64("slippage", true),
    JsonFieldSpec::utf8_json("commissions", false),
    JsonFieldSpec::utf8("status", false),
    JsonFieldSpec::boolean("is_post_only", false),
    JsonFieldSpec::boolean("is_reduce_only", false),
    JsonFieldSpec::boolean("is_quote_quantity", false),
    JsonFieldSpec::utf8("display_qty", true),
    JsonFieldSpec::utf8("emulation_trigger", true),
    JsonFieldSpec::utf8("trigger_instrument_id", true),
    JsonFieldSpec::utf8("contingency_type", true),
    JsonFieldSpec::utf8("order_list_id", true),
    JsonFieldSpec::utf8_json("linked_order_ids", true),
    JsonFieldSpec::utf8("parent_order_id", true),
    JsonFieldSpec::utf8("exec_algorithm_id", true),
    JsonFieldSpec::utf8_json("exec_algorithm_params", true),
    JsonFieldSpec::utf8("exec_spawn_id", true),
    JsonFieldSpec::utf8_json("tags", true),
    JsonFieldSpec::utf8("init_id", false),
    JsonFieldSpec::u64("ts_init", false),
    JsonFieldSpec::u64("ts_last", false),
];
71
/// Arrow column layout for [`PositionSnapshot`] record batches.
///
/// Entries are listed in column order, one per snapshot field; the same table is
/// handed to the schema, encode and decode helpers, so reordering or renaming an
/// entry changes the on-disk layout.
///
/// The boolean argument is `true` for exactly the fields the snapshot holds as
/// `Option` (`closing_order_id`, `base_currency`, `avg_px_close`,
/// `realized_return`, `realized_pnl`, `unrealized_pnl`, `duration_ns`,
/// `ts_closed`) — presumably the column's nullability flag; NOTE(review):
/// confirm against `JsonFieldSpec`.
///
/// `Quantity`, `Currency` and `Money` values are stored as `utf8`, plain floats
/// (`signed_qty`, prices, `realized_return`) as `f64`, and nanosecond
/// timestamps/durations as `u64`. `commissions` is a `utf8_json` column,
/// presumably a JSON-serialized list — TODO confirm in the `json` module.
const POSITION_SNAPSHOT_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("trader_id", false),
    JsonFieldSpec::utf8("strategy_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("position_id", false),
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("opening_order_id", false),
    JsonFieldSpec::utf8("closing_order_id", true),
    JsonFieldSpec::utf8("entry", false),
    JsonFieldSpec::utf8("side", false),
    JsonFieldSpec::f64("signed_qty", false),
    JsonFieldSpec::utf8("quantity", false),
    JsonFieldSpec::utf8("peak_qty", false),
    JsonFieldSpec::utf8("quote_currency", false),
    JsonFieldSpec::utf8("base_currency", true),
    JsonFieldSpec::utf8("settlement_currency", false),
    JsonFieldSpec::f64("avg_px_open", false),
    JsonFieldSpec::f64("avg_px_close", true),
    JsonFieldSpec::f64("realized_return", true),
    JsonFieldSpec::utf8("realized_pnl", true),
    JsonFieldSpec::utf8("unrealized_pnl", true),
    JsonFieldSpec::utf8_json("commissions", false),
    JsonFieldSpec::u64("duration_ns", true),
    JsonFieldSpec::u64("ts_opened", false),
    JsonFieldSpec::u64("ts_closed", true),
    JsonFieldSpec::u64("ts_init", false),
    JsonFieldSpec::u64("ts_last", false),
];
100
101fn instrument_metadata(type_name: &'static str, instrument_id: &str) -> HashMap<String, String> {
102 let mut metadata = metadata_for_type(type_name);
103 metadata.insert(KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string());
104 metadata
105}
106
/// Implements the Arrow decode, encode and schema traits for a snapshot type by
/// delegating to the shared JSON-field helpers.
///
/// * `$snapshot` — the type that receives the impls.
/// * `$name` — the type-name string passed to every helper (wrapped in `Some`
///   for decoding).
/// * `$specs` — the ordered column specifications for the type.
///
/// Per-row metadata produced by the encoder also carries the row's
/// `instrument_id`.
macro_rules! impl_snapshot_arrow {
    ($snapshot:ty, $name:expr, $specs:expr) => {
        impl DecodeTypedFromRecordBatch for $snapshot {
            fn decode_typed_batch(
                metadata: &HashMap<String, String>,
                record_batch: RecordBatch,
            ) -> Result<Vec<Self>, EncodingError> {
                let expected_type = Some($name);
                decode_batch(metadata, &record_batch, $specs, expected_type)
            }
        }

        impl EncodeToRecordBatch for $snapshot {
            fn encode_batch(
                metadata: &HashMap<String, String>,
                data: &[Self],
            ) -> Result<RecordBatch, ArrowError> {
                encode_batch($name, metadata, data, $specs)
            }

            fn metadata(&self) -> HashMap<String, String> {
                let instrument_id = self.instrument_id.to_string();
                instrument_metadata($name, &instrument_id)
            }
        }

        impl ArrowSchemaProvider for $snapshot {
            fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
                schema_for_type($name, metadata, $specs)
            }
        }
    };
}
138
// Wire each snapshot type to its column layout. The string literal is the type
// name handed to the JSON schema/encode/decode helpers; it presumably has to stay
// stable for previously written data — NOTE(review): confirm how `decode_batch`
// checks it.
impl_snapshot_arrow!(OrderSnapshot, "OrderSnapshot", ORDER_SNAPSHOT_FIELDS);
impl_snapshot_arrow!(
    PositionSnapshot,
    "PositionSnapshot",
    POSITION_SNAPSHOT_FIELDS
);
145
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use nautilus_core::UnixNanos;
    use nautilus_model::{
        enums::{OrderSide, OrderType, PositionSide},
        identifiers::{AccountId, ClientOrderId, InstrumentId, PositionId, StrategyId, TraderId},
        orders::OrderTestBuilder,
        types::{Currency, Money, Price, Quantity},
    };
    use rstest::rstest;
    use rust_decimal::Decimal;

    use super::*;

    /// Encodes `rows` into a single record batch (tagged with the first row's
    /// metadata), then decodes it again using the metadata attached to the
    /// batch schema.
    fn round_trip<T>(rows: &[T]) -> Vec<T>
    where
        T: EncodeToRecordBatch + DecodeTypedFromRecordBatch,
    {
        let metadata = rows
            .first()
            .expect("round_trip requires at least one row")
            .metadata();
        let batch = T::encode_batch(&metadata, rows).unwrap();
        T::decode_typed_batch(batch.schema().metadata(), batch).unwrap()
    }

    #[rstest]
    fn test_order_snapshot_round_trip_preserves_decimal_precision() {
        let order = OrderTestBuilder::new(OrderType::TrailingStopLimit)
            .instrument_id(InstrumentId::from("BTCUSDT.BINANCE"))
            .side(OrderSide::Buy)
            .price(Price::from("50000"))
            .trigger_price(Price::from("50500"))
            .limit_offset(Decimal::from_str("0.123456789123456789").unwrap())
            .trailing_offset(Decimal::from_str("0.987654321987654321").unwrap())
            .quantity(Quantity::from("0.5"))
            .build();
        let snapshot = OrderSnapshot::from(order);

        let decoded = round_trip(std::slice::from_ref(&snapshot));

        assert_eq!(decoded, vec![snapshot]);
    }

    /// A closed long EURUSD position with every optional field populated.
    fn make_position_snapshot() -> PositionSnapshot {
        PositionSnapshot {
            trader_id: TraderId::from("TRADER-001"),
            strategy_id: StrategyId::from("EMA-CROSS"),
            instrument_id: InstrumentId::from("EURUSD.SIM"),
            position_id: PositionId::from("P-001"),
            account_id: AccountId::from("SIM-001"),
            opening_order_id: ClientOrderId::from("O-1"),
            closing_order_id: Some(ClientOrderId::from("O-2")),
            entry: OrderSide::Buy,
            side: PositionSide::Long,
            signed_qty: 100.0,
            quantity: Quantity::from("100"),
            peak_qty: Quantity::from("100"),
            quote_currency: Currency::USD(),
            base_currency: Some(Currency::EUR()),
            settlement_currency: Currency::USD(),
            avg_px_open: 1.0500,
            avg_px_close: Some(1.0600),
            realized_return: Some(0.0095),
            realized_pnl: Some(Money::new(100.0, Currency::USD())),
            unrealized_pnl: Some(Money::new(50.0, Currency::USD())),
            commissions: vec![Money::new(2.0, Currency::USD())],
            // Kept consistent with the timestamps: ts_closed - ts_opened.
            duration_ns: Some(3_600_000_000),
            ts_opened: UnixNanos::from(1_000_000_000),
            ts_closed: Some(UnixNanos::from(4_600_000_000)),
            ts_init: UnixNanos::from(2_000_000_000),
            ts_last: UnixNanos::from(4_600_000_000),
        }
    }

    /// Sets every field the fixture populates with `Some(..)` to `None`, so the
    /// row encodes nulls in all nullable columns.
    fn clear_optionals(snapshot: &mut PositionSnapshot) {
        snapshot.closing_order_id = None;
        snapshot.base_currency = None;
        snapshot.avg_px_close = None;
        snapshot.realized_return = None;
        snapshot.realized_pnl = None;
        snapshot.unrealized_pnl = None;
        snapshot.duration_ns = None;
        snapshot.ts_closed = None;
    }

    #[rstest]
    fn test_position_snapshot_round_trip() {
        let snapshot = make_position_snapshot();

        let decoded = round_trip(std::slice::from_ref(&snapshot));

        assert_eq!(decoded, vec![snapshot]);
    }

    #[rstest]
    fn test_position_snapshot_round_trip_null_optionals() {
        let mut snapshot = make_position_snapshot();
        clear_optionals(&mut snapshot);

        let decoded = round_trip(std::slice::from_ref(&snapshot));

        assert_eq!(decoded, vec![snapshot]);
    }

    /// Populated and null values in the same columns must decode row-by-row.
    #[rstest]
    fn test_position_snapshot_round_trip_mixed_null_and_populated_rows() {
        let closed = make_position_snapshot();
        let mut open = make_position_snapshot();
        open.position_id = PositionId::from("P-002");
        clear_optionals(&mut open);
        let rows = vec![closed, open];

        let decoded = round_trip(&rows);

        assert_eq!(decoded, rows);
    }
}