Skip to main content

nautilus_serialization/arrow/
position_event.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::collections::HashMap;
17
18use arrow::{datatypes::Schema, error::ArrowError, record_batch::RecordBatch};
19use nautilus_model::events::{PositionAdjusted, PositionChanged, PositionClosed, PositionOpened};
20
21use super::{
22    ArrowSchemaProvider, DecodeTypedFromRecordBatch, EncodeToRecordBatch, EncodingError,
23    KEY_INSTRUMENT_ID,
24    json::{JsonFieldSpec, decode_batch, encode_batch, metadata_for_type, schema_for_type},
25};
26
/// Arrow field specifications for [`PositionOpened`].
///
/// Handed by `impl_position_event_arrow!` to the shared `json` helpers (`schema_for_type`,
/// `encode_batch`, `decode_batch`). Each entry names one struct field and the Arrow storage
/// kind used for it (`utf8`, `f64` or `u64`); the boolean argument is the nullable flag (in
/// this file it is `true` exactly for the event fields of `Option` type). `PositionOpened`
/// has no optional fields, so every flag here is `false`.
///
/// NOTE(review): entry order presumably determines Arrow column order, so reordering is
/// likely a schema change — confirm in `schema_for_type`.
const POSITION_OPENED_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("trader_id", false),
    JsonFieldSpec::utf8("strategy_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("position_id", false),
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("opening_order_id", false),
    JsonFieldSpec::utf8("entry", false),
    JsonFieldSpec::utf8("side", false),
    JsonFieldSpec::f64("signed_qty", false),
    // Quantity/Price/Currency values are stored as strings rather than numbers — presumably
    // to keep fixed-precision values lossless; TODO confirm in the `json` module.
    JsonFieldSpec::utf8("quantity", false),
    JsonFieldSpec::utf8("last_qty", false),
    JsonFieldSpec::utf8("last_px", false),
    JsonFieldSpec::utf8("currency", false),
    JsonFieldSpec::f64("avg_px_open", false),
    JsonFieldSpec::utf8("event_id", false),
    JsonFieldSpec::u64("ts_event", false),
    JsonFieldSpec::u64("ts_init", false),
];
46
/// Arrow field specifications for [`PositionChanged`].
///
/// Handed by `impl_position_event_arrow!` to the shared `json` helpers (`schema_for_type`,
/// `encode_batch`, `decode_batch`). Each entry names one struct field and its Arrow storage
/// kind; the boolean argument is the nullable flag. Only `avg_px_close` and `realized_pnl`
/// are nullable, matching the two `Option` fields of `PositionChanged`.
///
/// NOTE(review): entry order presumably determines Arrow column order, so reordering is
/// likely a schema change — confirm in `schema_for_type`.
const POSITION_CHANGED_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("trader_id", false),
    JsonFieldSpec::utf8("strategy_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("position_id", false),
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("opening_order_id", false),
    JsonFieldSpec::utf8("entry", false),
    JsonFieldSpec::utf8("side", false),
    JsonFieldSpec::f64("signed_qty", false),
    JsonFieldSpec::utf8("quantity", false),
    JsonFieldSpec::utf8("peak_quantity", false),
    JsonFieldSpec::utf8("last_qty", false),
    JsonFieldSpec::utf8("last_px", false),
    JsonFieldSpec::utf8("currency", false),
    JsonFieldSpec::f64("avg_px_open", false),
    // `None` until the position has closing fills.
    JsonFieldSpec::f64("avg_px_close", true),
    JsonFieldSpec::f64("realized_return", false),
    JsonFieldSpec::utf8("realized_pnl", true),
    JsonFieldSpec::utf8("unrealized_pnl", false),
    JsonFieldSpec::utf8("event_id", false),
    JsonFieldSpec::u64("ts_opened", false),
    JsonFieldSpec::u64("ts_event", false),
    JsonFieldSpec::u64("ts_init", false),
];
72
/// Arrow field specifications for [`PositionClosed`].
///
/// Handed by `impl_position_event_arrow!` to the shared `json` helpers (`schema_for_type`,
/// `encode_batch`, `decode_batch`). Each entry names one struct field and its Arrow storage
/// kind; the boolean argument is the nullable flag. The nullable columns are
/// `closing_order_id`, `avg_px_close`, `realized_pnl` and `ts_closed`, matching the `Option`
/// fields of `PositionClosed`.
///
/// NOTE(review): entry order presumably determines Arrow column order, so reordering is
/// likely a schema change — confirm in `schema_for_type`.
const POSITION_CLOSED_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("trader_id", false),
    JsonFieldSpec::utf8("strategy_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("position_id", false),
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("opening_order_id", false),
    JsonFieldSpec::utf8("closing_order_id", true),
    JsonFieldSpec::utf8("entry", false),
    JsonFieldSpec::utf8("side", false),
    JsonFieldSpec::f64("signed_qty", false),
    JsonFieldSpec::utf8("quantity", false),
    JsonFieldSpec::utf8("peak_quantity", false),
    JsonFieldSpec::utf8("last_qty", false),
    JsonFieldSpec::utf8("last_px", false),
    JsonFieldSpec::utf8("currency", false),
    JsonFieldSpec::f64("avg_px_open", false),
    JsonFieldSpec::f64("avg_px_close", true),
    JsonFieldSpec::f64("realized_return", false),
    JsonFieldSpec::utf8("realized_pnl", true),
    JsonFieldSpec::utf8("unrealized_pnl", false),
    // Stored as a plain u64 count; the tests treat it as nanoseconds.
    JsonFieldSpec::u64("duration", false),
    JsonFieldSpec::utf8("event_id", false),
    JsonFieldSpec::u64("ts_opened", false),
    JsonFieldSpec::u64("ts_closed", true),
    JsonFieldSpec::u64("ts_event", false),
    JsonFieldSpec::u64("ts_init", false),
];
101
/// Arrow field specifications for [`PositionAdjusted`].
///
/// Handed by `impl_position_event_arrow!` to the shared `json` helpers (`schema_for_type`,
/// `encode_batch`, `decode_batch`). Each entry names one event field and its Arrow storage
/// kind; the boolean argument is the nullable flag. `quantity_change`, `pnl_change` and
/// `reason` are the optional arguments of `PositionAdjusted::new`, hence nullable.
///
/// NOTE(review): entry order presumably determines Arrow column order, so reordering is
/// likely a schema change — confirm in `schema_for_type`.
const POSITION_ADJUSTED_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("trader_id", false),
    JsonFieldSpec::utf8("strategy_id", false),
    JsonFieldSpec::utf8("instrument_id", false),
    JsonFieldSpec::utf8("position_id", false),
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("adjustment_type", false),
    // A decimal kept as a string; the round-trip test exercises an 18-decimal-place value,
    // presumably to guard against precision loss.
    JsonFieldSpec::utf8("quantity_change", true),
    JsonFieldSpec::utf8("pnl_change", true),
    JsonFieldSpec::utf8("reason", true),
    JsonFieldSpec::utf8("event_id", false),
    JsonFieldSpec::u64("ts_event", false),
    JsonFieldSpec::u64("ts_init", false),
];
116
117fn instrument_metadata(type_name: &'static str, instrument_id: &str) -> HashMap<String, String> {
118    let mut metadata = metadata_for_type(type_name);
119    metadata.insert(KEY_INSTRUMENT_ID.to_string(), instrument_id.to_string());
120    metadata
121}
122
/// Implements the Arrow schema, encode and decode traits for one position event type.
///
/// Arguments:
/// - `$event`: the event type that receives the three impls.
/// - `$name`: the type-name string passed to the shared `json` helpers and to
///   `instrument_metadata`.
/// - `$specs`: the `JsonFieldSpec` slice describing the event's columns.
///
/// Every impl delegates to the shared helpers imported from `super::json`; the per-event
/// metadata additionally carries the event's instrument ID.
macro_rules! impl_position_event_arrow {
    ($event:ty, $name:expr, $specs:expr) => {
        impl DecodeTypedFromRecordBatch for $event {
            fn decode_typed_batch(
                metadata: &HashMap<String, String>,
                batch: RecordBatch,
            ) -> Result<Vec<Self>, EncodingError> {
                // The expected type name is forwarded so the shared decoder can use it
                // (see `json::decode_batch` for how).
                decode_batch(metadata, &batch, $specs, Some($name))
            }
        }

        impl EncodeToRecordBatch for $event {
            fn encode_batch(
                metadata: &HashMap<String, String>,
                events: &[Self],
            ) -> Result<RecordBatch, ArrowError> {
                // Unqualified call resolves to the free `json::encode_batch` helper, not
                // to this associated function.
                encode_batch($name, metadata, events, $specs)
            }

            fn metadata(&self) -> HashMap<String, String> {
                let instrument_id = self.instrument_id.to_string();
                instrument_metadata($name, &instrument_id)
            }
        }

        impl ArrowSchemaProvider for $event {
            fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
                schema_for_type($name, metadata, $specs)
            }
        }
    };
}
154
155impl_position_event_arrow!(PositionOpened, "PositionOpened", POSITION_OPENED_FIELDS);
156impl_position_event_arrow!(PositionChanged, "PositionChanged", POSITION_CHANGED_FIELDS);
157impl_position_event_arrow!(PositionClosed, "PositionClosed", POSITION_CLOSED_FIELDS);
158impl_position_event_arrow!(
159    PositionAdjusted,
160    "PositionAdjusted",
161    POSITION_ADJUSTED_FIELDS
162);
163
#[cfg(test)]
mod tests {
    // Round-trip tests: every position event must survive encode -> decode unchanged,
    // including its nullable columns when they hold `None`.
    use std::str::FromStr;

    use nautilus_core::{UUID4, UnixNanos};
    use nautilus_model::{
        enums::{OrderSide, PositionAdjustmentType, PositionSide},
        identifiers::{AccountId, ClientOrderId, InstrumentId, PositionId, StrategyId, TraderId},
        types::{Currency, Money, Price, Quantity},
    };
    use rstest::rstest;
    use rust_decimal::Decimal;
    use ustr::Ustr;

    use super::*;

    /// Encodes `event` as a single-row batch using its own metadata, decodes the batch with
    /// the metadata carried on the batch schema, and asserts the result is exactly `[event]`.
    fn assert_round_trip<T>(event: T)
    where
        T: EncodeToRecordBatch + DecodeTypedFromRecordBatch + PartialEq + std::fmt::Debug,
    {
        let metadata = event.metadata();
        // `from_ref` avoids requiring `T: Copy`/`Clone` just to build a one-element slice.
        let batch = T::encode_batch(&metadata, std::slice::from_ref(&event)).unwrap();
        let decoded = T::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded, vec![event]);
    }

    /// Builds a fully-populated `PositionClosed` fixture.
    fn closed_event() -> PositionClosed {
        PositionClosed {
            trader_id: TraderId::from("TRADER-001"),
            strategy_id: StrategyId::from("EMA-CROSS"),
            instrument_id: InstrumentId::from("EURUSD.SIM"),
            position_id: PositionId::from("P-001"),
            account_id: AccountId::from("SIM-001"),
            opening_order_id: ClientOrderId::from("O-19700101-000000-001-001-1"),
            closing_order_id: Some(ClientOrderId::from("O-19700101-000000-001-001-2")),
            entry: OrderSide::Buy,
            side: PositionSide::Flat,
            signed_qty: 0.0,
            quantity: Quantity::from("0"),
            peak_quantity: Quantity::from("150"),
            last_qty: Quantity::from("150"),
            last_px: Price::from("1.0600"),
            currency: Currency::USD(),
            avg_px_open: 1.0525,
            avg_px_close: Some(1.0600),
            realized_return: 0.0071,
            realized_pnl: Some(Money::new(112.50, Currency::USD())),
            unrealized_pnl: Money::new(0.0, Currency::USD()),
            // Kept equal to `ts_closed - ts_opened` so the fixture is self-consistent.
            duration: 3_600_000_000,
            event_id: UUID4::default(),
            ts_opened: UnixNanos::from(1_000_000_000),
            ts_closed: Some(UnixNanos::from(4_600_000_000)),
            ts_event: UnixNanos::from(4_600_000_000),
            ts_init: UnixNanos::from(5_000_000_000),
        }
    }

    #[rstest]
    fn test_position_adjusted_round_trip() {
        let event = PositionAdjusted::new(
            TraderId::from("TRADER-001"),
            StrategyId::from("EMA-CROSS"),
            InstrumentId::from("BTCUSDT.BINANCE"),
            PositionId::from("P-001"),
            AccountId::from("BINANCE-001"),
            PositionAdjustmentType::Funding,
            Some(Decimal::from_str("-0.123456789123456789").unwrap()),
            Some(Money::new(-5.50, Currency::USD())),
            Some(Ustr::from("funding_2024_01_15_08:00")),
            UUID4::default(),
            UnixNanos::from(1_000_000_000),
            UnixNanos::from(2_000_000_000),
        );

        // The per-event metadata must carry the instrument ID.
        assert_eq!(
            event.metadata().get(KEY_INSTRUMENT_ID).map(String::as_str),
            Some("BTCUSDT.BINANCE"),
        );
        assert_round_trip(event);
    }

    #[rstest]
    fn test_position_adjusted_round_trip_with_null_fields() {
        let event = PositionAdjusted::new(
            TraderId::from("TRADER-001"),
            StrategyId::from("EMA-CROSS"),
            InstrumentId::from("BTCUSDT.BINANCE"),
            PositionId::from("P-001"),
            AccountId::from("BINANCE-001"),
            PositionAdjustmentType::Funding,
            None,
            None,
            None,
            UUID4::default(),
            UnixNanos::from(1_000_000_000),
            UnixNanos::from(2_000_000_000),
        );

        assert_round_trip(event);
    }

    #[rstest]
    fn test_position_opened_round_trip() {
        assert_round_trip(PositionOpened {
            trader_id: TraderId::from("TRADER-001"),
            strategy_id: StrategyId::from("EMA-CROSS"),
            instrument_id: InstrumentId::from("EURUSD.SIM"),
            position_id: PositionId::from("P-001"),
            account_id: AccountId::from("SIM-001"),
            opening_order_id: ClientOrderId::from("O-19700101-000000-001-001-1"),
            entry: OrderSide::Buy,
            side: PositionSide::Long,
            signed_qty: 150.0,
            quantity: Quantity::from("150"),
            last_qty: Quantity::from("150"),
            last_px: Price::from("1.0525"),
            currency: Currency::USD(),
            avg_px_open: 1.0525,
            event_id: UUID4::default(),
            ts_event: UnixNanos::from(1_000_000_000),
            ts_init: UnixNanos::from(1_000_000_001),
        });
    }

    #[rstest]
    fn test_position_changed_round_trip() {
        assert_round_trip(PositionChanged {
            trader_id: TraderId::from("TRADER-001"),
            strategy_id: StrategyId::from("EMA-CROSS"),
            instrument_id: InstrumentId::from("EURUSD.SIM"),
            position_id: PositionId::from("P-001"),
            account_id: AccountId::from("SIM-001"),
            opening_order_id: ClientOrderId::from("O-19700101-000000-001-001-1"),
            entry: OrderSide::Buy,
            side: PositionSide::Long,
            signed_qty: 300.0,
            quantity: Quantity::from("300"),
            peak_quantity: Quantity::from("300"),
            last_qty: Quantity::from("150"),
            last_px: Price::from("1.0600"),
            currency: Currency::USD(),
            avg_px_open: 1.0562,
            avg_px_close: None,
            realized_return: 0.0,
            realized_pnl: None,
            unrealized_pnl: Money::new(56.25, Currency::USD()),
            event_id: UUID4::default(),
            ts_opened: UnixNanos::from(1_000_000_000),
            ts_event: UnixNanos::from(2_000_000_000),
            ts_init: UnixNanos::from(2_000_000_001),
        });
    }

    #[rstest]
    fn test_position_closed_round_trip() {
        assert_round_trip(closed_event());
    }

    #[rstest]
    fn test_position_closed_round_trip_with_null_fields() {
        let event = PositionClosed {
            closing_order_id: None,
            avg_px_close: None,
            realized_pnl: None,
            ts_closed: None,
            ..closed_event()
        };

        assert_round_trip(event);
    }
}