Skip to main content

nautilus_network/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! WebSocket client implementation with automatic reconnection.
17//!
18//! This module contains the core WebSocket client implementation including:
19//! - Connection management with automatic reconnection.
20//! - Split read/write architecture with separate tasks.
21//! - Unbounded channels on latency-sensitive paths.
22//! - Event-driven state notification via `Notify` for immediate wakeup on transitions.
23//! - Heartbeat support.
24//! - Rate limiting integration.
25
26use std::{
27    collections::VecDeque,
28    fmt::Debug,
29    sync::{
30        Arc, OnceLock,
31        atomic::{AtomicBool, AtomicU8, Ordering},
32    },
33    time::Duration,
34};
35
36use futures_util::{SinkExt, StreamExt};
37use http::HeaderName;
38use nautilus_core::CleanDrop;
39use nautilus_cryptography::providers::install_cryptographic_provider;
40#[cfg(any(feature = "turmoil", feature = "transport-sockudo"))]
41use rustls::ClientConfig;
42#[cfg(feature = "transport-sockudo")]
43use sockudo_ws::{
44    Config as SockudoConfig, Http1, Role, Stream as SockudoStream,
45    WebSocketStream as SockudoWebSocketStream,
46};
47#[cfg(feature = "transport-sockudo")]
48use tokio::io::{AsyncRead, AsyncWrite};
49#[cfg(any(feature = "turmoil", feature = "transport-sockudo"))]
50use tokio_rustls::TlsConnector;
51#[cfg(feature = "turmoil")]
52use tokio_tungstenite::MaybeTlsStream;
53#[cfg(feature = "turmoil")]
54use tokio_tungstenite::client_async;
55#[cfg(not(feature = "turmoil"))]
56use tokio_tungstenite::connect_async_with_config;
57use tokio_tungstenite::tungstenite::{client::IntoClientRequest, http::HeaderValue};
58use ustr::Ustr;
59
60#[cfg(not(feature = "turmoil"))]
61use super::proxy::{ProxiedStream, ProxyKind, WsTarget, tunnel_via_proxy};
62use super::{
63    auth::{AuthState, AuthTracker},
64    config::{TransportBackend, WebSocketConfig},
65    consts::{
66        CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
67        GRACEFUL_SHUTDOWN_TIMEOUT_SECS,
68    },
69    types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
70};
71#[cfg(feature = "turmoil")]
72use crate::net::TcpConnector;
73#[cfg(feature = "transport-sockudo")]
74use crate::net::TcpStream;
75#[cfg(feature = "transport-sockudo")]
76use crate::transport::sockudo::{
77    PrefixedIo, SockudoTransport, client_handshake_with_headers, validate_extra_headers,
78};
79use crate::{
80    RECONNECTED,
81    backoff::ExponentialBackoff,
82    dst,
83    error::SendError,
84    logging::{log_task_aborted, log_task_started, log_task_stopped},
85    mode::ConnectionMode,
86    ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
87    transport::{BoxedWsTransport, Message, TransportError, tungstenite::TungsteniteTransport},
88};
89
90/// `WebSocketClient` connects to a websocket server to read and send messages.
91///
92/// The client is opinionated about how messages are read and written. It
93/// assumes that data can only have one reader but multiple writers.
94///
95/// The client splits the connection into read and write halves. It moves
96/// the read half into a tokio task which keeps receiving messages from the
97/// server and calls a handler - a Python function that takes the data
98/// as its parameter. It stores the write half in the struct wrapped
99/// with an Arc Mutex. This way the client struct can be used to write
100/// data to the server from multiple scopes/tasks.
101///
102/// The client also maintains a heartbeat if given a duration in seconds.
103/// It's preferable to set the duration slightly lower - heartbeat more
104/// frequently - than the required amount.
105pub struct WebSocketClientInner {
106    config: WebSocketConfig,
107    /// The function to handle incoming messages (stored separately from config).
108    message_handler: Option<MessageHandler>,
109    /// The handler for incoming pings (stored separately from config).
110    ping_handler: Option<PingHandler>,
111    read_task: Option<tokio::task::JoinHandle<()>>,
112    write_task: tokio::task::JoinHandle<()>,
113    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
114    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
115    connection_mode: Arc<AtomicU8>,
116    state_notify: Arc<tokio::sync::Notify>,
117    reconnect_timeout: Duration,
118    backoff: ExponentialBackoff,
119    /// True if this is a stream-based client (created via `connect_stream`).
120    /// Stream-based clients disable auto-reconnect because the reader is
121    /// owned by the caller and cannot be replaced during reconnection.
122    is_stream_mode: bool,
123    /// Maximum number of reconnection attempts before giving up (None = unlimited).
124    reconnect_max_attempts: Option<u32>,
125    /// Current count of consecutive reconnection attempts.
126    reconnection_attempt_count: u32,
127    /// Shared auth tracker invalidated on connection drops.
128    auth_tracker: Arc<OnceLock<AuthTracker>>,
129    /// Controls whether buffered replay waits for the next authenticated session.
130    reconnect_buffer_waits_for_auth: Arc<AtomicBool>,
131}
132
/// What to do with writer messages that were buffered across a reconnect.
///
/// NOTE(review): the meaning of each variant is inferred from the variant
/// names and the `reconnect_buffer_waits_for_auth` flag; confirm against the
/// write task that consumes this enum.
enum ReconnectBufferAction {
    /// Replay the buffered messages now.
    Drain,
    /// Keep the messages buffered for the time being.
    Wait,
    /// Drop the buffered messages.
    Discard,
}
138
139impl WebSocketClientInner {
140    /// Create an inner websocket client with an existing writer.
141    ///
142    /// This is used for stream mode where the reader is owned by the caller.
143    ///
144    /// # Errors
145    ///
146    /// Returns an error if the exponential backoff configuration is invalid.
147    #[expect(
148        clippy::unused_async,
149        reason = "async signature for consistency with connect-based constructors"
150    )]
151    pub async fn new_with_writer(
152        config: WebSocketConfig,
153        writer: MessageWriter,
154    ) -> Result<Self, TransportError> {
155        install_cryptographic_provider();
156
157        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
158        let state_notify = Arc::new(tokio::sync::Notify::new());
159
160        // Note: We don't spawn a read task here since the reader is handled externally
161        let read_task = None;
162
163        // Stream mode ignores reconnect settings, use harmless defaults
164        let backoff = ExponentialBackoff::new(
165            Duration::from_secs(2),
166            Duration::from_secs(30),
167            1.5,
168            100,
169            true,
170        )
171        .map_err(|e| {
172            TransportError::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
173        })?;
174
175        let auth_tracker = Arc::new(OnceLock::new());
176        let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(false));
177
178        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
179        let write_task = Self::spawn_write_task(
180            connection_mode.clone(),
181            state_notify.clone(),
182            writer,
183            writer_rx,
184            Arc::clone(&auth_tracker),
185            Arc::clone(&reconnect_buffer_waits_for_auth),
186        );
187
188        let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
189            Some(Self::spawn_heartbeat_task(
190                connection_mode.clone(),
191                heartbeat_interval,
192                config.heartbeat_msg.clone(),
193                writer_tx.clone(),
194            ))
195        } else {
196            None
197        };
198
199        let reconnect_max_attempts = None; // Stream mode does not reconnect
200        let reconnect_timeout = Duration::from_secs(10);
201
202        Ok(Self {
203            config,
204            message_handler: None, // Stream mode has no handler
205            ping_handler: None,
206            writer_tx,
207            connection_mode,
208            state_notify,
209            reconnect_timeout,
210            heartbeat_task,
211            read_task,
212            write_task,
213            backoff,
214            is_stream_mode: true,
215            reconnect_max_attempts,
216            reconnection_attempt_count: 0,
217            auth_tracker,
218            reconnect_buffer_waits_for_auth,
219        })
220    }
221
222    /// Create an inner websocket client.
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if:
227    /// - The connection to the server fails.
228    /// - The exponential backoff configuration is invalid.
229    pub async fn connect_url(
230        config: WebSocketConfig,
231        message_handler: Option<MessageHandler>,
232        ping_handler: Option<PingHandler>,
233    ) -> Result<Self, TransportError> {
234        install_cryptographic_provider();
235
236        if config.heartbeat == Some(0) {
237            return Err(TransportError::Io(std::io::Error::new(
238                std::io::ErrorKind::InvalidInput,
239                "Heartbeat interval cannot be zero",
240            )));
241        }
242
243        if config.idle_timeout_ms == Some(0) {
244            return Err(TransportError::Io(std::io::Error::new(
245                std::io::ErrorKind::InvalidInput,
246                "Idle timeout cannot be zero",
247            )));
248        }
249
250        // Capture whether we're in stream mode before moving config
251        let is_stream_mode = message_handler.is_none();
252        let reconnect_max_attempts = config.reconnect_max_attempts;
253
254        let (writer, reader) = Box::pin(Self::connect_with_server(
255            &config.url,
256            config.headers.clone(),
257            config.backend,
258            config.proxy_url.as_deref(),
259        ))
260        .await?;
261
262        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
263        let state_notify = Arc::new(tokio::sync::Notify::new());
264
265        let read_task = if message_handler.is_some() {
266            Some(Self::spawn_message_handler_task(
267                connection_mode.clone(),
268                state_notify.clone(),
269                reader,
270                message_handler.as_ref(),
271                ping_handler.as_ref(),
272                config.idle_timeout_ms,
273            ))
274        } else {
275            None
276        };
277
278        let auth_tracker = Arc::new(OnceLock::new());
279        let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(false));
280
281        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
282        let write_task = Self::spawn_write_task(
283            connection_mode.clone(),
284            state_notify.clone(),
285            writer,
286            writer_rx,
287            Arc::clone(&auth_tracker),
288            Arc::clone(&reconnect_buffer_waits_for_auth),
289        );
290
291        // Optionally spawn a heartbeat task to periodically ping server
292        let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
293            Self::spawn_heartbeat_task(
294                connection_mode.clone(),
295                heartbeat_secs,
296                config.heartbeat_msg.clone(),
297                writer_tx.clone(),
298            )
299        });
300
301        let reconnect_timeout =
302            Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
303        let backoff = ExponentialBackoff::new(
304            Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
305            Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
306            config.reconnect_backoff_factor.unwrap_or(1.5),
307            config.reconnect_jitter_ms.unwrap_or(100),
308            true, // immediate-first
309        )
310        .map_err(|e| {
311            TransportError::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
312        })?;
313
314        Ok(Self {
315            config,
316            message_handler,
317            ping_handler,
318            read_task,
319            write_task,
320            writer_tx,
321            heartbeat_task,
322            connection_mode,
323            state_notify,
324            reconnect_timeout,
325            backoff,
326            // Set stream mode when no message handler (reader not managed by client)
327            is_stream_mode,
328            reconnect_max_attempts,
329            reconnection_attempt_count: 0,
330            auth_tracker,
331            reconnect_buffer_waits_for_auth,
332        })
333    }
334
335    /// Connect to the server and return the split halves of the active transport.
336    ///
337    /// Dispatches on `backend` to the matching backend helper. The
338    /// [`TransportBackend::Tungstenite`] backend is always available; the
339    /// [`TransportBackend::Sockudo`] requires the `transport-sockudo` Cargo
340    /// feature and uses a custom HTTP/1.1 handshake path for upgrade headers.
341    ///
342    /// When `proxy_url` is `Some`, the Tungstenite backend establishes an HTTP
343    /// `CONNECT` tunnel through the proxy before performing the WebSocket
344    /// handshake. The Sockudo backend does not yet support proxying and will
345    /// return an error if a proxy URL is supplied.
346    ///
347    /// # Errors
348    ///
349    /// Returns a [`TransportError`] if the URL is invalid, headers fail to
350    /// parse, the TCP / TLS layer cannot be established, the proxy refuses
351    /// the tunnel, or the WebSocket handshake is rejected by the peer. When
352    /// the Sockudo backend is selected without the `transport-sockudo`
353    /// feature, returns [`TransportError::Other`].
354    #[inline]
355    pub async fn connect_with_server(
356        url: &str,
357        headers: Vec<(String, String)>,
358        backend: TransportBackend,
359        proxy_url: Option<&str>,
360    ) -> Result<(MessageWriter, MessageReader), TransportError> {
361        match backend {
362            TransportBackend::Tungstenite => match proxy_url {
363                Some(proxy) => {
364                    Box::pin(Self::connect_tungstenite_via_proxy(url, headers, proxy)).await
365                }
366                None => Self::connect_tungstenite(url, headers).await,
367            },
368            TransportBackend::Sockudo => {
369                if proxy_url.is_some() {
370                    return Err(TransportError::Other(
371                        "proxy_url is not supported with the Sockudo backend".to_string(),
372                    ));
373                }
374                #[cfg(feature = "transport-sockudo")]
375                {
376                    Self::connect_sockudo(url, headers).await
377                }
378                #[cfg(not(feature = "transport-sockudo"))]
379                {
380                    Err(TransportError::Other(
381                        "sockudo backend selected but the transport-sockudo \
382                         Cargo feature is not enabled"
383                            .to_string(),
384                    ))
385                }
386            }
387        }
388    }
389
390    /// Connects with the server creating a tokio-tungstenite websocket stream.
391    /// Production version that uses `connect_async_with_config` convenience helper.
392    #[inline]
393    #[cfg(not(feature = "turmoil"))]
394    async fn connect_tungstenite(
395        url: &str,
396        headers: Vec<(String, String)>,
397    ) -> Result<(MessageWriter, MessageReader), TransportError> {
398        let mut request = url.into_client_request().map_err(TransportError::from)?;
399        let req_headers = request.headers_mut();
400
401        for (key, val) in headers {
402            let header_value = HeaderValue::from_str(&val)
403                .map_err(|e| TransportError::Handshake(format!("invalid header value: {e}")))?;
404            let header_name: HeaderName = key
405                .parse()
406                .map_err(|e| TransportError::Handshake(format!("invalid header name: {e}")))?;
407            req_headers.insert(header_name, header_value);
408        }
409
410        let (stream, _resp) = connect_async_with_config(request, None, true)
411            .await
412            .map_err(TransportError::from)?;
413        let transport: BoxedWsTransport = Box::pin(TungsteniteTransport::new(stream));
414        Ok(transport.split())
415    }
416
417    /// Connects via an HTTP `CONNECT` proxy and performs the WebSocket
418    /// handshake over the resulting tunnel.
419    ///
420    /// Recognised but unsupported proxy schemes (currently SOCKS) log a
421    /// warning and fall back to a direct connection so existing REST proxy
422    /// configs remain usable. Only available in production builds; the
423    /// turmoil simulator does not model arbitrary outbound TCP via a proxy.
424    #[inline]
425    #[cfg(not(feature = "turmoil"))]
426    async fn connect_tungstenite_via_proxy(
427        url: &str,
428        headers: Vec<(String, String)>,
429        proxy_url: &str,
430    ) -> Result<(MessageWriter, MessageReader), TransportError> {
431        let proxy = match ProxyKind::parse(proxy_url)? {
432            ProxyKind::Http(target) => target,
433            ProxyKind::Unsupported { scheme } => {
434                log::warn!(
435                    "WebSocket proxy_url scheme '{scheme}' is not yet supported; \
436                     connecting without a WebSocket proxy"
437                );
438                return Self::connect_tungstenite(url, headers).await;
439            }
440        };
441
442        let mut request = url.into_client_request().map_err(TransportError::from)?;
443        let req_headers = request.headers_mut();
444
445        for (key, val) in headers {
446            let header_value = HeaderValue::from_str(&val)
447                .map_err(|e| TransportError::Handshake(format!("invalid header value: {e}")))?;
448            let header_name: HeaderName = key
449                .parse()
450                .map_err(|e| TransportError::Handshake(format!("invalid header name: {e}")))?;
451            req_headers.insert(header_name, header_value);
452        }
453
454        let target = WsTarget::parse(url)?;
455        let stream = tunnel_via_proxy(&target, &proxy).await?;
456
457        // Each ProxiedStream variant carries a distinct concrete stream type,
458        // so we monomorphize the handshake through `proxied_ws_handshake`
459        // rather than duplicating the body four times. The arms are
460        // syntactically identical post-deref, but each call instantiates a
461        // different generic; the `match_same_arms` lint is a false positive
462        // here. The futures are boxed because `client_async` produces a
463        // large state machine.
464        #[allow(clippy::match_same_arms)]
465        let transport: BoxedWsTransport = match stream {
466            ProxiedStream::Plain(tcp) => Box::pin(proxied_ws_handshake(request, tcp)).await?,
467            ProxiedStream::PlainOverTlsProxy(s) => {
468                Box::pin(proxied_ws_handshake(request, *s)).await?
469            }
470            ProxiedStream::Tls(s) => Box::pin(proxied_ws_handshake(request, *s)).await?,
471            ProxiedStream::TlsOverTlsProxy(s) => {
472                Box::pin(proxied_ws_handshake(request, *s)).await?
473            }
474        };
475
476        Ok(transport.split())
477    }
478
479    /// Turmoil simulator variant: HTTP `CONNECT` tunneling is not supported
480    /// under the simulator so any proxy URL is rejected up front.
481    #[inline]
482    #[cfg(feature = "turmoil")]
483    #[expect(
484        clippy::unused_async,
485        reason = "signature mirrors the production variant; both are awaited in the dispatcher"
486    )]
487    async fn connect_tungstenite_via_proxy(
488        _url: &str,
489        _headers: Vec<(String, String)>,
490        _proxy_url: &str,
491    ) -> Result<(MessageWriter, MessageReader), TransportError> {
492        Err(TransportError::Other(
493            "proxy_url is not supported under the turmoil simulator".to_string(),
494        ))
495    }
496
497    /// Connects with the server creating a tokio-tungstenite websocket stream.
498    /// Turmoil version that uses the lower-level `client_async` API with injected stream.
499    #[inline]
500    #[cfg(feature = "turmoil")]
501    async fn connect_tungstenite(
502        url: &str,
503        headers: Vec<(String, String)>,
504    ) -> Result<(MessageWriter, MessageReader), TransportError> {
505        let mut request = url.into_client_request().map_err(TransportError::from)?;
506        let req_headers = request.headers_mut();
507
508        for (key, val) in headers {
509            let header_value = HeaderValue::from_str(&val)
510                .map_err(|e| TransportError::Handshake(format!("invalid header value: {e}")))?;
511            let header_name: HeaderName = key
512                .parse()
513                .map_err(|e| TransportError::Handshake(format!("invalid header name: {e}")))?;
514            req_headers.insert(header_name, header_value);
515        }
516
517        let uri = request.uri();
518        let scheme = uri.scheme_str().unwrap_or("ws");
519        let host = uri
520            .host()
521            .ok_or_else(|| TransportError::InvalidUrl("missing hostname".to_string()))?;
522
523        // Determine port: use explicit port if specified, otherwise default based on scheme
524        let port = uri
525            .port_u16()
526            .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
527
528        let addr = format!("{host}:{port}");
529
530        // Use the connector to get a turmoil-compatible stream
531        let connector = crate::net::RealTcpConnector;
532        let tcp_stream = connector.connect(&addr).await?;
533        if let Err(e) = tcp_stream.set_nodelay(true) {
534            log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
535        }
536
537        // Wrap stream appropriately based on scheme
538        let maybe_tls_stream = if scheme == "wss" {
539            // Build TLS config with webpki roots
540            let mut root_store = rustls::RootCertStore::empty();
541            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
542
543            let config = ClientConfig::builder()
544                .with_root_certificates(root_store)
545                .with_no_client_auth();
546
547            let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
548            let domain = rustls::pki_types::ServerName::try_from(host.to_string())
549                .map_err(|e| TransportError::Tls(format!("Invalid DNS name: {e}")))?;
550
551            let tls_stream = tls_connector
552                .connect(domain, tcp_stream)
553                .await
554                .map_err(TransportError::Io)?;
555            MaybeTlsStream::Rustls(tls_stream)
556        } else {
557            MaybeTlsStream::Plain(tcp_stream)
558        };
559
560        // Use client_async with the stream (plain or TLS)
561        let (stream, _resp) = client_async(request, maybe_tls_stream)
562            .await
563            .map_err(TransportError::from)?;
564        let transport: BoxedWsTransport = Box::pin(TungsteniteTransport::new(stream));
565        Ok(transport.split())
566    }
567
568    /// Connects with the server using the sockudo-ws backend.
569    ///
570    /// Uses a local HTTP/1.1 handshake helper so error logging and stream
571    /// construction stay in our hands regardless of header count.
572    ///
573    /// Under the turmoil simulator, only plaintext `ws://` is supported (the
574    /// simulator does not model TLS), so a `wss://` URL returns
575    /// [`TransportError::Tls`] up front.
576    #[inline]
577    #[cfg(feature = "transport-sockudo")]
578    async fn connect_sockudo(
579        url: &str,
580        headers: Vec<(String, String)>,
581    ) -> Result<(MessageWriter, MessageReader), TransportError> {
582        let target = SockudoTarget::parse(url)?;
583        validate_extra_headers(&headers).map_err(TransportError::from)?;
584
585        #[cfg(feature = "turmoil")]
586        if target.is_tls {
587            return Err(TransportError::Tls(
588                "wss:// is not supported under the turmoil simulator; use ws://".to_string(),
589            ));
590        }
591
592        let tcp_stream = TcpStream::connect((target.host.as_str(), target.port))
593            .await
594            .map_err(TransportError::Io)?;
595
596        if let Err(e) = tcp_stream.set_nodelay(true) {
597            log::warn!("Failed to enable TCP_NODELAY for sockudo client: {e:?}");
598        }
599
600        #[cfg(not(feature = "turmoil"))]
601        if target.is_tls {
602            let mut root_store = rustls::RootCertStore::empty();
603            root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
604            let config = ClientConfig::builder()
605                .with_root_certificates(root_store)
606                .with_no_client_auth();
607            let connector = TlsConnector::from(std::sync::Arc::new(config));
608            let domain = rustls::pki_types::ServerName::try_from(target.host.clone())
609                .map_err(|e| TransportError::Tls(format!("Invalid DNS name: {e}")))?;
610            let tls_stream = connector
611                .connect(domain, tcp_stream)
612                .await
613                .map_err(TransportError::Io)?;
614            return Self::finish_sockudo_handshake(tls_stream, &target, &headers).await;
615        }
616
617        Self::finish_sockudo_handshake(tcp_stream, &target, &headers).await
618    }
619
620    #[cfg(feature = "transport-sockudo")]
621    async fn finish_sockudo_handshake<S>(
622        mut stream: S,
623        target: &SockudoTarget,
624        headers: &[(String, String)],
625    ) -> Result<(MessageWriter, MessageReader), TransportError>
626    where
627        S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
628    {
629        // Use our helper for both paths: uniform error logging, and we own
630        // stream construction since sockudo's high-level client drops the
631        // handshake leftover.
632        let handshake = client_handshake_with_headers(
633            &mut stream,
634            &target.host_header,
635            &target.path,
636            None,
637            headers,
638        )
639        .await
640        .map_err(TransportError::from)?;
641
642        // Reading the HTTP 101 may also read the first WebSocket frame prefix;
643        // replay it only when present so the ordinary path stays unwrapped.
644        let stream = match handshake.leftover {
645            Some(prefix) => SockudoStream::<Http1>::new(PrefixedIo::new(stream, prefix)),
646            None => SockudoStream::<Http1>::new(stream),
647        };
648        let ws = SockudoWebSocketStream::from_raw(stream, Role::Client, SockudoConfig::default());
649        let transport: BoxedWsTransport = Box::pin(SockudoTransport::new(ws));
650        Ok(transport.split())
651    }
652}
653
654/// Complete the WebSocket handshake over a stream that has already been
655/// tunneled through an HTTP `CONNECT` proxy. Generic over the concrete
656/// stream type so the four [`super::proxy::ProxiedStream`] variants share
657/// a single body.
658#[cfg(not(feature = "turmoil"))]
659async fn proxied_ws_handshake<S>(
660    request: tokio_tungstenite::tungstenite::handshake::client::Request,
661    stream: S,
662) -> Result<BoxedWsTransport, TransportError>
663where
664    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
665{
666    let (ws, _resp) = tokio_tungstenite::client_async(request, stream)
667        .await
668        .map_err(TransportError::from)?;
669    Ok(Box::pin(TungsteniteTransport::new(ws)))
670}
671
/// The parts of a `ws://` / `wss://` URL that the sockudo backend needs.
///
/// The HTTP/1.1 handshake helper receives `host_header` as its host
/// argument and sends it verbatim as the `Host:` header. It must therefore
/// carry the explicit port whenever the URL has one (RFC 7230 section 5.4).
/// DNS / TCP dialing and TLS SNI use the bare `host` instead.
#[cfg(feature = "transport-sockudo")]
#[derive(Debug, PartialEq, Eq)]
struct SockudoTarget {
    /// Hostname or IP literal without IPv6 brackets, used for dialing and SNI.
    host: String,
    /// Value to send as the HTTP `Host:` header. Includes `:port` only when
    /// the URL specifies a non-default port explicitly.
    host_header: String,
    /// TCP port: the explicit URL port, else 443 for `wss` and 80 for `ws`.
    port: u16,
    /// Request target: the URL path, followed by `?query` when one is present.
    path: String,
    /// `true` for `wss://`, `false` for `ws://`.
    is_tls: bool,
}
689
#[cfg(feature = "transport-sockudo")]
impl SockudoTarget {
    /// Parses a `ws://` or `wss://` URL into the sockudo connection parameters.
    ///
    /// # Errors
    ///
    /// Returns [`TransportError::InvalidUrl`] if the URL cannot be parsed,
    /// uses a scheme other than `ws` / `wss`, or has no host.
    fn parse(url: &str) -> Result<Self, TransportError> {
        let parsed =
            url::Url::parse(url).map_err(|e| TransportError::InvalidUrl(format!("{url}: {e}")))?;

        let is_tls = match parsed.scheme() {
            "ws" => false,
            "wss" => true,
            other => {
                return Err(TransportError::InvalidUrl(format!(
                    "expected ws:// or wss:// scheme, was {other}"
                )));
            }
        };

        let raw_host = parsed
            .host_str()
            .ok_or_else(|| TransportError::InvalidUrl("missing hostname".to_string()))?;

        // url::Url stores IPv6 hosts in their bracketed form (e.g. `[::1]`).
        // Brackets are correct for the HTTP `Host:` header but invalid for
        // DNS/TCP and TLS SNI, so we keep two representations: a bracketed
        // `host_header` for the upgrade, and a bare `host` for socket dialing.
        let host = raw_host
            .strip_prefix('[')
            .and_then(|h| h.strip_suffix(']'))
            .unwrap_or(raw_host)
            .to_string();

        // `Url::port` yields `None` when the port equals the scheme default, so
        // `host_header` only ever carries a non-default port.
        let explicit_port = parsed.port();
        let port = explicit_port.unwrap_or(if is_tls { 443 } else { 80 });
        let host_header = match explicit_port {
            Some(p) => format!("{raw_host}:{p}"),
            None => raw_host.to_string(),
        };

        // Request target: the path, defaulted to `/` if empty, and then the query
        // string. The query is appended in both cases so it is never dropped.
        let mut path = match parsed.path() {
            "" => String::from("/"),
            p => p.to_string(),
        };
        if let Some(query) = parsed.query() {
            path.push('?');
            path.push_str(query);
        }

        Ok(Self {
            host,
            host_header,
            port,
            path,
            is_tls,
        })
    }
}
749
750impl WebSocketClientInner {
    /// Reconnect with server.
    ///
    /// Opens a new connection using the stored config (URL, headers, backend, proxy), hands the
    /// new write half to the writer task via `WriterCommand::Update`, aborts the previous read
    /// task, and - when a message handler is set - spawns a new read task on the new read half.
    /// The heartbeat task is not touched here.
    ///
    /// The whole attempt is bounded by `reconnect_timeout`. The connection mode is re-checked
    /// after the connect and after the grace delay; if a disconnect was requested meanwhile, the
    /// attempt is abandoned and `Ok(())` is returned. If the mode is no longer `Reconnect` at the
    /// final transition, `Ok(())` is also returned (by then the writer may already have been
    /// swapped and the old read task aborted).
    ///
    /// For stream-based clients (created via `connect_stream`), reconnection is disabled
    /// because the reader is owned by the caller and cannot be replaced. The mode is set to
    /// `Closed` and `Ok(())` is returned; stream users should handle disconnections by creating
    /// a new connection.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - The reconnection attempt times out.
    /// - The connection to the server fails.
    /// - The writer task cannot receive the update command, rejects it, or drops the
    ///   acknowledgement channel.
    pub async fn reconnect(&mut self) -> Result<(), TransportError> {
        log::debug!("Reconnecting");

        if self.is_stream_mode {
            log::warn!(
                "Auto-reconnect disabled for stream-based WebSocket client; \
                stream users must manually reconnect by creating a new connection"
            );
            // Transition to CLOSED state to stop reconnection attempts
            self.connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
            return Ok(());
        }

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            log::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        // If the timeout elapses after the update command has been queued, the writer task
        // processes it independently of this future and may still install the new socket.
        dst::time::timeout(self.reconnect_timeout, async {
            // Attempt to connect; abort early if a disconnect was requested
            let (new_writer, reader) = Self::connect_with_server(
                &self.config.url,
                self.config.headers.clone(),
                self.config.backend,
                self.config.proxy_url.as_deref(),
            )
            .await?;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }

            // Use a oneshot channel to synchronize the writer swap before transitioning
            // back to ACTIVE. Buffered messages stay in the writer task and replay later.
            let (tx, rx) = tokio::sync::oneshot::channel();
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
                log::error!("{e}");
                return Err(TransportError::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    format!("Failed to send update command: {e}"),
                )));
            }

            // Wait for writer to confirm it accepted the new socket
            match rx.await {
                Ok(true) => log::debug!("Writer confirmed socket update"),
                Ok(false) => {
                    log::warn!("Writer rejected socket update, aborting reconnect");
                    return Err(TransportError::Io(std::io::Error::other(
                        "Failed to update reconnection writer",
                    )));
                }
                Err(e) => {
                    // The writer task dropped the sender without answering (e.g. it exited).
                    log::error!("Writer dropped update channel: {e}");
                    return Err(TransportError::Io(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "Writer task dropped response channel",
                    )));
                }
            }

            // Grace period after the writer swap, before the old read task is aborted below
            dst::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // The read task is always taken out of `self`; it is only aborted if still running
            if let Some(ref read_task) = self.read_task.take()
                && !read_task.is_finished()
            {
                read_task.abort();
                log_task_aborted("read");
            }

            // Atomically transition from Reconnect to Active
            // This prevents race condition where disconnect could be requested between check and store
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                log::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            // A read task is only spawned when a message handler exists; otherwise the new
            // read half is dropped here.
            self.read_task = if self.message_handler.is_some() {
                Some(Self::spawn_message_handler_task(
                    self.connection_mode.clone(),
                    self.state_notify.clone(),
                    reader,
                    self.message_handler.as_ref(),
                    self.ping_handler.as_ref(),
                    self.config.idle_timeout_ms,
                ))
            } else {
                None
            };

            log::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        .map_err(|_| {
            TransportError::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
886
887    /// Check if the client is still alive.
888    ///
889    /// Returns `true` if both the read and write tasks are still running.
890    /// There may be some delay between the connection closing and the
891    /// client detecting it.
892    #[inline]
893    #[must_use]
894    pub fn is_alive(&self) -> bool {
895        match &self.read_task {
896            Some(read_task) => !read_task.is_finished() && !self.write_task.is_finished(),
897            None => !self.write_task.is_finished(),
898        }
899    }
900
901    fn spawn_message_handler_task(
902        connection_state: Arc<AtomicU8>,
903        state_notify: Arc<tokio::sync::Notify>,
904        mut reader: MessageReader,
905        message_handler: Option<&MessageHandler>,
906        ping_handler: Option<&PingHandler>,
907        idle_timeout_ms: Option<u64>,
908    ) -> tokio::task::JoinHandle<()> {
909        log::debug!("Started message handler task 'read'");
910
911        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
912        let idle_timeout = idle_timeout_ms.map(Duration::from_millis);
913
914        // Clone Arc handlers for the async task
915        let message_handler = message_handler.cloned();
916        let ping_handler = ping_handler.cloned();
917
918        tokio::task::spawn(async move {
919            let mut last_data_time = dst::time::Instant::now();
920
921            loop {
922                if !ConnectionMode::from_atomic(&connection_state).is_active() {
923                    break;
924                }
925
926                match dst::time::timeout(check_interval, reader.next()).await {
927                    Ok(Some(Ok(Message::Binary(data)))) => {
928                        log::trace!("Received message <binary> {} bytes", data.len());
929                        last_data_time = dst::time::Instant::now();
930
931                        if let Some(ref handler) = message_handler {
932                            handler(Message::Binary(data));
933                        }
934                    }
935                    Ok(Some(Ok(Message::Text(data)))) => {
936                        log::trace!("Received message: {data:?}");
937                        last_data_time = dst::time::Instant::now();
938
939                        if let Some(ref handler) = message_handler {
940                            handler(Message::Text(data));
941                        }
942                    }
943                    Ok(Some(Ok(Message::Ping(ping_data)))) => {
944                        log::trace!("Received ping: {ping_data:?}");
945                        // Do not reset last_data_time: pings are keep-alive frames, not application
946                        // data, so a peer that emits only pings must still trip the idle timeout.
947
948                        if let Some(ref handler) = ping_handler {
949                            handler(ping_data.to_vec());
950                        }
951                    }
952                    Ok(Some(Ok(Message::Pong(_)))) => {
953                        log::trace!("Received pong");
954                        // Do not reset last_data_time: pongs are keep-alive replies (not data)
955                    }
956                    Ok(Some(Ok(Message::Close(_)))) => {
957                        log::debug!("Received close message - terminating");
958                        break;
959                    }
960                    Ok(Some(Err(e))) => {
961                        log::error!("Received error message - terminating: {e}");
962                        break;
963                    }
964                    Ok(None) => {
965                        log::debug!("No message received - terminating");
966                        break;
967                    }
968                    Err(_) => {
969                        if let Some(timeout) = idle_timeout {
970                            let idle_duration = last_data_time.elapsed();
971                            if idle_duration >= timeout {
972                                log::warn!(
973                                    "Read idle timeout: no data received for {:.1}s",
974                                    idle_duration.as_secs_f64()
975                                );
976                                break;
977                            }
978                        }
979                    }
980                }
981            }
982
983            // Wake the controller immediately so it detects the dead read task
984            state_notify.notify_one();
985        })
986    }
987
988    /// Attempts to send all buffered messages after reconnection.
989    ///
990    /// Returns `true` if a send error occurred (caller should trigger reconnection).
991    /// Messages remain in buffer if send fails, preserving them for the next reconnection attempt.
992    async fn drain_reconnect_buffer(
993        buffer: &mut VecDeque<Message>,
994        writer: &mut MessageWriter,
995    ) -> bool {
996        if buffer.is_empty() {
997            return false;
998        }
999
1000        let initial_buffer_len = buffer.len();
1001        log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
1002
1003        let mut send_error_occurred = false;
1004
1005        while let Some(buffered_msg) = buffer.front() {
1006            // Clone message before attempting send (to keep in buffer if send fails)
1007            let msg_to_send = buffered_msg.clone();
1008
1009            if let Err(e) = writer.send(msg_to_send).await {
1010                log::error!(
1011                    "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
1012                    buffer.len()
1013                );
1014                send_error_occurred = true;
1015                break; // Stop processing buffer, remaining messages preserved for next reconnection
1016            }
1017
1018            // Only remove from buffer after successful send
1019            buffer.pop_front();
1020        }
1021
1022        if buffer.is_empty() {
1023            log::info!("Successfully sent all {initial_buffer_len} buffered messages");
1024        }
1025
1026        send_error_occurred
1027    }
1028
1029    fn can_drain_reconnect_buffer(
1030        reconnect_buffer_waits_for_auth: &AtomicBool,
1031        auth_tracker: &Arc<OnceLock<AuthTracker>>,
1032    ) -> ReconnectBufferAction {
1033        if !reconnect_buffer_waits_for_auth.load(Ordering::Acquire) {
1034            return ReconnectBufferAction::Drain;
1035        }
1036
1037        match auth_tracker.get().map(AuthTracker::auth_state) {
1038            Some(AuthState::Authenticated) => ReconnectBufferAction::Drain,
1039            Some(AuthState::Failed) => ReconnectBufferAction::Discard,
1040            Some(AuthState::Unauthenticated) | None => ReconnectBufferAction::Wait,
1041        }
1042    }
1043
1044    fn spawn_write_task(
1045        connection_state: Arc<AtomicU8>,
1046        state_notify: Arc<tokio::sync::Notify>,
1047        writer: MessageWriter,
1048        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
1049        auth_tracker: Arc<OnceLock<AuthTracker>>,
1050        reconnect_buffer_waits_for_auth: Arc<AtomicBool>,
1051    ) -> tokio::task::JoinHandle<()> {
1052        log_task_started("write");
1053
1054        // Interval between checking the connection mode
1055        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
1056
1057        tokio::task::spawn(async move {
1058            let mut active_writer = writer;
1059            // Buffer for messages received during reconnection
1060            // VecDeque for efficient pop_front() operations
1061            let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();
1062
1063            loop {
1064                let mode = ConnectionMode::from_atomic(&connection_state);
1065
1066                match mode {
1067                    ConnectionMode::Disconnect => {
1068                        // Log any buffered messages that will be lost
1069                        if !reconnect_buffer.is_empty() {
1070                            log::warn!(
1071                                "Discarding {} buffered messages due to disconnect",
1072                                reconnect_buffer.len()
1073                            );
1074                            reconnect_buffer.clear();
1075                        }
1076
1077                        // Attempt to close the writer gracefully before exiting,
1078                        // we ignore any error as the writer may already be closed.
1079                        _ = dst::time::timeout(
1080                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
1081                            active_writer.close(),
1082                        )
1083                        .await;
1084                        break;
1085                    }
1086                    ConnectionMode::Closed => {
1087                        // Log any buffered messages that will be lost
1088                        if !reconnect_buffer.is_empty() {
1089                            log::warn!(
1090                                "Discarding {} buffered messages due to closed connection",
1091                                reconnect_buffer.len()
1092                            );
1093                            reconnect_buffer.clear();
1094                        }
1095                        break;
1096                    }
1097                    _ => {}
1098                }
1099
1100                if mode.is_active() && !reconnect_buffer.is_empty() {
1101                    match Self::can_drain_reconnect_buffer(
1102                        reconnect_buffer_waits_for_auth.as_ref(),
1103                        &auth_tracker,
1104                    ) {
1105                        ReconnectBufferAction::Drain => {
1106                            let send_error = Self::drain_reconnect_buffer(
1107                                &mut reconnect_buffer,
1108                                &mut active_writer,
1109                            )
1110                            .await;
1111
1112                            if send_error {
1113                                if let Some(tracker) = auth_tracker.get() {
1114                                    tracker.invalidate();
1115                                }
1116                                connection_state
1117                                    .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
1118                                state_notify.notify_one();
1119                            }
1120
1121                            continue;
1122                        }
1123                        ReconnectBufferAction::Discard => {
1124                            log::warn!(
1125                                "Discarding {} buffered messages after authentication failed",
1126                                reconnect_buffer.len()
1127                            );
1128                            reconnect_buffer.clear();
1129                            continue;
1130                        }
1131                        ReconnectBufferAction::Wait => {}
1132                    }
1133                }
1134
1135                match dst::time::timeout(check_interval, writer_rx.recv()).await {
1136                    Ok(Some(msg)) => {
1137                        // Re-check connection mode after receiving a message
1138                        let mode = ConnectionMode::from_atomic(&connection_state);
1139                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1140                            break;
1141                        }
1142
1143                        match msg {
1144                            WriterCommand::Update(new_writer, tx) => {
1145                                log::debug!("Received new writer");
1146
1147                                // Delay before closing connection
1148                                dst::time::sleep(Duration::from_millis(100)).await;
1149
1150                                // Attempt to close the writer gracefully on update,
1151                                // we ignore any error as the writer may already be closed.
1152                                _ = dst::time::timeout(
1153                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
1154                                    active_writer.close(),
1155                                )
1156                                .await;
1157
1158                                active_writer = new_writer;
1159                                log::debug!("Updated writer");
1160
1161                                if let Err(e) = tx.send(true) {
1162                                    log::error!(
1163                                        "Failed to report writer update to controller: {e:?}"
1164                                    );
1165                                }
1166                            }
1167                            WriterCommand::Send(msg) if mode.is_reconnect() => {
1168                                // Buffer messages during reconnection instead of dropping them
1169                                log::debug!(
1170                                    "Buffering message during reconnection (buffer size: {})",
1171                                    reconnect_buffer.len() + 1
1172                                );
1173                                reconnect_buffer.push_back(msg);
1174                            }
1175                            WriterCommand::Send(msg) => {
1176                                if let Err(e) = active_writer.send(msg.clone()).await {
1177                                    log::error!("Failed to send message: {e}");
1178                                    log::warn!("Writer triggering reconnect");
1179                                    reconnect_buffer.push_back(msg);
1180
1181                                    if let Some(tracker) = auth_tracker.get() {
1182                                        tracker.invalidate();
1183                                    }
1184                                    connection_state
1185                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
1186                                    state_notify.notify_one();
1187                                }
1188                            }
1189                        }
1190                    }
1191                    Ok(None) => {
1192                        // Channel closed - writer task should terminate
1193                        log::debug!("Writer channel closed, terminating writer task");
1194                        break;
1195                    }
1196                    Err(_) => {
1197                        // Timeout - just continue the loop
1198                    }
1199                }
1200            }
1201
1202            // Attempt to close the writer gracefully before exiting,
1203            // we ignore any error as the writer may already be closed.
1204            _ = dst::time::timeout(
1205                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
1206                active_writer.close(),
1207            )
1208            .await;
1209
1210            log_task_stopped("write");
1211        })
1212    }
1213
1214    fn spawn_heartbeat_task(
1215        connection_state: Arc<AtomicU8>,
1216        heartbeat_secs: u64,
1217        message: Option<String>,
1218        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
1219    ) -> tokio::task::JoinHandle<()> {
1220        log_task_started("heartbeat");
1221
1222        tokio::task::spawn(async move {
1223            let interval = Duration::from_secs(heartbeat_secs);
1224
1225            loop {
1226                dst::time::sleep(interval).await;
1227
1228                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
1229                    ConnectionMode::Active => {
1230                        let msg = match &message {
1231                            Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
1232                            None => WriterCommand::Send(Message::Ping(vec![].into())),
1233                        };
1234
1235                        match writer_tx.send(msg) {
1236                            Ok(()) => log::trace!("Sent heartbeat to writer task"),
1237                            Err(e) => {
1238                                log::error!("Failed to send heartbeat to writer task: {e}");
1239                            }
1240                        }
1241                    }
1242                    ConnectionMode::Reconnect => {}
1243                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
1244                }
1245            }
1246
1247            log_task_stopped("heartbeat");
1248        })
1249    }
1250}
1251
1252impl Drop for WebSocketClientInner {
1253    fn drop(&mut self) {
1254        // Delegate to explicit cleanup handler
1255        self.clean_drop();
1256    }
1257}
1258
1259/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
1260impl CleanDrop for WebSocketClientInner {
1261    fn clean_drop(&mut self) {
1262        if let Some(ref read_task) = self.read_task.take()
1263            && !read_task.is_finished()
1264        {
1265            read_task.abort();
1266            log_task_aborted("read");
1267        }
1268
1269        if !self.write_task.is_finished() {
1270            self.write_task.abort();
1271            log_task_aborted("write");
1272        }
1273
1274        if let Some(ref handle) = self.heartbeat_task.take()
1275            && !handle.is_finished()
1276        {
1277            handle.abort();
1278            log_task_aborted("heartbeat");
1279        }
1280
1281        // Clear handlers to break potential reference cycles
1282        self.message_handler = None;
1283        self.ping_handler = None;
1284    }
1285}
1286
1287#[expect(
1288    clippy::missing_fields_in_debug,
1289    reason = "handler closures and internal task handles are intentionally omitted"
1290)]
1291impl Debug for WebSocketClientInner {
1292    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1293        f.debug_struct(stringify!(WebSocketClientInner))
1294            .field("config", &self.config)
1295            .field(
1296                "connection_mode",
1297                &ConnectionMode::from_atomic(&self.connection_mode),
1298            )
1299            .field("reconnect_timeout", &self.reconnect_timeout)
1300            .field("is_stream_mode", &self.is_stream_mode)
1301            .finish()
1302    }
1303}
1304
/// WebSocket client with automatic reconnection.
///
/// Handles connection state, callbacks, and rate limiting.
/// Automatic reconnection applies to handler mode (`connect`); stream mode
/// (`connect_stream`) disables it because the caller owns the reader.
/// See module docs for architecture details.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.network")
)]
pub struct WebSocketClient {
    /// Handle to the background controller task; `is_disconnected` reports whether it finished.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the inner client and its tasks (stored via `ConnectionMode::as_u8`).
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Notifier shared with the inner client; tasks use it to wake the controller on state changes.
    pub(crate) state_notify: Arc<tokio::sync::Notify>,
    /// Reconnect timeout copied from the inner client at construction.
    pub(crate) reconnect_timeout: Duration,
    /// Rate limiter built from the keyed and default quotas passed at construction.
    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
    /// Sender for the command channel consumed by the writer task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Auth tracker shared with the inner client; registered at most once via `set_auth_tracker`.
    auth_tracker: Arc<OnceLock<AuthTracker>>,
    /// When `true`, messages buffered during reconnection wait for authentication before replay.
    reconnect_buffer_waits_for_auth: Arc<AtomicBool>,
}
1327
1328impl Debug for WebSocketClient {
1329    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1330        f.debug_struct(stringify!(WebSocketClient)).finish()
1331    }
1332}
1333
1334impl WebSocketClient {
1335    /// Creates a websocket client in **stream mode** that returns a [`MessageReader`].
1336    ///
1337    /// Returns a stream that the caller owns and reads from directly. Automatic reconnection
1338    /// is **disabled** because the reader cannot be replaced internally. On disconnection, the
1339    /// client transitions to CLOSED state and the caller must manually reconnect by calling
1340    /// `connect_stream` again.
1341    ///
1342    /// Use stream mode when you need custom reconnection logic, direct control over message
1343    /// reading, or fine-grained backpressure handling.
1344    ///
1345    /// See [`WebSocketConfig`] documentation for comparison with handler mode.
1346    ///
1347    /// # Errors
1348    ///
1349    /// Returns an error if the connection cannot be established.
1350    pub async fn connect_stream(
1351        config: WebSocketConfig,
1352        keyed_quotas: Vec<(String, Quota)>,
1353        default_quota: Option<Quota>,
1354        post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
1355    ) -> Result<(MessageReader, Self), TransportError> {
1356        install_cryptographic_provider();
1357
1358        // Create a single connection and split it, respecting configured headers
1359        let (writer, reader) = WebSocketClientInner::connect_with_server(
1360            &config.url,
1361            config.headers.clone(),
1362            config.backend,
1363            config.proxy_url.as_deref(),
1364        )
1365        .await?;
1366
1367        // Create inner without connecting (we'll provide the writer)
1368        let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
1369
1370        let connection_mode = inner.connection_mode.clone();
1371        let state_notify = inner.state_notify.clone();
1372        let reconnect_timeout = inner.reconnect_timeout;
1373        let auth_tracker = Arc::clone(&inner.auth_tracker);
1374        let reconnect_buffer_waits_for_auth = Arc::clone(&inner.reconnect_buffer_waits_for_auth);
1375        let keyed_quotas = keyed_quotas
1376            .into_iter()
1377            .map(|(key, quota)| (Ustr::from(&key), quota))
1378            .collect();
1379        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
1380        let writer_tx = inner.writer_tx.clone();
1381
1382        let controller_task = Self::spawn_controller_task(
1383            inner,
1384            connection_mode.clone(),
1385            state_notify.clone(),
1386            post_reconnect,
1387            Arc::clone(&auth_tracker),
1388        );
1389
1390        Ok((
1391            reader,
1392            Self {
1393                controller_task,
1394                connection_mode,
1395                state_notify,
1396                reconnect_timeout,
1397                rate_limiter,
1398                writer_tx,
1399                auth_tracker,
1400                reconnect_buffer_waits_for_auth,
1401            },
1402        ))
1403    }
1404
1405    /// Creates a websocket client in **handler mode** with automatic reconnection.
1406    ///
1407    /// The handler is called for each incoming message on an internal task.
1408    /// Automatic reconnection is **enabled** with exponential backoff. On disconnection,
1409    /// the client automatically attempts to reconnect and replaces the internal reader
1410    /// (the handler continues working seamlessly).
1411    ///
1412    /// Use handler mode for simplified connection management, automatic reconnection, Python
1413    /// bindings, or callback-based message handling.
1414    ///
1415    /// See [`WebSocketConfig`] documentation for comparison with stream mode.
1416    ///
1417    /// # Errors
1418    ///
1419    /// Returns an error if:
1420    /// - The connection cannot be established.
1421    /// - `message_handler` is `None` (use `connect_stream` instead).
1422    pub async fn connect(
1423        config: WebSocketConfig,
1424        message_handler: Option<MessageHandler>,
1425        ping_handler: Option<PingHandler>,
1426        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1427        keyed_quotas: Vec<(String, Quota)>,
1428        default_quota: Option<Quota>,
1429    ) -> Result<Self, TransportError> {
1430        // Validate that handler mode has a message handler
1431        if message_handler.is_none() {
1432            return Err(TransportError::Io(std::io::Error::new(
1433                std::io::ErrorKind::InvalidInput,
1434                "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
1435            )));
1436        }
1437
1438        log::debug!("Connecting");
1439        let inner =
1440            WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
1441        let connection_mode = inner.connection_mode.clone();
1442        let state_notify = inner.state_notify.clone();
1443        let writer_tx = inner.writer_tx.clone();
1444        let reconnect_timeout = inner.reconnect_timeout;
1445        let auth_tracker = Arc::clone(&inner.auth_tracker);
1446        let reconnect_buffer_waits_for_auth = Arc::clone(&inner.reconnect_buffer_waits_for_auth);
1447
1448        let controller_task = Self::spawn_controller_task(
1449            inner,
1450            connection_mode.clone(),
1451            state_notify.clone(),
1452            post_reconnection,
1453            Arc::clone(&auth_tracker),
1454        );
1455
1456        let keyed_quotas = keyed_quotas
1457            .into_iter()
1458            .map(|(key, quota)| (Ustr::from(&key), quota))
1459            .collect();
1460        let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
1461
1462        Ok(Self {
1463            controller_task,
1464            connection_mode,
1465            state_notify,
1466            reconnect_timeout,
1467            rate_limiter,
1468            writer_tx,
1469            auth_tracker,
1470            reconnect_buffer_waits_for_auth,
1471        })
1472    }
1473
1474    /// Returns the current connection mode.
1475    #[must_use]
1476    pub fn connection_mode(&self) -> ConnectionMode {
1477        ConnectionMode::from_atomic(&self.connection_mode)
1478    }
1479
1480    /// Returns a clone of the connection mode atomic for external state tracking.
1481    ///
1482    /// This allows adapter clients to track connection state across reconnections
1483    /// without message-passing delays.
1484    #[must_use]
1485    pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
1486        Arc::clone(&self.connection_mode)
1487    }
1488
1489    /// Check if the client connection is active.
1490    ///
1491    /// Returns `true` if the client is connected and has not been signalled to disconnect.
1492    /// The client will automatically retry connection based on its configuration.
1493    #[inline]
1494    #[must_use]
1495    pub fn is_active(&self) -> bool {
1496        self.connection_mode().is_active()
1497    }
1498
1499    /// Check if the client is disconnected.
1500    #[must_use]
1501    pub fn is_disconnected(&self) -> bool {
1502        self.controller_task.is_finished()
1503    }
1504
1505    /// Check if the client is reconnecting.
1506    ///
1507    /// Returns `true` if the client lost connection and is attempting to reestablish it.
1508    /// The client will automatically retry connection based on its configuration.
1509    #[inline]
1510    #[must_use]
1511    pub fn is_reconnecting(&self) -> bool {
1512        self.connection_mode().is_reconnect()
1513    }
1514
1515    /// Registers an [`AuthTracker`] with the client.
1516    ///
1517    /// When the controller detects a dead connection and transitions to
1518    /// `Reconnect`, it calls `invalidate()` on the tracker so that any
1519    /// pending authenticated sends see the state change immediately.
1520    /// Set `reconnect_buffer_waits_for_auth` for clients that must not replay
1521    /// buffered messages until the next session authenticates.
1522    ///
1523    /// Call this once after construction, before any authenticated sends.
1524    pub fn set_auth_tracker(&self, tracker: AuthTracker, reconnect_buffer_waits_for_auth: bool) {
1525        let _ = self.auth_tracker.set(tracker);
1526        self.reconnect_buffer_waits_for_auth
1527            .store(reconnect_buffer_waits_for_auth, Ordering::Release);
1528    }
1529
1530    /// Check if the client is disconnecting.
1531    ///
1532    /// Returns `true` if the client is in disconnect mode.
1533    #[inline]
1534    #[must_use]
1535    pub fn is_disconnecting(&self) -> bool {
1536        self.connection_mode().is_disconnect()
1537    }
1538
1539    /// Check if the client is closed.
1540    ///
1541    /// Returns `true` if the client has been explicitly disconnected or reached
1542    /// maximum reconnection attempts. In this state, the client cannot be reused
1543    /// and a new client must be created for further connections.
1544    #[inline]
1545    #[must_use]
1546    pub fn is_closed(&self) -> bool {
1547        self.connection_mode().is_closed()
1548    }
1549
1550    /// Checks whether the connection is in a terminal state (disconnecting or closed).
1551    ///
1552    /// Single atomic load to fail fast before rate limiting or waiting.
1553    #[inline]
1554    fn check_not_terminal(&self) -> Result<(), SendError> {
1555        match self.connection_mode() {
1556            ConnectionMode::Disconnect | ConnectionMode::Closed => Err(SendError::Closed),
1557            _ => Ok(()),
1558        }
1559    }
1560
1561    /// Waits for rate limiter quota, aborting early if connection enters a terminal state.
1562    async fn await_rate_limit_or_closed(&self, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1563        const CHECK_INTERVAL_MS: u64 = 100;
1564
1565        tokio::select! {
1566            biased;
1567            () = self.rate_limiter.await_keys_ready(keys) => Ok(()),
1568            () = async {
1569                loop {
1570                    let notified = self.state_notify.notified();
1571
1572                    if matches!(self.connection_mode(), ConnectionMode::Disconnect | ConnectionMode::Closed) {
1573                        break;
1574                    }
1575                    tokio::select! {
1576                        biased;
1577                        () = notified => {}
1578                        () = dst::time::sleep(Duration::from_millis(CHECK_INTERVAL_MS)) => {}
1579                    }
1580                }
1581            } => Err(SendError::Closed),
1582        }
1583    }
1584
1585    /// Waits for the client to become active before sending.
1586    ///
1587    /// Uses `state_notify` for event-driven wakeup so sends resume immediately
1588    /// after reconnection completes. A fallback interval guards against missed
1589    /// notifications.
1590    async fn wait_for_active(&self) -> Result<(), SendError> {
1591        const FALLBACK_INTERVAL_MS: u64 = 100;
1592
1593        let mode = self.connection_mode();
1594        if mode.is_active() {
1595            return Ok(());
1596        }
1597
1598        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1599            return Err(SendError::Closed);
1600        }
1601
1602        log::debug!("Waiting for client to become ACTIVE before sending...");
1603
1604        let fallback_interval = Duration::from_millis(FALLBACK_INTERVAL_MS);
1605
1606        dst::time::timeout(self.reconnect_timeout, async {
1607            loop {
1608                // Register notification interest BEFORE checking state to prevent
1609                // a race where the state changes between our check and the await
1610                let notified = self.state_notify.notified();
1611
1612                let mode = self.connection_mode();
1613                if mode.is_active() {
1614                    return Ok(());
1615                }
1616
1617                if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1618                    return Err(());
1619                }
1620
1621                tokio::select! {
1622                    biased;
1623                    () = notified => {}
1624                    () = dst::time::sleep(fallback_interval) => {}
1625                }
1626            }
1627        })
1628        .await
1629        .map_err(|_| SendError::Timeout)?
1630        .map_err(|()| SendError::Closed)
1631    }
1632
1633    /// Signals that the caller's reader has observed EOF or a fatal error.
1634    ///
1635    /// In stream mode the controller has no visibility into the caller-owned reader.
1636    /// Call this method when `reader.next().await` returns `None` or an unrecoverable
1637    /// error so the controller transitions to `Closed` and dependent tasks shut down.
1638    ///
1639    /// For peer-initiated close frames (`Message::Close`), use [`disconnect`](Self::disconnect)
1640    /// instead so the writer can send the close reply before shutting down.
1641    ///
1642    /// This is a no-op if the connection is already closed or disconnecting.
1643    pub fn notify_closed(&self) {
1644        let mode = self.connection_mode();
1645        if mode.is_disconnect() || mode.is_closed() {
1646            return;
1647        }
1648
1649        log::debug!("Stream reader signalled EOF, transitioning to CLOSED");
1650
1651        self.connection_mode
1652            .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1653        self.state_notify.notify_waiters();
1654    }
1655
1656    /// Set disconnect mode to true.
1657    ///
1658    /// Controller task will periodically check the disconnect mode
1659    /// and shutdown the client if it is alive
1660    pub async fn disconnect(&self) {
1661        log::debug!("Disconnecting");
1662        self.connection_mode
1663            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1664        self.state_notify.notify_waiters();
1665
1666        if dst::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1667            while !self.is_disconnected() {
1668                dst::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
1669            }
1670
1671            if !self.controller_task.is_finished() {
1672                self.controller_task.abort();
1673                log_task_aborted("controller");
1674            }
1675        })
1676        .await
1677            == Ok(())
1678        {
1679            log::debug!("Controller task finished");
1680        } else {
1681            log::error!("Timeout waiting for controller task to finish");
1682
1683            if !self.controller_task.is_finished() {
1684                self.controller_task.abort();
1685                log_task_aborted("controller");
1686            }
1687            self.connection_mode
1688                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1689        }
1690    }
1691
1692    /// Sends the given text `data` to the server.
1693    ///
1694    /// Returns `Ok(())` when the message is enqueued to the writer channel. This does NOT
1695    /// guarantee delivery: if a disconnect occurs concurrently, the writer task may drop the
1696    /// message. During reconnection, messages are buffered and replayed on the new connection.
1697    ///
1698    /// # Errors
1699    ///
1700    /// Returns a websocket error if unable to send.
1701    #[allow(unused_variables)]
1702    pub async fn send_text(&self, data: String, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1703        self.check_not_terminal()?;
1704
1705        self.await_rate_limit_or_closed(keys).await?;
1706        self.wait_for_active().await?;
1707
1708        log::trace!("Sending text: {data:?}");
1709
1710        let msg = Message::Text(data.into());
1711        self.writer_tx
1712            .send(WriterCommand::Send(msg))
1713            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1714    }
1715
1716    /// Sends a pong frame back to the server.
1717    ///
1718    /// # Errors
1719    ///
1720    /// Returns a websocket error if unable to send.
1721    pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1722        self.wait_for_active().await?;
1723
1724        log::trace!("Sending pong frame ({} bytes)", data.len());
1725
1726        let msg = Message::Pong(data.into());
1727        self.writer_tx
1728            .send(WriterCommand::Send(msg))
1729            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1730    }
1731
1732    /// Sends the given bytes `data` to the server.
1733    ///
1734    /// Returns `Ok(())` when the message is enqueued to the writer channel. This does NOT
1735    /// guarantee delivery: if a disconnect occurs concurrently, the writer task may drop the
1736    /// message. During reconnection, messages are buffered and replayed on the new connection.
1737    ///
1738    /// # Errors
1739    ///
1740    /// Returns a websocket error if unable to send.
1741    #[allow(unused_variables)]
1742    pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1743        self.check_not_terminal()?;
1744
1745        self.await_rate_limit_or_closed(keys).await?;
1746        self.wait_for_active().await?;
1747
1748        log::trace!("Sending bytes: {data:?}");
1749
1750        let msg = Message::Binary(data.into());
1751        self.writer_tx
1752            .send(WriterCommand::Send(msg))
1753            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1754    }
1755
1756    /// Sends a close message to the server.
1757    ///
1758    /// # Errors
1759    ///
1760    /// Returns a websocket error if unable to send.
1761    pub async fn send_close_message(&self) -> Result<(), SendError> {
1762        self.wait_for_active().await?;
1763
1764        let msg = Message::Close(None);
1765        self.writer_tx
1766            .send(WriterCommand::Send(msg))
1767            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1768    }
1769
1770    fn spawn_controller_task(
1771        mut inner: WebSocketClientInner,
1772        connection_mode: Arc<AtomicU8>,
1773        state_notify: Arc<tokio::sync::Notify>,
1774        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1775        auth_tracker: Arc<OnceLock<AuthTracker>>,
1776    ) -> tokio::task::JoinHandle<()> {
1777        const CONTROLLER_FALLBACK_INTERVAL_MS: u64 = 100;
1778
1779        tokio::task::spawn(async move {
1780            log_task_started("controller");
1781
1782            let fallback_interval = Duration::from_millis(CONTROLLER_FALLBACK_INTERVAL_MS);
1783
1784            loop {
1785                tokio::select! {
1786                    biased;
1787                    () = state_notify.notified() => {}
1788                    () = dst::time::sleep(fallback_interval) => {}
1789                }
1790
1791                let mut mode = ConnectionMode::from_atomic(&connection_mode);
1792
1793                if mode.is_disconnect() {
1794                    log::debug!("Disconnecting");
1795
1796                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1797                    if dst::time::timeout(timeout, async {
1798                        // Delay awaiting graceful shutdown
1799                        dst::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1800
1801                        if let Some(task) = &inner.read_task
1802                            && !task.is_finished()
1803                        {
1804                            task.abort();
1805                            log_task_aborted("read");
1806                        }
1807
1808                        if let Some(task) = &inner.heartbeat_task
1809                            && !task.is_finished()
1810                        {
1811                            task.abort();
1812                            log_task_aborted("heartbeat");
1813                        }
1814                    })
1815                    .await
1816                    .is_err()
1817                    {
1818                        log::error!("Shutdown timed out after {}s", timeout.as_secs());
1819                    }
1820
1821                    log::debug!("Closed");
1822                    break; // Controller finished
1823                }
1824
1825                if mode.is_closed() {
1826                    log::debug!("Connection closed");
1827                    break;
1828                }
1829
1830                if mode.is_active() && !inner.is_alive() {
1831                    let target = if inner.is_stream_mode {
1832                        ConnectionMode::Closed
1833                    } else {
1834                        ConnectionMode::Reconnect
1835                    };
1836
1837                    if connection_mode
1838                        .compare_exchange(
1839                            ConnectionMode::Active.as_u8(),
1840                            target.as_u8(),
1841                            Ordering::SeqCst,
1842                            Ordering::SeqCst,
1843                        )
1844                        .is_ok()
1845                    {
1846                        if let Some(tracker) = auth_tracker.get() {
1847                            tracker.invalidate();
1848                        }
1849                        log::debug!("Detected dead connection, transitioning to {target:?}");
1850                    }
1851                    mode = ConnectionMode::from_atomic(&connection_mode);
1852                }
1853
1854                if mode.is_reconnect() {
1855                    // Check if max reconnection attempts exceeded
1856                    if let Some(max_attempts) = inner.reconnect_max_attempts
1857                        && inner.reconnection_attempt_count >= max_attempts
1858                    {
1859                        log::error!(
1860                            "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
1861                        );
1862                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1863                        state_notify.notify_waiters();
1864                        break;
1865                    }
1866
1867                    inner.reconnection_attempt_count += 1;
1868                    log::debug!(
1869                        "Reconnection attempt {} of {}",
1870                        inner.reconnection_attempt_count,
1871                        inner
1872                            .reconnect_max_attempts
1873                            .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
1874                    );
1875
1876                    // Race reconnect against disconnect notification
1877                    let reconnect_result = tokio::select! {
1878                        biased;
1879                        result = inner.reconnect() => Some(result),
1880                        () = async {
1881                            loop {
1882                                state_notify.notified().await;
1883
1884                                if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
1885                                    break;
1886                                }
1887                            }
1888                        } => None,
1889                    };
1890
1891                    match reconnect_result {
1892                        None => {
1893                            log::debug!("Reconnect interrupted by disconnect");
1894                        }
1895                        Some(Ok(())) => {
1896                            inner.backoff.reset();
1897                            inner.reconnection_attempt_count = 0;
1898
1899                            state_notify.notify_waiters();
1900
1901                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
1902                                if let Some(ref handler) = inner.message_handler {
1903                                    let reconnected_msg =
1904                                        Message::Text(RECONNECTED.to_string().into());
1905                                    handler(reconnected_msg);
1906                                    log::debug!("Sent reconnected message to handler");
1907                                }
1908
1909                                // TODO: Retain this legacy callback for use from Python
1910                                if let Some(ref callback) = post_reconnection {
1911                                    callback();
1912                                    log::debug!("Called `post_reconnection` handler");
1913                                }
1914
1915                                log::debug!("Reconnected successfully");
1916                            } else {
1917                                log::debug!(
1918                                    "Skipping post_reconnection handlers due to disconnect state"
1919                                );
1920                            }
1921                        }
1922                        Some(Err(e)) => {
1923                            let duration = inner.backoff.next_duration();
1924                            log::warn!(
1925                                "Reconnect attempt {} failed: {e}",
1926                                inner.reconnection_attempt_count
1927                            );
1928
1929                            if !duration.is_zero() {
1930                                log::warn!("Backing off for {}s...", duration.as_secs_f64());
1931                                // Race backoff sleep against disconnect
1932                                tokio::select! {
1933                                    biased;
1934                                    () = dst::time::sleep(duration) => {}
1935                                    () = async {
1936                                        loop {
1937                                            state_notify.notified().await;
1938
1939                                            if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
1940                                                break;
1941                                            }
1942                                        }
1943                                    } => {
1944                                        log::debug!("Backoff interrupted by disconnect");
1945                                    }
1946                                }
1947                            }
1948                        }
1949                    }
1950                }
1951            }
1952            inner
1953                .connection_mode
1954                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1955
1956            log_task_stopped("controller");
1957        })
1958    }
1959}
1960
1961// Abort controller task on drop to clean up background tasks
1962impl Drop for WebSocketClient {
1963    fn drop(&mut self) {
1964        if !self.controller_task.is_finished() {
1965            self.controller_task.abort();
1966            log_task_aborted("controller");
1967        }
1968    }
1969}
1970
#[cfg(test)]
#[cfg(not(feature = "turmoil"))]
#[cfg(not(all(feature = "simulation", madsim)))] // transport-layer I/O not simulated
#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
mod tests {
    use std::{num::NonZeroU32, sync::Arc};

    use futures_util::{SinkExt, StreamExt};
    use tokio::{
        net::TcpListener,
        task::{self, JoinHandle},
    };
    use tokio_tungstenite::{
        accept_hdr_async,
        tungstenite::{
            Message as WsMessage,
            handshake::server::{self, Callback},
            http::HeaderValue,
        },
    };

    use crate::{
        ratelimiter::quota::Quota,
        websocket::{TransportBackend, WebSocketClient, WebSocketConfig},
    };

    /// Echo server used by the tests; the accept loop is aborted on drop.
    struct TestServer {
        task: JoinHandle<()>,
        port: u16,
    }

    /// Handshake callback asserting that the request carries header `key` with `value`.
    #[derive(Debug, Clone)]
    struct TestCallback {
        key: String,
        value: HeaderValue,
    }

    impl Callback for TestCallback {
        #[expect(clippy::panic_in_result_fn)]
        fn on_request(
            self,
            request: &server::Request,
            response: server::Response,
        ) -> Result<server::Response, server::ErrorResponse> {
            // Look the header up once, then assert both presence and value
            let value = request.headers().get(&self.key);
            assert!(value.is_some(), "missing `{}` header", self.key);

            if let Some(value) = value {
                assert_eq!(value, self.value);
            }

            Ok(response)
        }
    }

    impl TestServer {
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let header_key = "test".to_string();
            let header_value = "test".to_string();

            let test_call_back = TestCallback {
                key: header_key,
                value: HeaderValue::from_str(&header_value).unwrap(),
            };

            let task = task::spawn(async move {
                // Keep accepting connections
                loop {
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    task::spawn(async move {
                        // Inner if consumes `msg`, cannot hoist into a match guard
                        #[expect(clippy::collapsible_match)]
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                WsMessage::Text(txt) if txt == "close-now" => {
                                    log::debug!("Forcibly closing from server side");
                                    // This sends a close frame, then stops reading
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo text/binary frames
                                WsMessage::Text(_) | WsMessage::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // If the client closes, we also break
                                WsMessage::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Ignore pings/pongs
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }

    impl Drop for TestServer {
        fn drop(&mut self) {
            self.task.abort();
        }
    }

    async fn setup_test_client(port: u16) -> WebSocketClient {
        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![("test".into(), "test".into())],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Tungstenite,
            proxy_url: None,
        };
        WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
            .await
            .expect("Failed to connect")
    }

    #[tokio::test]
    async fn test_websocket_basic() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        assert!(!client.is_disconnected());

        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // `setup_test_client` configures no heartbeat, so this only verifies that the
        // connection stays up over a few idle seconds
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_websocket_reconnect_exhausted() {
        // Reserve an ephemeral port and release it immediately, so nothing is listening
        // there; this avoids depending on a fixed port being unused on the host
        let port = {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            listener.local_addr().unwrap().port()
        };

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"), // <-- No server
            headers: vec![],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Tungstenite,
            proxy_url: None,
        };
        let res =
            WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
                .await;
        assert!(res.is_err(), "Should fail quickly with no server");
    }

    #[tokio::test]
    async fn test_websocket_forced_close_reconnect() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // 1) Send normal message
        client.send_text("Hello".into(), None).await.unwrap();

        // 2) Trigger forced close from server
        client.send_text("close-now".into(), None).await.unwrap();

        // 3) Wait a bit => read loop sees close => reconnect
        tokio::time::sleep(std::time::Duration::from_secs(1)).await;

        // Confirm not disconnected
        assert!(!client.is_disconnected());

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap();

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{}", server.port),
            headers: vec![("test".into(), "test".into())],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Tungstenite,
            proxy_url: None,
        };

        let client = WebSocketClient::connect(
            config,
            Some(Arc::new(|_| {})),
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        // First 2 should succeed
        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();

        // Third should also succeed: rate limiting delays a send until quota is
        // available rather than rejecting it
        client.send_text("test3".into(), None).await.unwrap();

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }

    #[tokio::test]
    async fn test_concurrent_writers() {
        let server = TestServer::setup().await;
        let client = Arc::new(setup_test_client(server.port).await);

        let mut handles = vec![];

        for i in 0..10 {
            let client = client.clone();
            handles.push(task::spawn(async move {
                client.send_text(format!("test{i}"), None).await.unwrap();
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        // Cleanup
        client.disconnect().await;
        assert!(client.is_disconnected());
    }
}
2245
2246#[cfg(test)]
2247#[cfg(not(feature = "turmoil"))]
2248#[cfg(not(all(feature = "simulation", madsim)))] // transport-layer I/O not simulated
2249mod rust_tests {
2250    use std::sync::{
2251        Arc, OnceLock,
2252        atomic::{AtomicBool, AtomicU8, Ordering},
2253    };
2254
2255    use futures_util::{SinkExt, StreamExt};
2256    use nautilus_common::testing::wait_until_async;
2257    use rstest::rstest;
2258    #[cfg(feature = "transport-sockudo")]
2259    use sockudo_ws::handshake as sockudo_handshake;
2260    #[cfg(feature = "transport-sockudo")]
2261    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
2262    use tokio::{
2263        net::TcpListener,
2264        task::{self, JoinHandle},
2265        time::{Duration, sleep},
2266    };
2267    use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage};
2268    #[cfg(feature = "transport-sockudo")]
2269    use tokio_tungstenite::{
2270        accept_hdr_async,
2271        tungstenite::{
2272            handshake::server::{self, Callback},
2273            http::HeaderValue,
2274        },
2275    };
2276
2277    use super::*;
2278    use crate::websocket::types::channel_message_handler;
2279
    /// Local test WebSocket server that stores the payload of every text frame
    /// it receives, across all accepted connections.
    struct RecordingServer {
        /// Background accept-loop task; aborted by the `Drop` impl.
        task: JoinHandle<()>,
        /// Port of the listener, bound on `127.0.0.1:0` (OS-assigned).
        port: u16,
        /// Received text payloads, in the order they were recorded.
        messages: Arc<tokio::sync::Mutex<Vec<String>>>,
    }
2285
2286    #[cfg(feature = "transport-sockudo")]
2287    async fn read_http_request<S>(stream: &mut S) -> Vec<u8>
2288    where
2289        S: AsyncRead + Unpin,
2290    {
2291        let mut buf = Vec::new();
2292        let mut chunk = [0u8; 256];
2293
2294        loop {
2295            let n = stream.read(&mut chunk).await.unwrap();
2296            assert!(n > 0, "HTTP request closed before headers completed");
2297            buf.extend_from_slice(&chunk[..n]);
2298            if buf.windows(4).any(|window| window == b"\r\n\r\n") {
2299                return buf;
2300            }
2301        }
2302    }
2303
2304    #[cfg(feature = "transport-sockudo")]
2305    fn extract_header<'a>(request: &'a str, name: &str) -> Option<&'a str> {
2306        request.lines().find_map(|line| {
2307            let (header_name, header_value) = line.split_once(':')?;
2308            if header_name.eq_ignore_ascii_case(name) {
2309                Some(header_value.trim())
2310            } else {
2311                None
2312            }
2313        })
2314    }
2315
    /// Server handshake callback that asserts the incoming upgrade request
    /// carries header `key` with exactly `value`.
    #[cfg(feature = "transport-sockudo")]
    #[derive(Debug, Clone)]
    struct HeaderAssertCallback {
        /// Header name looked up in the request.
        key: String,
        /// Value the header is expected to have.
        value: HeaderValue,
    }
2322
2323    #[cfg(feature = "transport-sockudo")]
2324    impl Callback for HeaderAssertCallback {
2325        #[expect(
2326            clippy::panic_in_result_fn,
2327            reason = "assertion failures should fail the test"
2328        )]
2329        fn on_request(
2330            self,
2331            request: &server::Request,
2332            response: server::Response,
2333        ) -> Result<server::Response, server::ErrorResponse> {
2334            assert_eq!(request.headers().get(&self.key), Some(&self.value));
2335            Ok(response)
2336        }
2337    }
2338
2339    impl RecordingServer {
2340        async fn setup() -> Self {
2341            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2342            let port = listener.local_addr().unwrap().port();
2343            let messages = Arc::new(tokio::sync::Mutex::new(Vec::new()));
2344            let messages_clone = Arc::clone(&messages);
2345
2346            let task = task::spawn(async move {
2347                loop {
2348                    let (stream, _) = listener.accept().await.unwrap();
2349                    let mut websocket = accept_async(stream).await.unwrap();
2350                    let messages = Arc::clone(&messages_clone);
2351
2352                    task::spawn(async move {
2353                        while let Some(Ok(msg)) = websocket.next().await {
2354                            match msg {
2355                                WsMessage::Text(text) => {
2356                                    messages.lock().await.push(text.to_string());
2357                                }
2358                                WsMessage::Close(_) => {
2359                                    let _ = websocket.close(None).await;
2360                                    break;
2361                                }
2362                                _ => {}
2363                            }
2364                        }
2365                    });
2366                }
2367            });
2368
2369            Self {
2370                task,
2371                port,
2372                messages,
2373            }
2374        }
2375
2376        async fn messages(&self) -> Vec<String> {
2377            self.messages.lock().await.clone()
2378        }
2379    }
2380
2381    impl Drop for RecordingServer {
2382        fn drop(&mut self) {
2383            self.task.abort();
2384        }
2385    }
2386
2387    #[rstest]
2388    #[tokio::test]
2389    async fn test_reconnect_then_disconnect() {
2390        // Bind an ephemeral port
2391        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2392        let port = listener.local_addr().unwrap().port();
2393
2394        // Server task: accept one ws connection then close it
2395        let server = task::spawn(async move {
2396            let (stream, _) = listener.accept().await.unwrap();
2397            let ws = accept_async(stream).await.unwrap();
2398            drop(ws);
2399            // Keep alive briefly
2400            sleep(Duration::from_secs(1)).await;
2401        });
2402
2403        // Build a channel-based message handler for incoming messages (unused here)
2404        let (handler, _rx) = channel_message_handler();
2405
2406        // Configure client with short reconnect backoff
2407        let config = WebSocketConfig {
2408            url: format!("ws://127.0.0.1:{port}"),
2409            headers: vec![],
2410            heartbeat: None,
2411            heartbeat_msg: None,
2412            reconnect_timeout_ms: Some(1_000),
2413            reconnect_delay_initial_ms: Some(50),
2414            reconnect_delay_max_ms: Some(100),
2415            reconnect_backoff_factor: Some(1.0),
2416            reconnect_jitter_ms: Some(0),
2417            reconnect_max_attempts: None,
2418            idle_timeout_ms: None,
2419            backend: TransportBackend::Tungstenite,
2420            proxy_url: None,
2421        };
2422
2423        // Connect the client
2424        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2425            .await
2426            .unwrap();
2427
2428        // Allow server to drop connection and client to detect
2429        sleep(Duration::from_millis(100)).await;
2430        // Now immediately disconnect the client
2431        client.disconnect().await;
2432        assert!(client.is_disconnected());
2433        server.abort();
2434    }
2435
2436    #[rstest]
2437    #[tokio::test]
2438    async fn test_reconnect_state_flips_when_reader_stops() {
2439        // Bind an ephemeral port and accept a single websocket connection which we drop.
2440        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2441        let port = listener.local_addr().unwrap().port();
2442
2443        let server = task::spawn(async move {
2444            if let Ok((stream, _)) = listener.accept().await
2445                && let Ok(ws) = accept_async(stream).await
2446            {
2447                drop(ws);
2448            }
2449            sleep(Duration::from_millis(50)).await;
2450        });
2451
2452        let (handler, _rx) = channel_message_handler();
2453
2454        let config = WebSocketConfig {
2455            url: format!("ws://127.0.0.1:{port}"),
2456            headers: vec![],
2457            heartbeat: None,
2458            heartbeat_msg: None,
2459            reconnect_timeout_ms: Some(1_000),
2460            reconnect_delay_initial_ms: Some(50),
2461            reconnect_delay_max_ms: Some(100),
2462            reconnect_backoff_factor: Some(1.0),
2463            reconnect_jitter_ms: Some(0),
2464            reconnect_max_attempts: None,
2465            idle_timeout_ms: None,
2466            backend: TransportBackend::Tungstenite,
2467            proxy_url: None,
2468        };
2469
2470        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2471            .await
2472            .unwrap();
2473
2474        tokio::time::timeout(Duration::from_secs(2), async {
2475            loop {
2476                if client.is_reconnecting() {
2477                    break;
2478                }
2479                tokio::time::sleep(Duration::from_millis(10)).await;
2480            }
2481        })
2482        .await
2483        .expect("client did not enter RECONNECT state");
2484
2485        client.disconnect().await;
2486        server.abort();
2487    }
2488
2489    #[rstest]
2490    #[tokio::test]
2491    async fn test_stream_mode_disables_auto_reconnect() {
2492        // Test that stream-based clients (created via connect_stream) set is_stream_mode flag
2493        // and that reconnect() transitions to CLOSED state for stream mode
2494        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2495        let port = listener.local_addr().unwrap().port();
2496
2497        let server = task::spawn(async move {
2498            if let Ok((stream, _)) = listener.accept().await
2499                && let Ok(_ws) = accept_async(stream).await
2500            {
2501                // Keep connection alive briefly
2502                sleep(Duration::from_millis(100)).await;
2503            }
2504        });
2505
2506        let config = WebSocketConfig {
2507            url: format!("ws://127.0.0.1:{port}"),
2508            headers: vec![],
2509            heartbeat: None,
2510            heartbeat_msg: None,
2511            reconnect_timeout_ms: Some(1_000),
2512            reconnect_delay_initial_ms: Some(50),
2513            reconnect_delay_max_ms: Some(100),
2514            reconnect_backoff_factor: Some(1.0),
2515            reconnect_jitter_ms: Some(0),
2516            reconnect_max_attempts: None,
2517            idle_timeout_ms: None,
2518            backend: TransportBackend::Tungstenite,
2519            proxy_url: None,
2520        };
2521
2522        let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
2523            .await
2524            .unwrap();
2525
2526        // Note: We can't easily test the reconnect behavior from the outside since
2527        // the inner client is private. The key fix is that WebSocketClientInner
2528        // now has is_stream_mode=true for connect_stream, and reconnect() will
2529        // transition to CLOSED state instead of creating a new reader that gets dropped.
2530        // This is tested implicitly by the fact that stream users won't get stuck
2531        // in an infinite reconnect loop.
2532
2533        server.abort();
2534    }
2535
2536    #[rstest]
2537    #[tokio::test]
2538    async fn test_message_handler_mode_allows_auto_reconnect() {
2539        // Test that regular clients (with message handler) can auto-reconnect
2540        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2541        let port = listener.local_addr().unwrap().port();
2542
2543        let server = task::spawn(async move {
2544            // Accept first connection and close it
2545            if let Ok((stream, _)) = listener.accept().await
2546                && let Ok(ws) = accept_async(stream).await
2547            {
2548                drop(ws);
2549            }
2550            sleep(Duration::from_millis(50)).await;
2551        });
2552
2553        let (handler, _rx) = channel_message_handler();
2554
2555        let config = WebSocketConfig {
2556            url: format!("ws://127.0.0.1:{port}"),
2557            headers: vec![],
2558            heartbeat: None,
2559            heartbeat_msg: None,
2560            reconnect_timeout_ms: Some(1_000),
2561            reconnect_delay_initial_ms: Some(50),
2562            reconnect_delay_max_ms: Some(100),
2563            reconnect_backoff_factor: Some(1.0),
2564            reconnect_jitter_ms: Some(0),
2565            reconnect_max_attempts: None,
2566            idle_timeout_ms: None,
2567            backend: TransportBackend::Tungstenite,
2568            proxy_url: None,
2569        };
2570
2571        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2572            .await
2573            .unwrap();
2574
2575        // Wait for the connection to be dropped and reconnection to be attempted
2576        tokio::time::timeout(Duration::from_secs(2), async {
2577            loop {
2578                if client.is_reconnecting() || client.is_closed() {
2579                    break;
2580                }
2581                tokio::time::sleep(Duration::from_millis(10)).await;
2582            }
2583        })
2584        .await
2585        .expect("client should attempt reconnection or close");
2586
2587        // Should either be reconnecting or closed (depending on timing)
2588        // The important thing is it's not staying active forever
2589        assert!(
2590            client.is_reconnecting() || client.is_closed(),
2591            "Client with message handler should attempt reconnection"
2592        );
2593
2594        client.disconnect().await;
2595        server.abort();
2596    }
2597
2598    #[rstest]
2599    #[tokio::test]
2600    async fn test_handler_mode_reconnect_with_new_connection() {
2601        // Test that handler mode successfully reconnects and messages continue flowing
2602        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2603        let port = listener.local_addr().unwrap().port();
2604
2605        let server = task::spawn(async move {
2606            // First connection - accept and immediately close
2607            if let Ok((stream, _)) = listener.accept().await
2608                && let Ok(ws) = accept_async(stream).await
2609            {
2610                drop(ws);
2611            }
2612
2613            // Small delay to let client detect disconnection
2614            sleep(Duration::from_millis(100)).await;
2615
2616            // Second connection - accept, send a message, then keep alive
2617            if let Ok((stream, _)) = listener.accept().await
2618                && let Ok(mut ws) = accept_async(stream).await
2619            {
2620                use futures_util::SinkExt;
2621                let _ = ws
2622                    .send(WsMessage::Text("reconnected".to_string().into()))
2623                    .await;
2624                sleep(Duration::from_secs(1)).await;
2625            }
2626        });
2627
2628        let (handler, mut rx) = channel_message_handler();
2629
2630        let config = WebSocketConfig {
2631            url: format!("ws://127.0.0.1:{port}"),
2632            headers: vec![],
2633            heartbeat: None,
2634            heartbeat_msg: None,
2635            reconnect_timeout_ms: Some(2_000),
2636            reconnect_delay_initial_ms: Some(50),
2637            reconnect_delay_max_ms: Some(200),
2638            reconnect_backoff_factor: Some(1.5),
2639            reconnect_jitter_ms: Some(10),
2640            reconnect_max_attempts: None,
2641            idle_timeout_ms: None,
2642            backend: TransportBackend::Tungstenite,
2643            proxy_url: None,
2644        };
2645
2646        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2647            .await
2648            .unwrap();
2649
2650        // Wait for reconnection to happen and message to arrive
2651        let result = tokio::time::timeout(Duration::from_secs(5), async {
2652            loop {
2653                if let Ok(msg) = rx.try_recv()
2654                    && matches!(msg, WsMessage::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
2655                {
2656                    return true;
2657                }
2658                tokio::time::sleep(Duration::from_millis(10)).await;
2659            }
2660        })
2661        .await;
2662
2663        assert!(
2664            result.is_ok(),
2665            "Should receive message after reconnection within timeout"
2666        );
2667
2668        client.disconnect().await;
2669        server.abort();
2670    }
2671
2672    #[rstest]
2673    #[tokio::test]
2674    async fn test_stream_mode_no_auto_reconnect() {
2675        // Test that stream mode does not automatically reconnect when connection is lost
2676        // The caller owns the reader and is responsible for detecting disconnection
2677        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2678        let port = listener.local_addr().unwrap().port();
2679
2680        let server = task::spawn(async move {
2681            // Accept connection and send one message, then close
2682            if let Ok((stream, _)) = listener.accept().await
2683                && let Ok(mut ws) = accept_async(stream).await
2684            {
2685                use futures_util::SinkExt;
2686                let _ = ws.send(WsMessage::Text("hello".to_string().into())).await;
2687                sleep(Duration::from_millis(50)).await;
2688                // Connection closes when ws is dropped
2689            }
2690        });
2691
2692        let config = WebSocketConfig {
2693            url: format!("ws://127.0.0.1:{port}"),
2694            headers: vec![],
2695            heartbeat: None,
2696            heartbeat_msg: None,
2697            reconnect_timeout_ms: Some(1_000),
2698            reconnect_delay_initial_ms: Some(50),
2699            reconnect_delay_max_ms: Some(100),
2700            reconnect_backoff_factor: Some(1.0),
2701            reconnect_jitter_ms: Some(0),
2702            reconnect_max_attempts: None,
2703            idle_timeout_ms: None,
2704            backend: TransportBackend::Tungstenite,
2705            proxy_url: None,
2706        };
2707
2708        let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
2709            .await
2710            .unwrap();
2711
2712        // Initially active
2713        assert!(client.is_active(), "Client should start as active");
2714
2715        // Read the hello message
2716        let msg = reader.next().await;
2717        assert!(
2718            matches!(&msg, Some(Ok(Message::Text(bytes))) if bytes.as_ref() == b"hello"),
2719            "Should receive initial message"
2720        );
2721
2722        // Read until connection closes (reader will return None or error)
2723        while let Some(msg) = reader.next().await {
2724            if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
2725                break;
2726            }
2727        }
2728
2729        // Controller cannot detect reader EOF (reader is owned by caller),
2730        // so the client stays ACTIVE until the caller signals.
2731        sleep(Duration::from_millis(200)).await;
2732        assert!(
2733            client.is_active(),
2734            "Stream mode client stays ACTIVE before notify_closed()"
2735        );
2736
2737        // Caller signals EOF via notify_closed()
2738        client.notify_closed();
2739
2740        assert!(
2741            client.is_closed(),
2742            "Stream mode client should be CLOSED after notify_closed()"
2743        );
2744        assert!(
2745            !client.is_reconnecting(),
2746            "Stream mode client should never attempt reconnection"
2747        );
2748
2749        client.disconnect().await;
2750        server.abort();
2751    }
2752
2753    #[rstest]
2754    #[tokio::test]
2755    async fn test_send_timeout_uses_configured_reconnect_timeout() {
2756        // Test that send operations respect the configured reconnect_timeout.
2757        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
2758        use nautilus_common::testing::wait_until_async;
2759
2760        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2761        let port = listener.local_addr().unwrap().port();
2762
2763        let server = task::spawn(async move {
2764            // Accept first connection and immediately close it
2765            if let Ok((stream, _)) = listener.accept().await
2766                && let Ok(ws) = accept_async(stream).await
2767            {
2768                drop(ws);
2769            }
2770            // Don't accept second connection - client will be stuck in RECONNECT
2771            sleep(Duration::from_mins(1)).await;
2772        });
2773
2774        let (handler, _rx) = channel_message_handler();
2775
2776        // Configure with SHORT 2s reconnect timeout
2777        let config = WebSocketConfig {
2778            url: format!("ws://127.0.0.1:{port}"),
2779            headers: vec![],
2780            heartbeat: None,
2781            heartbeat_msg: None,
2782            reconnect_timeout_ms: Some(2_000), // 2s timeout
2783            reconnect_delay_initial_ms: Some(50),
2784            reconnect_delay_max_ms: Some(100),
2785            reconnect_backoff_factor: Some(1.0),
2786            reconnect_jitter_ms: Some(0),
2787            reconnect_max_attempts: None,
2788            idle_timeout_ms: None,
2789            backend: TransportBackend::Tungstenite,
2790            proxy_url: None,
2791        };
2792
2793        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2794            .await
2795            .unwrap();
2796
2797        // Wait for client to enter RECONNECT state
2798        wait_until_async(
2799            || async { client.is_reconnecting() },
2800            Duration::from_secs(3),
2801        )
2802        .await;
2803
2804        // Attempt send while stuck in RECONNECT - should timeout after 2s (configured timeout)
2805        let start = std::time::Instant::now();
2806        let send_result = client.send_text("test".to_string(), None).await;
2807        let elapsed = start.elapsed();
2808
2809        assert!(
2810            send_result.is_err(),
2811            "Send should fail when client stuck in RECONNECT"
2812        );
2813        assert!(
2814            matches!(send_result, Err(crate::error::SendError::Timeout)),
2815            "Send should return Timeout error, was: {send_result:?}"
2816        );
2817        // Verify timeout respects configured value (2s), but don't check upper bound
2818        // as CI scheduler jitter can cause legitimate delays beyond the timeout
2819        assert!(
2820            elapsed >= Duration::from_millis(1800),
2821            "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2822        );
2823
2824        client.disconnect().await;
2825        server.abort();
2826    }
2827
2828    #[rstest]
2829    #[tokio::test]
2830    async fn test_send_waits_during_reconnection() {
2831        // Test that send operations wait for reconnection to complete (up to timeout)
2832        use nautilus_common::testing::wait_until_async;
2833
2834        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2835        let port = listener.local_addr().unwrap().port();
2836
2837        let server = task::spawn(async move {
2838            // First connection - accept and immediately close
2839            if let Ok((stream, _)) = listener.accept().await
2840                && let Ok(ws) = accept_async(stream).await
2841            {
2842                drop(ws);
2843            }
2844
2845            // Wait a bit before accepting second connection
2846            sleep(Duration::from_millis(500)).await;
2847
2848            // Second connection - accept and keep alive
2849            if let Ok((stream, _)) = listener.accept().await
2850                && let Ok(mut ws) = accept_async(stream).await
2851            {
2852                // Echo messages
2853                while let Some(Ok(msg)) = ws.next().await {
2854                    if ws.send(msg).await.is_err() {
2855                        break;
2856                    }
2857                }
2858            }
2859        });
2860
2861        let (handler, _rx) = channel_message_handler();
2862
2863        let config = WebSocketConfig {
2864            url: format!("ws://127.0.0.1:{port}"),
2865            headers: vec![],
2866            heartbeat: None,
2867            heartbeat_msg: None,
2868            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
2869            reconnect_delay_initial_ms: Some(100),
2870            reconnect_delay_max_ms: Some(200),
2871            reconnect_backoff_factor: Some(1.0),
2872            reconnect_jitter_ms: Some(0),
2873            reconnect_max_attempts: None,
2874            idle_timeout_ms: None,
2875            backend: TransportBackend::Tungstenite,
2876            proxy_url: None,
2877        };
2878
2879        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2880            .await
2881            .unwrap();
2882
2883        // Wait for reconnection to trigger
2884        wait_until_async(
2885            || async { client.is_reconnecting() },
2886            Duration::from_secs(2),
2887        )
2888        .await;
2889
2890        // Try to send while reconnecting - should wait and succeed after reconnect
2891        let send_result = tokio::time::timeout(
2892            Duration::from_secs(3),
2893            client.send_text("test_message".to_string(), None),
2894        )
2895        .await;
2896
2897        assert!(
2898            send_result.is_ok() && send_result.unwrap().is_ok(),
2899            "Send should succeed after waiting for reconnection"
2900        );
2901
2902        client.disconnect().await;
2903        server.abort();
2904    }
2905
2906    #[rstest]
2907    #[tokio::test]
2908    async fn test_rate_limiter_before_active_wait() {
2909        // Test that rate limiting happens BEFORE active state check.
2910        // This prevents race conditions where connection state changes during rate limit wait.
2911        // We verify this by: (1) exhausting rate limit, (2) ensuring client is RECONNECTING,
2912        // (3) sending again and confirming it waits for rate limit THEN reconnection.
2913        use std::{num::NonZeroU32, sync::Arc};
2914
2915        use nautilus_common::testing::wait_until_async;
2916
2917        use crate::ratelimiter::quota::Quota;
2918
2919        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2920        let port = listener.local_addr().unwrap().port();
2921
2922        let server = task::spawn(async move {
2923            // First connection - accept and close after receiving one message
2924            if let Ok((stream, _)) = listener.accept().await
2925                && let Ok(mut ws) = accept_async(stream).await
2926            {
2927                // Receive first message then close
2928                if let Some(Ok(_)) = ws.next().await {
2929                    drop(ws);
2930                }
2931            }
2932
2933            // Wait before accepting reconnection
2934            sleep(Duration::from_millis(500)).await;
2935
2936            // Second connection - accept and keep alive
2937            if let Ok((stream, _)) = listener.accept().await
2938                && let Ok(mut ws) = accept_async(stream).await
2939            {
2940                while let Some(Ok(msg)) = ws.next().await {
2941                    if ws.send(msg).await.is_err() {
2942                        break;
2943                    }
2944                }
2945            }
2946        });
2947
2948        let (handler, _rx) = channel_message_handler();
2949
2950        let config = WebSocketConfig {
2951            url: format!("ws://127.0.0.1:{port}"),
2952            headers: vec![],
2953            heartbeat: None,
2954            heartbeat_msg: None,
2955            reconnect_timeout_ms: Some(5_000),
2956            reconnect_delay_initial_ms: Some(50),
2957            reconnect_delay_max_ms: Some(100),
2958            reconnect_backoff_factor: Some(1.0),
2959            reconnect_jitter_ms: Some(0),
2960            reconnect_max_attempts: None,
2961            idle_timeout_ms: None,
2962            backend: TransportBackend::Tungstenite,
2963            proxy_url: None,
2964        };
2965
2966        // Very restrictive rate limit: 1 request per second, burst of 1
2967        let quota = Quota::per_second(NonZeroU32::new(1).unwrap())
2968            .unwrap()
2969            .allow_burst(NonZeroU32::new(1).unwrap());
2970
2971        let client = Arc::new(
2972            WebSocketClient::connect(
2973                config,
2974                Some(handler),
2975                None,
2976                None,
2977                vec![("test_key".to_string(), quota)],
2978                None,
2979            )
2980            .await
2981            .unwrap(),
2982        );
2983
2984        // First send exhausts burst capacity and triggers connection close
2985        let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2986        client
2987            .send_text("msg1".to_string(), Some(test_key.as_slice()))
2988            .await
2989            .unwrap();
2990
2991        // Wait for client to enter RECONNECT state
2992        wait_until_async(
2993            || async { client.is_reconnecting() },
2994            Duration::from_secs(2),
2995        )
2996        .await;
2997
2998        // Second send: will hit rate limit (~1s) THEN wait for reconnection (~0.5s)
2999        let start = std::time::Instant::now();
3000        let send_result = client
3001            .send_text("msg2".to_string(), Some(test_key.as_slice()))
3002            .await;
3003        let elapsed = start.elapsed();
3004
3005        // Should succeed after both rate limit AND reconnection
3006        assert!(
3007            send_result.is_ok(),
3008            "Send should succeed after rate limit + reconnection, was: {send_result:?}"
3009        );
3010        // Total wait should be at least rate limit time (~1s)
3011        // The reconnection completes while rate limiting or after
3012        // Use 850ms threshold to account for timing jitter in CI
3013        assert!(
3014            elapsed >= Duration::from_millis(850),
3015            "Should wait for rate limit (~1s), waited {elapsed:?}"
3016        );
3017
3018        client.disconnect().await;
3019        server.abort();
3020    }
3021
3022    #[rstest]
3023    #[tokio::test]
3024    async fn test_disconnect_during_reconnect_exits_cleanly() {
3025        // Test CAS race condition: disconnect called during reconnection
3026        // Should exit cleanly without spawning new tasks
3027        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3028        let port = listener.local_addr().unwrap().port();
3029
3030        let server = task::spawn(async move {
3031            // Accept first connection and immediately close
3032            if let Ok((stream, _)) = listener.accept().await
3033                && let Ok(ws) = accept_async(stream).await
3034            {
3035                drop(ws);
3036            }
3037            // Don't accept second connection - let reconnect hang
3038            sleep(Duration::from_mins(1)).await;
3039        });
3040
3041        let (handler, _rx) = channel_message_handler();
3042
3043        let config = WebSocketConfig {
3044            url: format!("ws://127.0.0.1:{port}"),
3045            headers: vec![],
3046            heartbeat: None,
3047            heartbeat_msg: None,
3048            reconnect_timeout_ms: Some(2_000), // 2s timeout - shorter than disconnect timeout
3049            reconnect_delay_initial_ms: Some(100),
3050            reconnect_delay_max_ms: Some(200),
3051            reconnect_backoff_factor: Some(1.0),
3052            reconnect_jitter_ms: Some(0),
3053            reconnect_max_attempts: None,
3054            idle_timeout_ms: None,
3055            backend: TransportBackend::Tungstenite,
3056            proxy_url: None,
3057        };
3058
3059        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3060            .await
3061            .unwrap();
3062
3063        // Wait for reconnection to start
3064        tokio::time::timeout(Duration::from_secs(2), async {
3065            while !client.is_reconnecting() {
3066                sleep(Duration::from_millis(10)).await;
3067            }
3068        })
3069        .await
3070        .expect("Client should enter RECONNECT state");
3071
3072        // Disconnect while reconnecting
3073        client.disconnect().await;
3074
3075        // Should be cleanly closed
3076        assert!(
3077            client.is_disconnected(),
3078            "Client should be cleanly disconnected"
3079        );
3080
3081        server.abort();
3082    }
3083
3084    #[rstest]
3085    #[tokio::test]
3086    async fn test_send_fails_fast_when_closed_before_rate_limit() {
3087        // Test that send operations check connection state BEFORE rate limiting,
3088        // preventing unnecessary delays when the connection is already closed.
3089        use std::{num::NonZeroU32, sync::Arc};
3090
3091        use nautilus_common::testing::wait_until_async;
3092
3093        use crate::ratelimiter::quota::Quota;
3094
3095        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3096        let port = listener.local_addr().unwrap().port();
3097
3098        let server = task::spawn(async move {
3099            // Accept connection and immediately close
3100            if let Ok((stream, _)) = listener.accept().await
3101                && let Ok(ws) = accept_async(stream).await
3102            {
3103                drop(ws);
3104            }
3105            sleep(Duration::from_mins(1)).await;
3106        });
3107
3108        let (handler, _rx) = channel_message_handler();
3109
3110        let config = WebSocketConfig {
3111            url: format!("ws://127.0.0.1:{port}"),
3112            headers: vec![],
3113            heartbeat: None,
3114            heartbeat_msg: None,
3115            reconnect_timeout_ms: Some(5_000),
3116            reconnect_delay_initial_ms: Some(50),
3117            reconnect_delay_max_ms: Some(100),
3118            reconnect_backoff_factor: Some(1.0),
3119            reconnect_jitter_ms: Some(0),
3120            reconnect_max_attempts: None,
3121            idle_timeout_ms: None,
3122            backend: TransportBackend::Tungstenite,
3123            proxy_url: None,
3124        };
3125
3126        // Very restrictive rate limit: 1 request per 10 seconds
3127        // This ensures that if we wait for rate limit, the test will timeout
3128        let quota = Quota::with_period(Duration::from_secs(10))
3129            .unwrap()
3130            .allow_burst(NonZeroU32::new(1).unwrap());
3131
3132        let client = Arc::new(
3133            WebSocketClient::connect(
3134                config,
3135                Some(handler),
3136                None,
3137                None,
3138                vec![("test_key".to_string(), quota)],
3139                None,
3140            )
3141            .await
3142            .unwrap(),
3143        );
3144
3145        // Wait for disconnection
3146        wait_until_async(
3147            || async { client.is_reconnecting() || client.is_closed() },
3148            Duration::from_secs(2),
3149        )
3150        .await;
3151
3152        // Explicitly disconnect to move away from ACTIVE state
3153        client.disconnect().await;
3154        assert!(
3155            !client.is_active(),
3156            "Client should not be active after disconnect"
3157        );
3158
3159        // Attempt send - should fail IMMEDIATELY without waiting for rate limit
3160        let start = std::time::Instant::now();
3161        let test_key: [Ustr; 1] = [Ustr::from("test_key")];
3162        let result = client
3163            .send_text("test".to_string(), Some(test_key.as_slice()))
3164            .await;
3165        let elapsed = start.elapsed();
3166
3167        // Should fail with Closed error
3168        assert!(result.is_err(), "Send should fail when client is closed");
3169        assert!(
3170            matches!(result, Err(crate::error::SendError::Closed)),
3171            "Send should return Closed error, was: {result:?}"
3172        );
3173
3174        // Should fail FAST (< 100ms) without waiting for rate limit (10s)
3175        assert!(
3176            elapsed < Duration::from_millis(100),
3177            "Send should fail fast without rate limiting, took {elapsed:?}"
3178        );
3179
3180        server.abort();
3181    }
3182
3183    #[rstest]
3184    #[tokio::test]
3185    async fn test_connect_rejects_none_message_handler() {
3186        // Test that connect() properly rejects None message_handler
3187        // to prevent zombie connections that appear alive but never detect disconnections
3188
3189        let config = WebSocketConfig {
3190            url: "ws://127.0.0.1:9999".to_string(),
3191            headers: vec![],
3192            heartbeat: None,
3193            heartbeat_msg: None,
3194            reconnect_timeout_ms: Some(1_000),
3195            reconnect_delay_initial_ms: Some(100),
3196            reconnect_delay_max_ms: Some(500),
3197            reconnect_backoff_factor: Some(1.5),
3198            reconnect_jitter_ms: Some(0),
3199            reconnect_max_attempts: None,
3200            idle_timeout_ms: None,
3201            backend: TransportBackend::Tungstenite,
3202            proxy_url: None,
3203        };
3204
3205        // Pass None for message_handler - should be rejected
3206        let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
3207
3208        assert!(
3209            result.is_err(),
3210            "connect() should reject None message_handler"
3211        );
3212
3213        let err = result.unwrap_err();
3214        let err_msg = err.to_string();
3215        assert!(
3216            err_msg.contains("Handler mode requires message_handler"),
3217            "Error should mention missing message_handler, was: {err_msg}"
3218        );
3219    }
3220
3221    #[rstest]
3222    #[tokio::test]
3223    async fn test_client_without_handler_sets_stream_mode() {
3224        // Test that if a client is created without a handler via connect_url,
3225        // it properly sets is_stream_mode=true to prevent zombie connections
3226
3227        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3228        let port = listener.local_addr().unwrap().port();
3229
3230        let server = task::spawn(async move {
3231            // Accept and immediately close to simulate server disconnect
3232            if let Ok((stream, _)) = listener.accept().await
3233                && let Ok(ws) = accept_async(stream).await
3234            {
3235                drop(ws); // Drop connection immediately
3236            }
3237        });
3238
3239        let config = WebSocketConfig {
3240            url: format!("ws://127.0.0.1:{port}"),
3241            headers: vec![],
3242            heartbeat: None,
3243            heartbeat_msg: None,
3244            reconnect_timeout_ms: Some(1_000),
3245            reconnect_delay_initial_ms: Some(100),
3246            reconnect_delay_max_ms: Some(500),
3247            reconnect_backoff_factor: Some(1.5),
3248            reconnect_jitter_ms: Some(0),
3249            reconnect_max_attempts: None,
3250            idle_timeout_ms: None,
3251            backend: TransportBackend::Tungstenite,
3252            proxy_url: None,
3253        };
3254
3255        // Create client directly via connect_url with no handler (stream mode)
3256        let inner = WebSocketClientInner::connect_url(config, None, None)
3257            .await
3258            .unwrap();
3259
3260        // Verify is_stream_mode is true when no handler
3261        assert!(
3262            inner.is_stream_mode,
3263            "Client without handler should have is_stream_mode=true"
3264        );
3265
3266        // Verify that when stream mode is enabled, reconnection is disabled
3267        // (documented behavior - stream mode clients close instead of reconnecting)
3268
3269        server.abort();
3270    }
3271
3272    #[rstest]
3273    #[tokio::test]
3274    async fn test_idle_timeout_triggers_reconnect() {
3275        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3276        let port = listener.local_addr().unwrap().port();
3277
3278        // Server accepts WS connection but sends nothing (simulates silent death)
3279        let server = task::spawn(async move {
3280            let (stream, _) = listener.accept().await.unwrap();
3281            let _ws = accept_async(stream).await.unwrap();
3282            // Hold connection open but send nothing
3283            sleep(Duration::from_secs(5)).await;
3284        });
3285
3286        let (handler, _rx) = channel_message_handler();
3287
3288        let config = WebSocketConfig {
3289            url: format!("ws://127.0.0.1:{port}"),
3290            headers: vec![],
3291            heartbeat: None,
3292            heartbeat_msg: None,
3293            reconnect_timeout_ms: Some(2_000),
3294            reconnect_delay_initial_ms: Some(50),
3295            reconnect_delay_max_ms: Some(100),
3296            reconnect_backoff_factor: Some(1.0),
3297            reconnect_jitter_ms: Some(0),
3298            reconnect_max_attempts: Some(1),
3299            idle_timeout_ms: Some(500),
3300            backend: TransportBackend::Tungstenite,
3301            proxy_url: None,
3302        };
3303
3304        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3305            .await
3306            .unwrap();
3307
3308        assert!(client.is_active());
3309
3310        // Wait for idle timeout to fire and client to enter reconnect/closed
3311        wait_until_async(
3312            || async { client.is_reconnecting() || client.is_disconnected() },
3313            Duration::from_secs(3),
3314        )
3315        .await;
3316
3317        assert!(
3318            !client.is_active(),
3319            "Client should not be active after idle timeout"
3320        );
3321
3322        client.disconnect().await;
3323        server.abort();
3324    }
3325
3326    #[rstest]
3327    #[tokio::test]
3328    async fn test_idle_timeout_resets_on_data() {
3329        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3330        let port = listener.local_addr().unwrap().port();
3331
3332        // Server sends a message every 200ms (well within 1s idle timeout)
3333        let server = task::spawn(async move {
3334            let (stream, _) = listener.accept().await.unwrap();
3335            let mut ws = accept_async(stream).await.unwrap();
3336
3337            for _ in 0..10 {
3338                sleep(Duration::from_millis(200)).await;
3339
3340                if ws.send(WsMessage::Text("ping".into())).await.is_err() {
3341                    break;
3342                }
3343            }
3344        });
3345
3346        let (handler, _rx) = channel_message_handler();
3347
3348        let config = WebSocketConfig {
3349            url: format!("ws://127.0.0.1:{port}"),
3350            headers: vec![],
3351            heartbeat: None,
3352            heartbeat_msg: None,
3353            reconnect_timeout_ms: Some(2_000),
3354            reconnect_delay_initial_ms: Some(50),
3355            reconnect_delay_max_ms: Some(100),
3356            reconnect_backoff_factor: Some(1.0),
3357            reconnect_jitter_ms: Some(0),
3358            reconnect_max_attempts: Some(1),
3359            idle_timeout_ms: Some(1_000),
3360            backend: TransportBackend::Tungstenite,
3361            proxy_url: None,
3362        };
3363
3364        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3365            .await
3366            .unwrap();
3367
3368        assert!(client.is_active());
3369
3370        // Wait 1.5s - data arrives every 200ms so idle timeout (1s) should NOT fire
3371        sleep(Duration::from_millis(1_500)).await;
3372
3373        assert!(
3374            client.is_active(),
3375            "Client should remain active when data is flowing"
3376        );
3377
3378        client.disconnect().await;
3379        server.abort();
3380    }
3381
3382    #[rstest]
3383    #[tokio::test]
3384    async fn test_idle_timeout_fires_when_only_pings_received() {
3385        // Regression: pings and pongs are keep-alive frames, not application data,
3386        // so a peer that only emits control frames must still trip the idle timeout.
3387        // The peer keeps pinging for well past the observation window so the
3388        // pre-fix behavior (reset-on-ping) would keep the client active; under the
3389        // fix the idle timer never resets and fires after ~500ms.
3390        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3391        let port = listener.local_addr().unwrap().port();
3392
3393        let server = task::spawn(async move {
3394            let (stream, _) = listener.accept().await.unwrap();
3395            let mut ws = accept_async(stream).await.unwrap();
3396
3397            for _ in 0..60 {
3398                sleep(Duration::from_millis(100)).await;
3399
3400                if ws.send(WsMessage::Ping(Vec::new().into())).await.is_err() {
3401                    break;
3402                }
3403            }
3404        });
3405
3406        let (handler, _rx) = channel_message_handler();
3407
3408        let config = WebSocketConfig {
3409            url: format!("ws://127.0.0.1:{port}"),
3410            headers: vec![],
3411            heartbeat: None,
3412            heartbeat_msg: None,
3413            reconnect_timeout_ms: Some(2_000),
3414            reconnect_delay_initial_ms: Some(50),
3415            reconnect_delay_max_ms: Some(100),
3416            reconnect_backoff_factor: Some(1.0),
3417            reconnect_jitter_ms: Some(0),
3418            reconnect_max_attempts: Some(1),
3419            idle_timeout_ms: Some(500),
3420            backend: TransportBackend::Tungstenite,
3421            proxy_url: None,
3422        };
3423
3424        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3425            .await
3426            .unwrap();
3427
3428        assert!(client.is_active());
3429
3430        // Observation window is shorter than the ping stream (6s). If the idle
3431        // timer mistakenly reset on every ping the client would still be active
3432        // here; under the fix it goes inactive at ~500ms.
3433        wait_until_async(
3434            || async { client.is_reconnecting() || client.is_disconnected() },
3435            Duration::from_millis(1_500),
3436        )
3437        .await;
3438
3439        assert!(
3440            !client.is_active(),
3441            "Client should not be active after idle timeout when only pings/pongs flow"
3442        );
3443
3444        client.disconnect().await;
3445        server.abort();
3446    }
3447
3448    #[rstest]
3449    #[tokio::test]
3450    async fn test_idle_timeout_fires_when_only_pongs_received() {
3451        // Regression for the heartbeat-reply path. When the client heartbeat is
3452        // enabled, the peer auto-replies with pongs for every outgoing ping. If
3453        // those pongs refreshed last_data_time the idle timer would never fire on
3454        // a zombie connection (the motivating Polymarket scenario).
3455        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3456        let port = listener.local_addr().unwrap().port();
3457
3458        let server = task::spawn(async move {
3459            let (stream, _) = listener.accept().await.unwrap();
3460            let mut ws = accept_async(stream).await.unwrap();
3461
3462            // Drain incoming frames so tungstenite's internal pong replies are
3463            // actually flushed to the client. Hold the connection open well past
3464            // the observation window.
3465            let deadline = tokio::time::Instant::now() + Duration::from_secs(6);
3466            while tokio::time::Instant::now() < deadline {
3467                if let Ok(Some(Err(_)) | None) =
3468                    tokio::time::timeout(Duration::from_millis(100), ws.next()).await
3469                {
3470                    break;
3471                }
3472            }
3473        });
3474
3475        let (handler, _rx) = channel_message_handler();
3476
3477        let config = WebSocketConfig {
3478            url: format!("ws://127.0.0.1:{port}"),
3479            headers: vec![],
3480            heartbeat: Some(1),
3481            heartbeat_msg: None,
3482            reconnect_timeout_ms: Some(2_000),
3483            reconnect_delay_initial_ms: Some(50),
3484            reconnect_delay_max_ms: Some(100),
3485            reconnect_backoff_factor: Some(1.0),
3486            reconnect_jitter_ms: Some(0),
3487            reconnect_max_attempts: Some(1),
3488            idle_timeout_ms: Some(1_500),
3489            backend: TransportBackend::Tungstenite,
3490            proxy_url: None,
3491        };
3492
3493        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3494            .await
3495            .unwrap();
3496
3497        assert!(client.is_active());
3498
3499        // Heartbeat cadence is 1s; each ping draws a pong reply. Under the fix
3500        // the idle timer ignores those pongs and fires at ~1.5s. Under the bug
3501        // every pong reset the timer and the client would stay active.
3502        wait_until_async(
3503            || async { client.is_reconnecting() || client.is_disconnected() },
3504            Duration::from_millis(2_500),
3505        )
3506        .await;
3507
3508        assert!(
3509            !client.is_active(),
3510            "Client should not be active after idle timeout when only pongs flow"
3511        );
3512
3513        client.disconnect().await;
3514        server.abort();
3515    }
3516
3517    #[rstest]
3518    #[tokio::test]
3519    async fn test_disconnect_during_backoff_exits_promptly() {
3520        // Verify that disconnect interrupts backoff sleep (Finding 1).
3521        // Server accepts then drops, no second listener -> reconnect fails -> enters backoff.
3522        // We disconnect while backing off and assert the client shuts down quickly.
3523        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3524        let port = listener.local_addr().unwrap().port();
3525
3526        let server = task::spawn(async move {
3527            // Accept first connection, close immediately
3528            if let Ok((stream, _)) = listener.accept().await {
3529                let _ = accept_async(stream).await;
3530            }
3531            // Don't accept again so reconnect fails and enters backoff
3532            sleep(Duration::from_mins(1)).await;
3533        });
3534
3535        let (handler, _rx) = channel_message_handler();
3536
3537        let config = WebSocketConfig {
3538            url: format!("ws://127.0.0.1:{port}"),
3539            headers: vec![],
3540            heartbeat: None,
3541            heartbeat_msg: None,
3542            reconnect_timeout_ms: Some(1_000),
3543            reconnect_delay_initial_ms: Some(10_000), // 10s backoff to ensure we're sleeping
3544            reconnect_delay_max_ms: Some(10_000),
3545            reconnect_backoff_factor: Some(1.0),
3546            reconnect_jitter_ms: Some(0),
3547            reconnect_max_attempts: None,
3548            idle_timeout_ms: None,
3549            backend: TransportBackend::Tungstenite,
3550            proxy_url: None,
3551        };
3552
3553        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3554            .await
3555            .unwrap();
3556
3557        // Wait for client to enter reconnect
3558        wait_until_async(
3559            || async { client.is_reconnecting() },
3560            Duration::from_secs(3),
3561        )
3562        .await;
3563
3564        // Wait a bit more for the reconnect attempt to fail and enter backoff sleep
3565        sleep(Duration::from_millis(1_500)).await;
3566
3567        // Disconnect while backing off
3568        let start = std::time::Instant::now();
3569        client.disconnect().await;
3570        let elapsed = start.elapsed();
3571
3572        assert!(client.is_disconnected(), "Client should be disconnected");
3573        // Should exit well before the 10s backoff sleep completes
3574        assert!(
3575            elapsed < Duration::from_secs(2),
3576            "Disconnect should interrupt backoff sleep, took {elapsed:?}"
3577        );
3578
3579        server.abort();
3580    }
3581
3582    #[rstest]
3583    #[tokio::test]
3584    async fn test_rate_limit_cancelled_on_disconnect() {
3585        // Verify that a send blocked on rate limiting returns Closed when
3586        // the client disconnects (Finding 6).
3587        use std::{num::NonZeroU32, sync::Arc};
3588
3589        use crate::ratelimiter::quota::Quota;
3590
3591        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3592        let port = listener.local_addr().unwrap().port();
3593
3594        let server = task::spawn(async move {
3595            if let Ok((stream, _)) = listener.accept().await {
3596                let mut ws = accept_async(stream).await.unwrap();
3597                // Keep alive and echo
3598                while let Some(Ok(msg)) = ws.next().await {
3599                    if ws.send(msg).await.is_err() {
3600                        break;
3601                    }
3602                }
3603            }
3604        });
3605
3606        let (handler, _rx) = channel_message_handler();
3607
3608        let config = WebSocketConfig {
3609            url: format!("ws://127.0.0.1:{port}"),
3610            headers: vec![],
3611            heartbeat: None,
3612            heartbeat_msg: None,
3613            reconnect_timeout_ms: Some(5_000),
3614            reconnect_delay_initial_ms: Some(100),
3615            reconnect_delay_max_ms: Some(500),
3616            reconnect_backoff_factor: Some(1.5),
3617            reconnect_jitter_ms: Some(0),
3618            reconnect_max_attempts: None,
3619            idle_timeout_ms: None,
3620            backend: TransportBackend::Tungstenite,
3621            proxy_url: None,
3622        };
3623
3624        // Very restrictive: 1 req per 60 seconds
3625        let quota = Quota::with_period(Duration::from_mins(1))
3626            .unwrap()
3627            .allow_burst(NonZeroU32::new(1).unwrap());
3628
3629        let client = Arc::new(
3630            WebSocketClient::connect(
3631                config,
3632                Some(handler),
3633                None,
3634                None,
3635                vec![("rate_key".to_string(), quota)],
3636                None,
3637            )
3638            .await
3639            .unwrap(),
3640        );
3641
3642        let test_key: [Ustr; 1] = [Ustr::from("rate_key")];
3643
3644        // Exhaust the burst quota
3645        client
3646            .send_text("exhaust".to_string(), Some(test_key.as_slice()))
3647            .await
3648            .unwrap();
3649
3650        // Spawn a send that will block on rate limiter
3651        let client_clone = client.clone();
3652        let send_handle = task::spawn(async move {
3653            client_clone
3654                .send_text("blocked".to_string(), Some(&[Ustr::from("rate_key")]))
3655                .await
3656        });
3657
3658        // Let the send block on rate limit
3659        sleep(Duration::from_millis(200)).await;
3660
3661        // Disconnect while send is blocked
3662        let start = std::time::Instant::now();
3663        client.disconnect().await;
3664        let elapsed_disconnect = start.elapsed();
3665
3666        // The blocked send should return Closed
3667        let result = tokio::time::timeout(Duration::from_secs(2), send_handle)
3668            .await
3669            .expect("Send task should complete quickly")
3670            .expect("Send task should not panic");
3671
3672        assert!(
3673            matches!(result, Err(crate::error::SendError::Closed)),
3674            "Blocked send should return Closed, was: {result:?}"
3675        );
3676
3677        // Disconnect should be fast, not waiting for the 60s rate limit
3678        assert!(
3679            elapsed_disconnect < Duration::from_secs(3),
3680            "Disconnect should not wait for rate limiter, took {elapsed_disconnect:?}"
3681        );
3682
3683        server.abort();
3684    }
3685
3686    #[rstest]
3687    #[tokio::test]
3688    async fn test_stream_mode_transitions_to_closed_on_dead_write_task() {
3689        // Verify that stream mode transitions to CLOSED (not RECONNECT) when
3690        // the write task dies (Finding 4). We force write failure by sending
3691        // after the server closes the connection.
3692        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3693        let port = listener.local_addr().unwrap().port();
3694
3695        let server = task::spawn(async move {
3696            if let Ok((stream, _)) = listener.accept().await
3697                && let Ok(ws) = accept_async(stream).await
3698            {
3699                // Close immediately to cause write errors
3700                drop(ws);
3701            }
3702        });
3703
3704        let config = WebSocketConfig {
3705            url: format!("ws://127.0.0.1:{port}"),
3706            headers: vec![],
3707            heartbeat: None,
3708            heartbeat_msg: None,
3709            reconnect_timeout_ms: Some(1_000),
3710            reconnect_delay_initial_ms: Some(50),
3711            reconnect_delay_max_ms: Some(100),
3712            reconnect_backoff_factor: Some(1.0),
3713            reconnect_jitter_ms: Some(0),
3714            reconnect_max_attempts: None,
3715            idle_timeout_ms: None,
3716            backend: TransportBackend::Tungstenite,
3717            proxy_url: None,
3718        };
3719
3720        let (_reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
3721            .await
3722            .unwrap();
3723
3724        assert!(client.is_active(), "Client should start active");
3725
3726        // Wait for server to close, then send to trigger write task failure
3727        sleep(Duration::from_millis(100)).await;
3728
3729        // Keep sending until the write task detects the broken connection
3730        for _ in 0..20 {
3731            let _ = client.send_text("ping".to_string(), None).await;
3732            sleep(Duration::from_millis(50)).await;
3733
3734            if !client.is_active() {
3735                break;
3736            }
3737        }
3738
3739        // Wait for controller to process the state change
3740        wait_until_async(|| async { !client.is_active() }, Duration::from_secs(5)).await;
3741
3742        // Stream mode should go to CLOSED, not RECONNECT
3743        assert!(
3744            client.is_closed() || client.is_disconnected(),
3745            "Stream mode should transition to CLOSED, not RECONNECT. \
3746             is_reconnecting={}, is_closed={}, is_disconnected={}",
3747            client.is_reconnecting(),
3748            client.is_closed(),
3749            client.is_disconnected(),
3750        );
3751        assert!(
3752            !client.is_reconnecting(),
3753            "Stream mode should never attempt reconnection"
3754        );
3755
3756        server.abort();
3757    }
3758
3759    #[tokio::test]
3760    async fn test_write_task_waits_for_auth_before_replaying_buffer() {
3761        use nautilus_common::testing::wait_until_async;
3762
3763        let server = RecordingServer::setup().await;
3764        let url = format!("ws://127.0.0.1:{}", server.port);
3765        let (writer, _reader) = WebSocketClientInner::connect_with_server(
3766            &url,
3767            vec![],
3768            TransportBackend::Tungstenite,
3769            None,
3770        )
3771        .await
3772        .unwrap();
3773
3774        let connection_state = Arc::new(AtomicU8::new(ConnectionMode::Reconnect.as_u8()));
3775        let state_notify = Arc::new(tokio::sync::Notify::new());
3776        let auth_tracker = Arc::new(OnceLock::new());
3777        let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(true));
3778        let tracker = AuthTracker::new();
3779        auth_tracker.set(tracker.clone()).unwrap();
3780
3781        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel();
3782        let write_task = WebSocketClientInner::spawn_write_task(
3783            Arc::clone(&connection_state),
3784            Arc::clone(&state_notify),
3785            writer,
3786            writer_rx,
3787            Arc::clone(&auth_tracker),
3788            Arc::clone(&reconnect_buffer_waits_for_auth),
3789        );
3790
3791        writer_tx
3792            .send(WriterCommand::Send(Message::Text("stale".into())))
3793            .unwrap();
3794
3795        let (new_writer, _reader) = WebSocketClientInner::connect_with_server(
3796            &url,
3797            vec![],
3798            TransportBackend::Tungstenite,
3799            None,
3800        )
3801        .await
3802        .unwrap();
3803        let (tx, rx) = tokio::sync::oneshot::channel();
3804        writer_tx
3805            .send(WriterCommand::Update(new_writer, tx))
3806            .unwrap();
3807        assert!(rx.await.unwrap());
3808
3809        connection_state.store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
3810
3811        tokio::time::sleep(Duration::from_millis(300)).await;
3812        assert!(
3813            server.messages().await.is_empty(),
3814            "buffered messages should wait for re-authentication"
3815        );
3816
3817        tracker.succeed();
3818
3819        wait_until_async(
3820            || {
3821                let messages = Arc::clone(&server.messages);
3822                async move { !messages.lock().await.is_empty() }
3823            },
3824            Duration::from_secs(3),
3825        )
3826        .await;
3827
3828        assert_eq!(server.messages().await, vec!["stale".to_string()]);
3829
3830        connection_state.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
3831        state_notify.notify_waiters();
3832        drop(writer_tx);
3833        write_task.abort();
3834    }
3835
3836    #[tokio::test]
3837    async fn test_write_task_discards_buffer_after_auth_failure() {
3838        let server = RecordingServer::setup().await;
3839        let url = format!("ws://127.0.0.1:{}", server.port);
3840        let (writer, _reader) = WebSocketClientInner::connect_with_server(
3841            &url,
3842            vec![],
3843            TransportBackend::Tungstenite,
3844            None,
3845        )
3846        .await
3847        .unwrap();
3848
3849        let connection_state = Arc::new(AtomicU8::new(ConnectionMode::Reconnect.as_u8()));
3850        let state_notify = Arc::new(tokio::sync::Notify::new());
3851        let auth_tracker = Arc::new(OnceLock::new());
3852        let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(true));
3853        let tracker = AuthTracker::new();
3854        auth_tracker.set(tracker.clone()).unwrap();
3855
3856        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel();
3857        let write_task = WebSocketClientInner::spawn_write_task(
3858            Arc::clone(&connection_state),
3859            Arc::clone(&state_notify),
3860            writer,
3861            writer_rx,
3862            Arc::clone(&auth_tracker),
3863            Arc::clone(&reconnect_buffer_waits_for_auth),
3864        );
3865
3866        writer_tx
3867            .send(WriterCommand::Send(Message::Text("stale".into())))
3868            .unwrap();
3869
3870        let (new_writer, _reader) = WebSocketClientInner::connect_with_server(
3871            &url,
3872            vec![],
3873            TransportBackend::Tungstenite,
3874            None,
3875        )
3876        .await
3877        .unwrap();
3878        let (tx, rx) = tokio::sync::oneshot::channel();
3879        writer_tx
3880            .send(WriterCommand::Update(new_writer, tx))
3881            .unwrap();
3882        assert!(rx.await.unwrap());
3883
3884        connection_state.store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
3885        tracker.fail("rejected");
3886        tokio::time::sleep(Duration::from_millis(300)).await;
3887        assert!(
3888            server.messages().await.is_empty(),
3889            "buffered messages should be discarded after authentication failure"
3890        );
3891
3892        let _auth_receiver = tracker.begin();
3893        tracker.succeed();
3894        tokio::time::sleep(Duration::from_millis(300)).await;
3895        assert!(
3896            server.messages().await.is_empty(),
3897            "discarded buffered messages should not replay on a later auth success"
3898        );
3899
3900        connection_state.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
3901        state_notify.notify_waiters();
3902        drop(writer_tx);
3903        write_task.abort();
3904    }
3905
3906    #[rstest]
3907    #[tokio::test]
3908    async fn test_zero_idle_timeout_rejected() {
3909        let (handler, _rx) = channel_message_handler();
3910
3911        let config = WebSocketConfig {
3912            url: "ws://127.0.0.1:9999".to_string(),
3913            headers: vec![],
3914            heartbeat: None,
3915            heartbeat_msg: None,
3916            reconnect_timeout_ms: None,
3917            reconnect_delay_initial_ms: None,
3918            reconnect_delay_max_ms: None,
3919            reconnect_backoff_factor: None,
3920            reconnect_jitter_ms: None,
3921            reconnect_max_attempts: None,
3922            idle_timeout_ms: Some(0),
3923            backend: TransportBackend::Tungstenite,
3924            proxy_url: None,
3925        };
3926
3927        let result =
3928            WebSocketClient::connect(config, Some(handler), None, None, vec![], None).await;
3929
3930        assert!(result.is_err(), "Zero idle timeout should be rejected");
3931        let err_msg = result.unwrap_err().to_string();
3932        assert!(
3933            err_msg.contains("Idle timeout cannot be zero"),
3934            "Error should mention zero idle timeout, was: {err_msg}"
3935        );
3936    }
3937
3938    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
3939    #[rstest]
3940    #[tokio::test]
3941    async fn test_sockudo_backend_rejects_reserved_headers_before_connect() {
3942        let (handler, _rx) = channel_message_handler();
3943
3944        let config = WebSocketConfig {
3945            url: "ws://127.0.0.1:1".to_string(),
3946            headers: vec![("Host".to_string(), "example.com".to_string())],
3947            heartbeat: None,
3948            heartbeat_msg: None,
3949            reconnect_timeout_ms: None,
3950            reconnect_delay_initial_ms: None,
3951            reconnect_delay_max_ms: None,
3952            reconnect_backoff_factor: None,
3953            reconnect_jitter_ms: None,
3954            reconnect_max_attempts: None,
3955            idle_timeout_ms: None,
3956            backend: TransportBackend::Sockudo,
3957            proxy_url: None,
3958        };
3959
3960        let err = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3961            .await
3962            .expect_err("reserved header should fail before TCP connect");
3963
3964        assert!(
3965            err.to_string()
3966                .contains("reserved upgrade header not allowed in extra_headers"),
3967            "expected reserved-header failure, was: {err}"
3968        );
3969    }
3970
3971    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
3972    #[rstest]
3973    #[tokio::test]
3974    async fn test_sockudo_backend_replays_leftover_without_custom_headers() {
3975        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3976        let port = listener.local_addr().unwrap().port();
3977
3978        let server = task::spawn(async move {
3979            if let Ok((mut stream, _)) = listener.accept().await {
3980                let request = read_http_request(&mut stream).await;
3981                let request = String::from_utf8(request).unwrap();
3982                let sec_websocket_key = extract_header(&request, "Sec-WebSocket-Key").unwrap();
3983                let accept = sockudo_handshake::generate_accept_key(sec_websocket_key);
3984                let mut response = format!(
3985                    concat!(
3986                        "HTTP/1.1 101 Switching Protocols\r\n",
3987                        "Upgrade: websocket\r\n",
3988                        "Connection: Upgrade\r\n",
3989                        "Sec-WebSocket-Accept: {}\r\n",
3990                        "\r\n",
3991                    ),
3992                    accept
3993                )
3994                .into_bytes();
3995                response.extend_from_slice(b"\x81\x05hello");
3996                stream.write_all(&response).await.unwrap();
3997            }
3998        });
3999
4000        let (handler, mut rx) = channel_message_handler();
4001
4002        let config = WebSocketConfig {
4003            url: format!("ws://127.0.0.1:{port}/ws"),
4004            headers: vec![],
4005            heartbeat: None,
4006            heartbeat_msg: None,
4007            reconnect_timeout_ms: Some(2_000),
4008            reconnect_delay_initial_ms: Some(50),
4009            reconnect_delay_max_ms: Some(100),
4010            reconnect_backoff_factor: Some(1.0),
4011            reconnect_jitter_ms: Some(0),
4012            reconnect_max_attempts: None,
4013            idle_timeout_ms: None,
4014            backend: TransportBackend::Sockudo,
4015            proxy_url: None,
4016        };
4017
4018        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
4019            .await
4020            .expect("sockudo connect without custom headers");
4021
4022        let received = tokio::time::timeout(Duration::from_secs(3), async {
4023            loop {
4024                if let Ok(msg) = rx.try_recv() {
4025                    return msg;
4026                }
4027                tokio::time::sleep(Duration::from_millis(10)).await;
4028            }
4029        })
4030        .await
4031        .expect("did not receive leftover frame before timeout");
4032
4033        match received {
4034            WsMessage::Text(t) => assert_eq!(t.as_str(), "hello"),
4035            other => panic!("expected text, was {other:?}"),
4036        }
4037
4038        client.disconnect().await;
4039        tokio::time::timeout(Duration::from_secs(3), server)
4040            .await
4041            .expect("server did not close before timeout")
4042            .unwrap();
4043    }
4044
4045    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4046    #[rstest]
4047    #[tokio::test]
4048    async fn test_sockudo_backend_sends_custom_headers() {
4049        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4050        let port = listener.local_addr().unwrap().port();
4051
4052        let server = task::spawn(async move {
4053            if let Ok((stream, _)) = listener.accept().await {
4054                let callback = HeaderAssertCallback {
4055                    key: "X-Test".to_string(),
4056                    value: HeaderValue::from_static("value"),
4057                };
4058
4059                if let Ok(mut ws) = accept_hdr_async(stream, callback).await {
4060                    while let Some(Ok(msg)) = ws.next().await {
4061                        if msg.is_text() || msg.is_binary() {
4062                            if ws.send(msg).await.is_err() {
4063                                break;
4064                            }
4065
4066                            continue;
4067                        }
4068
4069                        if msg.is_close() {
4070                            let _ = ws.close(None).await;
4071                            break;
4072                        }
4073                    }
4074                }
4075            }
4076        });
4077
4078        let (handler, mut rx) = channel_message_handler();
4079
4080        let config = WebSocketConfig {
4081            url: format!("ws://127.0.0.1:{port}"),
4082            headers: vec![("X-Test".to_string(), "value".to_string())],
4083            heartbeat: None,
4084            heartbeat_msg: None,
4085            reconnect_timeout_ms: Some(2_000),
4086            reconnect_delay_initial_ms: Some(50),
4087            reconnect_delay_max_ms: Some(100),
4088            reconnect_backoff_factor: Some(1.0),
4089            reconnect_jitter_ms: Some(0),
4090            reconnect_max_attempts: None,
4091            idle_timeout_ms: None,
4092            backend: TransportBackend::Sockudo,
4093            proxy_url: None,
4094        };
4095
4096        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
4097            .await
4098            .expect("sockudo connect with custom headers");
4099
4100        client.send_text("ping".to_string(), None).await.unwrap();
4101
4102        let received = tokio::time::timeout(Duration::from_secs(3), async {
4103            loop {
4104                if let Ok(msg) = rx.try_recv() {
4105                    return msg;
4106                }
4107                tokio::time::sleep(Duration::from_millis(10)).await;
4108            }
4109        })
4110        .await
4111        .expect("did not receive echo before timeout");
4112
4113        match received {
4114            WsMessage::Text(t) => assert_eq!(t.as_str(), "ping"),
4115            other => panic!("expected text, was {other:?}"),
4116        }
4117
4118        client.disconnect().await;
4119        tokio::time::timeout(Duration::from_secs(3), server)
4120            .await
4121            .expect("server did not close before timeout")
4122            .unwrap();
4123    }
4124
4125    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4126    #[rstest]
4127    #[tokio::test]
4128    async fn test_sockudo_backend_round_trip_text() {
4129        // tokio-tungstenite test peer paired with a sockudo client.
4130        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4131        let port = listener.local_addr().unwrap().port();
4132
4133        let server = task::spawn(async move {
4134            if let Ok((stream, _)) = listener.accept().await
4135                && let Ok(mut ws) = accept_async(stream).await
4136            {
4137                while let Some(Ok(msg)) = ws.next().await {
4138                    // Inner if consumes `msg`, cannot hoist into a match guard
4139                    #[expect(clippy::collapsible_match)]
4140                    match msg {
4141                        WsMessage::Text(_) | WsMessage::Binary(_) => {
4142                            if ws.send(msg).await.is_err() {
4143                                break;
4144                            }
4145                        }
4146                        WsMessage::Close(_) => {
4147                            let _ = ws.close(None).await;
4148                            break;
4149                        }
4150                        _ => {}
4151                    }
4152                }
4153            }
4154        });
4155
4156        let (handler, mut rx) = channel_message_handler();
4157        let config = WebSocketConfig {
4158            url: format!("ws://127.0.0.1:{port}"),
4159            headers: vec![],
4160            heartbeat: None,
4161            heartbeat_msg: None,
4162            reconnect_timeout_ms: Some(2_000),
4163            reconnect_delay_initial_ms: Some(50),
4164            reconnect_delay_max_ms: Some(100),
4165            reconnect_backoff_factor: Some(1.0),
4166            reconnect_jitter_ms: Some(0),
4167            reconnect_max_attempts: None,
4168            idle_timeout_ms: None,
4169            backend: TransportBackend::Sockudo,
4170            proxy_url: None,
4171        };
4172
4173        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
4174            .await
4175            .expect("sockudo connect");
4176
4177        client.send_text("ping".to_string(), None).await.unwrap();
4178
4179        let received = tokio::time::timeout(Duration::from_secs(3), async {
4180            loop {
4181                if let Ok(msg) = rx.try_recv() {
4182                    return msg;
4183                }
4184                tokio::time::sleep(Duration::from_millis(10)).await;
4185            }
4186        })
4187        .await
4188        .expect("did not receive echo before timeout");
4189
4190        match received {
4191            WsMessage::Text(t) => assert_eq!(t.as_str(), "ping"),
4192            other => panic!("expected text, was {other:?}"),
4193        }
4194
4195        client.disconnect().await;
4196        server.abort();
4197    }
4198
4199    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4200    #[rstest]
4201    #[case::ws_default_port("ws://example.com/ws", "example.com", "example.com", 80, "/ws", false)]
4202    #[case::wss_default_port(
4203        "wss://example.com/ws",
4204        "example.com",
4205        "example.com",
4206        443,
4207        "/ws",
4208        true
4209    )]
4210    // url::Url normalises explicit default ports (`:80` for ws, `:443` for wss)
4211    // away, so `parsed.port()` reports `None` here and Host stays unqualified.
4212    #[case::ws_explicit_default(
4213        "ws://example.com:80/ws",
4214        "example.com",
4215        "example.com",
4216        80,
4217        "/ws",
4218        false
4219    )]
4220    #[case::ws_non_default(
4221        "ws://example.com:8443/feed",
4222        "example.com",
4223        "example.com:8443",
4224        8443,
4225        "/feed",
4226        false
4227    )]
4228    #[case::wss_non_default(
4229        "wss://example.com:9443/feed",
4230        "example.com",
4231        "example.com:9443",
4232        9443,
4233        "/feed",
4234        true
4235    )]
4236    #[case::root_path(
4237        "ws://example.com:9000/",
4238        "example.com",
4239        "example.com:9000",
4240        9000,
4241        "/",
4242        false
4243    )]
4244    #[case::query_string(
4245        "ws://example.com/feed?token=abc&channel=trades",
4246        "example.com",
4247        "example.com",
4248        80,
4249        "/feed?token=abc&channel=trades",
4250        false
4251    )]
4252    // IPv6: bare host strips brackets for DNS/TCP/SNI; Host header keeps them.
4253    #[case::ipv6_default("ws://[::1]/feed", "::1", "[::1]", 80, "/feed", false)]
4254    #[case::ipv6_explicit_port("ws://[::1]:9000/feed", "::1", "[::1]:9000", 9000, "/feed", false)]
4255    #[case::ipv6_wss(
4256        "wss://[2001:db8::1]:8443/",
4257        "2001:db8::1",
4258        "[2001:db8::1]:8443",
4259        8443,
4260        "/",
4261        true
4262    )]
4263    fn sockudo_target_parses_url(
4264        #[case] url: &str,
4265        #[case] host: &str,
4266        #[case] host_header: &str,
4267        #[case] port: u16,
4268        #[case] path: &str,
4269        #[case] is_tls: bool,
4270    ) {
4271        let target = super::SockudoTarget::parse(url).expect("parse should succeed");
4272        assert_eq!(target.host, host);
4273        assert_eq!(target.host_header, host_header);
4274        assert_eq!(target.port, port);
4275        assert_eq!(target.path, path);
4276        assert_eq!(target.is_tls, is_tls);
4277    }
4278
4279    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4280    #[rstest]
4281    fn sockudo_target_rejects_unsupported_scheme() {
4282        let err = super::SockudoTarget::parse("http://example.com/feed").expect_err("not a ws URL");
4283        let msg = err.to_string();
4284        assert!(
4285            msg.contains("expected ws:// or wss://"),
4286            "unexpected error: {msg}"
4287        );
4288    }
4289
4290    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4291    #[rstest]
4292    fn sockudo_target_rejects_malformed_url() {
4293        let err = super::SockudoTarget::parse("not a url").expect_err("malformed URL");
4294        assert!(
4295            matches!(err, super::TransportError::InvalidUrl(_)),
4296            "expected InvalidUrl, was: {err:?}"
4297        );
4298    }
4299}
4300
#[cfg(test)]
#[cfg(feature = "turmoil")]
mod turmoil_tests {
    //! turmoil network-simulation tests covering how the reconnect buffer interacts
    //! with the authentication tracker after a server-initiated disconnect.

    use std::{sync::Arc, time::Duration};

    use futures_util::{SinkExt, StreamExt};
    use nautilus_common::testing::wait_until_async;
    use rstest::rstest;
    use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage};
    use turmoil::{Builder, net};

    use super::*;
    use crate::websocket::types::channel_message_handler;

    /// Shared log of text frames the simulated server records on its second connection.
    type ReceivedLog = Arc<tokio::sync::Mutex<Vec<String>>>;

    /// A message queued while reconnecting must be held back until auth succeeds on
    /// the new connection, and then be delivered exactly once.
    #[rstest]
    fn test_turmoil_reconnect_buffer_waits_for_auth() {
        let mut sim = Builder::new().build();
        let received: ReceivedLog = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let server_received = Arc::clone(&received);

        sim.host("server", move || {
            auth_buffer_server(Arc::clone(&server_received))
        });

        sim.client("client", async move {
            let tracker = AuthTracker::new();
            let (handler, _rx) = channel_message_handler();
            let client = WebSocketClient::connect(
                turmoil_websocket_config(),
                Some(handler),
                None,
                None,
                vec![],
                None,
            )
            .await
            .expect("Should connect");

            client.set_auth_tracker(tracker.clone(), true);
            assert!(client.is_active(), "Client should start active");

            // The server drops the first connection, so the client enters reconnect.
            wait_until_async(
                || async { client.is_reconnecting() },
                Duration::from_secs(3),
            )
            .await;

            // Enqueue straight onto the writer channel while reconnecting.
            client
                .writer_tx
                .send(WriterCommand::Send(Message::Text("stale".into())))
                .unwrap();

            wait_until_async(|| async { client.is_active() }, Duration::from_secs(3)).await;

            // Start an auth attempt; the buffer must stay held while it is pending.
            let _auth_receiver = tracker.begin();

            tokio::time::sleep(Duration::from_millis(300)).await;
            assert!(
                received.lock().await.is_empty(),
                "buffered messages should wait for auth after reconnect"
            );

            tracker.succeed();

            wait_until_async(
                || {
                    let received = Arc::clone(&received);
                    async move { received.lock().await.as_slice() == ["stale"] }
                },
                Duration::from_secs(3),
            )
            .await;

            assert_eq!(received.lock().await.as_slice(), ["stale"]);

            client.disconnect().await;
            assert!(client.is_disconnected());

            Ok(())
        });

        sim.run().unwrap();
    }

    /// A message queued while reconnecting must be dropped when auth fails, and a
    /// later successful auth must not bring it back.
    #[rstest]
    fn test_turmoil_reconnect_buffer_discards_after_auth_failure() {
        let mut sim = Builder::new().build();
        let received: ReceivedLog = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let server_received = Arc::clone(&received);

        sim.host("server", move || {
            auth_buffer_server(Arc::clone(&server_received))
        });

        sim.client("client", async move {
            let tracker = AuthTracker::new();
            let (handler, _rx) = channel_message_handler();
            let client = WebSocketClient::connect(
                turmoil_websocket_config(),
                Some(handler),
                None,
                None,
                vec![],
                None,
            )
            .await
            .expect("Should connect");

            client.set_auth_tracker(tracker.clone(), true);
            assert!(client.is_active(), "Client should start active");

            // The server drops the first connection, so the client enters reconnect.
            wait_until_async(
                || async { client.is_reconnecting() },
                Duration::from_secs(3),
            )
            .await;

            // Enqueue straight onto the writer channel while reconnecting.
            client
                .writer_tx
                .send(WriterCommand::Send(Message::Text("stale".into())))
                .unwrap();

            wait_until_async(|| async { client.is_active() }, Duration::from_secs(3)).await;

            // First auth attempt is rejected: the buffered message must be discarded.
            let _auth_receiver = tracker.begin();
            tracker.fail("rejected");

            tokio::time::sleep(Duration::from_millis(300)).await;
            assert!(
                received.lock().await.is_empty(),
                "buffered messages should be discarded after auth failure"
            );

            // A retry that succeeds must not replay the discarded message.
            let _retry_auth_receiver = tracker.begin();
            tracker.succeed();

            tokio::time::sleep(Duration::from_millis(300)).await;
            assert!(
                received.lock().await.is_empty(),
                "discarded messages should not replay on a later auth success"
            );

            client.disconnect().await;
            assert!(client.is_disconnected());

            Ok(())
        });

        sim.run().unwrap();
    }

    /// Client config targeting the simulated `server` host on port 8080, with short
    /// (50-200 ms), jitter-free reconnect delays and the tungstenite backend.
    fn turmoil_websocket_config() -> WebSocketConfig {
        WebSocketConfig {
            url: "ws://server:8080".to_string(),
            headers: vec![],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: Some(5_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(200),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Tungstenite,
            proxy_url: None,
        }
    }

    /// Simulated server: accepts one connection, sends `"first"` and drops it to force
    /// a reconnect, waits 200 ms, then accepts a second connection and appends every
    /// text frame it receives to `received` until the peer closes or errors.
    async fn auth_buffer_server(received: ReceivedLog) -> Result<(), Box<dyn std::error::Error>> {
        let listener = net::TcpListener::bind("0.0.0.0:8080").await?;

        // First connection: greet, then let the socket drop at the end of the scope.
        {
            let (stream, _) = listener.accept().await?;
            let mut websocket = accept_async(stream).await?;
            let _ = websocket.send(WsMessage::Text("first".into())).await;
        }

        tokio::time::sleep(Duration::from_millis(200)).await;

        // Second connection: record text frames; stop on close, read error, or EOF.
        let (stream, _) = listener.accept().await?;
        let mut websocket = accept_async(stream).await?;

        while let Some(Ok(frame)) = websocket.next().await {
            match frame {
                WsMessage::Text(text) => {
                    received.lock().await.push(text.to_string());
                }
                WsMessage::Close(_) => {
                    let _ = websocket.close(None).await;
                    break;
                }
                _ => {}
            }
        }

        Ok(())
    }
}