Skip to main content

nautilus_architect_ax/websocket/data/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Market data WebSocket client for Ax.
17
18use std::{
19    fmt::Debug,
20    sync::{
21        Arc, Mutex,
22        atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering},
23    },
24    time::Duration,
25};
26
27use ahash::AHashSet;
28use arc_swap::ArcSwap;
29use nautilus_common::live::get_runtime;
30use nautilus_core::{AtomicMap, consts::NAUTILUS_USER_AGENT};
31use nautilus_network::{
32    backoff::ExponentialBackoff,
33    mode::ConnectionMode,
34    websocket::{
35        PingHandler, SubscriptionState, TransportBackend, WebSocketClient, WebSocketConfig,
36        channel_message_handler,
37    },
38};
39use ustr::Ustr;
40
41use super::handler::{AxMdWsFeedHandler, HandlerCommand};
42use crate::{
43    common::enums::{AxCandleWidth, AxMarketDataLevel},
44    websocket::messages::AxDataWsMessage,
45};
46
/// Delimiter separating the channel part from the symbol part of an Ax subscription topic.
const AX_TOPIC_DELIMITER: char = ':';

/// Result alias returned by every fallible Ax WebSocket client operation.
pub type AxWsResult<T> = core::result::Result<T, AxWsClientError>;

/// Errors produced by the Ax WebSocket client.
#[derive(Clone, Debug)]
pub enum AxWsClientError {
    /// The WebSocket transport failed (for example, connection attempts were exhausted).
    Transport(String),
    /// A message could not be sent on an internal channel.
    ChannelError(String),
}

impl core::fmt::Display for AxWsClientError {
    /// Renders the error as `"<kind> error: <message>"`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match *self {
            Self::Transport(ref msg) => write!(f, "Transport error: {msg}"),
            Self::ChannelError(ref msg) => write!(f, "Channel error: {msg}"),
        }
    }
}

// Marker impl: the error has no underlying source, so the default methods suffice.
impl std::error::Error for AxWsClientError {}
72
73#[derive(Debug, Default, Clone)]
74pub struct SymbolDataTypes {
75    pub quotes: bool,
76    pub trades: bool,
77    pub mark_prices: bool,
78    pub instrument_status: bool,
79    pub book_level: Option<AxMarketDataLevel>,
80}
81
82impl SymbolDataTypes {
83    pub fn effective_level(&self) -> Option<AxMarketDataLevel> {
84        if let Some(level) = self.book_level {
85            return Some(level);
86        }
87
88        if self.quotes || self.trades || self.mark_prices || self.instrument_status {
89            return Some(AxMarketDataLevel::Level1);
90        }
91        None
92    }
93
94    fn is_empty(&self) -> bool {
95        !self.quotes
96            && !self.trades
97            && !self.mark_prices
98            && !self.instrument_status
99            && self.book_level.is_none()
100    }
101}
102
/// Market data WebSocket client for Ax.
///
/// Provides streaming market data including tickers, trades, order books, and candles.
/// Authenticates with a Bearer token obtained via the HTTP `/api/authenticate` endpoint;
/// the token can be given at construction or later via `set_auth_token`.
pub struct AxMdWebSocketClient {
    /// WebSocket endpoint URL used when building the connection config.
    url: String,
    /// Heartbeat setting forwarded to `WebSocketConfig::heartbeat`.
    heartbeat: Option<u64>,
    /// Bearer token; when set, `connect()` sends it as an `Authorization` header.
    auth_token: Option<String>,
    /// Connection mode of the current transport; repointed at the live client on `connect()`.
    connection_mode: Arc<ArcSwap<AtomicU8>>,
    /// Sender for handler commands; shared with clones and replaced on each `connect()`.
    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
    /// Receiver for messages produced by the handler task; set by `connect()`, `None` in clones.
    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<AxDataWsMessage>>>,
    /// Stop flag shared with the handler; reset to `false` at the start of `connect()`.
    signal: Arc<AtomicBool>,
    /// Handle of the spawned handler task; set by `connect()`, `None` in clones.
    task_handle: Option<tokio::task::JoinHandle<()>>,
    /// Topic bookkeeping, also handed to the handler task on `connect()`.
    subscriptions: SubscriptionState,
    /// Source of request ids for subscribe/unsubscribe commands (starts at 1).
    request_id_counter: Arc<AtomicI64>,
    /// Serializes subscribe/unsubscribe operations across this client and its clones.
    subscribe_lock: Arc<tokio::sync::Mutex<()>>,
    /// Requested data types per symbol, used to derive each symbol's AX level.
    symbol_data_types: Arc<AtomicMap<String, SymbolDataTypes>>,
    /// Symbols whose instrument status cache has been invalidated by an unsubscribe.
    status_invalidations: Arc<Mutex<AHashSet<Ustr>>>,
    /// Transport backend forwarded to `WebSocketConfig::backend`.
    transport_backend: TransportBackend,
    /// Optional proxy URL forwarded to `WebSocketConfig::proxy_url`.
    proxy_url: Option<String>,
}
124
125impl Debug for AxMdWebSocketClient {
126    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
127        f.debug_struct(stringify!(AxMdWebSocketClient))
128            .field("url", &self.url)
129            .field("heartbeat", &self.heartbeat)
130            .field("confirmed_subscriptions", &self.subscriptions.len())
131            .finish()
132    }
133}
134
135impl Clone for AxMdWebSocketClient {
136    fn clone(&self) -> Self {
137        Self {
138            url: self.url.clone(),
139            heartbeat: self.heartbeat,
140            auth_token: self.auth_token.clone(),
141            connection_mode: Arc::clone(&self.connection_mode),
142            cmd_tx: Arc::clone(&self.cmd_tx),
143            out_rx: None,
144            signal: Arc::clone(&self.signal),
145            task_handle: None,
146            subscriptions: self.subscriptions.clone(),
147            subscribe_lock: Arc::clone(&self.subscribe_lock),
148            request_id_counter: Arc::clone(&self.request_id_counter),
149            symbol_data_types: Arc::clone(&self.symbol_data_types),
150            status_invalidations: Arc::clone(&self.status_invalidations),
151            transport_backend: self.transport_backend,
152            proxy_url: self.proxy_url.clone(),
153        }
154    }
155}
156
157impl AxMdWebSocketClient {
158    /// Creates a new Ax market data WebSocket client.
159    ///
160    /// The `auth_token` is a Bearer token obtained from the HTTP `/api/authenticate` endpoint.
161    #[must_use]
162    pub fn new(
163        url: String,
164        auth_token: String,
165        heartbeat: u64,
166        transport_backend: TransportBackend,
167        proxy_url: Option<String>,
168    ) -> Self {
169        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
170
171        let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
172        let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
173
174        Self {
175            url,
176            heartbeat: Some(heartbeat),
177            auth_token: Some(auth_token),
178            connection_mode,
179            cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
180            out_rx: None,
181            signal: Arc::new(AtomicBool::new(false)),
182            task_handle: None,
183            subscriptions: SubscriptionState::new(AX_TOPIC_DELIMITER),
184            request_id_counter: Arc::new(AtomicI64::new(1)),
185            subscribe_lock: Arc::new(tokio::sync::Mutex::new(())),
186            symbol_data_types: Arc::new(AtomicMap::new()),
187            status_invalidations: Arc::new(Mutex::new(AHashSet::new())),
188            transport_backend,
189            proxy_url,
190        }
191    }
192
193    /// Creates a new Ax market data WebSocket client without authentication.
194    ///
195    /// Use [`set_auth_token`](Self::set_auth_token) to set the token before connecting.
196    #[must_use]
197    pub fn without_auth(
198        url: String,
199        heartbeat: u64,
200        transport_backend: TransportBackend,
201        proxy_url: Option<String>,
202    ) -> Self {
203        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
204
205        let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
206        let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
207
208        Self {
209            url,
210            heartbeat: Some(heartbeat),
211            auth_token: None,
212            connection_mode,
213            cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
214            out_rx: None,
215            signal: Arc::new(AtomicBool::new(false)),
216            task_handle: None,
217            subscriptions: SubscriptionState::new(AX_TOPIC_DELIMITER),
218            request_id_counter: Arc::new(AtomicI64::new(1)),
219            subscribe_lock: Arc::new(tokio::sync::Mutex::new(())),
220            symbol_data_types: Arc::new(AtomicMap::new()),
221            status_invalidations: Arc::new(Mutex::new(AHashSet::new())),
222            transport_backend,
223            proxy_url,
224        }
225    }
226
227    /// Returns the WebSocket URL.
228    #[must_use]
229    pub fn url(&self) -> &str {
230        &self.url
231    }
232
233    /// Sets the authentication token for subsequent connections.
234    ///
235    /// This should be called before `connect()` if authentication is required.
236    pub fn set_auth_token(&mut self, token: String) {
237        self.auth_token = Some(token);
238    }
239
240    /// Returns whether the client is currently connected and active.
241    #[must_use]
242    pub fn is_active(&self) -> bool {
243        let connection_mode_arc = self.connection_mode.load();
244        ConnectionMode::from_atomic(&connection_mode_arc).is_active()
245            && !self.signal.load(Ordering::Acquire)
246    }
247
248    /// Returns whether the client is closed.
249    #[must_use]
250    pub fn is_closed(&self) -> bool {
251        let connection_mode_arc = self.connection_mode.load();
252        ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
253            || self.signal.load(Ordering::Acquire)
254    }
255
256    /// Returns the number of confirmed subscriptions.
257    #[must_use]
258    pub fn subscription_count(&self) -> usize {
259        self.subscriptions.len()
260    }
261
262    /// Returns the symbol data types map (shared with handler).
263    #[must_use]
264    pub fn symbol_data_types(&self) -> Arc<AtomicMap<String, SymbolDataTypes>> {
265        Arc::clone(&self.symbol_data_types)
266    }
267
268    /// Returns the shared set of symbols whose instrument status cache has been invalidated.
269    pub fn status_invalidations(&self) -> Arc<Mutex<AHashSet<Ustr>>> {
270        Arc::clone(&self.status_invalidations)
271    }
272
273    fn next_request_id(&self) -> i64 {
274        self.request_id_counter.fetch_add(1, Ordering::Relaxed)
275    }
276
277    fn is_subscribed_topic(&self, topic: &str) -> bool {
278        let (channel, symbol) = topic
279            .split_once(AX_TOPIC_DELIMITER)
280            .map_or((topic, None), |(c, s)| (c, Some(s)));
281        let channel_ustr = Ustr::from(channel);
282        let symbol_ustr = symbol.map_or_else(|| Ustr::from(""), Ustr::from);
283        self.subscriptions
284            .is_subscribed(&channel_ustr, &symbol_ustr)
285    }
286
287    /// Establishes the WebSocket connection.
288    ///
289    /// # Errors
290    ///
291    pub async fn connect(&mut self) -> AxWsResult<()> {
292        const MAX_RETRIES: u32 = 5;
293        const CONNECTION_TIMEOUT_SECS: u64 = 10;
294
295        self.signal.store(false, Ordering::Release);
296
297        let (raw_handler, raw_rx) = channel_message_handler();
298
299        // No-op: ping responses are handled internally by the WebSocketClient
300        let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {});
301
302        let mut headers = vec![("User-Agent".to_string(), NAUTILUS_USER_AGENT.to_string())];
303
304        if let Some(ref token) = self.auth_token {
305            headers.push(("Authorization".to_string(), format!("Bearer {token}")));
306        }
307
308        let config = WebSocketConfig {
309            url: self.url.clone(),
310            headers,
311            heartbeat: self.heartbeat,
312            heartbeat_msg: None, // Ax server sends heartbeats
313            reconnect_timeout_ms: Some(5_000),
314            reconnect_delay_initial_ms: Some(500),
315            reconnect_delay_max_ms: Some(5_000),
316            reconnect_backoff_factor: Some(1.5),
317            reconnect_jitter_ms: Some(250),
318            reconnect_max_attempts: None,
319            idle_timeout_ms: None,
320            backend: self.transport_backend,
321            proxy_url: self.proxy_url.clone(),
322        };
323
324        // Retry initial connection with exponential backoff
325        let mut backoff = ExponentialBackoff::new(
326            Duration::from_millis(500),
327            Duration::from_millis(5000),
328            2.0,
329            250,
330            false,
331        )
332        .map_err(|e| AxWsClientError::Transport(e.to_string()))?;
333
334        let mut last_error: String;
335        let mut attempt = 0;
336
337        let client = loop {
338            attempt += 1;
339
340            match tokio::time::timeout(
341                Duration::from_secs(CONNECTION_TIMEOUT_SECS),
342                WebSocketClient::connect(
343                    config.clone(),
344                    Some(raw_handler.clone()),
345                    Some(ping_handler.clone()),
346                    None,
347                    vec![],
348                    None,
349                ),
350            )
351            .await
352            {
353                Ok(Ok(client)) => {
354                    if attempt > 1 {
355                        log::info!("WebSocket connection established after {attempt} attempts");
356                    }
357                    break client;
358                }
359                Ok(Err(e)) => {
360                    last_error = e.to_string();
361                    log::warn!(
362                        "WebSocket connection attempt failed: attempt={attempt}/{MAX_RETRIES}, url={}, error={last_error}",
363                        self.url
364                    );
365                }
366                Err(_) => {
367                    last_error = format!("Connection timeout after {CONNECTION_TIMEOUT_SECS}s");
368                    log::warn!(
369                        "WebSocket connection attempt timed out: attempt={attempt}/{MAX_RETRIES}, url={}",
370                        self.url
371                    );
372                }
373            }
374
375            if attempt >= MAX_RETRIES {
376                return Err(AxWsClientError::Transport(format!(
377                    "Failed to connect to {} after {MAX_RETRIES} attempts: {}",
378                    self.url,
379                    if last_error.is_empty() {
380                        "unknown error"
381                    } else {
382                        &last_error
383                    }
384                )));
385            }
386
387            let delay = backoff.next_duration();
388            log::debug!(
389                "Retrying in {delay:?} (attempt {}/{MAX_RETRIES})",
390                attempt + 1
391            );
392            tokio::time::sleep(delay).await;
393        };
394
395        self.connection_mode.store(client.connection_mode_atomic());
396
397        let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<AxDataWsMessage>();
398        self.out_rx = Some(Arc::new(out_rx));
399
400        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
401        *self.cmd_tx.write().await = cmd_tx.clone();
402
403        self.send_cmd(HandlerCommand::SetClient(client)).await?;
404
405        let signal = Arc::clone(&self.signal);
406        let subscriptions = self.subscriptions.clone();
407
408        let stream_handle = get_runtime().spawn(async move {
409            let mut handler =
410                AxMdWsFeedHandler::new(signal.clone(), cmd_rx, raw_rx, subscriptions.clone());
411
412            while let Some(msg) = handler.next().await {
413                if matches!(msg, AxDataWsMessage::Reconnected) {
414                    log::info!("WebSocket reconnected, subscriptions will be replayed");
415                }
416
417                if out_tx.send(msg).is_err() {
418                    log::debug!("Output channel closed");
419                    break;
420                }
421            }
422
423            log::debug!("Handler loop exited");
424        });
425
426        self.task_handle = Some(stream_handle);
427
428        Ok(())
429    }
430
431    /// Subscribes to order book deltas for a symbol.
432    ///
433    /// Uses reference counting so the underlying AX subscription is only
434    /// removed when all data types have been unsubscribed.
435    ///
436    /// # Errors
437    ///
438    /// Returns an error if the subscription command cannot be sent.
439    pub async fn subscribe_book_deltas(
440        &self,
441        symbol: &str,
442        level: AxMarketDataLevel,
443    ) -> AxWsResult<()> {
444        let _guard = self.subscribe_lock.lock().await;
445
446        let current = self
447            .symbol_data_types
448            .load()
449            .get(symbol)
450            .cloned()
451            .unwrap_or_default();
452
453        // AX allows only one subscription per symbol, skip if book already subscribed
454        if current.book_level.is_some() {
455            log::debug!("Book deltas already subscribed for {symbol}, skipping");
456            return Ok(());
457        }
458
459        let old_level = current.effective_level();
460        let mut next = current.clone();
461        next.book_level = Some(level);
462        let new_level = next.effective_level();
463
464        self.update_data_subscription(symbol, old_level, new_level)
465            .await?;
466
467        self.symbol_data_types.rcu(|m| {
468            let entry = m.entry(symbol.to_string()).or_default();
469            entry.book_level = Some(level);
470        });
471
472        Ok(())
473    }
474
475    /// Subscribes to quote data for a symbol.
476    ///
477    /// Uses reference counting so the underlying AX subscription is only
478    /// removed when all data types have been unsubscribed.
479    ///
480    /// # Errors
481    ///
482    /// Returns an error if the subscription command cannot be sent.
483    pub async fn subscribe_quotes(&self, symbol: &str) -> AxWsResult<()> {
484        let _guard = self.subscribe_lock.lock().await;
485
486        let current = self
487            .symbol_data_types
488            .load()
489            .get(symbol)
490            .cloned()
491            .unwrap_or_default();
492        let old_level = current.effective_level();
493        let mut next = current.clone();
494        next.quotes = true;
495        let new_level = next.effective_level();
496
497        self.update_data_subscription(symbol, old_level, new_level)
498            .await?;
499
500        self.symbol_data_types.rcu(|m| {
501            m.entry(symbol.to_string()).or_default().quotes = true;
502        });
503
504        Ok(())
505    }
506
507    /// Subscribes to trade data for a symbol.
508    ///
509    /// Uses reference counting so the underlying AX subscription is only
510    /// removed when all data types have been unsubscribed.
511    ///
512    /// # Errors
513    ///
514    /// Returns an error if the subscription command cannot be sent.
515    pub async fn subscribe_trades(&self, symbol: &str) -> AxWsResult<()> {
516        let _guard = self.subscribe_lock.lock().await;
517
518        let current = self
519            .symbol_data_types
520            .load()
521            .get(symbol)
522            .cloned()
523            .unwrap_or_default();
524        let old_level = current.effective_level();
525        let mut next = current.clone();
526        next.trades = true;
527        let new_level = next.effective_level();
528
529        self.update_data_subscription(symbol, old_level, new_level)
530            .await?;
531
532        self.symbol_data_types.rcu(|m| {
533            m.entry(symbol.to_string()).or_default().trades = true;
534        });
535
536        Ok(())
537    }
538
539    /// Unsubscribes from order book deltas for a symbol.
540    ///
541    /// The underlying AX subscription is only removed when all data types
542    /// (quotes, trades, book) have been unsubscribed.
543    ///
544    /// # Errors
545    ///
546    /// Returns an error if the unsubscribe command cannot be sent.
547    pub async fn unsubscribe_book_deltas(&self, symbol: &str) -> AxWsResult<()> {
548        let _guard = self.subscribe_lock.lock().await;
549
550        let Some(current) = self.symbol_data_types.load().get(symbol).cloned() else {
551            log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe book deltas");
552            return Ok(());
553        };
554        let old_level = current.effective_level();
555        let mut next = current.clone();
556        next.book_level = None;
557        let new_level = next.effective_level();
558
559        self.update_data_subscription(symbol, old_level, new_level)
560            .await?;
561
562        self.symbol_data_types.rcu(|m| {
563            if let Some(entry) = m.get_mut(symbol) {
564                entry.book_level = None;
565                if entry.is_empty() {
566                    m.remove(symbol);
567                }
568            }
569        });
570
571        Ok(())
572    }
573
574    /// Unsubscribes from quote data for a symbol.
575    ///
576    /// The underlying AX subscription is only removed when all data types
577    /// (quotes, trades, book) have been unsubscribed.
578    ///
579    /// # Errors
580    ///
581    /// Returns an error if the unsubscribe command cannot be sent.
582    pub async fn unsubscribe_quotes(&self, symbol: &str) -> AxWsResult<()> {
583        let _guard = self.subscribe_lock.lock().await;
584
585        let Some(current) = self.symbol_data_types.load().get(symbol).cloned() else {
586            log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe quotes");
587            return Ok(());
588        };
589        let old_level = current.effective_level();
590        let mut next = current.clone();
591        next.quotes = false;
592        let new_level = next.effective_level();
593
594        self.update_data_subscription(symbol, old_level, new_level)
595            .await?;
596
597        self.symbol_data_types.rcu(|m| {
598            if let Some(entry) = m.get_mut(symbol) {
599                entry.quotes = false;
600                if entry.is_empty() {
601                    m.remove(symbol);
602                }
603            }
604        });
605
606        Ok(())
607    }
608
609    /// Unsubscribes from trade data for a symbol.
610    ///
611    /// The underlying AX subscription is only removed when all data types
612    /// (quotes, trades, book) have been unsubscribed.
613    ///
614    /// # Errors
615    ///
616    /// Returns an error if the unsubscribe command cannot be sent.
617    pub async fn unsubscribe_trades(&self, symbol: &str) -> AxWsResult<()> {
618        let _guard = self.subscribe_lock.lock().await;
619
620        let Some(current) = self.symbol_data_types.load().get(symbol).cloned() else {
621            log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe trades");
622            return Ok(());
623        };
624        let old_level = current.effective_level();
625        let mut next = current.clone();
626        next.trades = false;
627        let new_level = next.effective_level();
628
629        self.update_data_subscription(symbol, old_level, new_level)
630            .await?;
631
632        self.symbol_data_types.rcu(|m| {
633            if let Some(entry) = m.get_mut(symbol) {
634                entry.trades = false;
635                if entry.is_empty() {
636                    m.remove(symbol);
637                }
638            }
639        });
640
641        Ok(())
642    }
643
644    /// Subscribes to mark prices for a symbol.
645    ///
646    /// Ensures at least an L1 subscription so that ticker messages
647    /// (which carry the mark price field) are received.
648    ///
649    /// # Errors
650    ///
651    /// Returns an error if the subscription command cannot be sent.
652    pub async fn subscribe_mark_prices(&self, symbol: &str) -> AxWsResult<()> {
653        let _guard = self.subscribe_lock.lock().await;
654
655        let current = self
656            .symbol_data_types
657            .load()
658            .get(symbol)
659            .cloned()
660            .unwrap_or_default();
661        let old_level = current.effective_level();
662        let mut next = current.clone();
663        next.mark_prices = true;
664        let new_level = next.effective_level();
665
666        self.update_data_subscription(symbol, old_level, new_level)
667            .await?;
668
669        self.symbol_data_types.rcu(|m| {
670            m.entry(symbol.to_string()).or_default().mark_prices = true;
671        });
672
673        Ok(())
674    }
675
676    /// Unsubscribes from mark prices for a symbol.
677    ///
678    /// The underlying AX subscription is only removed when all data types
679    /// have been unsubscribed.
680    ///
681    /// # Errors
682    ///
683    /// Returns an error if the unsubscribe command cannot be sent.
684    pub async fn unsubscribe_mark_prices(&self, symbol: &str) -> AxWsResult<()> {
685        let _guard = self.subscribe_lock.lock().await;
686
687        let Some(current) = self.symbol_data_types.load().get(symbol).cloned() else {
688            log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe mark prices");
689            return Ok(());
690        };
691        let old_level = current.effective_level();
692        let mut next = current.clone();
693        next.mark_prices = false;
694        let new_level = next.effective_level();
695
696        self.update_data_subscription(symbol, old_level, new_level)
697            .await?;
698
699        self.symbol_data_types.rcu(|m| {
700            if let Some(entry) = m.get_mut(symbol) {
701                entry.mark_prices = false;
702                if entry.is_empty() {
703                    m.remove(symbol);
704                }
705            }
706        });
707
708        Ok(())
709    }
710
711    /// Subscribes to instrument status for a symbol.
712    ///
713    /// Ensures at least an L1 subscription so that ticker messages
714    /// (which carry the instrument state field) are received.
715    ///
716    /// # Errors
717    ///
718    /// Returns an error if the subscription command cannot be sent.
719    pub async fn subscribe_instrument_status(&self, symbol: &str) -> AxWsResult<()> {
720        let _guard = self.subscribe_lock.lock().await;
721
722        let current = self
723            .symbol_data_types
724            .load()
725            .get(symbol)
726            .cloned()
727            .unwrap_or_default();
728        let old_level = current.effective_level();
729        let mut next = current.clone();
730        next.instrument_status = true;
731        let new_level = next.effective_level();
732
733        self.update_data_subscription(symbol, old_level, new_level)
734            .await?;
735
736        self.symbol_data_types.rcu(|m| {
737            m.entry(symbol.to_string()).or_default().instrument_status = true;
738        });
739
740        Ok(())
741    }
742
743    /// Unsubscribes from instrument status for a symbol.
744    ///
745    /// The underlying AX subscription is only removed when all data types
746    /// have been unsubscribed.
747    ///
748    /// # Errors
749    ///
750    /// Returns an error if the unsubscribe command cannot be sent.
751    pub async fn unsubscribe_instrument_status(&self, symbol: &str) -> AxWsResult<()> {
752        let _guard = self.subscribe_lock.lock().await;
753
754        let Some(current) = self.symbol_data_types.load().get(symbol).cloned() else {
755            log::debug!("Symbol {symbol} not subscribed, skipping unsubscribe instrument status");
756            return Ok(());
757        };
758        let old_level = current.effective_level();
759        let mut next = current.clone();
760        next.instrument_status = false;
761        let new_level = next.effective_level();
762
763        self.update_data_subscription(symbol, old_level, new_level)
764            .await?;
765
766        self.symbol_data_types.rcu(|m| {
767            if let Some(entry) = m.get_mut(symbol) {
768                entry.instrument_status = false;
769                if entry.is_empty() {
770                    m.remove(symbol);
771                }
772            }
773        });
774
775        if let Ok(mut invalidations) = self.status_invalidations.lock() {
776            invalidations.insert(Ustr::from(symbol));
777        }
778
779        Ok(())
780    }
781
782    async fn update_data_subscription(
783        &self,
784        symbol: &str,
785        old_level: Option<AxMarketDataLevel>,
786        new_level: Option<AxMarketDataLevel>,
787    ) -> AxWsResult<()> {
788        if old_level == new_level {
789            return Ok(());
790        }
791
792        match (old_level, new_level) {
793            (None, Some(level)) => {
794                log::debug!("Subscribing {symbol} at {level:?}");
795                self.send_subscribe(symbol, level).await
796            }
797            (Some(_), None) => {
798                log::debug!("Unsubscribing {symbol} (no remaining data types)");
799                self.send_unsubscribe(symbol).await
800            }
801            (Some(old), Some(new)) => {
802                log::debug!("Resubscribing {symbol}: {old:?} -> {new:?}");
803                self.send_unsubscribe(symbol).await?;
804                if let Err(e) = self.send_subscribe(symbol, new).await {
805                    log::warn!("Resubscribe failed for {symbol} at {new:?}: {e}");
806                    if let Err(restore_err) = self.send_subscribe(symbol, old).await {
807                        // Channel dead, mark old topic for reconnection replay
808                        log::error!(
809                            "Failed to restore {symbol} at {old:?}: {restore_err}, \
810                             reconnection required"
811                        );
812                        let old_topic = format!("{symbol}:{old:?}");
813                        self.subscriptions.mark_subscribe(&old_topic);
814                    }
815                    return Err(e);
816                }
817                Ok(())
818            }
819            (None, None) => Ok(()),
820        }
821    }
822
823    async fn send_subscribe(&self, symbol: &str, level: AxMarketDataLevel) -> AxWsResult<()> {
824        let topic = format!("{symbol}:{level:?}");
825        let request_id = self.next_request_id();
826
827        self.subscriptions.mark_subscribe(&topic);
828
829        if let Err(e) = self
830            .send_cmd(HandlerCommand::Subscribe {
831                request_id,
832                symbol: Ustr::from(symbol),
833                level,
834            })
835            .await
836        {
837            self.subscriptions.mark_unsubscribe(&topic);
838            return Err(e);
839        }
840
841        Ok(())
842    }
843
844    async fn send_unsubscribe(&self, symbol: &str) -> AxWsResult<()> {
845        let request_id = self.next_request_id();
846
847        self.send_cmd(HandlerCommand::Unsubscribe {
848            request_id,
849            symbol: Ustr::from(symbol),
850        })
851        .await?;
852
853        for level in [
854            AxMarketDataLevel::Level1,
855            AxMarketDataLevel::Level2,
856            AxMarketDataLevel::Level3,
857        ] {
858            let topic = format!("{symbol}:{level:?}");
859            self.subscriptions.mark_unsubscribe(&topic);
860        }
861
862        Ok(())
863    }
864
865    /// Subscribes to candle data for a symbol.
866    ///
867    /// Skips sending if already subscribed or subscription is pending.
868    ///
869    /// # Errors
870    ///
871    /// Returns an error if the subscription command cannot be sent.
872    pub async fn subscribe_candles(&self, symbol: &str, width: AxCandleWidth) -> AxWsResult<()> {
873        let _guard = self.subscribe_lock.lock().await;
874        let topic = format!("candles:{symbol}:{width:?}");
875
876        // Skip if already subscribed or pending
877        if self.is_subscribed_topic(&topic) {
878            log::debug!("Already subscribed to {topic}, skipping");
879            return Ok(());
880        }
881
882        let request_id = self.next_request_id();
883
884        // Mark pending BEFORE sending to prevent race conditions with concurrent subscribes
885        self.subscriptions.mark_subscribe(&topic);
886
887        if let Err(e) = self
888            .send_cmd(HandlerCommand::SubscribeCandles {
889                request_id,
890                symbol: Ustr::from(symbol),
891                width,
892            })
893            .await
894        {
895            // Rollback pending state on send failure
896            self.subscriptions.mark_unsubscribe(&topic);
897            return Err(e);
898        }
899
900        Ok(())
901    }
902
903    /// Unsubscribes from candle data for a symbol.
904    ///
905    /// # Errors
906    ///
907    /// Returns an error if the unsubscribe command cannot be sent.
908    pub async fn unsubscribe_candles(&self, symbol: &str, width: AxCandleWidth) -> AxWsResult<()> {
909        let _guard = self.subscribe_lock.lock().await;
910        let request_id = self.next_request_id();
911        let topic = format!("candles:{symbol}:{width:?}");
912
913        self.subscriptions.mark_unsubscribe(&topic);
914
915        self.send_cmd(HandlerCommand::UnsubscribeCandles {
916            request_id,
917            symbol: Ustr::from(symbol),
918            width,
919        })
920        .await
921    }
922
923    /// Returns a stream of WebSocket messages.
924    ///
925    /// # Panics
926    ///
927    /// Panics if called before `connect()` or if the stream has already been taken.
928    pub fn stream(&mut self) -> impl futures_util::Stream<Item = AxDataWsMessage> + 'static {
929        let rx = self
930            .out_rx
931            .take()
932            .expect("Stream receiver already taken or client not connected - stream() can only be called once");
933        let mut rx = Arc::try_unwrap(rx).expect(
934            "Cannot take ownership of stream - client was cloned and other references exist",
935        );
936        async_stream::stream! {
937            while let Some(msg) = rx.recv().await {
938                yield msg;
939            }
940        }
941    }
942
943    /// Disconnects the WebSocket connection gracefully.
944    pub async fn disconnect(&self) {
945        log::debug!("Disconnecting WebSocket");
946        let _ = self.send_cmd(HandlerCommand::Disconnect).await;
947    }
948
949    /// Closes the WebSocket connection and cleans up resources.
950    pub async fn close(&mut self) {
951        log::debug!("Closing WebSocket client");
952
953        // Send disconnect first to allow graceful cleanup before signal
954        let _ = self.send_cmd(HandlerCommand::Disconnect).await;
955        tokio::time::sleep(Duration::from_millis(50)).await;
956        self.signal.store(true, Ordering::Release);
957
958        if let Some(handle) = self.task_handle.take() {
959            const CLOSE_TIMEOUT: Duration = Duration::from_secs(2);
960            let abort_handle = handle.abort_handle();
961
962            match tokio::time::timeout(CLOSE_TIMEOUT, handle).await {
963                Ok(Ok(())) => log::debug!("Handler task completed gracefully"),
964                Ok(Err(e)) => log::warn!("Handler task panicked: {e}"),
965                Err(_) => {
966                    log::warn!("Handler task did not complete within timeout, aborting");
967                    abort_handle.abort();
968                }
969            }
970        }
971    }
972
973    async fn send_cmd(&self, cmd: HandlerCommand) -> AxWsResult<()> {
974        let guard = self.cmd_tx.read().await;
975        guard
976            .send(cmd)
977            .map_err(|e| AxWsClientError::ChannelError(e.to_string()))
978    }
979}
980
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// A default `SymbolDataTypes` requests nothing, so it has no effective level.
    #[rstest]
    fn test_effective_level_empty_returns_none() {
        let data_types = SymbolDataTypes::default();

        assert!(data_types.is_empty());
        assert_eq!(data_types.effective_level(), None);
    }

    /// An explicit book level wins over the Level1 implied by `quotes`.
    #[rstest]
    fn test_effective_level_book_level_takes_precedence() {
        let data_types = SymbolDataTypes {
            quotes: true,
            book_level: Some(AxMarketDataLevel::Level2),
            ..Default::default()
        };

        assert!(!data_types.is_empty());
        assert_eq!(data_types.effective_level(), Some(AxMarketDataLevel::Level2));
    }

    /// Any single top-of-book flag without a book level maps to Level1.
    #[rstest]
    #[case::quotes_only(true, false, false, false)]
    #[case::trades_only(false, true, false, false)]
    #[case::mark_prices_only(false, false, true, false)]
    #[case::instrument_status_only(false, false, false, true)]
    fn test_effective_level_any_flag_returns_level1(
        #[case] quotes: bool,
        #[case] trades: bool,
        #[case] mark_prices: bool,
        #[case] instrument_status: bool,
    ) {
        // All fields listed explicitly so adding a new field forces this test to be revisited.
        let data_types = SymbolDataTypes {
            book_level: None,
            quotes,
            trades,
            mark_prices,
            instrument_status,
        };

        assert!(!data_types.is_empty());
        assert_eq!(data_types.effective_level(), Some(AxMarketDataLevel::Level1));
    }
}