1use std::sync::{
17 Arc,
18 atomic::{AtomicBool, Ordering},
19};
20
21use ahash::AHashMap;
22use futures_util::{Stream, StreamExt, pin_mut};
23use nautilus_model::data::Data;
24use ustr::Ustr;
25
26use super::{
27 Error,
28 message::WsMessage,
29 replay_normalized, stream_normalized,
30 types::{
31 ReplayNormalizedRequestOptions, StreamNormalizedRequestOptions, TardisInstrumentKey,
32 TardisInstrumentMiniInfo,
33 },
34};
35use crate::{
36 common::urls::resolve_ws_base_url, config::BookSnapshotOutput,
37 machine::parse::parse_tardis_ws_message,
38};
39
/// Client for a Tardis Machine WebSocket server.
///
/// Bundles the resolved base URL, one shutdown flag per connection kind (replay and
/// live stream), the instrument lookup table used when replaying, and parsing options.
///
/// The type derives `Clone`; because both flags are held behind `Arc`, clones share
/// the same flags, so closing one clone is observed by all of them.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.tardis", from_py_object)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.tardis")
)]
#[derive(Debug, Clone)]
pub struct TardisMachineClient {
    /// Base URL as returned by `resolve_ws_base_url` in [`TardisMachineClient::new`].
    pub base_url: String,
    /// Flag handed to `replay_normalized` by `replay`; set to `true` by `close`.
    /// NOTE(review): presumably polled by the replay connection to stop itself; confirm
    /// in `replay_normalized`.
    pub replay_signal: Arc<AtomicBool>,
    /// Flag handed to `stream_normalized` by `stream`; set to `true` by `close`.
    /// NOTE(review): presumably polled by the live connection to stop itself; confirm
    /// in `stream_normalized`.
    pub stream_signal: Arc<AtomicBool>,
    /// Instrument info keyed by [`TardisInstrumentKey`]; populated through
    /// `add_instrument_info` and consulted per message during `replay`.
    pub instruments: AHashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>,
    /// Stored exactly as passed to `new`; not read by any method of this type.
    /// NOTE(review): presumably consumed by callers that build instrument info; confirm.
    pub normalize_symbols: bool,
    /// Forwarded to `parse_tardis_ws_message` for every message of both `replay` and
    /// `stream`.
    pub book_snapshot_output: BookSnapshotOutput,
}
58
59impl TardisMachineClient {
60 pub fn new(
66 base_url: Option<&str>,
67 normalize_symbols: bool,
68 book_snapshot_output: BookSnapshotOutput,
69 ) -> anyhow::Result<Self> {
70 let base_url = resolve_ws_base_url(base_url)?;
71
72 Ok(Self {
73 base_url,
74 replay_signal: Arc::new(AtomicBool::new(false)),
75 stream_signal: Arc::new(AtomicBool::new(false)),
76 instruments: AHashMap::new(),
77 normalize_symbols,
78 book_snapshot_output,
79 })
80 }
81
82 pub fn add_instrument_info(&mut self, info: TardisInstrumentMiniInfo) {
83 let key = info.as_tardis_instrument_key();
84 self.instruments.insert(key, Arc::new(info));
85 }
86
87 #[must_use]
92 pub fn is_closed(&self) -> bool {
93 self.replay_signal.load(Ordering::Acquire) && self.stream_signal.load(Ordering::Acquire)
95 }
96
97 pub fn close(&mut self) {
98 log::debug!("Closing");
99
100 self.replay_signal.store(true, Ordering::Release);
102 self.stream_signal.store(true, Ordering::Release);
103
104 log::debug!("Closed");
105 }
106
107 pub async fn replay(
113 &self,
114 options: Vec<ReplayNormalizedRequestOptions>,
115 ) -> Result<impl Stream<Item = Result<Data, Error>>, Error> {
116 let stream = replay_normalized(&self.base_url, options, self.replay_signal.clone()).await?;
117
118 Ok(handle_ws_stream(
121 Box::pin(stream),
122 None,
123 Some(self.instruments.clone()),
124 self.book_snapshot_output.clone(),
125 ))
126 }
127
128 pub async fn stream(
134 &self,
135 instrument: TardisInstrumentMiniInfo,
136 options: Vec<StreamNormalizedRequestOptions>,
137 ) -> Result<impl Stream<Item = Result<Data, Error>>, Error> {
138 let stream = stream_normalized(&self.base_url, options, self.stream_signal.clone()).await?;
139
140 Ok(handle_ws_stream(
143 Box::pin(stream),
144 Some(Arc::new(instrument)),
145 None,
146 self.book_snapshot_output.clone(),
147 ))
148 }
149}
150
151fn handle_ws_stream<S>(
152 stream: S,
153 instrument: Option<Arc<TardisInstrumentMiniInfo>>,
154 instrument_map: Option<AHashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>>,
155 book_snapshot_output: BookSnapshotOutput,
156) -> impl Stream<Item = Result<Data, Error>>
157where
158 S: Stream<Item = Result<WsMessage, Error>> + Unpin,
159{
160 assert!(
161 instrument.is_some() || instrument_map.is_some(),
162 "Either `instrument` or `instrument_map` must be provided"
163 );
164
165 async_stream::stream! {
166 pin_mut!(stream);
167
168 while let Some(result) = stream.next().await {
169 match result {
170 Ok(msg) => {
171 if matches!(msg, WsMessage::Disconnect(_)) {
172 log::debug!("Received disconnect message: {msg:?}");
173 continue;
174 }
175
176 let info = instrument.clone().or_else(|| {
177 instrument_map
178 .as_ref()
179 .and_then(|map| determine_instrument_info(&msg, map))
180 });
181
182 if let Some(info) = info {
183 if let Some(data) = parse_tardis_ws_message(msg, &info, &book_snapshot_output) {
184 yield Ok(data);
185 }
186 } else {
187 log::error!("Missing instrument info for message: {msg:?}");
188 yield Err(Error::ConnectionClosed {
189 reason: "Missing instrument definition info".to_string()
190 });
191 break;
192 }
193 }
194 Err(e) => {
195 log::error!("Error in WebSocket stream: {e:?}");
196 yield Err(e);
197 break;
198 }
199 }
200 }
201 }
202}
203
204pub fn determine_instrument_info(
205 msg: &WsMessage,
206 instrument_map: &AHashMap<TardisInstrumentKey, Arc<TardisInstrumentMiniInfo>>,
207) -> Option<Arc<TardisInstrumentMiniInfo>> {
208 let key = match msg {
209 WsMessage::BookChange(msg) => {
210 TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange)
211 }
212 WsMessage::BookSnapshot(msg) => {
213 TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange)
214 }
215 WsMessage::Trade(msg) => TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange),
216 WsMessage::TradeBar(msg) => TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange),
217 WsMessage::DerivativeTicker(msg) => {
218 TardisInstrumentKey::new(Ustr::from(&msg.symbol), msg.exchange)
219 }
220 WsMessage::Disconnect(_) => return None,
221 };
222
223 if let Some(inst) = instrument_map.get(&key) {
224 Some(inst.clone())
225 } else {
226 log::error!("Instrument definition info not available for {key:?}");
227 None
228 }
229}
230
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Builds the client shared by every test: explicit local URL, symbol
    /// normalization off, book snapshots emitted as deltas.
    fn local_client() -> TardisMachineClient {
        TardisMachineClient::new(
            Some("ws://localhost:8001"),
            false,
            BookSnapshotOutput::Deltas,
        )
        .unwrap()
    }

    #[rstest]
    fn test_is_closed_initial_state() {
        // A freshly built client has neither flag set.
        let client = local_client();
        assert!(!client.is_closed());
    }

    #[rstest]
    fn test_is_closed_after_close() {
        // `close` sets both flags, so the client reports closed.
        let mut client = local_client();
        client.close();
        assert!(client.is_closed());
    }

    #[rstest]
    fn test_is_closed_partial_signal() {
        let client = local_client();

        // Only the replay flag set: still considered open.
        client.replay_signal.store(true, Ordering::Release);
        assert!(!client.is_closed());

        // Both flags set: now closed.
        client.stream_signal.store(true, Ordering::Release);
        assert!(client.is_closed());
    }
}