nautilus_serialization/arrow/display/
account_state.rs1use std::sync::Arc;
19
20use arrow::{
21 array::{BooleanBuilder, StringBuilder, TimestampNanosecondBuilder},
22 datatypes::Schema,
23 error::ArrowError,
24 record_batch::RecordBatch,
25};
26use nautilus_model::events::AccountState;
27
28use super::{bool_field, timestamp_field, unix_nanos_to_i64, utf8_field};
29
30#[must_use]
32pub fn account_state_schema() -> Schema {
33 Schema::new(vec![
34 utf8_field("account_id", false),
35 utf8_field("account_type", false),
36 utf8_field("base_currency", true),
37 utf8_field("balances", false),
38 utf8_field("margins", false),
39 bool_field("is_reported", false),
40 utf8_field("event_id", false),
41 timestamp_field("ts_event", false),
42 timestamp_field("ts_init", false),
43 ])
44}
45
46fn balances_to_json(state: &AccountState) -> String {
47 let entries: Vec<serde_json::Value> = state
48 .balances
49 .iter()
50 .map(|b| {
51 serde_json::json!({
52 "currency": b.currency.to_string(),
53 "total": b.total.as_f64(),
54 "locked": b.locked.as_f64(),
55 "free": b.free.as_f64(),
56 })
57 })
58 .collect();
59 serde_json::to_string(&entries).unwrap_or_default()
60}
61
62fn margins_to_json(state: &AccountState) -> String {
63 let entries: Vec<serde_json::Value> = state
64 .margins
65 .iter()
66 .map(|m| {
67 serde_json::json!({
68 "instrument_id": m.instrument_id.map(|id| id.to_string()),
69 "currency": m.currency.to_string(),
70 "initial": m.initial.as_f64(),
71 "maintenance": m.maintenance.as_f64(),
72 })
73 })
74 .collect();
75 serde_json::to_string(&entries).unwrap_or_default()
76}
77
78pub fn encode_account_states(data: &[AccountState]) -> Result<RecordBatch, ArrowError> {
91 let mut account_id = StringBuilder::new();
92 let mut account_type = StringBuilder::new();
93 let mut base_currency = StringBuilder::new();
94 let mut balances = StringBuilder::new();
95 let mut margins = StringBuilder::new();
96 let mut is_reported = BooleanBuilder::with_capacity(data.len());
97 let mut event_id = StringBuilder::new();
98 let mut ts_event = TimestampNanosecondBuilder::with_capacity(data.len());
99 let mut ts_init = TimestampNanosecondBuilder::with_capacity(data.len());
100
101 for state in data {
102 account_id.append_value(state.account_id);
103 account_type.append_value(format!("{}", state.account_type));
104 base_currency.append_option(state.base_currency.map(|v| v.to_string()));
105 balances.append_value(balances_to_json(state));
106 margins.append_value(margins_to_json(state));
107 is_reported.append_value(state.is_reported);
108 event_id.append_value(state.event_id.to_string());
109 ts_event.append_value(unix_nanos_to_i64(state.ts_event.as_u64()));
110 ts_init.append_value(unix_nanos_to_i64(state.ts_init.as_u64()));
111 }
112
113 RecordBatch::try_new(
114 Arc::new(account_state_schema()),
115 vec![
116 Arc::new(account_id.finish()),
117 Arc::new(account_type.finish()),
118 Arc::new(base_currency.finish()),
119 Arc::new(balances.finish()),
120 Arc::new(margins.finish()),
121 Arc::new(is_reported.finish()),
122 Arc::new(event_id.finish()),
123 Arc::new(ts_event.finish()),
124 Arc::new(ts_init.finish()),
125 ],
126 )
127}
128
#[cfg(test)]
mod tests {
    use arrow::{
        array::{Array, BooleanArray, StringArray, TimestampNanosecondArray},
        datatypes::{DataType, TimeUnit},
    };
    use nautilus_core::UUID4;
    use nautilus_model::{
        enums::AccountType,
        identifiers::AccountId,
        types::{AccountBalance, Currency, Money},
    };
    use rstest::rstest;

    use super::*;

    /// Number of columns the encoder is expected to emit.
    const EXPECTED_COLUMNS: usize = 9;

    /// Builds a cash account state with a single USD balance.
    ///
    /// `ts` is used for `ts_event`; `ts_init` is set to `ts + 1`.
    fn make_account_state(ts: u64) -> AccountState {
        let usd = Currency::USD();
        let total = Money::new(10_000.0, usd);
        let locked = Money::new(1_000.0, usd);
        let free = Money::new(9_000.0, usd);

        AccountState {
            account_id: AccountId::from("SIM-001"),
            account_type: AccountType::Cash,
            base_currency: Some(usd),
            balances: vec![AccountBalance::new(total, locked, free)],
            margins: vec![],
            is_reported: false,
            event_id: UUID4::default(),
            ts_event: ts.into(),
            ts_init: (ts + 1).into(),
        }
    }

    /// Returns column `index` of `batch` as a string array, panicking on a type mismatch.
    fn string_column(batch: &RecordBatch, index: usize) -> &StringArray {
        batch
            .column(index)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap()
    }

    /// Returns column `index` of `batch` as a boolean array, panicking on a type mismatch.
    fn bool_column(batch: &RecordBatch, index: usize) -> &BooleanArray {
        batch
            .column(index)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap()
    }

    /// Returns column `index` of `batch` as a nanosecond timestamp array, panicking on a
    /// type mismatch.
    fn timestamp_column(batch: &RecordBatch, index: usize) -> &TimestampNanosecondArray {
        batch
            .column(index)
            .as_any()
            .downcast_ref::<TimestampNanosecondArray>()
            .unwrap()
    }

    #[rstest]
    fn test_encode_account_states_schema() {
        let batch = encode_account_states(&[]).unwrap();
        let schema = batch.schema();
        let fields = schema.fields();

        assert_eq!(fields.len(), EXPECTED_COLUMNS);

        let first = &fields[0];
        assert_eq!(first.name(), "account_id");
        assert_eq!(first.data_type(), &DataType::Utf8);

        let reported = &fields[5];
        assert_eq!(reported.name(), "is_reported");
        assert_eq!(reported.data_type(), &DataType::Boolean);

        let event_ts = &fields[7];
        assert_eq!(event_ts.name(), "ts_event");
        assert_eq!(
            event_ts.data_type(),
            &DataType::Timestamp(TimeUnit::Nanosecond, None)
        );
    }

    #[rstest]
    fn test_encode_account_states_values() {
        let states = vec![make_account_state(1_000_000)];
        let batch = encode_account_states(&states).unwrap();

        assert_eq!(batch.num_rows(), 1);

        // Scalar columns.
        assert_eq!(string_column(&batch, 0).value(0), "SIM-001");
        assert!(!bool_column(&batch, 5).value(0));
        assert_eq!(timestamp_column(&batch, 7).value(0), 1_000_000);

        // The balances column holds a JSON array with one object per balance.
        let raw_balances = string_column(&batch, 3).value(0);
        let balances: Vec<serde_json::Value> = serde_json::from_str(raw_balances).unwrap();
        assert_eq!(balances.len(), 1);
        assert_eq!(balances[0]["currency"], "USD");

        let expected_amounts = [("total", 10_000.0), ("locked", 1_000.0), ("free", 9_000.0)];
        for (key, expected) in expected_amounts {
            let actual = balances[0][key].as_f64().unwrap();
            assert!((actual - expected).abs() < 1e-9);
        }
    }

    #[rstest]
    fn test_encode_account_states_empty() {
        let batch = encode_account_states(&[]).unwrap();

        assert_eq!(batch.num_rows(), 0);
        assert_eq!(batch.schema().fields().len(), EXPECTED_COLUMNS);
    }

    #[rstest]
    fn test_encode_account_states_null_base_currency() {
        let state = AccountState {
            base_currency: None,
            ..make_account_state(1_000)
        };
        let batch = encode_account_states(&[state]).unwrap();

        assert!(string_column(&batch, 2).is_null(0));
    }
}