1use std::{path::Path, sync::Arc};
17
18use ahash::AHashMap;
19use futures_util::{Stream, StreamExt, pin_mut};
20use nautilus_core::python::{IntoPyObjectNautilusExt, call_python, to_pyruntime_err};
21use nautilus_model::{
22 data::{Bar, Data, funding::FundingRateUpdate},
23 identifiers::InstrumentId,
24 python::data::data_to_pycapsule,
25};
26use pyo3::{prelude::*, types::PyList};
27
28use crate::{
29 config::BookSnapshotOutput,
30 machine::{
31 Error,
32 client::{TardisMachineClient, determine_instrument_info},
33 message::WsMessage,
34 parse::{parse_tardis_ws_message, parse_tardis_ws_message_funding_rate},
35 replay_normalized, stream_normalized,
36 types::{
37 ReplayNormalizedRequestOptions, StreamNormalizedRequestOptions, TardisInstrumentKey,
38 TardisInstrumentMiniInfo,
39 },
40 },
41 replay::run_tardis_machine_replay_from_config,
42};
43
44#[pymethods]
45#[pyo3_stub_gen::derive::gen_stub_pymethods]
46impl ReplayNormalizedRequestOptions {
47 #[staticmethod]
48 #[pyo3(name = "from_json")]
49 fn py_from_json(#[gen_stub(override_type(type_repr = "bytes"))] data: &[u8]) -> Self {
50 serde_json::from_slice(data).expect("Failed to parse JSON")
51 }
52
53 #[pyo3(name = "from_json_array")]
54 #[staticmethod]
55 fn py_from_json_array(
56 #[gen_stub(override_type(type_repr = "bytes"))] data: &[u8],
57 ) -> Vec<Self> {
58 serde_json::from_slice(data).expect("Failed to parse JSON array")
59 }
60}
61
62#[pymethods]
63#[pyo3_stub_gen::derive::gen_stub_pymethods]
64impl StreamNormalizedRequestOptions {
65 #[staticmethod]
66 #[pyo3(name = "from_json")]
67 fn py_from_json(#[gen_stub(override_type(type_repr = "bytes"))] data: &[u8]) -> Self {
68 serde_json::from_slice(data).expect("Failed to parse JSON")
69 }
70
71 #[pyo3(name = "from_json_array")]
72 #[staticmethod]
73 fn py_from_json_array(
74 #[gen_stub(override_type(type_repr = "bytes"))] data: &[u8],
75 ) -> Vec<Self> {
76 serde_json::from_slice(data).expect("Failed to parse JSON array")
77 }
78}
79
80#[pymethods]
81#[pyo3_stub_gen::derive::gen_stub_pymethods]
82impl TardisMachineClient {
83 #[new]
85 #[pyo3(signature = (base_url=None, normalize_symbols=true, book_snapshot_output="deltas"))]
86 fn py_new(
87 base_url: Option<&str>,
88 normalize_symbols: bool,
89 book_snapshot_output: &str,
90 ) -> PyResult<Self> {
91 let output = match book_snapshot_output {
92 "depth10" => BookSnapshotOutput::Depth10,
93 "deltas" => BookSnapshotOutput::Deltas,
94 _ => {
95 return Err(to_pyruntime_err(anyhow::anyhow!(
96 "Invalid book_snapshot_output: '{book_snapshot_output}'. Expected 'depth10' or 'deltas'"
97 )));
98 }
99 };
100 Self::new(base_url, normalize_symbols, output).map_err(to_pyruntime_err)
101 }
102
103 #[pyo3(name = "is_closed")]
108 #[must_use]
109 pub fn py_is_closed(&self) -> bool {
110 self.is_closed()
111 }
112
113 #[pyo3(name = "close")]
114 fn py_close(&mut self) {
115 self.close();
116 }
117
118 #[pyo3(name = "replay")]
120 fn py_replay<'py>(
121 &self,
122 instruments: Vec<TardisInstrumentMiniInfo>,
123 options: Vec<ReplayNormalizedRequestOptions>,
124 callback: Py<PyAny>,
125 py: Python<'py>,
126 ) -> PyResult<Bound<'py, PyAny>> {
127 let map = if instruments.is_empty() {
128 self.instruments.clone()
129 } else {
130 let mut instrument_map: AHashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>> =
131 AHashMap::new();
132
133 for inst in instruments {
134 let key = inst.as_tardis_instrument_key();
135 instrument_map.insert(key, Arc::new(inst.clone()));
136 }
137 instrument_map
138 };
139
140 let base_url = self.base_url.clone();
141 let replay_signal = self.replay_signal.clone();
142 let book_snapshot_output = self.book_snapshot_output.clone();
143
144 pyo3_async_runtimes::tokio::future_into_py(py, async move {
145 let stream = replay_normalized(&base_url, options, replay_signal)
146 .await
147 .map_err(to_pyruntime_err)?;
148
149 handle_python_stream(
152 Box::pin(stream),
153 callback,
154 None,
155 Some(map),
156 book_snapshot_output,
157 )
158 .await;
159 Ok(())
160 })
161 }
162
163 #[pyo3(name = "replay_bars")]
164 fn py_replay_bars<'py>(
165 &self,
166 instruments: Vec<TardisInstrumentMiniInfo>,
167 options: Vec<ReplayNormalizedRequestOptions>,
168 py: Python<'py>,
169 ) -> PyResult<Bound<'py, PyAny>> {
170 let map = if instruments.is_empty() {
171 self.instruments.clone()
172 } else {
173 instruments
174 .into_iter()
175 .map(|inst| (inst.as_tardis_instrument_key(), Arc::new(inst)))
176 .collect()
177 };
178
179 let base_url = self.base_url.clone();
180 let replay_signal = self.replay_signal.clone();
181 let book_snapshot_output = self.book_snapshot_output.clone();
182
183 pyo3_async_runtimes::tokio::future_into_py(py, async move {
184 let stream = replay_normalized(&base_url, options, replay_signal)
185 .await
186 .map_err(to_pyruntime_err)?;
187
188 pin_mut!(stream);
191
192 let mut bars: Vec<Bar> = Vec::new();
193
194 while let Some(result) = stream.next().await {
195 match result {
196 Ok(msg) => {
197 if let Some(Data::Bar(bar)) = determine_instrument_info(&msg, &map)
198 .and_then(|info| {
199 parse_tardis_ws_message(msg, &info, &book_snapshot_output)
200 })
201 {
202 bars.push(bar);
203 }
204 }
205 Err(e) => {
206 log::error!("Error in WebSocket stream: {e:?}");
207 break;
208 }
209 }
210 }
211
212 Python::attach(|py| {
213 let pylist =
214 PyList::new(py, bars.into_iter().map(|bar| bar.into_py_any_unwrap(py)))
215 .expect("Invalid `ExactSizeIterator`");
216 Ok(pylist.into_py_any_unwrap(py))
217 })
218 })
219 }
220
221 #[pyo3(name = "stream")]
223 fn py_stream<'py>(
224 &self,
225 instruments: Vec<TardisInstrumentMiniInfo>,
226 options: Vec<StreamNormalizedRequestOptions>,
227 callback: Py<PyAny>,
228 py: Python<'py>,
229 ) -> PyResult<Bound<'py, PyAny>> {
230 let mut instrument_map: AHashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>> =
231 AHashMap::new();
232
233 for inst in instruments {
234 let key = inst.as_tardis_instrument_key();
235 instrument_map.insert(key, Arc::new(inst.clone()));
236 }
237
238 let base_url = self.base_url.clone();
239 let replay_signal = self.replay_signal.clone();
240 let book_snapshot_output = self.book_snapshot_output.clone();
241
242 pyo3_async_runtimes::tokio::future_into_py(py, async move {
243 let stream = stream_normalized(&base_url, options, replay_signal)
244 .await
245 .map_err(to_pyruntime_err)?;
246
247 handle_python_stream(
250 Box::pin(stream),
251 callback,
252 None,
253 Some(instrument_map),
254 book_snapshot_output,
255 )
256 .await;
257 Ok(())
258 })
259 }
260}
261
262#[pyfunction]
268#[pyo3_stub_gen::derive::gen_stub_pyfunction(module = "nautilus_trader.tardis")]
269#[pyo3(name = "run_tardis_machine_replay")]
270#[pyo3(signature = (config_filepath))]
271pub fn py_run_tardis_machine_replay(
272 py: Python<'_>,
273 config_filepath: String,
274) -> PyResult<Bound<'_, PyAny>> {
275 nautilus_common::logging::ensure_logging_initialized();
276
277 pyo3_async_runtimes::tokio::future_into_py(py, async move {
278 let config_filepath = Path::new(&config_filepath);
279 run_tardis_machine_replay_from_config(config_filepath)
280 .await
281 .map_err(to_pyruntime_err)?;
282 Ok(())
283 })
284}
285
286async fn handle_python_stream<S>(
287 stream: S,
288 callback: Py<PyAny>,
289 instrument: Option<Arc<TardisInstrumentMiniInfo>>,
290 instrument_map: Option<AHashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>>,
291 book_snapshot_output: BookSnapshotOutput,
292) where
293 S: Stream<Item = Result<WsMessage, Error>> + Unpin,
294{
295 pin_mut!(stream);
296
297 let mut funding_rate_cache: AHashMap<InstrumentId, FundingRateUpdate> = AHashMap::new();
299
300 while let Some(result) = stream.next().await {
301 match result {
302 Ok(msg) => {
303 let info = instrument.clone().or_else(|| {
304 instrument_map
305 .as_ref()
306 .and_then(|map| determine_instrument_info(&msg, map))
307 });
308
309 if let Some(info) = info.clone() {
310 if let Some(data) =
311 parse_tardis_ws_message(msg.clone(), &info, &book_snapshot_output)
312 {
313 Python::attach(|py| {
314 let py_obj = data_to_pycapsule(py, data);
315 call_python(py, &callback, py_obj);
316 });
317 } else if let Some(funding_rate) =
318 parse_tardis_ws_message_funding_rate(msg, &info)
319 {
320 let should_emit = if let Some(cached_rate) =
322 funding_rate_cache.get(&funding_rate.instrument_id)
323 {
324 if cached_rate == &funding_rate {
326 false } else {
328 funding_rate_cache.insert(funding_rate.instrument_id, funding_rate);
329 true
330 }
331 } else {
332 funding_rate_cache.insert(funding_rate.instrument_id, funding_rate);
334 true
335 };
336
337 if should_emit {
338 Python::attach(|py| {
339 let py_obj = funding_rate.into_py_any_unwrap(py);
340 call_python(py, &callback, py_obj);
341 });
342 }
343 }
344 }
345 }
346 Err(e) => {
347 log::error!("Error in WebSocket stream: {e:?}");
348 break;
349 }
350 }
351 }
352}