nautilus_serialization/arrow/
account_state.rs1use std::collections::HashMap;
17
18use arrow::{datatypes::Schema, error::ArrowError, record_batch::RecordBatch};
19use nautilus_model::events::AccountState;
20
21use super::{
22 ArrowSchemaProvider, DecodeTypedFromRecordBatch, EncodeToRecordBatch, EncodingError,
23 json::{JsonFieldSpec, decode_batch, encode_batch, metadata_for_type, schema_for_type},
24};
25
/// Column layout for [`AccountState`] record batches.
///
/// The same table is handed to `schema_for_type`, `encode_batch` and `decode_batch`,
/// so the schema, the encoder and the decoder always agree on column names and order.
/// Entry order therefore defines the Arrow column order; reordering entries changes
/// the on-disk schema.
// NOTE(review): the boolean argument appears to be each column's nullability flag
// (only `base_currency` is `true`) — confirm against `json::JsonFieldSpec`.
// NOTE(review): `balances` and `margins` use `utf8_json`, presumably storing the
// collections as JSON text in a UTF-8 column — confirm against `json::JsonFieldSpec`.
const ACCOUNT_STATE_FIELDS: &[JsonFieldSpec] = &[
    JsonFieldSpec::utf8("account_id", false),
    JsonFieldSpec::utf8("account_type", false),
    JsonFieldSpec::utf8("base_currency", true),
    JsonFieldSpec::utf8_json("balances", false),
    JsonFieldSpec::utf8_json("margins", false),
    JsonFieldSpec::boolean("is_reported", false),
    JsonFieldSpec::utf8("event_id", false),
    JsonFieldSpec::u64("ts_event", false),
    JsonFieldSpec::u64("ts_init", false),
];
37
38impl ArrowSchemaProvider for AccountState {
39 fn get_schema(metadata: Option<HashMap<String, String>>) -> Schema {
40 schema_for_type("AccountState", metadata, ACCOUNT_STATE_FIELDS)
41 }
42}
43
44impl EncodeToRecordBatch for AccountState {
45 fn encode_batch(
46 metadata: &HashMap<String, String>,
47 data: &[Self],
48 ) -> Result<RecordBatch, ArrowError> {
49 encode_batch("AccountState", metadata, data, ACCOUNT_STATE_FIELDS)
50 }
51
52 fn metadata(&self) -> HashMap<String, String> {
53 metadata_for_type("AccountState")
54 }
55}
56
57impl DecodeTypedFromRecordBatch for AccountState {
58 fn decode_typed_batch(
59 metadata: &HashMap<String, String>,
60 record_batch: RecordBatch,
61 ) -> Result<Vec<Self>, EncodingError> {
62 decode_batch(
63 metadata,
64 &record_batch,
65 ACCOUNT_STATE_FIELDS,
66 Some("AccountState"),
67 )
68 }
69}
70
#[cfg(test)]
mod tests {
    use nautilus_model::events::account::stubs::cash_account_state;
    use rstest::rstest;

    use super::*;

    /// Encodes a single stub `AccountState` and decodes it again, checking that
    /// every column declared in `ACCOUNT_STATE_FIELDS` survives the round trip.
    #[rstest]
    fn test_account_state_round_trip(cash_account_state: AccountState) {
        let state = cash_account_state;
        let metadata = state.metadata();
        let batch = AccountState::encode_batch(&metadata, std::slice::from_ref(&state)).unwrap();

        // One input event must produce exactly one row.
        assert_eq!(batch.num_rows(), 1);

        let decoded = AccountState::decode_typed_batch(batch.schema().metadata(), batch).unwrap();
        assert_eq!(decoded.len(), 1);

        // Compare every encoded column, not just a subset, so a corrupted or
        // swapped column in the encoder/decoder cannot slip through unnoticed.
        let round_tripped = &decoded[0];
        assert_eq!(round_tripped.account_id, state.account_id);
        assert_eq!(round_tripped.account_type, state.account_type);
        assert_eq!(round_tripped.base_currency, state.base_currency);
        assert_eq!(round_tripped.balances, state.balances);
        assert_eq!(round_tripped.margins, state.margins);
        assert_eq!(round_tripped.is_reported, state.is_reported);
        assert_eq!(round_tripped.event_id, state.event_id);
        assert_eq!(round_tripped.ts_event, state.ts_event);
        assert_eq!(round_tripped.ts_init, state.ts_init);
    }
}