1use std::collections::HashMap;
17
18use arrow::{datatypes::Schema, error::ArrowError, record_batch::RecordBatch};
19use nautilus_model::events::{PositionAdjusted, PositionChanged, PositionClosed, PositionOpened};
20
21use super::{
22 ArrowSchemaProvider, DecodeTypedFromRecordBatch, EncodeToRecordBatch, EncodingError,
23 KEY_INSTRUMENT_ID,
24 json::{JsonFieldSpec, decode_batch, encode_batch, metadata_for_type, schema_for_type},
25};
26
/// Arrow column specs for [`PositionOpened`], passed unchanged to `schema_for_type`,
/// `encode_batch` and `decode_batch` by `impl_position_event_arrow!`.
///
/// Identifiers, enum values (`entry`, `side`), quantities, prices and the currency are
/// `utf8` columns. `signed_qty` and `avg_px_open` are `f64`. The `UnixNanos` timestamps
/// are `u64`.
///
/// NOTE(review): the boolean argument appears to be the nullability flag. Every entry here
/// is `false`, matching the fact that no `PositionOpened` field is an `Option`. Reordering
/// or renaming entries most likely changes the persisted schema; confirm against `json`
/// before editing.
const POSITION_OPENED_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("trader_id", false),
    JsonFieldSpec::utf8("strategy_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("position_id", false),
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("opening_order_id", false),
    JsonFieldSpec::utf8("entry", false),
    JsonFieldSpec::utf8("side", false),
    JsonFieldSpec::f64("signed_qty", false),
    JsonFieldSpec::utf8("quantity", false),
    JsonFieldSpec::utf8("last_qty", false),
    JsonFieldSpec::utf8("last_px", false),
    JsonFieldSpec::utf8("currency", false),
    JsonFieldSpec::f64("avg_px_open", false),
    JsonFieldSpec::utf8("event_id", false),
    JsonFieldSpec::u64("ts_event", false),
    JsonFieldSpec::u64("ts_init", false),
];
46
/// Arrow column specs for [`PositionChanged`], passed unchanged to `schema_for_type`,
/// `encode_batch` and `decode_batch` by `impl_position_event_arrow!`.
///
/// This extends the `PositionOpened` layout with `peak_quantity`, the closing/PnL columns
/// and `ts_opened`.
///
/// NOTE(review): the boolean argument appears to be the nullability flag. It is `true`
/// only for `avg_px_close` and `realized_pnl`, which are the `Option` fields of
/// `PositionChanged`. Money amounts are stored as `utf8`, and `UnixNanos` timestamps as
/// `u64`.
const POSITION_CHANGED_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("trader_id", false),
    JsonFieldSpec::utf8("strategy_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("position_id", false),
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("opening_order_id", false),
    JsonFieldSpec::utf8("entry", false),
    JsonFieldSpec::utf8("side", false),
    JsonFieldSpec::f64("signed_qty", false),
    JsonFieldSpec::utf8("quantity", false),
    JsonFieldSpec::utf8("peak_quantity", false),
    JsonFieldSpec::utf8("last_qty", false),
    JsonFieldSpec::utf8("last_px", false),
    JsonFieldSpec::utf8("currency", false),
    JsonFieldSpec::f64("avg_px_open", false),
    JsonFieldSpec::f64("avg_px_close", true),
    JsonFieldSpec::f64("realized_return", false),
    JsonFieldSpec::utf8("realized_pnl", true),
    JsonFieldSpec::utf8("unrealized_pnl", false),
    JsonFieldSpec::utf8("event_id", false),
    JsonFieldSpec::u64("ts_opened", false),
    JsonFieldSpec::u64("ts_event", false),
    JsonFieldSpec::u64("ts_init", false),
];
72
/// Arrow column specs for [`PositionClosed`], passed unchanged to `schema_for_type`,
/// `encode_batch` and `decode_batch` by `impl_position_event_arrow!`.
///
/// This extends the `PositionChanged` layout with `closing_order_id`, `duration` (a `u64`)
/// and `ts_closed`.
///
/// NOTE(review): the boolean argument appears to be the nullability flag. It is `true`
/// for `closing_order_id`, `avg_px_close`, `realized_pnl` and `ts_closed`, which are the
/// `Option` fields of `PositionClosed`.
const POSITION_CLOSED_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("trader_id", false),
    JsonFieldSpec::utf8("strategy_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("position_id", false),
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("opening_order_id", false),
    JsonFieldSpec::utf8("closing_order_id", true),
    JsonFieldSpec::utf8("entry", false),
    JsonFieldSpec::utf8("side", false),
    JsonFieldSpec::f64("signed_qty", false),
    JsonFieldSpec::utf8("quantity", false),
    JsonFieldSpec::utf8("peak_quantity", false),
    JsonFieldSpec::utf8("last_qty", false),
    JsonFieldSpec::utf8("last_px", false),
    JsonFieldSpec::utf8("currency", false),
    JsonFieldSpec::f64("avg_px_open", false),
    JsonFieldSpec::f64("avg_px_close", true),
    JsonFieldSpec::f64("realized_return", false),
    JsonFieldSpec::utf8("realized_pnl", true),
    JsonFieldSpec::utf8("unrealized_pnl", false),
    JsonFieldSpec::u64("duration", false),
    JsonFieldSpec::utf8("event_id", false),
    JsonFieldSpec::u64("ts_opened", false),
    JsonFieldSpec::u64("ts_closed", true),
    JsonFieldSpec::u64("ts_event", false),
    JsonFieldSpec::u64("ts_init", false),
];
101
/// Arrow column specs for [`PositionAdjusted`], passed unchanged to `schema_for_type`,
/// `encode_batch` and `decode_batch` by `impl_position_event_arrow!`.
///
/// The decimal `quantity_change`, the money `pnl_change` and the `reason` string are
/// stored as `utf8`. Timestamps are `u64`.
///
/// NOTE(review): the boolean argument appears to be the nullability flag. It is `true`
/// for the three optional adjustment payload fields, which are passed as `Option`s to
/// `PositionAdjusted::new`.
const POSITION_ADJUSTED_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("trader_id", false),
    JsonFieldSpec::utf8("strategy_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("position_id", false),
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("adjustment_type", false),
    JsonFieldSpec::utf8("quantity_change", true),
    JsonFieldSpec::utf8("pnl_change", true),
    JsonFieldSpec::utf8("reason", true),
    JsonFieldSpec::utf8("event_id", false),
    JsonFieldSpec::u64("ts_event", false),
    JsonFieldSpec::u64("ts_init", false),
];
116
117fn instrument_metadata(type_name: &'static str, instrument_id: &str) -> HashMap<String, String> {
118 let mut metadata = metadata_for_type(type_name);
119 metadata.insert(KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string());
120 metadata
121}
122
/// Implements `ArrowSchemaProvider`, `EncodeToRecordBatch` and
/// `DecodeTypedFromRecordBatch` for a position event type.
///
/// The macro takes three arguments, in order:
/// 1. the event type;
/// 2. the type name string forwarded to every `json` helper;
/// 3. the `JsonFieldSpec` slice describing the event's columns.
///
/// Each generated method is a thin wrapper over the matching generic helper in `json`.
/// The only exception is `metadata()`, which tags the map with the event's instrument ID.
macro_rules! impl_position_event_arrow {
    ($event:ty, $name:expr, $specs:expr) => {
        impl ArrowSchemaProvider for $event {
            fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
                schema_for_type($name, metadata, $specs)
            }
        }

        impl EncodeToRecordBatch for $event {
            fn encode_batch(
                metadata: &HashMap<String, String>,
                data: &[Self],
            ) -> Result<RecordBatch, ArrowError> {
                encode_batch($name, metadata, data, $specs)
            }

            fn metadata(&self) -> HashMap<String, String> {
                // Render the instrument ID once, then lend it to the metadata helper.
                let instrument_id = self.instrument_id.to_string();
                instrument_metadata($name, &instrument_id)
            }
        }

        impl DecodeTypedFromRecordBatch for $event {
            fn decode_typed_batch(
                metadata: &HashMap<String, String>,
                record_batch: RecordBatch,
            ) -> Result<Vec<Self>, EncodingError> {
                // NOTE(review): the expected type name is presumably checked against the
                // batch metadata by `json::decode_batch`; confirm there.
                let expected_type = Some($name);
                decode_batch(metadata, &record_batch, $specs, expected_type)
            }
        }
    };
}
154
// Wire each position event type to its column specs. Each type-name string matches its
// Rust type name and is forwarded by the macro to every `json` helper.
impl_position_event_arrow!(PositionOpened, "PositionOpened", POSITION_OPENED_FIELDS);
impl_position_event_arrow!(PositionChanged, "PositionChanged", POSITION_CHANGED_FIELDS);
impl_position_event_arrow!(PositionClosed, "PositionClosed", POSITION_CLOSED_FIELDS);
impl_position_event_arrow!(
    PositionAdjusted,
    "PositionAdjusted",
    POSITION_ADJUSTED_FIELDS
);
163
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use nautilus_core::{UUID4, UnixNanos};
    use nautilus_model::{
        enums::{OrderSide, PositionAdjustmentType, PositionSide},
        identifiers::{AccountId, ClientOrderId, InstrumentId, PositionId, StrategyId, TraderId},
        types::{Currency, Money, Price, Quantity},
    };
    use rstest::rstest;
    use rust_decimal::Decimal;
    use ustr::Ustr;

    use super::*;

    /// A long `PositionOpened` event on EURUSD.SIM, shared by several tests.
    fn position_opened() -> PositionOpened {
        PositionOpened {
            trader_id: TraderId::from("TRADER-001"),
            strategy_id: StrategyId::from("EMA-CROSS"),
            instrument_id: InstrumentId::from("EURUSD.SIM"),
            position_id: PositionId::from("P-001"),
            account_id: AccountId::from("SIM-001"),
            opening_order_id: ClientOrderId::from("O-19700101-000000-001-001-1"),
            entry: OrderSide::Buy,
            side: PositionSide::Long,
            signed_qty: 150.0,
            quantity: Quantity::from("150"),
            last_qty: Quantity::from("150"),
            last_px: Price::from("1.0525"),
            currency: Currency::USD(),
            avg_px_open: 1.0525,
            event_id: UUID4::default(),
            ts_event: UnixNanos::from(1_000_000_000),
            ts_init: UnixNanos::from(1_000_000_001),
        }
    }

    #[rstest]
    fn test_position_adjusted_round_trip() {
        let event = PositionAdjusted::new(
            TraderId::from("TRADER-001"),
            StrategyId::from("EMA-CROSS"),
            InstrumentId::from("BTCUSDT.BINANCE"),
            PositionId::from("P-001"),
            AccountId::from("BINANCE-001"),
            PositionAdjustmentType::Funding,
            Some(Decimal::from_str("-0.123456789123456789").unwrap()),
            Some(Money::new(-5.50, Currency::USD())),
            Some(Ustr::from("funding_2024_01_15_08:00")),
            UUID4::default(),
            UnixNanos::from(1_000_000_000),
            UnixNanos::from(2_000_000_000),
        );
        let metadata = event.metadata();
        let batch =
            PositionAdjusted::encode_batch(&metadata, std::slice::from_ref(&event)).unwrap();
        let decoded =
            PositionAdjusted::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded, vec![event]);
    }

    /// All optional adjustment payload columns are null and must decode back to `None`.
    #[rstest]
    fn test_position_adjusted_round_trip_with_nulls() {
        let event = PositionAdjusted::new(
            TraderId::from("TRADER-001"),
            StrategyId::from("EMA-CROSS"),
            InstrumentId::from("BTCUSDT.BINANCE"),
            PositionId::from("P-001"),
            AccountId::from("BINANCE-001"),
            PositionAdjustmentType::Funding,
            None,
            None,
            None,
            UUID4::default(),
            UnixNanos::from(1_000_000_000),
            UnixNanos::from(2_000_000_000),
        );
        let metadata = event.metadata();
        let batch =
            PositionAdjusted::encode_batch(&metadata, std::slice::from_ref(&event)).unwrap();
        let decoded =
            PositionAdjusted::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded, vec![event]);
    }

    #[rstest]
    fn test_position_opened_round_trip() {
        let event = position_opened();
        let metadata = event.metadata();
        let batch = PositionOpened::encode_batch(&metadata, std::slice::from_ref(&event)).unwrap();
        let decoded = PositionOpened::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded, vec![event]);
    }

    /// Several rows in one batch must decode back in order.
    #[rstest]
    fn test_position_opened_multi_row_round_trip() {
        let first = position_opened();
        let second = PositionOpened {
            position_id: PositionId::from("P-002"),
            entry: OrderSide::Sell,
            side: PositionSide::Short,
            signed_qty: -75.0,
            quantity: Quantity::from("75"),
            last_qty: Quantity::from("75"),
            last_px: Price::from("1.0510"),
            avg_px_open: 1.0510,
            event_id: UUID4::default(),
            ts_event: UnixNanos::from(3_000_000_000),
            ts_init: UnixNanos::from(3_000_000_001),
            ..first
        };
        let events = vec![first, second];
        let metadata = events[0].metadata();
        let batch = PositionOpened::encode_batch(&metadata, &events).unwrap();
        let decoded = PositionOpened::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded, events);
    }

    /// `metadata()` must tag the map with the event's instrument ID.
    #[rstest]
    fn test_metadata_contains_instrument_id() {
        let metadata = position_opened().metadata();

        assert_eq!(
            metadata.get(KEY_INSTRUMENT_ID).map(String::as_str),
            Some("EURUSD.SIM"),
        );
    }

    #[rstest]
    fn test_position_changed_round_trip() {
        let event = PositionChanged {
            trader_id: TraderId::from("TRADER-001"),
            strategy_id: StrategyId::from("EMA-CROSS"),
            instrument_id: InstrumentId::from("EURUSD.SIM"),
            position_id: PositionId::from("P-001"),
            account_id: AccountId::from("SIM-001"),
            opening_order_id: ClientOrderId::from("O-19700101-000000-001-001-1"),
            entry: OrderSide::Buy,
            side: PositionSide::Long,
            signed_qty: 300.0,
            quantity: Quantity::from("300"),
            peak_quantity: Quantity::from("300"),
            last_qty: Quantity::from("150"),
            last_px: Price::from("1.0600"),
            currency: Currency::USD(),
            avg_px_open: 1.0562,
            avg_px_close: None,
            realized_return: 0.0,
            realized_pnl: None,
            unrealized_pnl: Money::new(56.25, Currency::USD()),
            event_id: UUID4::default(),
            ts_opened: UnixNanos::from(1_000_000_000),
            ts_event: UnixNanos::from(2_000_000_000),
            ts_init: UnixNanos::from(2_000_000_001),
        };
        let metadata = event.metadata();
        let batch = PositionChanged::encode_batch(&metadata, std::slice::from_ref(&event)).unwrap();
        let decoded =
            PositionChanged::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded, vec![event]);
    }

    #[rstest]
    fn test_position_closed_round_trip() {
        let event = PositionClosed {
            trader_id: TraderId::from("TRADER-001"),
            strategy_id: StrategyId::from("EMA-CROSS"),
            instrument_id: InstrumentId::from("EURUSD.SIM"),
            position_id: PositionId::from("P-001"),
            account_id: AccountId::from("SIM-001"),
            opening_order_id: ClientOrderId::from("O-19700101-000000-001-001-1"),
            closing_order_id: Some(ClientOrderId::from("O-19700101-000000-001-001-2")),
            entry: OrderSide::Buy,
            side: PositionSide::Flat,
            signed_qty: 0.0,
            quantity: Quantity::from("0"),
            peak_quantity: Quantity::from("150"),
            last_qty: Quantity::from("150"),
            last_px: Price::from("1.0600"),
            currency: Currency::USD(),
            avg_px_open: 1.0525,
            avg_px_close: Some(1.0600),
            realized_return: 0.0071,
            realized_pnl: Some(Money::new(112.50, Currency::USD())),
            unrealized_pnl: Money::new(0.0, Currency::USD()),
            duration: 3_600_000_000_000,
            event_id: UUID4::default(),
            ts_opened: UnixNanos::from(1_000_000_000),
            ts_closed: Some(UnixNanos::from(4_600_000_000)),
            ts_event: UnixNanos::from(4_600_000_000),
            ts_init: UnixNanos::from(5_000_000_000),
        };
        let metadata = event.metadata();
        let batch = PositionClosed::encode_batch(&metadata, std::slice::from_ref(&event)).unwrap();
        let decoded = PositionClosed::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded, vec![event]);
    }

    /// Every nullable `PositionClosed` column is null and must decode back to `None`.
    #[rstest]
    fn test_position_closed_round_trip_with_nulls() {
        let event = PositionClosed {
            trader_id: TraderId::from("TRADER-001"),
            strategy_id: StrategyId::from("EMA-CROSS"),
            instrument_id: InstrumentId::from("EURUSD.SIM"),
            position_id: PositionId::from("P-001"),
            account_id: AccountId::from("SIM-001"),
            opening_order_id: ClientOrderId::from("O-19700101-000000-001-001-1"),
            closing_order_id: None,
            entry: OrderSide::Sell,
            side: PositionSide::Flat,
            signed_qty: 0.0,
            quantity: Quantity::from("0"),
            peak_quantity: Quantity::from("100"),
            last_qty: Quantity::from("100"),
            last_px: Price::from("1.0500"),
            currency: Currency::USD(),
            avg_px_open: 1.0500,
            avg_px_close: None,
            realized_return: 0.0,
            realized_pnl: None,
            unrealized_pnl: Money::new(0.0, Currency::USD()),
            duration: 0,
            event_id: UUID4::default(),
            ts_opened: UnixNanos::from(1_000_000_000),
            ts_closed: None,
            ts_event: UnixNanos::from(1_000_000_000),
            ts_init: UnixNanos::from(1_000_000_001),
        };
        let metadata = event.metadata();
        let batch = PositionClosed::encode_batch(&metadata, std::slice::from_ref(&event)).unwrap();
        let decoded = PositionClosed::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded, vec![event]);
    }
}