1use std::{
19 fmt::Debug,
20 sync::{
21 Arc, Mutex,
22 atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering},
23 },
24 time::Duration,
25};
26
27use ahash::AHashSet;
28use arc_swap::ArcSwap;
29use nautilus_common::live::get_runtime;
30use nautilus_core::{AtomicMap, consts::NAUTILUS_USER_AGENT};
31use nautilus_network::{
32 backoff::ExponentialBackoff,
33 mode::ConnectionMode,
34 websocket::{
35 PingHandler, SubscriptionState, TransportBackend, WebSocketClient, WebSocketConfig,
36 channel_message_handler,
37 },
38};
39use ustr::Ustr;
40
41use super::handler::{AxMdWsFeedHandler, HandlerCommand};
42use crate::{
43 common::enums::{AxCandleWidth, AxMarketDataLevel},
44 websocket::messages::AxDataWsMessage,
45};
46
/// Separator between the channel part and the symbol part of a tracked
/// subscription topic (e.g. `candles:<symbol>:<width>`).
const AX_TOPIC_DELIMITER: char = ':';

/// Result alias used by every fallible operation of the AX WebSocket client.
pub type AxWsResult<T> = std::result::Result<T, AxWsClientError>;

/// Errors produced by the AX market data WebSocket client.
#[derive(Debug, Clone)]
pub enum AxWsClientError {
    /// Connecting or configuring the underlying transport failed.
    Transport(String),
    /// A command could not be delivered to the feed handler task.
    ChannelError(String),
}

impl std::fmt::Display for AxWsClientError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Transport(detail) => write!(f, "Transport error: {detail}"),
            Self::ChannelError(detail) => write!(f, "Channel error: {detail}"),
        }
    }
}

impl std::error::Error for AxWsClientError {}
72
73#[derive(Debug, Default, Clone)]
74pub struct SymbolDataTypes {
75 pub quotes: bool,
76 pub trades: bool,
77 pub mark_prices: bool,
78 pub instrument_status: bool,
79 pub book_level: Option<AxMarketDataLevel>,
80}
81
82impl SymbolDataTypes {
83 pub fn effective_level(&self) -> Option<AxMarketDataLevel> {
84 if let Some(level) = self.book_level {
85 return Some(level);
86 }
87
88 if self.quotes || self.trades || self.mark_prices || self.instrument_status {
89 return Some(AxMarketDataLevel::Level1);
90 }
91 None
92 }
93
94 fn is_empty(&self) -> bool {
95 !self.quotes
96 && !self.trades
97 && !self.mark_prices
98 && !self.instrument_status
99 && self.book_level.is_none()
100 }
101}
102
103pub struct AxMdWebSocketClient {
108 url: String,
109 heartbeat: Option<u64>,
110 auth_token: Option<String>,
111 connection_mode: Arc<ArcSwap<AtomicU8>>,
112 cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
113 out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<AxDataWsMessage>>>,
114 signal: Arc<AtomicBool>,
115 task_handle: Option<tokio::task::JoinHandle<()>>,
116 subscriptions: SubscriptionState,
117 request_id_counter: Arc<AtomicI64>,
118 subscribe_lock: Arc<tokio::sync::Mutex<()>>,
119 symbol_data_types: Arc<AtomicMap<String, SymbolDataTypes>>,
120 status_invalidations: Arc<Mutex<AHashSet<Ustr>>>,
121 transport_backend: TransportBackend,
122 proxy_url: Option<String>,
123}
124
125impl Debug for AxMdWebSocketClient {
126 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
127 f.debug_struct(stringify!(AxMdWebSocketClient))
128 .field("url", &self.url)
129 .field("heartbeat", &self.heartbeat)
130 .field("confirmed_subscriptions", &self.subscriptions.len())
131 .finish()
132 }
133}
134
135impl Clone for AxMdWebSocketClient {
136 fn clone(&self) -> Self {
137 Self {
138 url: self.url.clone(),
139 heartbeat: self.heartbeat,
140 auth_token: self.auth_token.clone(),
141 connection_mode: Arc::clone(&self.connection_mode),
142 cmd_tx: Arc::clone(&self.cmd_tx),
143 out_rx: None,
144 signal: Arc::clone(&self.signal),
145 task_handle: None,
146 subscriptions: self.subscriptions.clone(),
147 subscribe_lock: Arc::clone(&self.subscribe_lock),
148 request_id_counter: Arc::clone(&self.request_id_counter),
149 symbol_data_types: Arc::clone(&self.symbol_data_types),
150 status_invalidations: Arc::clone(&self.status_invalidations),
151 transport_backend: self.transport_backend,
152 proxy_url: self.proxy_url.clone(),
153 }
154 }
155}
156
157impl AxMdWebSocketClient {
158 #[must_use]
162 pub fn new(
163 url: String,
164 auth_token: String,
165 heartbeat: u64,
166 transport_backend: TransportBackend,
167 proxy_url: Option<String>,
168 ) -> Self {
169 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
170
171 let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
172 let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
173
174 Self {
175 url,
176 heartbeat: Some(heartbeat),
177 auth_token: Some(auth_token),
178 connection_mode,
179 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
180 out_rx: None,
181 signal: Arc::new(AtomicBool::new(false)),
182 task_handle: None,
183 subscriptions: SubscriptionState::new(AX_TOPIC_DELIMITER),
184 request_id_counter: Arc::new(AtomicI64::new(1)),
185 subscribe_lock: Arc::new(tokio::sync::Mutex::new(())),
186 symbol_data_types: Arc::new(AtomicMap::new()),
187 status_invalidations: Arc::new(Mutex::new(AHashSet::new())),
188 transport_backend,
189 proxy_url,
190 }
191 }
192
193 #[must_use]
197 pub fn without_auth(
198 url: String,
199 heartbeat: u64,
200 transport_backend: TransportBackend,
201 proxy_url: Option<String>,
202 ) -> Self {
203 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
204
205 let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
206 let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
207
208 Self {
209 url,
210 heartbeat: Some(heartbeat),
211 auth_token: None,
212 connection_mode,
213 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
214 out_rx: None,
215 signal: Arc::new(AtomicBool::new(false)),
216 task_handle: None,
217 subscriptions: SubscriptionState::new(AX_TOPIC_DELIMITER),
218 request_id_counter: Arc::new(AtomicI64::new(1)),
219 subscribe_lock: Arc::new(tokio::sync::Mutex::new(())),
220 symbol_data_types: Arc::new(AtomicMap::new()),
221 status_invalidations: Arc::new(Mutex::new(AHashSet::new())),
222 transport_backend,
223 proxy_url,
224 }
225 }
226
227 #[must_use]
229 pub fn url(&self) -> &str {
230 &self.url
231 }
232
233 pub fn set_auth_token(&mut self, token: String) {
237 self.auth_token = Some(token);
238 }
239
240 #[must_use]
242 pub fn is_active(&self) -> bool {
243 let connection_mode_arc = self.connection_mode.load();
244 ConnectionMode::from_atomic(&connection_mode_arc).is_active()
245 && !self.signal.load(Ordering::Acquire)
246 }
247
248 #[must_use]
250 pub fn is_closed(&self) -> bool {
251 let connection_mode_arc = self.connection_mode.load();
252 ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
253 || self.signal.load(Ordering::Acquire)
254 }
255
256 #[must_use]
258 pub fn subscription_count(&self) -> usize {
259 self.subscriptions.len()
260 }
261
262 #[must_use]
264 pub fn symbol_data_types(&self) -> Arc<AtomicMap<String, SymbolDataTypes>> {
265 Arc::clone(&self.symbol_data_types)
266 }
267
268 pub fn status_invalidations(&self) -> Arc<Mutex<AHashSet<Ustr>>> {
270 Arc::clone(&self.status_invalidations)
271 }
272
273 fn next_request_id(&self) -> i64 {
274 self.request_id_counter.fetch_add(1, Ordering::Relaxed)
275 }
276
277 fn is_subscribed_topic(&self, topic: &str) -> bool {
278 let (channel, symbol) = topic
279 .split_once(AX_TOPIC_DELIMITER)
280 .map_or((topic, None), |(c, s)| (c, Some(s)));
281 let channel_ustr = Ustr::from(channel);
282 let symbol_ustr = symbol.map_or_else(|| Ustr::from(""), Ustr::from);
283 self.subscriptions
284 .is_subscribed(&channel_ustr, &symbol_ustr)
285 }
286
287 pub async fn connect(&mut self) -> AxWsResult<()> {
292 const MAX_RETRIES: u32 = 5;
293 const CONNECTION_TIMEOUT_SECS: u64 = 10;
294
295 self.signal.store(false, Ordering::Release);
296
297 let (raw_handler, raw_rx) = channel_message_handler();
298
299 let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {});
301
302 let mut headers = vec![("User-Agent".to_string(), NAUTILUS_USER_AGENT.to_string())];
303
304 if let Some(ref token) = self.auth_token {
305 headers.push(("Authorization".to_string(), format!("Bearer {token}")));
306 }
307
308 let config = WebSocketConfig {
309 url: self.url.clone(),
310 headers,
311 heartbeat: self.heartbeat,
312 heartbeat_msg: None, reconnect_timeout_ms: Some(5_000),
314 reconnect_delay_initial_ms: Some(500),
315 reconnect_delay_max_ms: Some(5_000),
316 reconnect_backoff_factor: Some(1.5),
317 reconnect_jitter_ms: Some(250),
318 reconnect_max_attempts: None,
319 idle_timeout_ms: None,
320 backend: self.transport_backend,
321 proxy_url: self.proxy_url.clone(),
322 };
323
324 let mut backoff = ExponentialBackoff::new(
326 Duration::from_millis(500),
327 Duration::from_millis(5000),
328 2.0,
329 250,
330 false,
331 )
332 .map_err(|e| AxWsClientError::Transport(e.to_string()))?;
333
334 let mut last_error: String;
335 let mut attempt = 0;
336
337 let client = loop {
338 attempt += 1;
339
340 match tokio::time::timeout(
341 Duration::from_secs(CONNECTION_TIMEOUT_SECS),
342 WebSocketClient::connect(
343 config.clone(),
344 Some(raw_handler.clone()),
345 Some(ping_handler.clone()),
346 None,
347 vec![],
348 None,
349 ),
350 )
351 .await
352 {
353 Ok(Ok(client)) => {
354 if attempt > 1 {
355 log::info!("WebSocket connection established after {attempt} attempts");
356 }
357 break client;
358 }
359 Ok(Err(e)) => {
360 last_error = e.to_string();
361 log::warn!(
362 "WebSocket connection attempt failed: attempt={attempt}/{MAX_RETRIES}, url={}, error={last_error}",
363 self.url
364 );
365 }
366 Err(_) => {
367 last_error = format!("Connection timeout after {CONNECTION_TIMEOUT_SECS}s");
368 log::warn!(
369 "WebSocket connection attempt timed out: attempt={attempt}/{MAX_RETRIES}, url={}",
370 self.url
371 );
372 }
373 }
374
375 if attempt >= MAX_RETRIES {
376 return Err(AxWsClientError::Transport(format!(
377 "Failed to connect to {} after {MAX_RETRIES} attempts: {}",
378 self.url,
379 if last_error.is_empty() {
380 "unknown error"
381 } else {
382 &last_error
383 }
384 )));
385 }
386
387 let delay = backoff.next_duration();
388 log::debug!(
389 "Retrying in {delay:?} (attempt {}/{MAX_RETRIES})",
390 attempt + 1
391 );
392 tokio::time::sleep(delay).await;
393 };
394
395 self.connection_mode.store(client.connection_mode_atomic());
396
397 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<AxDataWsMessage>();
398 self.out_rx = Some(Arc::new(out_rx));
399
400 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
401 *self.cmd_tx.write().await = cmd_tx.clone();
402
403 self.send_cmd(HandlerCommand::SetClient(client)).await?;
404
405 let signal = Arc::clone(&self.signal);
406 let subscriptions = self.subscriptions.clone();
407
408 let stream_handle = get_runtime().spawn(async move {
409 let mut handler =
410 AxMdWsFeedHandler::new(signal.clone(), cmd_rx, raw_rx, subscriptions.clone());
411
412 while let Some(msg) = handler.next().await {
413 if matches!(msg, AxDataWsMessage::Reconnected) {
414 log::info!("WebSocket reconnected, subscriptions will be replayed");
415 }
416
417 if out_tx.send(msg).is_err() {
418 log::debug!("Output channel closed");
419 break;
420 }
421 }
422
423 log::debug!("Handler loop exited");
424 });
425
426 self.task_handle = Some(stream_handle);
427
428 Ok(())
429 }
430
431 pub async fn subscribe_book_deltas(
440 &self,
441 symbol: &str,
442 level: AxMarketDataLevel,
443 ) -> AxWsResult<()> {
444 let _guard = self.subscribe_lock.lock().await;
445
446 let current = self
447 .symbol_data_types
448 .load()
449 .get(symbol)
450 .cloned()
451 .unwrap_or_default();
452
453 if current.book_level.is_some() {
455 log::debug!("Book deltas already subscribed for {symbol}, skipping");
456 return Ok(());
457 }
458
459 let old_level = current.effective_level();
460 let mut next = current.clone();
461 next.book_level = Some(level);
462 let new_level = next.effective_level();
463
464 self.update_data_subscription(symbol, old_level, new_level)
465 .await?;
466
467 self.symbol_data_types.rcu(|m| {
468 let entry = m.entry(symbol.to_string()).or_default();
469 entry.book_level = Some(level);
470 });
471
472 Ok(())
473 }
474
475 pub async fn subscribe_quotes(&self, symbol: &str) -> AxWsResult<()> {
484 let _guard = self.subscribe_lock.lock().await;
485
486 let current = self
487 .symbol_data_types
488 .load()
489 .get(symbol)
490 .cloned()
491 .unwrap_or_default();
492 let old_level = current.effective_level();
493 let mut next = current.clone();
494 next.quotes = true;
495 let new_level = next.effective_level();
496
497 self.update_data_subscription(symbol, old_level, new_level)
498 .await?;
499
500 self.symbol_data_types.rcu(|m| {
501 m.entry(symbol.to_string()).or_default().quotes = true;
502 });
503
504 Ok(())
505 }
506
507 pub async fn subscribe_trades(&self, symbol: &str) -> AxWsResult<()> {
516 let _guard = self.subscribe_lock.lock().await;
517
518 let current = self
519 .symbol_data_types
520 .load()
521 .get(symbol)
522 .cloned()
523 .unwrap_or_default();
524 let old_level = current.effective_level();
525 let mut next = current.clone();
526 next.trades = true;
527 let new_level = next.effective_level();
528
529 self.update_data_subscription(symbol, old_level, new_level)
530 .await?;
531
532 self.symbol_data_types.rcu(|m| {
533 m.entry(symbol.to_string()).or_default().trades = true;
534 });
535
536 Ok(())
537 }
538
539 pub async fn unsubscribe_book_deltas(&self, symbol: &str) -> AxWsResult<()> {
548 let _guard = self.subscribe_lock.lock().await;
549
550 let Some(current) = self.symbol_data_types.load().get(symbol).cloned() else {
551 log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe book deltas");
552 return Ok(());
553 };
554 let old_level = current.effective_level();
555 let mut next = current.clone();
556 next.book_level = None;
557 let new_level = next.effective_level();
558
559 self.update_data_subscription(symbol, old_level, new_level)
560 .await?;
561
562 self.symbol_data_types.rcu(|m| {
563 if let Some(entry) = m.get_mut(symbol) {
564 entry.book_level = None;
565 if entry.is_empty() {
566 m.remove(symbol);
567 }
568 }
569 });
570
571 Ok(())
572 }
573
574 pub async fn unsubscribe_quotes(&self, symbol: &str) -> AxWsResult<()> {
583 let _guard = self.subscribe_lock.lock().await;
584
585 let Some(current) = self.symbol_data_types.load().get(symbol).cloned() else {
586 log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe quotes");
587 return Ok(());
588 };
589 let old_level = current.effective_level();
590 let mut next = current.clone();
591 next.quotes = false;
592 let new_level = next.effective_level();
593
594 self.update_data_subscription(symbol, old_level, new_level)
595 .await?;
596
597 self.symbol_data_types.rcu(|m| {
598 if let Some(entry) = m.get_mut(symbol) {
599 entry.quotes = false;
600 if entry.is_empty() {
601 m.remove(symbol);
602 }
603 }
604 });
605
606 Ok(())
607 }
608
609 pub async fn unsubscribe_trades(&self, symbol: &str) -> AxWsResult<()> {
618 let _guard = self.subscribe_lock.lock().await;
619
620 let Some(current) = self.symbol_data_types.load().get(symbol).cloned() else {
621 log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe trades");
622 return Ok(());
623 };
624 let old_level = current.effective_level();
625 let mut next = current.clone();
626 next.trades = false;
627 let new_level = next.effective_level();
628
629 self.update_data_subscription(symbol, old_level, new_level)
630 .await?;
631
632 self.symbol_data_types.rcu(|m| {
633 if let Some(entry) = m.get_mut(symbol) {
634 entry.trades = false;
635 if entry.is_empty() {
636 m.remove(symbol);
637 }
638 }
639 });
640
641 Ok(())
642 }
643
644 pub async fn subscribe_mark_prices(&self, symbol: &str) -> AxWsResult<()> {
653 let _guard = self.subscribe_lock.lock().await;
654
655 let current = self
656 .symbol_data_types
657 .load()
658 .get(symbol)
659 .cloned()
660 .unwrap_or_default();
661 let old_level = current.effective_level();
662 let mut next = current.clone();
663 next.mark_prices = true;
664 let new_level = next.effective_level();
665
666 self.update_data_subscription(symbol, old_level, new_level)
667 .await?;
668
669 self.symbol_data_types.rcu(|m| {
670 m.entry(symbol.to_string()).or_default().mark_prices = true;
671 });
672
673 Ok(())
674 }
675
676 pub async fn unsubscribe_mark_prices(&self, symbol: &str) -> AxWsResult<()> {
685 let _guard = self.subscribe_lock.lock().await;
686
687 let Some(current) = self.symbol_data_types.load().get(symbol).cloned() else {
688 log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe mark prices");
689 return Ok(());
690 };
691 let old_level = current.effective_level();
692 let mut next = current.clone();
693 next.mark_prices = false;
694 let new_level = next.effective_level();
695
696 self.update_data_subscription(symbol, old_level, new_level)
697 .await?;
698
699 self.symbol_data_types.rcu(|m| {
700 if let Some(entry) = m.get_mut(symbol) {
701 entry.mark_prices = false;
702 if entry.is_empty() {
703 m.remove(symbol);
704 }
705 }
706 });
707
708 Ok(())
709 }
710
711 pub async fn subscribe_instrument_status(&self, symbol: &str) -> AxWsResult<()> {
720 let _guard = self.subscribe_lock.lock().await;
721
722 let current = self
723 .symbol_data_types
724 .load()
725 .get(symbol)
726 .cloned()
727 .unwrap_or_default();
728 let old_level = current.effective_level();
729 let mut next = current.clone();
730 next.instrument_status = true;
731 let new_level = next.effective_level();
732
733 self.update_data_subscription(symbol, old_level, new_level)
734 .await?;
735
736 self.symbol_data_types.rcu(|m| {
737 m.entry(symbol.to_string()).or_default().instrument_status = true;
738 });
739
740 Ok(())
741 }
742
743 pub async fn unsubscribe_instrument_status(&self, symbol: &str) -> AxWsResult<()> {
752 let _guard = self.subscribe_lock.lock().await;
753
754 let Some(current) = self.symbol_data_types.load().get(symbol).cloned() else {
755 log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe instrument status");
756 return Ok(());
757 };
758 let old_level = current.effective_level();
759 let mut next = current.clone();
760 next.instrument_status = false;
761 let new_level = next.effective_level();
762
763 self.update_data_subscription(symbol, old_level, new_level)
764 .await?;
765
766 self.symbol_data_types.rcu(|m| {
767 if let Some(entry) = m.get_mut(symbol) {
768 entry.instrument_status = false;
769 if entry.is_empty() {
770 m.remove(symbol);
771 }
772 }
773 });
774
775 if let Ok(mut invalidations) = self.status_invalidations.lock() {
776 invalidations.insert(Ustr::from(symbol));
777 }
778
779 Ok(())
780 }
781
782 async fn update_data_subscription(
783 &self,
784 symbol: &str,
785 old_level: Option<AxMarketDataLevel>,
786 new_level: Option<AxMarketDataLevel>,
787 ) -> AxWsResult<()> {
788 if old_level == new_level {
789 return Ok(());
790 }
791
792 match (old_level, new_level) {
793 (None, Some(level)) => {
794 log::debug!("Subscribing {symbol} at {level:?}");
795 self.send_subscribe(symbol, level).await
796 }
797 (Some(_), None) => {
798 log::debug!("Unsubscribing {symbol} (no remaining data types)");
799 self.send_unsubscribe(symbol).await
800 }
801 (Some(old), Some(new)) => {
802 log::debug!("Resubscribing {symbol}: {old:?} -> {new:?}");
803 self.send_unsubscribe(symbol).await?;
804 if let Err(e) = self.send_subscribe(symbol, new).await {
805 log::warn!("Resubscribe failed for {symbol} at {new:?}: {e}");
806 if let Err(restore_err) = self.send_subscribe(symbol, old).await {
807 log::error!(
809 "Failed to restore {symbol} at {old:?}: {restore_err}, \
810 reconnection required"
811 );
812 let old_topic = format!("{symbol}:{old:?}");
813 self.subscriptions.mark_subscribe(&old_topic);
814 }
815 return Err(e);
816 }
817 Ok(())
818 }
819 (None, None) => Ok(()),
820 }
821 }
822
823 async fn send_subscribe(&self, symbol: &str, level: AxMarketDataLevel) -> AxWsResult<()> {
824 let topic = format!("{symbol}:{level:?}");
825 let request_id = self.next_request_id();
826
827 self.subscriptions.mark_subscribe(&topic);
828
829 if let Err(e) = self
830 .send_cmd(HandlerCommand::Subscribe {
831 request_id,
832 symbol: Ustr::from(symbol),
833 level,
834 })
835 .await
836 {
837 self.subscriptions.mark_unsubscribe(&topic);
838 return Err(e);
839 }
840
841 Ok(())
842 }
843
844 async fn send_unsubscribe(&self, symbol: &str) -> AxWsResult<()> {
845 let request_id = self.next_request_id();
846
847 self.send_cmd(HandlerCommand::Unsubscribe {
848 request_id,
849 symbol: Ustr::from(symbol),
850 })
851 .await?;
852
853 for level in [
854 AxMarketDataLevel::Level1,
855 AxMarketDataLevel::Level2,
856 AxMarketDataLevel::Level3,
857 ] {
858 let topic = format!("{symbol}:{level:?}");
859 self.subscriptions.mark_unsubscribe(&topic);
860 }
861
862 Ok(())
863 }
864
865 pub async fn subscribe_candles(&self, symbol: &str, width: AxCandleWidth) -> AxWsResult<()> {
873 let _guard = self.subscribe_lock.lock().await;
874 let topic = format!("candles:{symbol}:{width:?}");
875
876 if self.is_subscribed_topic(&topic) {
878 log::debug!("Already subscribed to {topic}, skipping");
879 return Ok(());
880 }
881
882 let request_id = self.next_request_id();
883
884 self.subscriptions.mark_subscribe(&topic);
886
887 if let Err(e) = self
888 .send_cmd(HandlerCommand::SubscribeCandles {
889 request_id,
890 symbol: Ustr::from(symbol),
891 width,
892 })
893 .await
894 {
895 self.subscriptions.mark_unsubscribe(&topic);
897 return Err(e);
898 }
899
900 Ok(())
901 }
902
903 pub async fn unsubscribe_candles(&self, symbol: &str, width: AxCandleWidth) -> AxWsResult<()> {
909 let _guard = self.subscribe_lock.lock().await;
910 let request_id = self.next_request_id();
911 let topic = format!("candles:{symbol}:{width:?}");
912
913 self.subscriptions.mark_unsubscribe(&topic);
914
915 self.send_cmd(HandlerCommand::UnsubscribeCandles {
916 request_id,
917 symbol: Ustr::from(symbol),
918 width,
919 })
920 .await
921 }
922
923 pub fn stream(&mut self) -> impl futures_util::Stream<Item = AxDataWsMessage> + 'static {
929 let rx = self
930 .out_rx
931 .take()
932 .expect("Stream receiver already taken or client not connected - stream() can only be called once");
933 let mut rx = Arc::try_unwrap(rx).expect(
934 "Cannot take ownership of stream - client was cloned and other references exist",
935 );
936 async_stream::stream! {
937 while let Some(msg) = rx.recv().await {
938 yield msg;
939 }
940 }
941 }
942
943 pub async fn disconnect(&self) {
945 log::debug!("Disconnecting WebSocket");
946 let _ = self.send_cmd(HandlerCommand::Disconnect).await;
947 }
948
949 pub async fn close(&mut self) {
951 log::debug!("Closing WebSocket client");
952
953 let _ = self.send_cmd(HandlerCommand::Disconnect).await;
955 tokio::time::sleep(Duration::from_millis(50)).await;
956 self.signal.store(true, Ordering::Release);
957
958 if let Some(handle) = self.task_handle.take() {
959 const CLOSE_TIMEOUT: Duration = Duration::from_secs(2);
960 let abort_handle = handle.abort_handle();
961
962 match tokio::time::timeout(CLOSE_TIMEOUT, handle).await {
963 Ok(Ok(())) => log::debug!("Handler task completed gracefully"),
964 Ok(Err(e)) => log::warn!("Handler task panicked: {e}"),
965 Err(_) => {
966 log::warn!("Handler task did not complete within timeout, aborting");
967 abort_handle.abort();
968 }
969 }
970 }
971 }
972
973 async fn send_cmd(&self, cmd: HandlerCommand) -> AxWsResult<()> {
974 let guard = self.cmd_tx.read().await;
975 guard
976 .send(cmd)
977 .map_err(|e| AxWsClientError::ChannelError(e.to_string()))
978 }
979}
980
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_effective_level_empty_returns_none() {
        let sdt = SymbolDataTypes::default();
        assert_eq!(sdt.effective_level(), None);
        assert!(sdt.is_empty());
    }

    #[rstest]
    fn test_effective_level_book_level_takes_precedence() {
        let sdt = SymbolDataTypes {
            book_level: Some(AxMarketDataLevel::Level2),
            quotes: true,
            ..Default::default()
        };
        assert_eq!(sdt.effective_level(), Some(AxMarketDataLevel::Level2));
        assert!(!sdt.is_empty());
    }

    #[rstest]
    #[case(true, false, false, false)]
    #[case(false, true, false, false)]
    #[case(false, false, true, false)]
    #[case(false, false, false, true)]
    fn test_effective_level_any_flag_returns_level1(
        #[case] quotes: bool,
        #[case] trades: bool,
        #[case] mark_prices: bool,
        #[case] instrument_status: bool,
    ) {
        let sdt = SymbolDataTypes {
            quotes,
            trades,
            mark_prices,
            instrument_status,
            book_level: None,
        };
        assert_eq!(sdt.effective_level(), Some(AxMarketDataLevel::Level1));
        assert!(!sdt.is_empty());
    }

    /// A book level with no flags set is reported as-is and is not empty.
    #[rstest]
    #[case(AxMarketDataLevel::Level1)]
    #[case(AxMarketDataLevel::Level2)]
    #[case(AxMarketDataLevel::Level3)]
    fn test_effective_level_book_level_without_flags(#[case] level: AxMarketDataLevel) {
        let sdt = SymbolDataTypes {
            book_level: Some(level),
            ..Default::default()
        };
        assert_eq!(sdt.effective_level(), Some(level));
        assert!(!sdt.is_empty());
    }

    /// Clearing the only requested flag returns the entry to the empty state,
    /// which is what lets unsubscribe paths drop the symbol entirely.
    #[rstest]
    fn test_clearing_last_flag_makes_empty() {
        let mut sdt = SymbolDataTypes {
            trades: true,
            ..Default::default()
        };
        sdt.trades = false;
        assert_eq!(sdt.effective_level(), None);
        assert!(sdt.is_empty());
    }
}
1026}