Skip to main content

nautilus_serialization/arrow/display/account_state.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

//! Display-mode Arrow encoder for [`AccountState`].

18use std::sync::Arc;
19
20use arrow::{
21    array::{BooleanBuilder, StringBuilder, TimestampNanosecondBuilder},
22    datatypes::Schema,
23    error::ArrowError,
24    record_batch::RecordBatch,
25};
26use nautilus_model::events::AccountState;
27
28use super::{bool_field, timestamp_field, unix_nanos_to_i64, utf8_field};
29
30/// Returns the display-mode Arrow schema for [`AccountState`].
31#[must_use]
32pub fn account_state_schema() -> Schema {
33    Schema::new(vec![
34        utf8_field("account_id", false),
35        utf8_field("account_type", false),
36        utf8_field("base_currency", true),
37        utf8_field("balances", false),
38        utf8_field("margins", false),
39        bool_field("is_reported", false),
40        utf8_field("event_id", false),
41        timestamp_field("ts_event", false),
42        timestamp_field("ts_init", false),
43    ])
44}
45
46fn balances_to_json(state: &AccountState) -> String {
47    let entries: Vec<serde_json::Value> = state
48        .balances
49        .iter()
50        .map(|b| {
51            serde_json::json!({
52                "currency": b.currency.to_string(),
53                "total": b.total.as_f64(),
54                "locked": b.locked.as_f64(),
55                "free": b.free.as_f64(),
56            })
57        })
58        .collect();
59    serde_json::to_string(&entries).unwrap_or_default()
60}
61
62fn margins_to_json(state: &AccountState) -> String {
63    let entries: Vec<serde_json::Value> = state
64        .margins
65        .iter()
66        .map(|m| {
67            serde_json::json!({
68                "instrument_id": m.instrument_id.map(|id| id.to_string()),
69                "currency": m.currency.to_string(),
70                "initial": m.initial.as_f64(),
71                "maintenance": m.maintenance.as_f64(),
72            })
73        })
74        .collect();
75    serde_json::to_string(&entries).unwrap_or_default()
76}
77
78/// Encodes account state snapshots as a display-friendly Arrow [`RecordBatch`].
79///
80/// Emits `Utf8` columns for identifiers and JSON-serialized balances/margins,
81/// `Timestamp(Nanosecond)` columns for event and init times, and a `Boolean`
82/// column for `is_reported`. Balances and margins are serialized as JSON arrays
83/// with `f64` amounts for display readability.
84///
85/// Returns an empty [`RecordBatch`] with the correct schema when `data` is empty.
86///
87/// # Errors
88///
89/// Returns an [`ArrowError`] if the Arrow `RecordBatch` cannot be constructed.
90pub fn encode_account_states(data: &[AccountState]) -> Result<RecordBatch, ArrowError> {
91    let mut account_id = StringBuilder::new();
92    let mut account_type = StringBuilder::new();
93    let mut base_currency = StringBuilder::new();
94    let mut balances = StringBuilder::new();
95    let mut margins = StringBuilder::new();
96    let mut is_reported = BooleanBuilder::with_capacity(data.len());
97    let mut event_id = StringBuilder::new();
98    let mut ts_event = TimestampNanosecondBuilder::with_capacity(data.len());
99    let mut ts_init = TimestampNanosecondBuilder::with_capacity(data.len());
100
101    for state in data {
102        account_id.append_value(state.account_id);
103        account_type.append_value(format!("{}", state.account_type));
104        base_currency.append_option(state.base_currency.map(|v| v.to_string()));
105        balances.append_value(balances_to_json(state));
106        margins.append_value(margins_to_json(state));
107        is_reported.append_value(state.is_reported);
108        event_id.append_value(state.event_id.to_string());
109        ts_event.append_value(unix_nanos_to_i64(state.ts_event.as_u64()));
110        ts_init.append_value(unix_nanos_to_i64(state.ts_init.as_u64()));
111    }
112
113    RecordBatch::try_new(
114        Arc::new(account_state_schema()),
115        vec![
116            Arc::new(account_id.finish()),
117            Arc::new(account_type.finish()),
118            Arc::new(base_currency.finish()),
119            Arc::new(balances.finish()),
120            Arc::new(margins.finish()),
121            Arc::new(is_reported.finish()),
122            Arc::new(event_id.finish()),
123            Arc::new(ts_event.finish()),
124            Arc::new(ts_init.finish()),
125        ],
126    )
127}
128
#[cfg(test)]
mod tests {
    use arrow::{
        array::{Array, BooleanArray, StringArray, TimestampNanosecondArray},
        datatypes::{DataType, TimeUnit},
    };
    use nautilus_core::UUID4;
    use nautilus_model::{
        enums::AccountType,
        identifiers::AccountId,
        types::{AccountBalance, Currency, Money},
    };
    use rstest::rstest;

    use super::*;

    /// Builds a cash-account fixture with:
    /// - a single USD balance (total 10_000, locked 1_000, free 9_000);
    /// - no margins;
    /// - `ts_event = ts` and `ts_init = ts + 1`.
    fn make_account_state(ts: u64) -> AccountState {
        let currency = Currency::USD();
        let balance = AccountBalance::new(
            Money::new(10_000.0, currency),
            Money::new(1_000.0, currency),
            Money::new(9_000.0, currency),
        );
        AccountState {
            account_id: AccountId::from("SIM-001"),
            account_type: AccountType::Cash,
            base_currency: Some(currency),
            balances: vec![balance],
            margins: vec![],
            is_reported: false,
            event_id: UUID4::default(),
            ts_event: ts.into(),
            ts_init: (ts + 1).into(),
        }
    }

    /// Downcasts column `idx` of `batch` to a [`StringArray`].
    ///
    /// Panics if the column has a different type.
    fn string_column(batch: &RecordBatch, idx: usize) -> &StringArray {
        batch
            .column(idx)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap()
    }

    #[rstest]
    fn test_encode_account_states_schema() {
        let batch = encode_account_states(&[]).unwrap();
        let schema = batch.schema();
        let fields = schema.fields();
        assert_eq!(fields.len(), 9);
        assert_eq!(fields[0].name(), "account_id");
        assert_eq!(fields[0].data_type(), &DataType::Utf8);
        assert_eq!(fields[5].name(), "is_reported");
        assert_eq!(fields[5].data_type(), &DataType::Boolean);
        assert_eq!(fields[7].name(), "ts_event");
        assert_eq!(
            fields[7].data_type(),
            &DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
    }

    #[rstest]
    fn test_encode_account_states_values() {
        let states = vec![make_account_state(1_000_000)];
        let batch = encode_account_states(&states).unwrap();

        assert_eq!(batch.num_rows(), 1);

        let account_id_col = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        let is_reported_col = batch
            .column(5)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        let ts_event_col = batch
            .column(7)
            .as_any()
            .downcast_ref::<TimestampNanosecondArray>()
            .unwrap();
        let balances_col = batch
            .column(3)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();

        assert_eq!(account_id_col.value(0), "SIM-001");
        assert!(!is_reported_col.value(0));
        assert_eq!(ts_event_col.value(0), 1_000_000);

        let balances: Vec<serde_json::Value> = serde_json::from_str(balances_col.value(0)).unwrap();
        assert_eq!(balances.len(), 1);
        assert_eq!(balances[0]["currency"], "USD");
        assert!((balances[0]["total"].as_f64().unwrap() - 10_000.0).abs() < 1e-9);
        assert!((balances[0]["locked"].as_f64().unwrap() - 1_000.0).abs() < 1e-9);
        assert!((balances[0]["free"].as_f64().unwrap() - 9_000.0).abs() < 1e-9);
    }

    /// Pins the columns not covered by `test_encode_account_states_values`.
    #[rstest]
    fn test_encode_account_states_remaining_columns() {
        let state = make_account_state(1_000_000);
        let expected_event_id = state.event_id.to_string();
        let batch = encode_account_states(&[state]).unwrap();

        assert_eq!(
            string_column(&batch, 1).value(0),
            AccountType::Cash.to_string()
        );
        assert_eq!(string_column(&batch, 2).value(0), "USD");
        // The fixture has no margins, so the column holds an empty JSON array.
        assert_eq!(string_column(&batch, 4).value(0), "[]");
        assert_eq!(string_column(&batch, 6).value(0), expected_event_id);

        let ts_init_col = batch
            .column(8)
            .as_any()
            .downcast_ref::<TimestampNanosecondArray>()
            .unwrap();
        assert_eq!(ts_init_col.value(0), 1_000_001);
    }

    #[rstest]
    fn test_encode_account_states_empty() {
        let batch = encode_account_states(&[]).unwrap();
        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.schema().fields().len(), 9);
    }

    #[rstest]
    fn test_encode_account_states_null_base_currency() {
        let mut state = make_account_state(1_000);
        state.base_currency = None;
        let batch = encode_account_states(&[state]).unwrap();

        let base_currency_col = batch
            .column(2)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert!(base_currency_col.is_null(0));
    }
}