Skip to main content

nautilus_serialization/arrow/
account_state.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::collections::HashMap;
17
18use arrow::{datatypes::Schema, error::ArrowError, record_batch::RecordBatch};
19use nautilus_model::events::AccountState;
20
21use super::{
22    ArrowSchemaProvider, DecodeTypedFromRecordBatch, EncodeToRecordBatch, EncodingError,
23    json::{JsonFieldSpec, decode_batch, encode_batch, metadata_for_type, schema_for_type},
24};
25
/// Column specification for encoding [`AccountState`] events into Arrow record batches.
///
/// The same slice is passed to `schema_for_type`, `encode_batch` and `decode_batch` below.
/// The schema, encoder and decoder therefore always agree on the set of columns.
/// NOTE(review): entry order presumably determines column order. Reordering may change the
/// on-disk layout of previously written data, so confirm against `json::schema_for_type`
/// before editing.
const ACCOUNT_STATE_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("account_type", false),
    // The boolean argument looks like a nullability flag; this is the only `true` entry
    // (presumably because `base_currency` is optional) - TODO confirm in `JsonFieldSpec`.
    JsonFieldSpec::utf8("base_currency", true),
    // Collection-valued fields use the `utf8_json` spec, presumably stored as JSON text
    // inside a UTF-8 column - verify against `json::encode_batch`.
    JsonFieldSpec::utf8_json("balances", false),
    JsonFieldSpec::utf8_json("margins", false),
    JsonFieldSpec::boolean("is_reported", false),
    JsonFieldSpec::utf8("event_id", false),
    JsonFieldSpec::u64("ts_event", false),
    JsonFieldSpec::u64("ts_init", false),
];
37
38impl ArrowSchemaProvider for AccountState {
39    fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
40        schema_for_type("AccountState", metadata, ACCOUNT_STATE_FIELDS)
41    }
42}
43
44impl EncodeToRecordBatch for AccountState {
45    fn encode_batch(
46        metadata: &HashMap<String, String>,
47        data: &[Self],
48    ) -> Result<RecordBatch, ArrowError> {
49        encode_batch("AccountState", metadata, data, ACCOUNT_STATE_FIELDS)
50    }
51
52    fn metadata(&self) -> HashMap<String, String> {
53        metadata_for_type("AccountState")
54    }
55}
56
57impl DecodeTypedFromRecordBatch for AccountState {
58    fn decode_typed_batch(
59        metadata: &HashMap<String, String>,
60        record_batch: RecordBatch,
61    ) -> Result<Vec<Self>, EncodingError> {
62        decode_batch(
63            metadata,
64            &record_batch,
65            ACCOUNT_STATE_FIELDS,
66            Some("AccountState"),
67        )
68    }
69}
70
#[cfg(test)]
mod tests {
    use nautilus_model::events::account::stubs::cash_account_state;
    use rstest::rstest;

    use super::*;

    /// Asserts that every column encoded by `ACCOUNT_STATE_FIELDS` survived the round trip.
    ///
    /// Fields are compared one by one rather than with `assert_eq!` on the whole struct.
    /// This keeps every encoded column covered regardless of how `AccountState`
    /// implements `PartialEq`.
    fn assert_account_state_eq(actual: &AccountState, expected: &AccountState) {
        assert_eq!(actual.account_id, expected.account_id);
        assert_eq!(actual.account_type, expected.account_type);
        assert_eq!(actual.base_currency, expected.base_currency);
        assert_eq!(actual.balances, expected.balances);
        assert_eq!(actual.margins, expected.margins);
        assert_eq!(actual.is_reported, expected.is_reported);
        assert_eq!(actual.event_id, expected.event_id);
        assert_eq!(actual.ts_event, expected.ts_event);
        assert_eq!(actual.ts_init, expected.ts_init);
    }

    #[rstest]
    fn test_account_state_round_trip(cash_account_state: AccountState) {
        let state = cash_account_state;
        let metadata = state.metadata();
        let batch = AccountState::encode_batch(&metadata, std::slice::from_ref(&state)).unwrap();
        assert_eq!(batch.num_rows(), 1);

        let decoded = AccountState::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded.len(), 1);
        assert_account_state_eq(&decoded[0], &state);
    }

    #[rstest]
    fn test_account_state_round_trip_preserves_row_order(cash_account_state: AccountState) {
        let first = cash_account_state;
        // Make the second row distinguishable so a reordering would be detected.
        let mut second = first.clone();
        second.is_reported = !first.is_reported;

        let metadata = first.metadata();
        let states = [first, second];
        let batch = AccountState::encode_batch(&metadata, &states).unwrap();
        assert_eq!(batch.num_rows(), states.len());

        let decoded = AccountState::decode_typed_batch(batch.schema().metadata(), batch).unwrap();

        assert_eq!(decoded.len(), states.len());
        for (actual, expected) in decoded.iter().zip(&states) {
            assert_account_state_eq(actual, expected);
        }
    }
}