1use std::collections::HashMap;
17
18use arrow::{datatypes::Schema, error::ArrowError, record_batch::RecordBatch};
19use nautilus_model::reports::{
20 ExecutionMassStatus, FillReport, OrderStatusReport, PositionStatusReport,
21};
22
23use super::{
24 ArrowSchemaProvider, DecodeTypedFromRecordBatch, EncodeToRecordBatch, EncodingError,
25 KEY_INSTRUMENT_ID,
26 json::{JsonFieldSpec, decode_batch, encode_batch, metadata_for_type, schema_for_type},
27};
28
/// Field specifications for `OrderStatusReport`, in the order they are handed to
/// `schema_for_type`, `encode_batch`, and `decode_batch`.
///
/// Each entry names a field and picks one of the `JsonFieldSpec` constructors
/// (`utf8`, `utf8_json`, `u64`, `boolean`). The boolean argument is presumably
/// the "nullable" flag (it is `true` for fields such as `client_order_id` and
/// `price`) — confirm against `JsonFieldSpec`.
const ORDER_STATUS_REPORT_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("client_order_id", true),
    JsonFieldSpec::utf8("venue_order_id", false),
    JsonFieldSpec::utf8("order_side", false),
    JsonFieldSpec::utf8("order_type", false),
    JsonFieldSpec::utf8("time_in_force", false),
    JsonFieldSpec::utf8("order_status", false),
    JsonFieldSpec::utf8("quantity", false),
    JsonFieldSpec::utf8("filled_qty", false),
    JsonFieldSpec::utf8("report_id", false),
    JsonFieldSpec::u64("ts_accepted", false),
    JsonFieldSpec::u64("ts_last", false),
    JsonFieldSpec::u64("ts_init", false),
    JsonFieldSpec::utf8("order_list_id", true),
    JsonFieldSpec::utf8("venue_position_id", true),
    // The only `utf8_json` column in this table; the test module populates it
    // with a list of client order IDs.
    JsonFieldSpec::utf8_json("linked_order_ids", true),
    JsonFieldSpec::utf8("parent_order_id", true),
    JsonFieldSpec::utf8("contingency_type", false),
    JsonFieldSpec::u64("expire_time", true),
    JsonFieldSpec::utf8("price", true),
    JsonFieldSpec::utf8("trigger_price", true),
    JsonFieldSpec::utf8("trigger_type", true),
    // Decimal-valued fields (`limit_offset`, `trailing_offset`, `avg_px`) are
    // declared as `utf8`; the round-trip test checks that high-precision values
    // survive encode/decode unchanged.
    JsonFieldSpec::utf8("limit_offset", true),
    JsonFieldSpec::utf8("trailing_offset", true),
    JsonFieldSpec::utf8("trailing_offset_type", false),
    JsonFieldSpec::utf8("avg_px", true),
    JsonFieldSpec::utf8("display_qty", true),
    JsonFieldSpec::boolean("post_only", false),
    JsonFieldSpec::boolean("reduce_only", false),
    JsonFieldSpec::utf8("cancel_reason", true),
    JsonFieldSpec::u64("ts_triggered", true),
];
63
/// Field specifications for `FillReport`, in the order they are handed to
/// `schema_for_type`, `encode_batch`, and `decode_batch`.
///
/// The boolean argument is presumably the "nullable" flag; here it is `true`
/// only for `client_order_id` and `venue_position_id` — confirm against
/// `JsonFieldSpec`.
const FILL_REPORT_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("venue_order_id", false),
    JsonFieldSpec::utf8("trade_id", false),
    JsonFieldSpec::utf8("order_side", false),
    JsonFieldSpec::utf8("last_qty", false),
    JsonFieldSpec::utf8("last_px", false),
    JsonFieldSpec::utf8("commission", false),
    JsonFieldSpec::utf8("liquidity_side", false),
    JsonFieldSpec::utf8("report_id", false),
    JsonFieldSpec::u64("ts_event", false),
    JsonFieldSpec::u64("ts_init", false),
    JsonFieldSpec::utf8("client_order_id", true),
    JsonFieldSpec::utf8("venue_position_id", true),
];
80
/// Field specifications for `PositionStatusReport`, in the order they are handed
/// to `schema_for_type`, `encode_batch`, and `decode_batch`.
///
/// The boolean argument is presumably the "nullable" flag; here it is `true`
/// only for `venue_position_id` and `avg_px_open` — confirm against
/// `JsonFieldSpec`.
const POSITION_STATUS_REPORT_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("position_side", false),
    JsonFieldSpec::utf8("quantity", false),
    // Declared as `utf8`; the round-trip test checks that a high-precision
    // decimal value in this field survives encode/decode unchanged.
    JsonFieldSpec::utf8("signed_decimal_qty", false),
    JsonFieldSpec::utf8("report_id", false),
    JsonFieldSpec::u64("ts_last", false),
    JsonFieldSpec::u64("ts_init", false),
    JsonFieldSpec::utf8("venue_position_id", true),
    JsonFieldSpec::utf8("avg_px_open", true),
];
93
/// Field specifications for `ExecutionMassStatus`, in the order they are handed
/// to `schema_for_type`, `encode_batch`, and `decode_batch`.
///
/// Every field is declared with `false` as the second argument (presumably
/// "not nullable" — confirm against `JsonFieldSpec`).
const EXECUTION_MASS_STATUS_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("client_id", false),
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("venue", false),
    JsonFieldSpec::utf8("report_id", false),
    JsonFieldSpec::u64("ts_init", false),
    // The three nested report collections use the `utf8_json` constructor
    // rather than plain `utf8`.
    JsonFieldSpec::utf8_json("order_reports", false),
    JsonFieldSpec::utf8_json("fill_reports", false),
    JsonFieldSpec::utf8_json("position_reports", false),
];
104
105fn instrument_metadata(type_name: &'static str, instrument_id: &str) -> HashMap<String, String> {
106 let mut metadata = metadata_for_type(type_name);
107 metadata.insert(KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string());
108 metadata
109}
110
/// Generates the Arrow trait implementations (`ArrowSchemaProvider`,
/// `EncodeToRecordBatch`, `DecodeTypedFromRecordBatch`) for a report type that
/// has an `instrument_id` field.
///
/// Arguments, in order:
/// 1. the report type,
/// 2. the type-name expression forwarded to the JSON helpers,
/// 3. the `JsonFieldSpec` slice describing the report's fields.
macro_rules! impl_report_arrow {
    ($report:ty, $name:expr, $specs:expr) => {
        impl ArrowSchemaProvider for $report {
            fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
                schema_for_type($name, metadata, $specs)
            }
        }

        impl EncodeToRecordBatch for $report {
            fn metadata(&self) -> HashMap<String, String> {
                // Per-report metadata carries the instrument ID alongside the
                // type-level entries.
                let instrument_id = self.instrument_id.to_string();
                instrument_metadata($name, &instrument_id)
            }

            fn encode_batch(
                metadata: &HashMap<String, String>,
                data: &[Self],
            ) -> Result<RecordBatch, ArrowError> {
                // Resolves to the free `encode_batch` helper, not this method.
                encode_batch($name, metadata, data, $specs)
            }
        }

        impl DecodeTypedFromRecordBatch for $report {
            fn decode_typed_batch(
                metadata: &HashMap<String, String>,
                record_batch: RecordBatch,
            ) -> Result<Vec<Self>, EncodingError> {
                decode_batch(metadata, &record_batch, $specs, Some($name))
            }
        }
    };
}
142
// Instrument-scoped reports share the macro-generated implementations; each
// invocation pairs the report type with its type-name string and field table.
impl_report_arrow!(
    OrderStatusReport,
    "OrderStatusReport",
    ORDER_STATUS_REPORT_FIELDS
);
impl_report_arrow!(FillReport, "FillReport", FILL_REPORT_FIELDS);
impl_report_arrow!(
    PositionStatusReport,
    "PositionStatusReport",
    POSITION_STATUS_REPORT_FIELDS
);
154
155impl ArrowSchemaProvider for ExecutionMassStatus {
156 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
157 schema_for_type(
158 "ExecutionMassStatus",
159 metadata,
160 EXECUTION_MASS_STATUS_FIELDS,
161 )
162 }
163}
164
165impl EncodeToRecordBatch for ExecutionMassStatus {
166 fn encode_batch(
167 metadata: &HashMap<String, String>,
168 data: &[Self],
169 ) -> Result<RecordBatch, ArrowError> {
170 encode_batch(
171 "ExecutionMassStatus",
172 metadata,
173 data,
174 EXECUTION_MASS_STATUS_FIELDS,
175 )
176 }
177
178 fn metadata(&self) -> HashMap<String, String> {
179 metadata_for_type("ExecutionMassStatus")
180 }
181}
182
183impl DecodeTypedFromRecordBatch for ExecutionMassStatus {
184 fn decode_typed_batch(
185 metadata: &HashMap<String, String>,
186 record_batch: RecordBatch,
187 ) -> Result<Vec<Self>, EncodingError> {
188 decode_batch(
189 metadata,
190 &record_batch,
191 EXECUTION_MASS_STATUS_FIELDS,
192 Some("ExecutionMassStatus"),
193 )
194 }
195}
196
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use nautilus_core::{UUID4, UnixNanos};
    use nautilus_model::{
        enums::{OrderSide, OrderStatus, OrderType, PositionSideSpecified, TimeInForce},
        identifiers::{AccountId, ClientOrderId, InstrumentId, PositionId, VenueOrderId},
        reports::{OrderStatusReport, PositionStatusReport},
        types::Quantity,
    };
    use rstest::rstest;
    use rust_decimal::Decimal;

    use super::*;

    /// Parses a decimal literal, panicking if the literal is malformed.
    fn decimal(literal: &str) -> Decimal {
        Decimal::from_str(literal).unwrap()
    }

    /// An order status report with linked orders and high-precision decimal
    /// fields must decode to a value equal to the one that was encoded.
    #[rstest]
    fn test_order_status_report_round_trip() {
        let base = OrderStatusReport::new(
            AccountId::from("SIM-001"),
            InstrumentId::from("AUDUSD.SIM"),
            Some(ClientOrderId::from("O-19700101-000000-001-001-1")),
            VenueOrderId::from("1"),
            OrderSide::Buy,
            OrderType::Limit,
            TimeInForce::Gtc,
            OrderStatus::Accepted,
            Quantity::from("100"),
            Quantity::from("25"),
            UnixNanos::from(1_000_000_000),
            UnixNanos::from(2_000_000_000),
            UnixNanos::from(3_000_000_000),
            None,
        )
        .with_linked_order_ids([ClientOrderId::from("O-19700101-000000-001-001-2")]);
        let expected = OrderStatusReport {
            limit_offset: Some(decimal("0.123456789123456789")),
            trailing_offset: Some(decimal("0.987654321987654321")),
            avg_px: Some(decimal("1.23456789123456789")),
            ..base
        };

        let report_metadata = expected.metadata();
        let encoded =
            OrderStatusReport::encode_batch(&report_metadata, std::slice::from_ref(&expected))
                .unwrap();
        // Decode using the metadata stored on the encoded batch's own schema.
        let schema = encoded.schema();
        let decoded = OrderStatusReport::decode_typed_batch(schema.metadata(), encoded).unwrap();

        assert_eq!(decoded, vec![expected]);
    }

    /// A position status report with high-precision decimal fields must decode
    /// to a value equal to the one that was encoded.
    #[rstest]
    fn test_position_status_report_round_trip_preserves_decimal_precision() {
        let expected = PositionStatusReport {
            account_id: AccountId::from("SIM-001"),
            instrument_id: InstrumentId::from("AUDUSD.SIM"),
            position_side: PositionSideSpecified::Long,
            quantity: Quantity::from("100.25"),
            signed_decimal_qty: decimal("100.250000000123456789"),
            report_id: UUID4::default(),
            ts_last: UnixNanos::from(1_000_000_000),
            ts_init: UnixNanos::from(2_000_000_000),
            venue_position_id: Some(PositionId::from("P-001")),
            avg_px_open: Some(decimal("1.23456789123456789")),
        };

        let report_metadata = expected.metadata();
        let encoded =
            PositionStatusReport::encode_batch(&report_metadata, std::slice::from_ref(&expected))
                .unwrap();
        // Decode using the metadata stored on the encoded batch's own schema.
        let schema = encoded.schema();
        let decoded =
            PositionStatusReport::decode_typed_batch(schema.metadata(), encoded).unwrap();

        assert_eq!(decoded, vec![expected]);
    }
}