1use std::{
27 collections::VecDeque,
28 fmt::Debug,
29 sync::{
30 Arc, OnceLock,
31 atomic::{AtomicBool, AtomicU8, Ordering},
32 },
33 time::Duration,
34};
35
36use futures_util::{SinkExt, StreamExt};
37use http::HeaderName;
38use nautilus_core::CleanDrop;
39use nautilus_cryptography::providers::install_cryptographic_provider;
40#[cfg(any(feature = "turmoil", feature = "transport-sockudo"))]
41use rustls::ClientConfig;
42#[cfg(feature = "transport-sockudo")]
43use sockudo_ws::{
44 Config as SockudoConfig, Http1, Role, Stream as SockudoStream,
45 WebSocketStream as SockudoWebSocketStream,
46};
47#[cfg(feature = "transport-sockudo")]
48use tokio::io::{AsyncRead, AsyncWrite};
49#[cfg(any(feature = "turmoil", feature = "transport-sockudo"))]
50use tokio_rustls::TlsConnector;
51#[cfg(feature = "turmoil")]
52use tokio_tungstenite::MaybeTlsStream;
53#[cfg(feature = "turmoil")]
54use tokio_tungstenite::client_async;
55#[cfg(not(feature = "turmoil"))]
56use tokio_tungstenite::connect_async_with_config;
57use tokio_tungstenite::tungstenite::{client::IntoClientRequest, http::HeaderValue};
58use ustr::Ustr;
59
60#[cfg(not(feature = "turmoil"))]
61use super::proxy::{ProxiedStream, ProxyKind, WsTarget, tunnel_via_proxy};
62use super::{
63 auth::{AuthState, AuthTracker},
64 config::{TransportBackend, WebSocketConfig},
65 consts::{
66 CONNECTION_STATE_CHECK_INTERVAL_MS, GRACEFUL_SHUTDOWN_DELAY_MS,
67 GRACEFUL_SHUTDOWN_TIMEOUT_SECS,
68 },
69 types::{MessageHandler, MessageReader, MessageWriter, PingHandler, WriterCommand},
70};
71#[cfg(feature = "turmoil")]
72use crate::net::TcpConnector;
73#[cfg(feature = "transport-sockudo")]
74use crate::net::TcpStream;
75#[cfg(feature = "transport-sockudo")]
76use crate::transport::sockudo::{
77 PrefixedIo, SockudoTransport, client_handshake_with_headers, validate_extra_headers,
78};
79use crate::{
80 RECONNECTED,
81 backoff::ExponentialBackoff,
82 dst,
83 error::SendError,
84 logging::{log_task_aborted, log_task_started, log_task_stopped},
85 mode::ConnectionMode,
86 ratelimiter::{RateLimiter, clock::MonotonicClock, quota::Quota},
87 transport::{BoxedWsTransport, Message, TransportError, tungstenite::TungsteniteTransport},
88};
89
/// Internal state of one WebSocket connection together with its background tasks.
///
/// Owns the join handles of the read, write, and heartbeat tasks, the command
/// channel into the write task, and the shared connection-mode flag those tasks
/// poll.
pub struct WebSocketClientInner {
    /// Connection settings; `url`, `headers`, `backend`, and `proxy_url` are
    /// reused by `reconnect`.
    config: WebSocketConfig,
    /// Callback for incoming text and binary messages; `None` means stream mode.
    message_handler: Option<MessageHandler>,
    /// Callback invoked with the payload of each incoming ping frame.
    ping_handler: Option<PingHandler>,
    /// Handle of the read task; only spawned when a message handler is set.
    read_task: Option<tokio::task::JoinHandle<()>>,
    /// Handle of the write task, which owns the active writer.
    write_task: tokio::task::JoinHandle<()>,
    /// Sender for the commands (`Send`, `Update`) consumed by the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task; only spawned when `config.heartbeat` is set.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Current `ConnectionMode` stored as a `u8`, shared with the background tasks.
    connection_mode: Arc<AtomicU8>,
    /// Notified by the read task when it exits and by the write task when it
    /// switches the mode to `Reconnect`.
    state_notify: Arc<tokio::sync::Notify>,
    /// Upper bound on the duration of a single `reconnect` call.
    reconnect_timeout: Duration,
    /// Backoff schedule built from the reconnect delay settings.
    // NOTE(review): not read in this part of the file; presumably consumed by
    // the controller task. Confirm.
    backoff: ExponentialBackoff,
    /// `true` when no message handler is installed; `reconnect` refuses to run
    /// in this mode.
    is_stream_mode: bool,
    /// Configured cap on reconnect attempts; `new_with_writer` always sets `None`.
    // NOTE(review): enforcement is not visible in this part of the file.
    reconnect_max_attempts: Option<u32>,
    /// Reconnect attempt counter, initialised to zero by both constructors.
    reconnection_attempt_count: u32,
    /// Lazily-set authentication tracker; the write task consults it before
    /// draining the reconnect buffer and invalidates it when a send fails.
    auth_tracker: Arc<OnceLock<AuthTracker>>,
    /// When `true`, buffered messages are only drained once the tracker reports
    /// an authenticated state.
    reconnect_buffer_waits_for_auth: Arc<AtomicBool>,
}
132
/// Decision on what the write task should do with messages buffered during a
/// reconnect.
enum ReconnectBufferAction {
    /// Send the buffered messages now.
    Drain,
    /// Leave the buffer untouched and check again on a later loop iteration.
    Wait,
    /// Drop the buffered messages without sending them.
    Discard,
}
138
139impl WebSocketClientInner {
140 #[expect(
148 clippy::unused_async,
149 reason = "async signature for consistency with connect-based constructors"
150 )]
151 pub async fn new_with_writer(
152 config: WebSocketConfig,
153 writer: MessageWriter,
154 ) -> Result<Self, TransportError> {
155 install_cryptographic_provider();
156
157 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
158 let state_notify = Arc::new(tokio::sync::Notify::new());
159
160 let read_task = None;
162
163 let backoff = ExponentialBackoff::new(
165 Duration::from_secs(2),
166 Duration::from_secs(30),
167 1.5,
168 100,
169 true,
170 )
171 .map_err(|e| {
172 TransportError::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
173 })?;
174
175 let auth_tracker = Arc::new(OnceLock::new());
176 let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(false));
177
178 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
179 let write_task = Self::spawn_write_task(
180 connection_mode.clone(),
181 state_notify.clone(),
182 writer,
183 writer_rx,
184 Arc::clone(&auth_tracker),
185 Arc::clone(&reconnect_buffer_waits_for_auth),
186 );
187
188 let heartbeat_task = if let Some(heartbeat_interval) = config.heartbeat {
189 Some(Self::spawn_heartbeat_task(
190 connection_mode.clone(),
191 heartbeat_interval,
192 config.heartbeat_msg.clone(),
193 writer_tx.clone(),
194 ))
195 } else {
196 None
197 };
198
199 let reconnect_max_attempts = None; let reconnect_timeout = Duration::from_secs(10);
201
202 Ok(Self {
203 config,
204 message_handler: None, ping_handler: None,
206 writer_tx,
207 connection_mode,
208 state_notify,
209 reconnect_timeout,
210 heartbeat_task,
211 read_task,
212 write_task,
213 backoff,
214 is_stream_mode: true,
215 reconnect_max_attempts,
216 reconnection_attempt_count: 0,
217 auth_tracker,
218 reconnect_buffer_waits_for_auth,
219 })
220 }
221
222 pub async fn connect_url(
230 config: WebSocketConfig,
231 message_handler: Option<MessageHandler>,
232 ping_handler: Option<PingHandler>,
233 ) -> Result<Self, TransportError> {
234 install_cryptographic_provider();
235
236 if config.heartbeat == Some(0) {
237 return Err(TransportError::Io(std::io::Error::new(
238 std::io::ErrorKind::InvalidInput,
239 "Heartbeat interval cannot be zero",
240 )));
241 }
242
243 if config.idle_timeout_ms == Some(0) {
244 return Err(TransportError::Io(std::io::Error::new(
245 std::io::ErrorKind::InvalidInput,
246 "Idle timeout cannot be zero",
247 )));
248 }
249
250 let is_stream_mode = message_handler.is_none();
252 let reconnect_max_attempts = config.reconnect_max_attempts;
253
254 let (writer, reader) = Box::pin(Self::connect_with_server(
255 &config.url,
256 config.headers.clone(),
257 config.backend,
258 config.proxy_url.as_deref(),
259 ))
260 .await?;
261
262 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
263 let state_notify = Arc::new(tokio::sync::Notify::new());
264
265 let read_task = if message_handler.is_some() {
266 Some(Self::spawn_message_handler_task(
267 connection_mode.clone(),
268 state_notify.clone(),
269 reader,
270 message_handler.as_ref(),
271 ping_handler.as_ref(),
272 config.idle_timeout_ms,
273 ))
274 } else {
275 None
276 };
277
278 let auth_tracker = Arc::new(OnceLock::new());
279 let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(false));
280
281 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
282 let write_task = Self::spawn_write_task(
283 connection_mode.clone(),
284 state_notify.clone(),
285 writer,
286 writer_rx,
287 Arc::clone(&auth_tracker),
288 Arc::clone(&reconnect_buffer_waits_for_auth),
289 );
290
291 let heartbeat_task = config.heartbeat.map(|heartbeat_secs| {
293 Self::spawn_heartbeat_task(
294 connection_mode.clone(),
295 heartbeat_secs,
296 config.heartbeat_msg.clone(),
297 writer_tx.clone(),
298 )
299 });
300
301 let reconnect_timeout =
302 Duration::from_millis(config.reconnect_timeout_ms.unwrap_or(10_000));
303 let backoff = ExponentialBackoff::new(
304 Duration::from_millis(config.reconnect_delay_initial_ms.unwrap_or(2_000)),
305 Duration::from_millis(config.reconnect_delay_max_ms.unwrap_or(30_000)),
306 config.reconnect_backoff_factor.unwrap_or(1.5),
307 config.reconnect_jitter_ms.unwrap_or(100),
308 true, )
310 .map_err(|e| {
311 TransportError::Io(std::io::Error::new(std::io::ErrorKind::InvalidInput, e))
312 })?;
313
314 Ok(Self {
315 config,
316 message_handler,
317 ping_handler,
318 read_task,
319 write_task,
320 writer_tx,
321 heartbeat_task,
322 connection_mode,
323 state_notify,
324 reconnect_timeout,
325 backoff,
326 is_stream_mode,
328 reconnect_max_attempts,
329 reconnection_attempt_count: 0,
330 auth_tracker,
331 reconnect_buffer_waits_for_auth,
332 })
333 }
334
335 #[inline]
355 pub async fn connect_with_server(
356 url: &str,
357 headers: Vec<(String, String)>,
358 backend: TransportBackend,
359 proxy_url: Option<&str>,
360 ) -> Result<(MessageWriter, MessageReader), TransportError> {
361 match backend {
362 TransportBackend::Tungstenite => match proxy_url {
363 Some(proxy) => {
364 Box::pin(Self::connect_tungstenite_via_proxy(url, headers, proxy)).await
365 }
366 None => Self::connect_tungstenite(url, headers).await,
367 },
368 TransportBackend::Sockudo => {
369 if proxy_url.is_some() {
370 return Err(TransportError::Other(
371 "proxy_url is not supported with the Sockudo backend".to_string(),
372 ));
373 }
374 #[cfg(feature = "transport-sockudo")]
375 {
376 Self::connect_sockudo(url, headers).await
377 }
378 #[cfg(not(feature = "transport-sockudo"))]
379 {
380 Err(TransportError::Other(
381 "sockudo backend selected but the transport-sockudo \
382 Cargo feature is not enabled"
383 .to_string(),
384 ))
385 }
386 }
387 }
388 }
389
390 #[inline]
393 #[cfg(not(feature = "turmoil"))]
394 async fn connect_tungstenite(
395 url: &str,
396 headers: Vec<(String, String)>,
397 ) -> Result<(MessageWriter, MessageReader), TransportError> {
398 let mut request = url.into_client_request().map_err(TransportError::from)?;
399 let req_headers = request.headers_mut();
400
401 for (key, val) in headers {
402 let header_value = HeaderValue::from_str(&val)
403 .map_err(|e| TransportError::Handshake(format!("invalid header value: {e}")))?;
404 let header_name: HeaderName = key
405 .parse()
406 .map_err(|e| TransportError::Handshake(format!("invalid header name: {e}")))?;
407 req_headers.insert(header_name, header_value);
408 }
409
410 let (stream, _resp) = connect_async_with_config(request, None, true)
411 .await
412 .map_err(TransportError::from)?;
413 let transport: BoxedWsTransport = Box::pin(TungsteniteTransport::new(stream));
414 Ok(transport.split())
415 }
416
417 #[inline]
425 #[cfg(not(feature = "turmoil"))]
426 async fn connect_tungstenite_via_proxy(
427 url: &str,
428 headers: Vec<(String, String)>,
429 proxy_url: &str,
430 ) -> Result<(MessageWriter, MessageReader), TransportError> {
431 let proxy = match ProxyKind::parse(proxy_url)? {
432 ProxyKind::Http(target) => target,
433 ProxyKind::Unsupported { scheme } => {
434 log::warn!(
435 "WebSocket proxy_url scheme '{scheme}' is not yet supported; \
436 connecting without a WebSocket proxy"
437 );
438 return Self::connect_tungstenite(url, headers).await;
439 }
440 };
441
442 let mut request = url.into_client_request().map_err(TransportError::from)?;
443 let req_headers = request.headers_mut();
444
445 for (key, val) in headers {
446 let header_value = HeaderValue::from_str(&val)
447 .map_err(|e| TransportError::Handshake(format!("invalid header value: {e}")))?;
448 let header_name: HeaderName = key
449 .parse()
450 .map_err(|e| TransportError::Handshake(format!("invalid header name: {e}")))?;
451 req_headers.insert(header_name, header_value);
452 }
453
454 let target = WsTarget::parse(url)?;
455 let stream = tunnel_via_proxy(&target, &proxy).await?;
456
457 #[allow(clippy::match_same_arms)]
465 let transport: BoxedWsTransport = match stream {
466 ProxiedStream::Plain(tcp) => Box::pin(proxied_ws_handshake(request, tcp)).await?,
467 ProxiedStream::PlainOverTlsProxy(s) => {
468 Box::pin(proxied_ws_handshake(request, *s)).await?
469 }
470 ProxiedStream::Tls(s) => Box::pin(proxied_ws_handshake(request, *s)).await?,
471 ProxiedStream::TlsOverTlsProxy(s) => {
472 Box::pin(proxied_ws_handshake(request, *s)).await?
473 }
474 };
475
476 Ok(transport.split())
477 }
478
479 #[inline]
482 #[cfg(feature = "turmoil")]
483 #[expect(
484 clippy::unused_async,
485 reason = "signature mirrors the production variant; both are awaited in the dispatcher"
486 )]
487 async fn connect_tungstenite_via_proxy(
488 _url: &str,
489 _headers: Vec<(String, String)>,
490 _proxy_url: &str,
491 ) -> Result<(MessageWriter, MessageReader), TransportError> {
492 Err(TransportError::Other(
493 "proxy_url is not supported under the turmoil simulator".to_string(),
494 ))
495 }
496
497 #[inline]
500 #[cfg(feature = "turmoil")]
501 async fn connect_tungstenite(
502 url: &str,
503 headers: Vec<(String, String)>,
504 ) -> Result<(MessageWriter, MessageReader), TransportError> {
505 let mut request = url.into_client_request().map_err(TransportError::from)?;
506 let req_headers = request.headers_mut();
507
508 for (key, val) in headers {
509 let header_value = HeaderValue::from_str(&val)
510 .map_err(|e| TransportError::Handshake(format!("invalid header value: {e}")))?;
511 let header_name: HeaderName = key
512 .parse()
513 .map_err(|e| TransportError::Handshake(format!("invalid header name: {e}")))?;
514 req_headers.insert(header_name, header_value);
515 }
516
517 let uri = request.uri();
518 let scheme = uri.scheme_str().unwrap_or("ws");
519 let host = uri
520 .host()
521 .ok_or_else(|| TransportError::InvalidUrl("missing hostname".to_string()))?;
522
523 let port = uri
525 .port_u16()
526 .unwrap_or_else(|| if scheme == "wss" { 443 } else { 80 });
527
528 let addr = format!("{host}:{port}");
529
530 let connector = crate::net::RealTcpConnector;
532 let tcp_stream = connector.connect(&addr).await?;
533 if let Err(e) = tcp_stream.set_nodelay(true) {
534 log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
535 }
536
537 let maybe_tls_stream = if scheme == "wss" {
539 let mut root_store = rustls::RootCertStore::empty();
541 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
542
543 let config = ClientConfig::builder()
544 .with_root_certificates(root_store)
545 .with_no_client_auth();
546
547 let tls_connector = TlsConnector::from(std::sync::Arc::new(config));
548 let domain = rustls::pki_types::ServerName::try_from(host.to_string())
549 .map_err(|e| TransportError::Tls(format!("Invalid DNS name: {e}")))?;
550
551 let tls_stream = tls_connector
552 .connect(domain, tcp_stream)
553 .await
554 .map_err(TransportError::Io)?;
555 MaybeTlsStream::Rustls(tls_stream)
556 } else {
557 MaybeTlsStream::Plain(tcp_stream)
558 };
559
560 let (stream, _resp) = client_async(request, maybe_tls_stream)
562 .await
563 .map_err(TransportError::from)?;
564 let transport: BoxedWsTransport = Box::pin(TungsteniteTransport::new(stream));
565 Ok(transport.split())
566 }
567
568 #[inline]
577 #[cfg(feature = "transport-sockudo")]
578 async fn connect_sockudo(
579 url: &str,
580 headers: Vec<(String, String)>,
581 ) -> Result<(MessageWriter, MessageReader), TransportError> {
582 let target = SockudoTarget::parse(url)?;
583 validate_extra_headers(&headers).map_err(TransportError::from)?;
584
585 #[cfg(feature = "turmoil")]
586 if target.is_tls {
587 return Err(TransportError::Tls(
588 "wss:// is not supported under the turmoil simulator; use ws://".to_string(),
589 ));
590 }
591
592 let tcp_stream = TcpStream::connect((target.host.as_str(), target.port))
593 .await
594 .map_err(TransportError::Io)?;
595
596 if let Err(e) = tcp_stream.set_nodelay(true) {
597 log::warn!("Failed to enable TCP_NODELAY for sockudo client: {e:?}");
598 }
599
600 #[cfg(not(feature = "turmoil"))]
601 if target.is_tls {
602 let mut root_store = rustls::RootCertStore::empty();
603 root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());
604 let config = ClientConfig::builder()
605 .with_root_certificates(root_store)
606 .with_no_client_auth();
607 let connector = TlsConnector::from(std::sync::Arc::new(config));
608 let domain = rustls::pki_types::ServerName::try_from(target.host.clone())
609 .map_err(|e| TransportError::Tls(format!("Invalid DNS name: {e}")))?;
610 let tls_stream = connector
611 .connect(domain, tcp_stream)
612 .await
613 .map_err(TransportError::Io)?;
614 return Self::finish_sockudo_handshake(tls_stream, &target, &headers).await;
615 }
616
617 Self::finish_sockudo_handshake(tcp_stream, &target, &headers).await
618 }
619
620 #[cfg(feature = "transport-sockudo")]
621 async fn finish_sockudo_handshake<S>(
622 mut stream: S,
623 target: &SockudoTarget,
624 headers: &[(String, String)],
625 ) -> Result<(MessageWriter, MessageReader), TransportError>
626 where
627 S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
628 {
629 let handshake = client_handshake_with_headers(
633 &mut stream,
634 &target.host_header,
635 &target.path,
636 None,
637 headers,
638 )
639 .await
640 .map_err(TransportError::from)?;
641
642 let stream = match handshake.leftover {
645 Some(prefix) => SockudoStream::<Http1>::new(PrefixedIo::new(stream, prefix)),
646 None => SockudoStream::<Http1>::new(stream),
647 };
648 let ws = SockudoWebSocketStream::from_raw(stream, Role::Client, SockudoConfig::default());
649 let transport: BoxedWsTransport = Box::pin(SockudoTransport::new(ws));
650 Ok(transport.split())
651 }
652}
653
654#[cfg(not(feature = "turmoil"))]
659async fn proxied_ws_handshake<S>(
660 request: tokio_tungstenite::tungstenite::handshake::client::Request,
661 stream: S,
662) -> Result<BoxedWsTransport, TransportError>
663where
664 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
665{
666 let (ws, _resp) = tokio_tungstenite::client_async(request, stream)
667 .await
668 .map_err(TransportError::from)?;
669 Ok(Box::pin(TungsteniteTransport::new(ws)))
670}
671
/// Connection details extracted from a `ws://` or `wss://` URL for the Sockudo
/// backend.
#[cfg(feature = "transport-sockudo")]
#[derive(Debug, PartialEq, Eq)]
struct SockudoTarget {
    /// Host with any surrounding IPv6 brackets removed; used for the TCP
    /// connection and the TLS server name.
    host: String,
    /// Host as written in the URL (brackets kept), with `:port` appended when
    /// the parsed URL reports an explicit port.
    host_header: String,
    /// Port to connect to; defaults to 443 for `wss` and 80 for `ws`.
    port: u16,
    /// Request path, including the query string when present.
    path: String,
    /// `true` for `wss` URLs.
    is_tls: bool,
}
689
#[cfg(feature = "transport-sockudo")]
impl SockudoTarget {
    /// Parses a `ws://` or `wss://` URL into the pieces needed to connect.
    ///
    /// - `host` has IPv6 brackets stripped; `host_header` keeps them and gains
    ///   `:port` when the parsed URL reports an explicit port.
    /// - `port` defaults to 443 for `wss` and 80 for `ws`.
    /// - `path` falls back to `/` when empty and always carries the query
    ///   string when one is present.
    ///
    /// # Errors
    ///
    /// Returns `TransportError::InvalidUrl` if the URL cannot be parsed, the
    /// scheme is neither `ws` nor `wss`, or the hostname is missing.
    fn parse(url: &str) -> Result<Self, TransportError> {
        let parsed =
            url::Url::parse(url).map_err(|e| TransportError::InvalidUrl(format!("{url}: {e}")))?;

        let is_tls = match parsed.scheme() {
            "ws" => false,
            "wss" => true,
            other => {
                return Err(TransportError::InvalidUrl(format!(
                    "expected ws:// or wss:// scheme, was {other}"
                )));
            }
        };

        let raw_host = parsed
            .host_str()
            .ok_or_else(|| TransportError::InvalidUrl("missing hostname".to_string()))?;

        // IPv6 literals arrive bracketed; the socket address and TLS server
        // name need the bare address, while the Host header keeps the brackets.
        let host = raw_host
            .strip_prefix('[')
            .and_then(|inner| inner.strip_suffix(']'))
            .unwrap_or(raw_host)
            .to_string();

        let explicit_port = parsed.port();
        let port = explicit_port.unwrap_or(if is_tls { 443 } else { 80 });
        let host_header = match explicit_port {
            Some(p) => format!("{raw_host}:{p}"),
            None => raw_host.to_string(),
        };

        // The query is appended regardless of whether the path needed the `/`
        // fallback, so it can never be dropped for an empty path.
        let mut path = match parsed.path() {
            "" => "/".to_string(),
            p => p.to_string(),
        };
        if let Some(query) = parsed.query() {
            path.push('?');
            path.push_str(query);
        }

        Ok(Self {
            host,
            host_header,
            port,
            path,
            is_tls,
        })
    }
}
749
750impl WebSocketClientInner {
    /// Re-establishes the connection and hands the new writer to the write task.
    ///
    /// In stream mode this only marks the connection `Closed` and returns
    /// `Ok(())`. It also returns `Ok(())` without reconnecting whenever the
    /// connection mode indicates a disconnect, or when the mode is no longer
    /// `Reconnect` at the point of switching back to `Active`.
    ///
    /// The whole attempt is bounded by `self.reconnect_timeout`.
    ///
    /// # Errors
    ///
    /// Returns an error if connecting fails, the write task cannot be reached
    /// or rejects the new writer, or the attempt times out.
    pub async fn reconnect(&mut self) -> Result<(), TransportError> {
        log::debug!("Reconnecting");

        if self.is_stream_mode {
            log::warn!(
                "Auto-reconnect disabled for stream-based WebSocket client; \
                 stream users must manually reconnect by creating a new connection"
            );
            // Stream clients cannot be reconnected here; mark the connection closed.
            self.connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
            return Ok(());
        }

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            log::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        dst::time::timeout(self.reconnect_timeout, async {
            // Open a fresh connection with the original settings.
            let (new_writer, reader) = Self::connect_with_server(
                &self.config.url,
                self.config.headers.clone(),
                self.config.backend,
                self.config.proxy_url.as_deref(),
            )
            .await?;

            // A disconnect may have been requested while connecting.
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }

            // Hand the new writer to the write task and wait for its reply.
            let (tx, rx) = tokio::sync::oneshot::channel();
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
                log::error!("{e}");
                return Err(TransportError::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    format!("Failed to send update command: {e}"),
                )));
            }

            match rx.await {
                Ok(true) => log::debug!("Writer confirmed socket update"),
                Ok(false) => {
                    log::warn!("Writer rejected socket update, aborting reconnect");
                    return Err(TransportError::Io(std::io::Error::other(
                        "Failed to update reconnection writer",
                    )));
                }
                Err(e) => {
                    // The write task dropped the reply sender without answering.
                    log::error!("Writer dropped update channel: {e}");
                    return Err(TransportError::Io(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "Writer task dropped response channel",
                    )));
                }
            }

            // Short pause before the old read task is torn down.
            dst::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            // Stop the read task bound to the previous connection.
            if let Some(ref read_task) = self.read_task.take()
                && !read_task.is_finished()
            {
                read_task.abort();
                log_task_aborted("read");
            }

            // Only go back to `Active` if the mode is still `Reconnect`;
            // any other value means the state changed underneath this call.
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                log::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            // Start a new read task on the new connection when a handler exists.
            self.read_task = if self.message_handler.is_some() {
                Some(Self::spawn_message_handler_task(
                    self.connection_mode.clone(),
                    self.state_notify.clone(),
                    reader,
                    self.message_handler.as_ref(),
                    self.ping_handler.as_ref(),
                    self.config.idle_timeout_ms,
                ))
            } else {
                None
            };

            log::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        // The outer `Err` is the timeout; the inner result is returned as-is by `?`.
        .map_err(|_| {
            TransportError::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
886
887 #[inline]
893 #[must_use]
894 pub fn is_alive(&self) -> bool {
895 match &self.read_task {
896 Some(read_task) => !read_task.is_finished() && !self.write_task.is_finished(),
897 None => !self.write_task.is_finished(),
898 }
899 }
900
901 fn spawn_message_handler_task(
902 connection_state: Arc<AtomicU8>,
903 state_notify: Arc<tokio::sync::Notify>,
904 mut reader: MessageReader,
905 message_handler: Option<&MessageHandler>,
906 ping_handler: Option<&PingHandler>,
907 idle_timeout_ms: Option<u64>,
908 ) -> tokio::task::JoinHandle<()> {
909 log::debug!("Started message handler task 'read'");
910
911 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
912 let idle_timeout = idle_timeout_ms.map(Duration::from_millis);
913
914 let message_handler = message_handler.cloned();
916 let ping_handler = ping_handler.cloned();
917
918 tokio::task::spawn(async move {
919 let mut last_data_time = dst::time::Instant::now();
920
921 loop {
922 if !ConnectionMode::from_atomic(&connection_state).is_active() {
923 break;
924 }
925
926 match dst::time::timeout(check_interval, reader.next()).await {
927 Ok(Some(Ok(Message::Binary(data)))) => {
928 log::trace!("Received message <binary> {} bytes", data.len());
929 last_data_time = dst::time::Instant::now();
930
931 if let Some(ref handler) = message_handler {
932 handler(Message::Binary(data));
933 }
934 }
935 Ok(Some(Ok(Message::Text(data)))) => {
936 log::trace!("Received message: {data:?}");
937 last_data_time = dst::time::Instant::now();
938
939 if let Some(ref handler) = message_handler {
940 handler(Message::Text(data));
941 }
942 }
943 Ok(Some(Ok(Message::Ping(ping_data)))) => {
944 log::trace!("Received ping: {ping_data:?}");
945 if let Some(ref handler) = ping_handler {
949 handler(ping_data.to_vec());
950 }
951 }
952 Ok(Some(Ok(Message::Pong(_)))) => {
953 log::trace!("Received pong");
954 }
956 Ok(Some(Ok(Message::Close(_)))) => {
957 log::debug!("Received close message - terminating");
958 break;
959 }
960 Ok(Some(Err(e))) => {
961 log::error!("Received error message - terminating: {e}");
962 break;
963 }
964 Ok(None) => {
965 log::debug!("No message received - terminating");
966 break;
967 }
968 Err(_) => {
969 if let Some(timeout) = idle_timeout {
970 let idle_duration = last_data_time.elapsed();
971 if idle_duration >= timeout {
972 log::warn!(
973 "Read idle timeout: no data received for {:.1}s",
974 idle_duration.as_secs_f64()
975 );
976 break;
977 }
978 }
979 }
980 }
981 }
982
983 state_notify.notify_one();
985 })
986 }
987
988 async fn drain_reconnect_buffer(
993 buffer: &mut VecDeque<Message>,
994 writer: &mut MessageWriter,
995 ) -> bool {
996 if buffer.is_empty() {
997 return false;
998 }
999
1000 let initial_buffer_len = buffer.len();
1001 log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
1002
1003 let mut send_error_occurred = false;
1004
1005 while let Some(buffered_msg) = buffer.front() {
1006 let msg_to_send = buffered_msg.clone();
1008
1009 if let Err(e) = writer.send(msg_to_send).await {
1010 log::error!(
1011 "Failed to send buffered message after reconnection: {e}, {} messages remain in buffer",
1012 buffer.len()
1013 );
1014 send_error_occurred = true;
1015 break; }
1017
1018 buffer.pop_front();
1020 }
1021
1022 if buffer.is_empty() {
1023 log::info!("Successfully sent all {initial_buffer_len} buffered messages");
1024 }
1025
1026 send_error_occurred
1027 }
1028
1029 fn can_drain_reconnect_buffer(
1030 reconnect_buffer_waits_for_auth: &AtomicBool,
1031 auth_tracker: &Arc<OnceLock<AuthTracker>>,
1032 ) -> ReconnectBufferAction {
1033 if !reconnect_buffer_waits_for_auth.load(Ordering::Acquire) {
1034 return ReconnectBufferAction::Drain;
1035 }
1036
1037 match auth_tracker.get().map(AuthTracker::auth_state) {
1038 Some(AuthState::Authenticated) => ReconnectBufferAction::Drain,
1039 Some(AuthState::Failed) => ReconnectBufferAction::Discard,
1040 Some(AuthState::Unauthenticated) | None => ReconnectBufferAction::Wait,
1041 }
1042 }
1043
    /// Spawns the write task, which owns the active writer and serialises all
    /// outgoing traffic.
    ///
    /// Each loop iteration:
    /// 1. Exits on `Disconnect` (after closing the writer) or `Closed`,
    ///    discarding any buffered messages.
    /// 2. While active with a non-empty reconnect buffer, drains, discards, or
    ///    keeps the buffer according to `can_drain_reconnect_buffer`.
    /// 3. Waits up to `CONNECTION_STATE_CHECK_INTERVAL_MS` for a command:
    ///    `Update` swaps in a new writer, `Send` either buffers the message
    ///    (while reconnecting) or sends it.
    ///
    /// A failed send keeps the message in the reconnect buffer, invalidates the
    /// auth tracker if one is set, switches the mode to `Reconnect`, and
    /// notifies `state_notify`. The writer is closed (bounded by
    /// `GRACEFUL_SHUTDOWN_TIMEOUT_SECS`) when the task exits.
    fn spawn_write_task(
        connection_state: Arc<AtomicU8>,
        state_notify: Arc<tokio::sync::Notify>,
        writer: MessageWriter,
        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
        auth_tracker: Arc<OnceLock<AuthTracker>>,
        reconnect_buffer_waits_for_auth: Arc<AtomicBool>,
    ) -> tokio::task::JoinHandle<()> {
        log_task_started("write");

        // Upper bound on how long the loop waits for a command before
        // re-checking the connection mode.
        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);

        tokio::task::spawn(async move {
            let mut active_writer = writer;
            // Messages that could not be sent yet; replayed once the connection is active again.
            let mut reconnect_buffer: VecDeque<Message> = VecDeque::new();

            loop {
                let mode = ConnectionMode::from_atomic(&connection_state);

                match mode {
                    ConnectionMode::Disconnect => {
                        if !reconnect_buffer.is_empty() {
                            log::warn!(
                                "Discarding {} buffered messages due to disconnect",
                                reconnect_buffer.len()
                            );
                            reconnect_buffer.clear();
                        }

                        // Best-effort close; the result (including a timeout) is ignored.
                        _ = dst::time::timeout(
                            Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
                            active_writer.close(),
                        )
                        .await;
                        break;
                    }
                    ConnectionMode::Closed => {
                        if !reconnect_buffer.is_empty() {
                            log::warn!(
                                "Discarding {} buffered messages due to closed connection",
                                reconnect_buffer.len()
                            );
                            reconnect_buffer.clear();
                        }
                        break;
                    }
                    _ => {}
                }

                // Replay buffered messages before accepting new commands.
                if mode.is_active() && !reconnect_buffer.is_empty() {
                    match Self::can_drain_reconnect_buffer(
                        reconnect_buffer_waits_for_auth.as_ref(),
                        &auth_tracker,
                    ) {
                        ReconnectBufferAction::Drain => {
                            let send_error = Self::drain_reconnect_buffer(
                                &mut reconnect_buffer,
                                &mut active_writer,
                            )
                            .await;

                            if send_error {
                                // The connection failed again: request another reconnect.
                                if let Some(tracker) = auth_tracker.get() {
                                    tracker.invalidate();
                                }
                                connection_state
                                    .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
                                state_notify.notify_one();
                            }

                            continue;
                        }
                        ReconnectBufferAction::Discard => {
                            log::warn!(
                                "Discarding {} buffered messages after authentication failed",
                                reconnect_buffer.len()
                            );
                            reconnect_buffer.clear();
                            continue;
                        }
                        // Fall through and keep serving commands while waiting.
                        ReconnectBufferAction::Wait => {}
                    }
                }

                match dst::time::timeout(check_interval, writer_rx.recv()).await {
                    Ok(Some(msg)) => {
                        // Re-read the mode: it may have changed while waiting.
                        // On disconnect/closed the received command is dropped.
                        let mode = ConnectionMode::from_atomic(&connection_state);
                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
                            break;
                        }

                        match msg {
                            WriterCommand::Update(new_writer, tx) => {
                                log::debug!("Received new writer");

                                // Brief pause before the old writer is closed.
                                dst::time::sleep(Duration::from_millis(100)).await;

                                // Best-effort close of the old writer; the result is ignored.
                                _ = dst::time::timeout(
                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
                                    active_writer.close(),
                                )
                                .await;

                                active_writer = new_writer;
                                log::debug!("Updated writer");

                                // Confirm the swap to the sender of the command.
                                if let Err(e) = tx.send(true) {
                                    log::error!(
                                        "Failed to report writer update to controller: {e:?}"
                                    );
                                }
                            }
                            WriterCommand::Send(msg) if mode.is_reconnect() => {
                                log::debug!(
                                    "Buffering message during reconnection (buffer size: {})",
                                    reconnect_buffer.len() + 1
                                );
                                reconnect_buffer.push_back(msg);
                            }
                            WriterCommand::Send(msg) => {
                                // Send a clone so the message can be buffered if the send fails.
                                if let Err(e) = active_writer.send(msg.clone()).await {
                                    log::error!("Failed to send message: {e}");
                                    log::warn!("Writer triggering reconnect");
                                    reconnect_buffer.push_back(msg);

                                    if let Some(tracker) = auth_tracker.get() {
                                        tracker.invalidate();
                                    }
                                    connection_state
                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
                                    state_notify.notify_one();
                                }
                            }
                        }
                    }
                    Ok(None) => {
                        // All senders are gone; nothing more can arrive.
                        log::debug!("Writer channel closed, terminating writer task");
                        break;
                    }
                    Err(_) => {
                        // Timed out waiting for a command; loop to re-check the mode.
                    }
                }
            }

            // Final best-effort close on every exit path; the result is ignored.
            _ = dst::time::timeout(
                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
                active_writer.close(),
            )
            .await;

            log_task_stopped("write");
        })
    }
1213
1214 fn spawn_heartbeat_task(
1215 connection_state: Arc<AtomicU8>,
1216 heartbeat_secs: u64,
1217 message: Option<String>,
1218 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
1219 ) -> tokio::task::JoinHandle<()> {
1220 log_task_started("heartbeat");
1221
1222 tokio::task::spawn(async move {
1223 let interval = Duration::from_secs(heartbeat_secs);
1224
1225 loop {
1226 dst::time::sleep(interval).await;
1227
1228 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
1229 ConnectionMode::Active => {
1230 let msg = match &message {
1231 Some(text) => WriterCommand::Send(Message::Text(text.clone().into())),
1232 None => WriterCommand::Send(Message::Ping(vec![].into())),
1233 };
1234
1235 match writer_tx.send(msg) {
1236 Ok(()) => log::trace!("Sent heartbeat to writer task"),
1237 Err(e) => {
1238 log::error!("Failed to send heartbeat to writer task: {e}");
1239 }
1240 }
1241 }
1242 ConnectionMode::Reconnect => {}
1243 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
1244 }
1245 }
1246
1247 log_task_stopped("heartbeat");
1248 })
1249 }
1250}
1251
1252impl Drop for WebSocketClientInner {
1253 fn drop(&mut self) {
1254 self.clean_drop();
1256 }
1257}
1258
1259impl CleanDrop for WebSocketClientInner {
1261 fn clean_drop(&mut self) {
1262 if let Some(ref read_task) = self.read_task.take()
1263 && !read_task.is_finished()
1264 {
1265 read_task.abort();
1266 log_task_aborted("read");
1267 }
1268
1269 if !self.write_task.is_finished() {
1270 self.write_task.abort();
1271 log_task_aborted("write");
1272 }
1273
1274 if let Some(ref handle) = self.heartbeat_task.take()
1275 && !handle.is_finished()
1276 {
1277 handle.abort();
1278 log_task_aborted("heartbeat");
1279 }
1280
1281 self.message_handler = None;
1283 self.ping_handler = None;
1284 }
1285}
1286
1287#[expect(
1288 clippy::missing_fields_in_debug,
1289 reason = "handler closures and internal task handles are intentionally omitted"
1290)]
1291impl Debug for WebSocketClientInner {
1292 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1293 f.debug_struct(stringify!(WebSocketClientInner))
1294 .field("config", &self.config)
1295 .field(
1296 "connection_mode",
1297 &ConnectionMode::from_atomic(&self.connection_mode),
1298 )
1299 .field("reconnect_timeout", &self.reconnect_timeout)
1300 .field("is_stream_mode", &self.is_stream_mode)
1301 .finish()
1302 }
1303}
1304
/// Public handle to a WebSocket connection managed by a background controller task.
///
/// The handle shares its connection state with the controller through atomics and a
/// [`tokio::sync::Notify`], and forwards outgoing frames through `writer_tx`.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.network")
)]
pub struct WebSocketClient {
    /// Handle of the controller task; `is_disconnected` reports whether it has finished.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Current [`ConnectionMode`] encoded as a `u8`, shared with the controller task.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Signalled with `notify_waiters()` when the connection mode changes.
    pub(crate) state_notify: Arc<tokio::sync::Notify>,
    /// Upper bound on how long a send waits for the client to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Rate limiter awaited before text and binary sends.
    pub(crate) rate_limiter: Arc<RateLimiter<Ustr, MonotonicClock>>,
    /// Channel used to hand [`WriterCommand`]s to the writer task.
    pub(crate) writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Optional auth tracker, installed at most once via `set_auth_tracker`.
    auth_tracker: Arc<OnceLock<AuthTracker>>,
    /// Flag stored by `set_auth_tracker`; its reader is outside this part of the file.
    reconnect_buffer_waits_for_auth: Arc<AtomicBool>,
}
1327
1328impl Debug for WebSocketClient {
1329 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1330 f.debug_struct(stringify!(WebSocketClient)).finish()
1331 }
1332}
1333
1334impl WebSocketClient {
1335 pub async fn connect_stream(
1351 config: WebSocketConfig,
1352 keyed_quotas: Vec<(String, Quota)>,
1353 default_quota: Option<Quota>,
1354 post_reconnect: Option<Arc<dyn Fn() + Send + Sync>>,
1355 ) -> Result<(MessageReader, Self), TransportError> {
1356 install_cryptographic_provider();
1357
1358 let (writer, reader) = WebSocketClientInner::connect_with_server(
1360 &config.url,
1361 config.headers.clone(),
1362 config.backend,
1363 config.proxy_url.as_deref(),
1364 )
1365 .await?;
1366
1367 let inner = WebSocketClientInner::new_with_writer(config, writer).await?;
1369
1370 let connection_mode = inner.connection_mode.clone();
1371 let state_notify = inner.state_notify.clone();
1372 let reconnect_timeout = inner.reconnect_timeout;
1373 let auth_tracker = Arc::clone(&inner.auth_tracker);
1374 let reconnect_buffer_waits_for_auth = Arc::clone(&inner.reconnect_buffer_waits_for_auth);
1375 let keyed_quotas = keyed_quotas
1376 .into_iter()
1377 .map(|(key, quota)| (Ustr::from(&key), quota))
1378 .collect();
1379 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
1380 let writer_tx = inner.writer_tx.clone();
1381
1382 let controller_task = Self::spawn_controller_task(
1383 inner,
1384 connection_mode.clone(),
1385 state_notify.clone(),
1386 post_reconnect,
1387 Arc::clone(&auth_tracker),
1388 );
1389
1390 Ok((
1391 reader,
1392 Self {
1393 controller_task,
1394 connection_mode,
1395 state_notify,
1396 reconnect_timeout,
1397 rate_limiter,
1398 writer_tx,
1399 auth_tracker,
1400 reconnect_buffer_waits_for_auth,
1401 },
1402 ))
1403 }
1404
1405 pub async fn connect(
1423 config: WebSocketConfig,
1424 message_handler: Option<MessageHandler>,
1425 ping_handler: Option<PingHandler>,
1426 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1427 keyed_quotas: Vec<(String, Quota)>,
1428 default_quota: Option<Quota>,
1429 ) -> Result<Self, TransportError> {
1430 if message_handler.is_none() {
1432 return Err(TransportError::Io(std::io::Error::new(
1433 std::io::ErrorKind::InvalidInput,
1434 "Handler mode requires message_handler to be set. Use connect_stream() for stream mode without a handler.",
1435 )));
1436 }
1437
1438 log::debug!("Connecting");
1439 let inner =
1440 WebSocketClientInner::connect_url(config, message_handler, ping_handler).await?;
1441 let connection_mode = inner.connection_mode.clone();
1442 let state_notify = inner.state_notify.clone();
1443 let writer_tx = inner.writer_tx.clone();
1444 let reconnect_timeout = inner.reconnect_timeout;
1445 let auth_tracker = Arc::clone(&inner.auth_tracker);
1446 let reconnect_buffer_waits_for_auth = Arc::clone(&inner.reconnect_buffer_waits_for_auth);
1447
1448 let controller_task = Self::spawn_controller_task(
1449 inner,
1450 connection_mode.clone(),
1451 state_notify.clone(),
1452 post_reconnection,
1453 Arc::clone(&auth_tracker),
1454 );
1455
1456 let keyed_quotas = keyed_quotas
1457 .into_iter()
1458 .map(|(key, quota)| (Ustr::from(&key), quota))
1459 .collect();
1460 let rate_limiter = Arc::new(RateLimiter::new_with_quota(default_quota, keyed_quotas));
1461
1462 Ok(Self {
1463 controller_task,
1464 connection_mode,
1465 state_notify,
1466 reconnect_timeout,
1467 rate_limiter,
1468 writer_tx,
1469 auth_tracker,
1470 reconnect_buffer_waits_for_auth,
1471 })
1472 }
1473
1474 #[must_use]
1476 pub fn connection_mode(&self) -> ConnectionMode {
1477 ConnectionMode::from_atomic(&self.connection_mode)
1478 }
1479
1480 #[must_use]
1485 pub fn connection_mode_atomic(&self) -> Arc<AtomicU8> {
1486 Arc::clone(&self.connection_mode)
1487 }
1488
1489 #[inline]
1494 #[must_use]
1495 pub fn is_active(&self) -> bool {
1496 self.connection_mode().is_active()
1497 }
1498
1499 #[must_use]
1501 pub fn is_disconnected(&self) -> bool {
1502 self.controller_task.is_finished()
1503 }
1504
1505 #[inline]
1510 #[must_use]
1511 pub fn is_reconnecting(&self) -> bool {
1512 self.connection_mode().is_reconnect()
1513 }
1514
    /// Installs the auth tracker and records whether the reconnect buffer waits for auth.
    ///
    /// The tracker lives in a `OnceLock`, so only the first call installs it; the
    /// result of `set` is ignored and a tracker passed on a later call is dropped.
    ///
    /// NOTE(review): the `reconnect_buffer_waits_for_auth` flag is overwritten on every
    /// call, even when the tracker was already set — confirm this asymmetry is intended.
    pub fn set_auth_tracker(&self, tracker: AuthTracker, reconnect_buffer_waits_for_auth: bool) {
        // `set` returns `Err` if a tracker is already installed; that case is ignored
        let _ = self.auth_tracker.set(tracker);
        self.reconnect_buffer_waits_for_auth
            .store(reconnect_buffer_waits_for_auth, Ordering::Release);
    }
1529
1530 #[inline]
1534 #[must_use]
1535 pub fn is_disconnecting(&self) -> bool {
1536 self.connection_mode().is_disconnect()
1537 }
1538
1539 #[inline]
1545 #[must_use]
1546 pub fn is_closed(&self) -> bool {
1547 self.connection_mode().is_closed()
1548 }
1549
1550 #[inline]
1554 fn check_not_terminal(&self) -> Result<(), SendError> {
1555 match self.connection_mode() {
1556 ConnectionMode::Disconnect | ConnectionMode::Closed => Err(SendError::Closed),
1557 _ => Ok(()),
1558 }
1559 }
1560
1561 async fn await_rate_limit_or_closed(&self, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1563 const CHECK_INTERVAL_MS: u64 = 100;
1564
1565 tokio::select! {
1566 biased;
1567 () = self.rate_limiter.await_keys_ready(keys) => Ok(()),
1568 () = async {
1569 loop {
1570 let notified = self.state_notify.notified();
1571
1572 if matches!(self.connection_mode(), ConnectionMode::Disconnect | ConnectionMode::Closed) {
1573 break;
1574 }
1575 tokio::select! {
1576 biased;
1577 () = notified => {}
1578 () = dst::time::sleep(Duration::from_millis(CHECK_INTERVAL_MS)) => {}
1579 }
1580 }
1581 } => Err(SendError::Closed),
1582 }
1583 }
1584
1585 async fn wait_for_active(&self) -> Result<(), SendError> {
1591 const FALLBACK_INTERVAL_MS: u64 = 100;
1592
1593 let mode = self.connection_mode();
1594 if mode.is_active() {
1595 return Ok(());
1596 }
1597
1598 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1599 return Err(SendError::Closed);
1600 }
1601
1602 log::debug!("Waiting for client to become ACTIVE before sending...");
1603
1604 let fallback_interval = Duration::from_millis(FALLBACK_INTERVAL_MS);
1605
1606 dst::time::timeout(self.reconnect_timeout, async {
1607 loop {
1608 let notified = self.state_notify.notified();
1611
1612 let mode = self.connection_mode();
1613 if mode.is_active() {
1614 return Ok(());
1615 }
1616
1617 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
1618 return Err(());
1619 }
1620
1621 tokio::select! {
1622 biased;
1623 () = notified => {}
1624 () = dst::time::sleep(fallback_interval) => {}
1625 }
1626 }
1627 })
1628 .await
1629 .map_err(|_| SendError::Timeout)?
1630 .map_err(|()| SendError::Closed)
1631 }
1632
1633 pub fn notify_closed(&self) {
1644 let mode = self.connection_mode();
1645 if mode.is_disconnect() || mode.is_closed() {
1646 return;
1647 }
1648
1649 log::debug!("Stream reader signalled EOF, transitioning to CLOSED");
1650
1651 self.connection_mode
1652 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1653 self.state_notify.notify_waiters();
1654 }
1655
1656 pub async fn disconnect(&self) {
1661 log::debug!("Disconnecting");
1662 self.connection_mode
1663 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
1664 self.state_notify.notify_waiters();
1665
1666 if dst::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
1667 while !self.is_disconnected() {
1668 dst::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
1669 }
1670
1671 if !self.controller_task.is_finished() {
1672 self.controller_task.abort();
1673 log_task_aborted("controller");
1674 }
1675 })
1676 .await
1677 == Ok(())
1678 {
1679 log::debug!("Controller task finished");
1680 } else {
1681 log::error!("Timeout waiting for controller task to finish");
1682
1683 if !self.controller_task.is_finished() {
1684 self.controller_task.abort();
1685 log_task_aborted("controller");
1686 }
1687 self.connection_mode
1688 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1689 }
1690 }
1691
1692 #[allow(unused_variables)]
1702 pub async fn send_text(&self, data: String, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1703 self.check_not_terminal()?;
1704
1705 self.await_rate_limit_or_closed(keys).await?;
1706 self.wait_for_active().await?;
1707
1708 log::trace!("Sending text: {data:?}");
1709
1710 let msg = Message::Text(data.into());
1711 self.writer_tx
1712 .send(WriterCommand::Send(msg))
1713 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1714 }
1715
1716 pub async fn send_pong(&self, data: Vec<u8>) -> Result<(), SendError> {
1722 self.wait_for_active().await?;
1723
1724 log::trace!("Sending pong frame ({} bytes)", data.len());
1725
1726 let msg = Message::Pong(data.into());
1727 self.writer_tx
1728 .send(WriterCommand::Send(msg))
1729 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1730 }
1731
1732 #[allow(unused_variables)]
1742 pub async fn send_bytes(&self, data: Vec<u8>, keys: Option<&[Ustr]>) -> Result<(), SendError> {
1743 self.check_not_terminal()?;
1744
1745 self.await_rate_limit_or_closed(keys).await?;
1746 self.wait_for_active().await?;
1747
1748 log::trace!("Sending bytes: {data:?}");
1749
1750 let msg = Message::Binary(data.into());
1751 self.writer_tx
1752 .send(WriterCommand::Send(msg))
1753 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1754 }
1755
1756 pub async fn send_close_message(&self) -> Result<(), SendError> {
1762 self.wait_for_active().await?;
1763
1764 let msg = Message::Close(None);
1765 self.writer_tx
1766 .send(WriterCommand::Send(msg))
1767 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1768 }
1769
    /// Spawns the controller task that owns `inner` and drives the connection lifecycle.
    ///
    /// The task wakes on `state_notify` or every `CONTROLLER_FALLBACK_INTERVAL_MS`, reads
    /// the shared `connection_mode` and then:
    ///
    /// - `Disconnect`: waits `GRACEFUL_SHUTDOWN_DELAY_MS`, aborts the read and heartbeat
    ///   tasks and exits.
    /// - `Closed`: exits.
    /// - `Active` with a dead connection: moves to `Closed` in stream mode, otherwise to
    ///   `Reconnect`, and invalidates the auth tracker when one is installed.
    /// - `Reconnect`: attempts a reconnect, honouring `reconnect_max_attempts` and the
    ///   backoff, and lets a disconnect request interrupt both the attempt and the backoff.
    ///
    /// When the loop exits the mode is stored as `Closed`.
    fn spawn_controller_task(
        mut inner: WebSocketClientInner,
        connection_mode: Arc<AtomicU8>,
        state_notify: Arc<tokio::sync::Notify>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
        auth_tracker: Arc<OnceLock<AuthTracker>>,
    ) -> tokio::task::JoinHandle<()> {
        const CONTROLLER_FALLBACK_INTERVAL_MS: u64 = 100;

        tokio::task::spawn(async move {
            log_task_started("controller");

            let fallback_interval = Duration::from_millis(CONTROLLER_FALLBACK_INTERVAL_MS);

            loop {
                // Wake on a state change, or periodically to check connection liveness
                tokio::select! {
                    biased;
                    () = state_notify.notified() => {}
                    () = dst::time::sleep(fallback_interval) => {}
                }

                let mut mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    log::debug!("Disconnecting");

                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    if dst::time::timeout(timeout, async {
                        // Short delay before the read and heartbeat tasks are aborted
                        dst::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        if let Some(task) = &inner.read_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        log::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    log::debug!("Closed");
                    break;
                }

                if mode.is_closed() {
                    log::debug!("Connection closed");
                    break;
                }

                if mode.is_active() && !inner.is_alive() {
                    // Stream mode does not reconnect; handler mode does
                    let target = if inner.is_stream_mode {
                        ConnectionMode::Closed
                    } else {
                        ConnectionMode::Reconnect
                    };

                    // Only transition if the mode is still ACTIVE, so a concurrent
                    // disconnect or close request is not overwritten
                    if connection_mode
                        .compare_exchange(
                            ConnectionMode::Active.as_u8(),
                            target.as_u8(),
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        )
                        .is_ok()
                    {
                        if let Some(tracker) = auth_tracker.get() {
                            tracker.invalidate();
                        }
                        log::debug!("Detected dead connection, transitioning to {target:?}");
                    }
                    // Re-read: the mode is either `target` or whatever won the race
                    mode = ConnectionMode::from_atomic(&connection_mode);
                }

                if mode.is_reconnect() {
                    if let Some(max_attempts) = inner.reconnect_max_attempts
                        && inner.reconnection_attempt_count >= max_attempts
                    {
                        log::error!(
                            "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
                        );
                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                        state_notify.notify_waiters();
                        break;
                    }

                    inner.reconnection_attempt_count += 1;
                    log::debug!(
                        "Reconnection attempt {} of {}",
                        inner.reconnection_attempt_count,
                        inner
                            .reconnect_max_attempts
                            .map_or_else(|| "unlimited".to_string(), |m| m.to_string())
                    );

                    // Race the reconnect against a disconnect request; `None` means interrupted
                    let reconnect_result = tokio::select! {
                        biased;
                        result = inner.reconnect() => Some(result),
                        () = async {
                            loop {
                                state_notify.notified().await;

                                if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
                                    break;
                                }
                            }
                        } => None,
                    };

                    match reconnect_result {
                        None => {
                            log::debug!("Reconnect interrupted by disconnect");
                        }
                        Some(Ok(())) => {
                            inner.backoff.reset();
                            inner.reconnection_attempt_count = 0;

                            state_notify.notify_waiters();

                            // Callbacks run only if the mode is ACTIVE after the reconnect
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = inner.message_handler {
                                    let reconnected_msg =
                                        Message::Text(RECONNECTED.to_string().into());
                                    handler(reconnected_msg);
                                    log::debug!("Sent reconnected message to handler");
                                }

                                if let Some(ref callback) = post_reconnection {
                                    callback();
                                    log::debug!("Called `post_reconnection` handler");
                                }

                                log::debug!("Reconnected successfully");
                            } else {
                                log::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Some(Err(e)) => {
                            let duration = inner.backoff.next_duration();
                            log::warn!(
                                "Reconnect attempt {} failed: {e}",
                                inner.reconnection_attempt_count
                            );

                            if !duration.is_zero() {
                                log::warn!("Backing off for {}s...", duration.as_secs_f64());
                                // The backoff sleep is also interruptible by a disconnect request
                                tokio::select! {
                                    biased;
                                    () = dst::time::sleep(duration) => {}
                                    () = async {
                                        loop {
                                            state_notify.notified().await;

                                            if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
                                                break;
                                            }
                                        }
                                    } => {
                                        log::debug!("Backoff interrupted by disconnect");
                                    }
                                }
                            }
                        }
                    }
                }
            }
            // Every exit path from the loop ends in CLOSED
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
1959}
1960
1961impl Drop for WebSocketClient {
1963 fn drop(&mut self) {
1964 if !self.controller_task.is_finished() {
1965 self.controller_task.abort();
1966 log_task_aborted("controller");
1967 }
1968 }
1969}
1970
1971#[cfg(test)]
1972#[cfg(not(feature = "turmoil"))]
1973#[cfg(not(all(feature = "simulation", madsim)))] #[cfg(target_os = "linux")] mod tests {
1976 use std::{num::NonZeroU32, sync::Arc};
1977
1978 use futures_util::{SinkExt, StreamExt};
1979 use tokio::{
1980 net::TcpListener,
1981 task::{self, JoinHandle},
1982 };
1983 use tokio_tungstenite::{
1984 accept_hdr_async,
1985 tungstenite::{
1986 Message as WsMessage,
1987 handshake::server::{self, Callback},
1988 http::HeaderValue,
1989 },
1990 };
1991
1992 use crate::{
1993 ratelimiter::quota::Quota,
1994 websocket::{TransportBackend, WebSocketClient, WebSocketConfig},
1995 };
1996
    /// Local WebSocket server used by the tests in this module.
    struct TestServer {
        /// Accept-loop task; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Port the listener is bound to on `127.0.0.1`.
        port: u16,
    }
2001
    /// Handshake callback that checks a single request header.
    #[derive(Debug, Clone)]
    struct TestCallback {
        /// Name of the header that must be present on the upgrade request.
        key: String,
        /// Value the header is expected to carry.
        value: HeaderValue,
    }
2007
2008 impl Callback for TestCallback {
2009 #[expect(clippy::panic_in_result_fn)]
2010 fn on_request(
2011 self,
2012 request: &server::Request,
2013 response: server::Response,
2014 ) -> Result<server::Response, server::ErrorResponse> {
2015 let _ = response;
2016 let value = request.headers().get(&self.key);
2017 assert!(value.is_some());
2018
2019 if let Some(value) = request.headers().get(&self.key) {
2020 assert_eq!(value, self.value);
2021 }
2022
2023 Ok(response)
2024 }
2025 }
2026
    impl TestServer {
        /// Starts an echo server on an ephemeral local port.
        ///
        /// Every handshake must carry the header `test: test`. Per connection the server
        /// echoes text and binary frames, closes the socket when it receives the text
        /// `close-now` or a close frame, and ignores all other frames.
        async fn setup() -> Self {
            let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let port = TcpListener::local_addr(&server).unwrap().port();

            let header_key = "test".to_string();
            let header_value = "test".to_string();

            let test_call_back = TestCallback {
                key: header_key,
                value: HeaderValue::from_str(&header_value).unwrap(),
            };

            let task = task::spawn(async move {
                loop {
                    // A failed accept or handshake panics this task and stops the accept loop
                    let (conn, _) = server.accept().await.unwrap();
                    let mut websocket = accept_hdr_async(conn, test_call_back.clone())
                        .await
                        .unwrap();

                    // Serve each connection on its own task so the accept loop keeps running
                    task::spawn(async move {
                        #[expect(clippy::collapsible_match)]
                        while let Some(Ok(msg)) = websocket.next().await {
                            match msg {
                                // Magic payload that makes the server close the connection
                                WsMessage::Text(txt) if txt == "close-now" => {
                                    log::debug!("Forcibly closing from server side");
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                // Echo data frames back to the client
                                WsMessage::Text(_) | WsMessage::Binary(_) => {
                                    if websocket.send(msg).await.is_err() {
                                        break;
                                    }
                                }
                                // Answer a client-initiated close and stop serving
                                WsMessage::Close(_frame) => {
                                    let _ = websocket.close(None).await;
                                    break;
                                }
                                _ => {}
                            }
                        }
                    });
                }
            });

            Self { task, port }
        }
    }
2081
2082 impl Drop for TestServer {
2083 fn drop(&mut self) {
2084 self.task.abort();
2085 }
2086 }
2087
2088 async fn setup_test_client(port: u16) -> WebSocketClient {
2089 let config = WebSocketConfig {
2090 url: format!("ws://127.0.0.1:{port}"),
2091 headers: vec![("test".into(), "test".into())],
2092 heartbeat: None,
2093 heartbeat_msg: None,
2094 reconnect_timeout_ms: None,
2095 reconnect_delay_initial_ms: None,
2096 reconnect_backoff_factor: None,
2097 reconnect_delay_max_ms: None,
2098 reconnect_jitter_ms: None,
2099 reconnect_max_attempts: None,
2100 idle_timeout_ms: None,
2101 backend: TransportBackend::Tungstenite,
2102 proxy_url: None,
2103 };
2104 WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
2105 .await
2106 .expect("Failed to connect")
2107 }
2108
2109 #[tokio::test]
2110 async fn test_websocket_basic() {
2111 let server = TestServer::setup().await;
2112 let client = setup_test_client(server.port).await;
2113
2114 assert!(!client.is_disconnected());
2115
2116 client.disconnect().await;
2117 assert!(client.is_disconnected());
2118 }
2119
    /// Keeps a connection open for three seconds and then verifies a clean disconnect.
    ///
    /// NOTE(review): `setup_test_client` builds its config with `heartbeat: None`, so no
    /// heartbeat appears to be configured here despite the test name — confirm intent.
    #[tokio::test]
    async fn test_websocket_heartbeat() {
        let server = TestServer::setup().await;
        let client = setup_test_client(server.port).await;

        // Leave the connection idle before shutting down
        tokio::time::sleep(std::time::Duration::from_secs(3)).await;

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
2132
2133 #[tokio::test]
2134 async fn test_websocket_reconnect_exhausted() {
2135 let config = WebSocketConfig {
2136 url: "ws://127.0.0.1:9997".into(), headers: vec![],
2138 heartbeat: None,
2139 heartbeat_msg: None,
2140 reconnect_timeout_ms: None,
2141 reconnect_delay_initial_ms: None,
2142 reconnect_backoff_factor: None,
2143 reconnect_delay_max_ms: None,
2144 reconnect_jitter_ms: None,
2145 reconnect_max_attempts: None,
2146 idle_timeout_ms: None,
2147 backend: TransportBackend::Tungstenite,
2148 proxy_url: None,
2149 };
2150 let res =
2151 WebSocketClient::connect(config, Some(Arc::new(|_| {})), None, None, vec![], None)
2152 .await;
2153 assert!(res.is_err(), "Should fail quickly with no server");
2154 }
2155
2156 #[tokio::test]
2157 async fn test_websocket_forced_close_reconnect() {
2158 let server = TestServer::setup().await;
2159 let client = setup_test_client(server.port).await;
2160
2161 client.send_text("Hello".into(), None).await.unwrap();
2163
2164 client.send_text("close-now".into(), None).await.unwrap();
2166
2167 tokio::time::sleep(std::time::Duration::from_secs(1)).await;
2169
2170 assert!(!client.is_disconnected());
2172
2173 client.disconnect().await;
2175 assert!(client.is_disconnected());
2176 }
2177
    /// Sends three messages through a client built with a two-per-second keyed quota
    /// and checks that all sends succeed and the client disconnects cleanly.
    ///
    /// NOTE(review): the quota is registered under the key `"default"` while every send
    /// passes `keys = None`; whether the quota applies to those sends depends on
    /// `RateLimiter`, which is not visible here — confirm the limiter is exercised.
    #[tokio::test]
    async fn test_rate_limiter() {
        let server = TestServer::setup().await;
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap();

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{}", server.port),
            headers: vec![("test".into(), "test".into())],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: None,
            reconnect_delay_initial_ms: None,
            reconnect_backoff_factor: None,
            reconnect_delay_max_ms: None,
            reconnect_jitter_ms: None,
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Tungstenite,
            proxy_url: None,
        };

        let client = WebSocketClient::connect(
            config,
            Some(Arc::new(|_| {})),
            None,
            None,
            vec![("default".into(), quota)],
            None,
        )
        .await
        .unwrap();

        // Two sends fit within the per-second quota
        client.send_text("test1".into(), None).await.unwrap();
        client.send_text("test2".into(), None).await.unwrap();

        // A third send exceeds the quota's count within the same second
        client.send_text("test3".into(), None).await.unwrap();

        client.disconnect().await;
        assert!(client.is_disconnected());
    }
2221
2222 #[tokio::test]
2223 async fn test_concurrent_writers() {
2224 let server = TestServer::setup().await;
2225 let client = Arc::new(setup_test_client(server.port).await);
2226
2227 let mut handles = vec![];
2228
2229 for i in 0..10 {
2230 let client = client.clone();
2231 handles.push(task::spawn(async move {
2232 client.send_text(format!("test{i}"), None).await.unwrap();
2233 }));
2234 }
2235
2236 for handle in handles {
2237 handle.await.unwrap();
2238 }
2239
2240 client.disconnect().await;
2242 assert!(client.is_disconnected());
2243 }
2244}
2245
2246#[cfg(test)]
2247#[cfg(not(feature = "turmoil"))]
2248#[cfg(not(all(feature = "simulation", madsim)))] mod rust_tests {
2250 use std::sync::{
2251 Arc, OnceLock,
2252 atomic::{AtomicBool, AtomicU8, Ordering},
2253 };
2254
2255 use futures_util::{SinkExt, StreamExt};
2256 use nautilus_common::testing::wait_until_async;
2257 use rstest::rstest;
2258 #[cfg(feature = "transport-sockudo")]
2259 use sockudo_ws::handshake as sockudo_handshake;
2260 #[cfg(feature = "transport-sockudo")]
2261 use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
2262 use tokio::{
2263 net::TcpListener,
2264 task::{self, JoinHandle},
2265 time::{Duration, sleep},
2266 };
2267 use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage};
2268 #[cfg(feature = "transport-sockudo")]
2269 use tokio_tungstenite::{
2270 accept_hdr_async,
2271 tungstenite::{
2272 handshake::server::{self, Callback},
2273 http::HeaderValue,
2274 },
2275 };
2276
2277 use super::*;
2278 use crate::websocket::types::channel_message_handler;
2279
    /// Local WebSocket server that records every text frame it receives.
    struct RecordingServer {
        /// Accept-loop task; aborted when the server is dropped.
        task: JoinHandle<()>,
        /// Port the listener is bound to on `127.0.0.1`.
        port: u16,
        /// Text payloads received so far, shared across all connections.
        messages: Arc<tokio::sync::Mutex<Vec<String>>>,
    }
2285
2286 #[cfg(feature = "transport-sockudo")]
2287 async fn read_http_request<S>(stream: &mut S) -> Vec<u8>
2288 where
2289 S: AsyncRead + Unpin,
2290 {
2291 let mut buf = Vec::new();
2292 let mut chunk = [0u8; 256];
2293
2294 loop {
2295 let n = stream.read(&mut chunk).await.unwrap();
2296 assert!(n > 0, "HTTP request closed before headers completed");
2297 buf.extend_from_slice(&chunk[..n]);
2298 if buf.windows(4).any(|window| window == b"\r\n\r\n") {
2299 return buf;
2300 }
2301 }
2302 }
2303
2304 #[cfg(feature = "transport-sockudo")]
2305 fn extract_header<'a>(request: &'a str, name: &str) -> Option<&'a str> {
2306 request.lines().find_map(|line| {
2307 let (header_name, header_value) = line.split_once(':')?;
2308 if header_name.eq_ignore_ascii_case(name) {
2309 Some(header_value.trim())
2310 } else {
2311 None
2312 }
2313 })
2314 }
2315
    /// Handshake callback that asserts a single request header has an exact value.
    #[cfg(feature = "transport-sockudo")]
    #[derive(Debug, Clone)]
    struct HeaderAssertCallback {
        /// Name of the header to look up on the upgrade request.
        key: String,
        /// Value the header must carry.
        value: HeaderValue,
    }
2322
2323 #[cfg(feature = "transport-sockudo")]
2324 impl Callback for HeaderAssertCallback {
2325 #[expect(
2326 clippy::panic_in_result_fn,
2327 reason = "assertion failures should fail the test"
2328 )]
2329 fn on_request(
2330 self,
2331 request: &server::Request,
2332 response: server::Response,
2333 ) -> Result<server::Response, server::ErrorResponse> {
2334 assert_eq!(request.headers().get(&self.key), Some(&self.value));
2335 Ok(response)
2336 }
2337 }
2338
2339 impl RecordingServer {
2340 async fn setup() -> Self {
2341 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2342 let port = listener.local_addr().unwrap().port();
2343 let messages = Arc::new(tokio::sync::Mutex::new(Vec::new()));
2344 let messages_clone = Arc::clone(&messages);
2345
2346 let task = task::spawn(async move {
2347 loop {
2348 let (stream, _) = listener.accept().await.unwrap();
2349 let mut websocket = accept_async(stream).await.unwrap();
2350 let messages = Arc::clone(&messages_clone);
2351
2352 task::spawn(async move {
2353 while let Some(Ok(msg)) = websocket.next().await {
2354 match msg {
2355 WsMessage::Text(text) => {
2356 messages.lock().await.push(text.to_string());
2357 }
2358 WsMessage::Close(_) => {
2359 let _ = websocket.close(None).await;
2360 break;
2361 }
2362 _ => {}
2363 }
2364 }
2365 });
2366 }
2367 });
2368
2369 Self {
2370 task,
2371 port,
2372 messages,
2373 }
2374 }
2375
2376 async fn messages(&self) -> Vec<String> {
2377 self.messages.lock().await.clone()
2378 }
2379 }
2380
2381 impl Drop for RecordingServer {
2382 fn drop(&mut self) {
2383 self.task.abort();
2384 }
2385 }
2386
    /// The client can be disconnected shortly after the server dropped its connection.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_then_disconnect() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Accept one connection, drop it right after the handshake, then keep the
        // listener alive for a second without accepting again
        let server = task::spawn(async move {
            let (stream, _) = listener.accept().await.unwrap();
            let ws = accept_async(stream).await.unwrap();
            drop(ws);
            sleep(Duration::from_secs(1)).await;
        });

        // The receiver is kept alive but never read
        let (handler, _rx) = channel_message_handler();

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Tungstenite,
            proxy_url: None,
        };

        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
            .await
            .unwrap();

        // Give the client time to notice the dropped connection before disconnecting
        sleep(Duration::from_millis(100)).await;
        client.disconnect().await;
        assert!(client.is_disconnected());
        server.abort();
    }
2435
2436 #[rstest]
2437 #[tokio::test]
2438 async fn test_reconnect_state_flips_when_reader_stops() {
2439 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2441 let port = listener.local_addr().unwrap().port();
2442
2443 let server = task::spawn(async move {
2444 if let Ok((stream, _)) = listener.accept().await
2445 && let Ok(ws) = accept_async(stream).await
2446 {
2447 drop(ws);
2448 }
2449 sleep(Duration::from_millis(50)).await;
2450 });
2451
2452 let (handler, _rx) = channel_message_handler();
2453
2454 let config = WebSocketConfig {
2455 url: format!("ws://127.0.0.1:{port}"),
2456 headers: vec![],
2457 heartbeat: None,
2458 heartbeat_msg: None,
2459 reconnect_timeout_ms: Some(1_000),
2460 reconnect_delay_initial_ms: Some(50),
2461 reconnect_delay_max_ms: Some(100),
2462 reconnect_backoff_factor: Some(1.0),
2463 reconnect_jitter_ms: Some(0),
2464 reconnect_max_attempts: None,
2465 idle_timeout_ms: None,
2466 backend: TransportBackend::Tungstenite,
2467 proxy_url: None,
2468 };
2469
2470 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2471 .await
2472 .unwrap();
2473
2474 tokio::time::timeout(Duration::from_secs(2), async {
2475 loop {
2476 if client.is_reconnecting() {
2477 break;
2478 }
2479 tokio::time::sleep(Duration::from_millis(10)).await;
2480 }
2481 })
2482 .await
2483 .expect("client did not enter RECONNECT state");
2484
2485 client.disconnect().await;
2486 server.abort();
2487 }
2488
    /// Connects in stream mode to a server that accepts a single connection.
    ///
    /// NOTE(review): as written, the test only checks that `connect_stream` succeeds;
    /// nothing is asserted after the connection is established — confirm intent.
    #[rstest]
    #[tokio::test]
    async fn test_stream_mode_disables_auto_reconnect() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Accept one connection and hold it for 100 ms before the task ends
        let server = task::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await
                && let Ok(_ws) = accept_async(stream).await
            {
                sleep(Duration::from_millis(100)).await;
            }
        });

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: Some(1_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Tungstenite,
            proxy_url: None,
        };

        // Both values are held until the end of the test so the client is not dropped early
        let (_reader, _client) = WebSocketClient::connect_stream(config, vec![], None, None)
            .await
            .unwrap();

        server.abort();
    }
2535
2536 #[rstest]
2537 #[tokio::test]
2538 async fn test_message_handler_mode_allows_auto_reconnect() {
2539 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2541 let port = listener.local_addr().unwrap().port();
2542
2543 let server = task::spawn(async move {
2544 if let Ok((stream, _)) = listener.accept().await
2546 && let Ok(ws) = accept_async(stream).await
2547 {
2548 drop(ws);
2549 }
2550 sleep(Duration::from_millis(50)).await;
2551 });
2552
2553 let (handler, _rx) = channel_message_handler();
2554
2555 let config = WebSocketConfig {
2556 url: format!("ws://127.0.0.1:{port}"),
2557 headers: vec![],
2558 heartbeat: None,
2559 heartbeat_msg: None,
2560 reconnect_timeout_ms: Some(1_000),
2561 reconnect_delay_initial_ms: Some(50),
2562 reconnect_delay_max_ms: Some(100),
2563 reconnect_backoff_factor: Some(1.0),
2564 reconnect_jitter_ms: Some(0),
2565 reconnect_max_attempts: None,
2566 idle_timeout_ms: None,
2567 backend: TransportBackend::Tungstenite,
2568 proxy_url: None,
2569 };
2570
2571 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2572 .await
2573 .unwrap();
2574
2575 tokio::time::timeout(Duration::from_secs(2), async {
2577 loop {
2578 if client.is_reconnecting() || client.is_closed() {
2579 break;
2580 }
2581 tokio::time::sleep(Duration::from_millis(10)).await;
2582 }
2583 })
2584 .await
2585 .expect("client should attempt reconnection or close");
2586
2587 assert!(
2590 client.is_reconnecting() || client.is_closed(),
2591 "Client with message handler should attempt reconnection"
2592 );
2593
2594 client.disconnect().await;
2595 server.abort();
2596 }
2597
    /// After the first connection is dropped, the client reconnects and the handler
    /// receives a text message containing `reconnected` within five seconds.
    #[rstest]
    #[tokio::test]
    async fn test_handler_mode_reconnect_with_new_connection() {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        let server = task::spawn(async move {
            // First connection: complete the handshake and drop it
            if let Ok((stream, _)) = listener.accept().await
                && let Ok(ws) = accept_async(stream).await
            {
                drop(ws);
            }

            sleep(Duration::from_millis(100)).await;

            // Second connection: send one text frame and keep the socket open for a second
            if let Ok((stream, _)) = listener.accept().await
                && let Ok(mut ws) = accept_async(stream).await
            {
                use futures_util::SinkExt;
                let _ = ws
                    .send(WsMessage::Text("reconnected".to_string().into()))
                    .await;
                sleep(Duration::from_secs(1)).await;
            }
        });

        let (handler, mut rx) = channel_message_handler();

        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: Some(2_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(200),
            reconnect_backoff_factor: Some(1.5),
            reconnect_jitter_ms: Some(10),
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Tungstenite,
            proxy_url: None,
        };

        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
            .await
            .unwrap();

        // Poll the handler channel until a text message equal to "reconnected" arrives
        let result = tokio::time::timeout(Duration::from_secs(5), async {
            loop {
                if let Ok(msg) = rx.try_recv()
                    && matches!(msg, WsMessage::Text(ref text) if AsRef::<str>::as_ref(text) == "reconnected")
                {
                    return true;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await;

        assert!(
            result.is_ok(),
            "Should receive message after reconnection within timeout"
        );

        client.disconnect().await;
        server.abort();
    }
2671
2672 #[rstest]
2673 #[tokio::test]
2674 async fn test_stream_mode_no_auto_reconnect() {
2675 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2678 let port = listener.local_addr().unwrap().port();
2679
2680 let server = task::spawn(async move {
2681 if let Ok((stream, _)) = listener.accept().await
2683 && let Ok(mut ws) = accept_async(stream).await
2684 {
2685 use futures_util::SinkExt;
2686 let _ = ws.send(WsMessage::Text("hello".to_string().into())).await;
2687 sleep(Duration::from_millis(50)).await;
2688 }
2690 });
2691
2692 let config = WebSocketConfig {
2693 url: format!("ws://127.0.0.1:{port}"),
2694 headers: vec![],
2695 heartbeat: None,
2696 heartbeat_msg: None,
2697 reconnect_timeout_ms: Some(1_000),
2698 reconnect_delay_initial_ms: Some(50),
2699 reconnect_delay_max_ms: Some(100),
2700 reconnect_backoff_factor: Some(1.0),
2701 reconnect_jitter_ms: Some(0),
2702 reconnect_max_attempts: None,
2703 idle_timeout_ms: None,
2704 backend: TransportBackend::Tungstenite,
2705 proxy_url: None,
2706 };
2707
2708 let (mut reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
2709 .await
2710 .unwrap();
2711
2712 assert!(client.is_active(), "Client should start as active");
2714
2715 let msg = reader.next().await;
2717 assert!(
2718 matches!(&msg, Some(Ok(Message::Text(bytes))) if bytes.as_ref() == b"hello"),
2719 "Should receive initial message"
2720 );
2721
2722 while let Some(msg) = reader.next().await {
2724 if msg.is_err() || matches!(msg, Ok(Message::Close(_))) {
2725 break;
2726 }
2727 }
2728
2729 sleep(Duration::from_millis(200)).await;
2732 assert!(
2733 client.is_active(),
2734 "Stream mode client stays ACTIVE before notify_closed()"
2735 );
2736
2737 client.notify_closed();
2739
2740 assert!(
2741 client.is_closed(),
2742 "Stream mode client should be CLOSED after notify_closed()"
2743 );
2744 assert!(
2745 !client.is_reconnecting(),
2746 "Stream mode client should never attempt reconnection"
2747 );
2748
2749 client.disconnect().await;
2750 server.abort();
2751 }
2752
2753 #[rstest]
2754 #[tokio::test]
2755 async fn test_send_timeout_uses_configured_reconnect_timeout() {
2756 use nautilus_common::testing::wait_until_async;
2759
2760 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2761 let port = listener.local_addr().unwrap().port();
2762
2763 let server = task::spawn(async move {
2764 if let Ok((stream, _)) = listener.accept().await
2766 && let Ok(ws) = accept_async(stream).await
2767 {
2768 drop(ws);
2769 }
2770 sleep(Duration::from_mins(1)).await;
2772 });
2773
2774 let (handler, _rx) = channel_message_handler();
2775
2776 let config = WebSocketConfig {
2778 url: format!("ws://127.0.0.1:{port}"),
2779 headers: vec![],
2780 heartbeat: None,
2781 heartbeat_msg: None,
2782 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(50),
2784 reconnect_delay_max_ms: Some(100),
2785 reconnect_backoff_factor: Some(1.0),
2786 reconnect_jitter_ms: Some(0),
2787 reconnect_max_attempts: None,
2788 idle_timeout_ms: None,
2789 backend: TransportBackend::Tungstenite,
2790 proxy_url: None,
2791 };
2792
2793 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2794 .await
2795 .unwrap();
2796
2797 wait_until_async(
2799 || async { client.is_reconnecting() },
2800 Duration::from_secs(3),
2801 )
2802 .await;
2803
2804 let start = std::time::Instant::now();
2806 let send_result = client.send_text("test".to_string(), None).await;
2807 let elapsed = start.elapsed();
2808
2809 assert!(
2810 send_result.is_err(),
2811 "Send should fail when client stuck in RECONNECT"
2812 );
2813 assert!(
2814 matches!(send_result, Err(crate::error::SendError::Timeout)),
2815 "Send should return Timeout error, was: {send_result:?}"
2816 );
2817 assert!(
2820 elapsed >= Duration::from_millis(1800),
2821 "Send should timeout after at least 2s (configured timeout), took {elapsed:?}"
2822 );
2823
2824 client.disconnect().await;
2825 server.abort();
2826 }
2827
2828 #[rstest]
2829 #[tokio::test]
2830 async fn test_send_waits_during_reconnection() {
2831 use nautilus_common::testing::wait_until_async;
2833
2834 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2835 let port = listener.local_addr().unwrap().port();
2836
2837 let server = task::spawn(async move {
2838 if let Ok((stream, _)) = listener.accept().await
2840 && let Ok(ws) = accept_async(stream).await
2841 {
2842 drop(ws);
2843 }
2844
2845 sleep(Duration::from_millis(500)).await;
2847
2848 if let Ok((stream, _)) = listener.accept().await
2850 && let Ok(mut ws) = accept_async(stream).await
2851 {
2852 while let Some(Ok(msg)) = ws.next().await {
2854 if ws.send(msg).await.is_err() {
2855 break;
2856 }
2857 }
2858 }
2859 });
2860
2861 let (handler, _rx) = channel_message_handler();
2862
2863 let config = WebSocketConfig {
2864 url: format!("ws://127.0.0.1:{port}"),
2865 headers: vec![],
2866 heartbeat: None,
2867 heartbeat_msg: None,
2868 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
2870 reconnect_delay_max_ms: Some(200),
2871 reconnect_backoff_factor: Some(1.0),
2872 reconnect_jitter_ms: Some(0),
2873 reconnect_max_attempts: None,
2874 idle_timeout_ms: None,
2875 backend: TransportBackend::Tungstenite,
2876 proxy_url: None,
2877 };
2878
2879 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
2880 .await
2881 .unwrap();
2882
2883 wait_until_async(
2885 || async { client.is_reconnecting() },
2886 Duration::from_secs(2),
2887 )
2888 .await;
2889
2890 let send_result = tokio::time::timeout(
2892 Duration::from_secs(3),
2893 client.send_text("test_message".to_string(), None),
2894 )
2895 .await;
2896
2897 assert!(
2898 send_result.is_ok() && send_result.unwrap().is_ok(),
2899 "Send should succeed after waiting for reconnection"
2900 );
2901
2902 client.disconnect().await;
2903 server.abort();
2904 }
2905
2906 #[rstest]
2907 #[tokio::test]
2908 async fn test_rate_limiter_before_active_wait() {
2909 use std::{num::NonZeroU32, sync::Arc};
2914
2915 use nautilus_common::testing::wait_until_async;
2916
2917 use crate::ratelimiter::quota::Quota;
2918
2919 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2920 let port = listener.local_addr().unwrap().port();
2921
2922 let server = task::spawn(async move {
2923 if let Ok((stream, _)) = listener.accept().await
2925 && let Ok(mut ws) = accept_async(stream).await
2926 {
2927 if let Some(Ok(_)) = ws.next().await {
2929 drop(ws);
2930 }
2931 }
2932
2933 sleep(Duration::from_millis(500)).await;
2935
2936 if let Ok((stream, _)) = listener.accept().await
2938 && let Ok(mut ws) = accept_async(stream).await
2939 {
2940 while let Some(Ok(msg)) = ws.next().await {
2941 if ws.send(msg).await.is_err() {
2942 break;
2943 }
2944 }
2945 }
2946 });
2947
2948 let (handler, _rx) = channel_message_handler();
2949
2950 let config = WebSocketConfig {
2951 url: format!("ws://127.0.0.1:{port}"),
2952 headers: vec![],
2953 heartbeat: None,
2954 heartbeat_msg: None,
2955 reconnect_timeout_ms: Some(5_000),
2956 reconnect_delay_initial_ms: Some(50),
2957 reconnect_delay_max_ms: Some(100),
2958 reconnect_backoff_factor: Some(1.0),
2959 reconnect_jitter_ms: Some(0),
2960 reconnect_max_attempts: None,
2961 idle_timeout_ms: None,
2962 backend: TransportBackend::Tungstenite,
2963 proxy_url: None,
2964 };
2965
2966 let quota = Quota::per_second(NonZeroU32::new(1).unwrap())
2968 .unwrap()
2969 .allow_burst(NonZeroU32::new(1).unwrap());
2970
2971 let client = Arc::new(
2972 WebSocketClient::connect(
2973 config,
2974 Some(handler),
2975 None,
2976 None,
2977 vec![("test_key".to_string(), quota)],
2978 None,
2979 )
2980 .await
2981 .unwrap(),
2982 );
2983
2984 let test_key: [Ustr; 1] = [Ustr::from("test_key")];
2986 client
2987 .send_text("msg1".to_string(), Some(test_key.as_slice()))
2988 .await
2989 .unwrap();
2990
2991 wait_until_async(
2993 || async { client.is_reconnecting() },
2994 Duration::from_secs(2),
2995 )
2996 .await;
2997
2998 let start = std::time::Instant::now();
3000 let send_result = client
3001 .send_text("msg2".to_string(), Some(test_key.as_slice()))
3002 .await;
3003 let elapsed = start.elapsed();
3004
3005 assert!(
3007 send_result.is_ok(),
3008 "Send should succeed after rate limit + reconnection, was: {send_result:?}"
3009 );
3010 assert!(
3014 elapsed >= Duration::from_millis(850),
3015 "Should wait for rate limit (~1s), waited {elapsed:?}"
3016 );
3017
3018 client.disconnect().await;
3019 server.abort();
3020 }
3021
3022 #[rstest]
3023 #[tokio::test]
3024 async fn test_disconnect_during_reconnect_exits_cleanly() {
3025 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3028 let port = listener.local_addr().unwrap().port();
3029
3030 let server = task::spawn(async move {
3031 if let Ok((stream, _)) = listener.accept().await
3033 && let Ok(ws) = accept_async(stream).await
3034 {
3035 drop(ws);
3036 }
3037 sleep(Duration::from_mins(1)).await;
3039 });
3040
3041 let (handler, _rx) = channel_message_handler();
3042
3043 let config = WebSocketConfig {
3044 url: format!("ws://127.0.0.1:{port}"),
3045 headers: vec![],
3046 heartbeat: None,
3047 heartbeat_msg: None,
3048 reconnect_timeout_ms: Some(2_000), reconnect_delay_initial_ms: Some(100),
3050 reconnect_delay_max_ms: Some(200),
3051 reconnect_backoff_factor: Some(1.0),
3052 reconnect_jitter_ms: Some(0),
3053 reconnect_max_attempts: None,
3054 idle_timeout_ms: None,
3055 backend: TransportBackend::Tungstenite,
3056 proxy_url: None,
3057 };
3058
3059 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3060 .await
3061 .unwrap();
3062
3063 tokio::time::timeout(Duration::from_secs(2), async {
3065 while !client.is_reconnecting() {
3066 sleep(Duration::from_millis(10)).await;
3067 }
3068 })
3069 .await
3070 .expect("Client should enter RECONNECT state");
3071
3072 client.disconnect().await;
3074
3075 assert!(
3077 client.is_disconnected(),
3078 "Client should be cleanly disconnected"
3079 );
3080
3081 server.abort();
3082 }
3083
3084 #[rstest]
3085 #[tokio::test]
3086 async fn test_send_fails_fast_when_closed_before_rate_limit() {
3087 use std::{num::NonZeroU32, sync::Arc};
3090
3091 use nautilus_common::testing::wait_until_async;
3092
3093 use crate::ratelimiter::quota::Quota;
3094
3095 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3096 let port = listener.local_addr().unwrap().port();
3097
3098 let server = task::spawn(async move {
3099 if let Ok((stream, _)) = listener.accept().await
3101 && let Ok(ws) = accept_async(stream).await
3102 {
3103 drop(ws);
3104 }
3105 sleep(Duration::from_mins(1)).await;
3106 });
3107
3108 let (handler, _rx) = channel_message_handler();
3109
3110 let config = WebSocketConfig {
3111 url: format!("ws://127.0.0.1:{port}"),
3112 headers: vec![],
3113 heartbeat: None,
3114 heartbeat_msg: None,
3115 reconnect_timeout_ms: Some(5_000),
3116 reconnect_delay_initial_ms: Some(50),
3117 reconnect_delay_max_ms: Some(100),
3118 reconnect_backoff_factor: Some(1.0),
3119 reconnect_jitter_ms: Some(0),
3120 reconnect_max_attempts: None,
3121 idle_timeout_ms: None,
3122 backend: TransportBackend::Tungstenite,
3123 proxy_url: None,
3124 };
3125
3126 let quota = Quota::with_period(Duration::from_secs(10))
3129 .unwrap()
3130 .allow_burst(NonZeroU32::new(1).unwrap());
3131
3132 let client = Arc::new(
3133 WebSocketClient::connect(
3134 config,
3135 Some(handler),
3136 None,
3137 None,
3138 vec![("test_key".to_string(), quota)],
3139 None,
3140 )
3141 .await
3142 .unwrap(),
3143 );
3144
3145 wait_until_async(
3147 || async { client.is_reconnecting() || client.is_closed() },
3148 Duration::from_secs(2),
3149 )
3150 .await;
3151
3152 client.disconnect().await;
3154 assert!(
3155 !client.is_active(),
3156 "Client should not be active after disconnect"
3157 );
3158
3159 let start = std::time::Instant::now();
3161 let test_key: [Ustr; 1] = [Ustr::from("test_key")];
3162 let result = client
3163 .send_text("test".to_string(), Some(test_key.as_slice()))
3164 .await;
3165 let elapsed = start.elapsed();
3166
3167 assert!(result.is_err(), "Send should fail when client is closed");
3169 assert!(
3170 matches!(result, Err(crate::error::SendError::Closed)),
3171 "Send should return Closed error, was: {result:?}"
3172 );
3173
3174 assert!(
3176 elapsed < Duration::from_millis(100),
3177 "Send should fail fast without rate limiting, took {elapsed:?}"
3178 );
3179
3180 server.abort();
3181 }
3182
3183 #[rstest]
3184 #[tokio::test]
3185 async fn test_connect_rejects_none_message_handler() {
3186 let config = WebSocketConfig {
3190 url: "ws://127.0.0.1:9999".to_string(),
3191 headers: vec![],
3192 heartbeat: None,
3193 heartbeat_msg: None,
3194 reconnect_timeout_ms: Some(1_000),
3195 reconnect_delay_initial_ms: Some(100),
3196 reconnect_delay_max_ms: Some(500),
3197 reconnect_backoff_factor: Some(1.5),
3198 reconnect_jitter_ms: Some(0),
3199 reconnect_max_attempts: None,
3200 idle_timeout_ms: None,
3201 backend: TransportBackend::Tungstenite,
3202 proxy_url: None,
3203 };
3204
3205 let result = WebSocketClient::connect(config, None, None, None, vec![], None).await;
3207
3208 assert!(
3209 result.is_err(),
3210 "connect() should reject None message_handler"
3211 );
3212
3213 let err = result.unwrap_err();
3214 let err_msg = err.to_string();
3215 assert!(
3216 err_msg.contains("Handler mode requires message_handler"),
3217 "Error should mention missing message_handler, was: {err_msg}"
3218 );
3219 }
3220
3221 #[rstest]
3222 #[tokio::test]
3223 async fn test_client_without_handler_sets_stream_mode() {
3224 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3228 let port = listener.local_addr().unwrap().port();
3229
3230 let server = task::spawn(async move {
3231 if let Ok((stream, _)) = listener.accept().await
3233 && let Ok(ws) = accept_async(stream).await
3234 {
3235 drop(ws); }
3237 });
3238
3239 let config = WebSocketConfig {
3240 url: format!("ws://127.0.0.1:{port}"),
3241 headers: vec![],
3242 heartbeat: None,
3243 heartbeat_msg: None,
3244 reconnect_timeout_ms: Some(1_000),
3245 reconnect_delay_initial_ms: Some(100),
3246 reconnect_delay_max_ms: Some(500),
3247 reconnect_backoff_factor: Some(1.5),
3248 reconnect_jitter_ms: Some(0),
3249 reconnect_max_attempts: None,
3250 idle_timeout_ms: None,
3251 backend: TransportBackend::Tungstenite,
3252 proxy_url: None,
3253 };
3254
3255 let inner = WebSocketClientInner::connect_url(config, None, None)
3257 .await
3258 .unwrap();
3259
3260 assert!(
3262 inner.is_stream_mode,
3263 "Client without handler should have is_stream_mode=true"
3264 );
3265
3266 server.abort();
3270 }
3271
3272 #[rstest]
3273 #[tokio::test]
3274 async fn test_idle_timeout_triggers_reconnect() {
3275 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3276 let port = listener.local_addr().unwrap().port();
3277
3278 let server = task::spawn(async move {
3280 let (stream, _) = listener.accept().await.unwrap();
3281 let _ws = accept_async(stream).await.unwrap();
3282 sleep(Duration::from_secs(5)).await;
3284 });
3285
3286 let (handler, _rx) = channel_message_handler();
3287
3288 let config = WebSocketConfig {
3289 url: format!("ws://127.0.0.1:{port}"),
3290 headers: vec![],
3291 heartbeat: None,
3292 heartbeat_msg: None,
3293 reconnect_timeout_ms: Some(2_000),
3294 reconnect_delay_initial_ms: Some(50),
3295 reconnect_delay_max_ms: Some(100),
3296 reconnect_backoff_factor: Some(1.0),
3297 reconnect_jitter_ms: Some(0),
3298 reconnect_max_attempts: Some(1),
3299 idle_timeout_ms: Some(500),
3300 backend: TransportBackend::Tungstenite,
3301 proxy_url: None,
3302 };
3303
3304 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3305 .await
3306 .unwrap();
3307
3308 assert!(client.is_active());
3309
3310 wait_until_async(
3312 || async { client.is_reconnecting() || client.is_disconnected() },
3313 Duration::from_secs(3),
3314 )
3315 .await;
3316
3317 assert!(
3318 !client.is_active(),
3319 "Client should not be active after idle timeout"
3320 );
3321
3322 client.disconnect().await;
3323 server.abort();
3324 }
3325
3326 #[rstest]
3327 #[tokio::test]
3328 async fn test_idle_timeout_resets_on_data() {
3329 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3330 let port = listener.local_addr().unwrap().port();
3331
3332 let server = task::spawn(async move {
3334 let (stream, _) = listener.accept().await.unwrap();
3335 let mut ws = accept_async(stream).await.unwrap();
3336
3337 for _ in 0..10 {
3338 sleep(Duration::from_millis(200)).await;
3339
3340 if ws.send(WsMessage::Text("ping".into())).await.is_err() {
3341 break;
3342 }
3343 }
3344 });
3345
3346 let (handler, _rx) = channel_message_handler();
3347
3348 let config = WebSocketConfig {
3349 url: format!("ws://127.0.0.1:{port}"),
3350 headers: vec![],
3351 heartbeat: None,
3352 heartbeat_msg: None,
3353 reconnect_timeout_ms: Some(2_000),
3354 reconnect_delay_initial_ms: Some(50),
3355 reconnect_delay_max_ms: Some(100),
3356 reconnect_backoff_factor: Some(1.0),
3357 reconnect_jitter_ms: Some(0),
3358 reconnect_max_attempts: Some(1),
3359 idle_timeout_ms: Some(1_000),
3360 backend: TransportBackend::Tungstenite,
3361 proxy_url: None,
3362 };
3363
3364 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3365 .await
3366 .unwrap();
3367
3368 assert!(client.is_active());
3369
3370 sleep(Duration::from_millis(1_500)).await;
3372
3373 assert!(
3374 client.is_active(),
3375 "Client should remain active when data is flowing"
3376 );
3377
3378 client.disconnect().await;
3379 server.abort();
3380 }
3381
3382 #[rstest]
3383 #[tokio::test]
3384 async fn test_idle_timeout_fires_when_only_pings_received() {
3385 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3391 let port = listener.local_addr().unwrap().port();
3392
3393 let server = task::spawn(async move {
3394 let (stream, _) = listener.accept().await.unwrap();
3395 let mut ws = accept_async(stream).await.unwrap();
3396
3397 for _ in 0..60 {
3398 sleep(Duration::from_millis(100)).await;
3399
3400 if ws.send(WsMessage::Ping(Vec::new().into())).await.is_err() {
3401 break;
3402 }
3403 }
3404 });
3405
3406 let (handler, _rx) = channel_message_handler();
3407
3408 let config = WebSocketConfig {
3409 url: format!("ws://127.0.0.1:{port}"),
3410 headers: vec![],
3411 heartbeat: None,
3412 heartbeat_msg: None,
3413 reconnect_timeout_ms: Some(2_000),
3414 reconnect_delay_initial_ms: Some(50),
3415 reconnect_delay_max_ms: Some(100),
3416 reconnect_backoff_factor: Some(1.0),
3417 reconnect_jitter_ms: Some(0),
3418 reconnect_max_attempts: Some(1),
3419 idle_timeout_ms: Some(500),
3420 backend: TransportBackend::Tungstenite,
3421 proxy_url: None,
3422 };
3423
3424 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3425 .await
3426 .unwrap();
3427
3428 assert!(client.is_active());
3429
3430 wait_until_async(
3434 || async { client.is_reconnecting() || client.is_disconnected() },
3435 Duration::from_millis(1_500),
3436 )
3437 .await;
3438
3439 assert!(
3440 !client.is_active(),
3441 "Client should not be active after idle timeout when only pings/pongs flow"
3442 );
3443
3444 client.disconnect().await;
3445 server.abort();
3446 }
3447
3448 #[rstest]
3449 #[tokio::test]
3450 async fn test_idle_timeout_fires_when_only_pongs_received() {
3451 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3456 let port = listener.local_addr().unwrap().port();
3457
3458 let server = task::spawn(async move {
3459 let (stream, _) = listener.accept().await.unwrap();
3460 let mut ws = accept_async(stream).await.unwrap();
3461
3462 let deadline = tokio::time::Instant::now() + Duration::from_secs(6);
3466 while tokio::time::Instant::now() < deadline {
3467 if let Ok(Some(Err(_)) | None) =
3468 tokio::time::timeout(Duration::from_millis(100), ws.next()).await
3469 {
3470 break;
3471 }
3472 }
3473 });
3474
3475 let (handler, _rx) = channel_message_handler();
3476
3477 let config = WebSocketConfig {
3478 url: format!("ws://127.0.0.1:{port}"),
3479 headers: vec![],
3480 heartbeat: Some(1),
3481 heartbeat_msg: None,
3482 reconnect_timeout_ms: Some(2_000),
3483 reconnect_delay_initial_ms: Some(50),
3484 reconnect_delay_max_ms: Some(100),
3485 reconnect_backoff_factor: Some(1.0),
3486 reconnect_jitter_ms: Some(0),
3487 reconnect_max_attempts: Some(1),
3488 idle_timeout_ms: Some(1_500),
3489 backend: TransportBackend::Tungstenite,
3490 proxy_url: None,
3491 };
3492
3493 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3494 .await
3495 .unwrap();
3496
3497 assert!(client.is_active());
3498
3499 wait_until_async(
3503 || async { client.is_reconnecting() || client.is_disconnected() },
3504 Duration::from_millis(2_500),
3505 )
3506 .await;
3507
3508 assert!(
3509 !client.is_active(),
3510 "Client should not be active after idle timeout when only pongs flow"
3511 );
3512
3513 client.disconnect().await;
3514 server.abort();
3515 }
3516
3517 #[rstest]
3518 #[tokio::test]
3519 async fn test_disconnect_during_backoff_exits_promptly() {
3520 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3524 let port = listener.local_addr().unwrap().port();
3525
3526 let server = task::spawn(async move {
3527 if let Ok((stream, _)) = listener.accept().await {
3529 let _ = accept_async(stream).await;
3530 }
3531 sleep(Duration::from_mins(1)).await;
3533 });
3534
3535 let (handler, _rx) = channel_message_handler();
3536
3537 let config = WebSocketConfig {
3538 url: format!("ws://127.0.0.1:{port}"),
3539 headers: vec![],
3540 heartbeat: None,
3541 heartbeat_msg: None,
3542 reconnect_timeout_ms: Some(1_000),
3543 reconnect_delay_initial_ms: Some(10_000), reconnect_delay_max_ms: Some(10_000),
3545 reconnect_backoff_factor: Some(1.0),
3546 reconnect_jitter_ms: Some(0),
3547 reconnect_max_attempts: None,
3548 idle_timeout_ms: None,
3549 backend: TransportBackend::Tungstenite,
3550 proxy_url: None,
3551 };
3552
3553 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3554 .await
3555 .unwrap();
3556
3557 wait_until_async(
3559 || async { client.is_reconnecting() },
3560 Duration::from_secs(3),
3561 )
3562 .await;
3563
3564 sleep(Duration::from_millis(1_500)).await;
3566
3567 let start = std::time::Instant::now();
3569 client.disconnect().await;
3570 let elapsed = start.elapsed();
3571
3572 assert!(client.is_disconnected(), "Client should be disconnected");
3573 assert!(
3575 elapsed < Duration::from_secs(2),
3576 "Disconnect should interrupt backoff sleep, took {elapsed:?}"
3577 );
3578
3579 server.abort();
3580 }
3581
3582 #[rstest]
3583 #[tokio::test]
3584 async fn test_rate_limit_cancelled_on_disconnect() {
3585 use std::{num::NonZeroU32, sync::Arc};
3588
3589 use crate::ratelimiter::quota::Quota;
3590
3591 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3592 let port = listener.local_addr().unwrap().port();
3593
3594 let server = task::spawn(async move {
3595 if let Ok((stream, _)) = listener.accept().await {
3596 let mut ws = accept_async(stream).await.unwrap();
3597 while let Some(Ok(msg)) = ws.next().await {
3599 if ws.send(msg).await.is_err() {
3600 break;
3601 }
3602 }
3603 }
3604 });
3605
3606 let (handler, _rx) = channel_message_handler();
3607
3608 let config = WebSocketConfig {
3609 url: format!("ws://127.0.0.1:{port}"),
3610 headers: vec![],
3611 heartbeat: None,
3612 heartbeat_msg: None,
3613 reconnect_timeout_ms: Some(5_000),
3614 reconnect_delay_initial_ms: Some(100),
3615 reconnect_delay_max_ms: Some(500),
3616 reconnect_backoff_factor: Some(1.5),
3617 reconnect_jitter_ms: Some(0),
3618 reconnect_max_attempts: None,
3619 idle_timeout_ms: None,
3620 backend: TransportBackend::Tungstenite,
3621 proxy_url: None,
3622 };
3623
3624 let quota = Quota::with_period(Duration::from_mins(1))
3626 .unwrap()
3627 .allow_burst(NonZeroU32::new(1).unwrap());
3628
3629 let client = Arc::new(
3630 WebSocketClient::connect(
3631 config,
3632 Some(handler),
3633 None,
3634 None,
3635 vec![("rate_key".to_string(), quota)],
3636 None,
3637 )
3638 .await
3639 .unwrap(),
3640 );
3641
3642 let test_key: [Ustr; 1] = [Ustr::from("rate_key")];
3643
3644 client
3646 .send_text("exhaust".to_string(), Some(test_key.as_slice()))
3647 .await
3648 .unwrap();
3649
3650 let client_clone = client.clone();
3652 let send_handle = task::spawn(async move {
3653 client_clone
3654 .send_text("blocked".to_string(), Some(&[Ustr::from("rate_key")]))
3655 .await
3656 });
3657
3658 sleep(Duration::from_millis(200)).await;
3660
3661 let start = std::time::Instant::now();
3663 client.disconnect().await;
3664 let elapsed_disconnect = start.elapsed();
3665
3666 let result = tokio::time::timeout(Duration::from_secs(2), send_handle)
3668 .await
3669 .expect("Send task should complete quickly")
3670 .expect("Send task should not panic");
3671
3672 assert!(
3673 matches!(result, Err(crate::error::SendError::Closed)),
3674 "Blocked send should return Closed, was: {result:?}"
3675 );
3676
3677 assert!(
3679 elapsed_disconnect < Duration::from_secs(3),
3680 "Disconnect should not wait for rate limiter, took {elapsed_disconnect:?}"
3681 );
3682
3683 server.abort();
3684 }
3685
3686 #[rstest]
3687 #[tokio::test]
3688 async fn test_stream_mode_transitions_to_closed_on_dead_write_task() {
3689 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3693 let port = listener.local_addr().unwrap().port();
3694
3695 let server = task::spawn(async move {
3696 if let Ok((stream, _)) = listener.accept().await
3697 && let Ok(ws) = accept_async(stream).await
3698 {
3699 drop(ws);
3701 }
3702 });
3703
3704 let config = WebSocketConfig {
3705 url: format!("ws://127.0.0.1:{port}"),
3706 headers: vec![],
3707 heartbeat: None,
3708 heartbeat_msg: None,
3709 reconnect_timeout_ms: Some(1_000),
3710 reconnect_delay_initial_ms: Some(50),
3711 reconnect_delay_max_ms: Some(100),
3712 reconnect_backoff_factor: Some(1.0),
3713 reconnect_jitter_ms: Some(0),
3714 reconnect_max_attempts: None,
3715 idle_timeout_ms: None,
3716 backend: TransportBackend::Tungstenite,
3717 proxy_url: None,
3718 };
3719
3720 let (_reader, client) = WebSocketClient::connect_stream(config, vec![], None, None)
3721 .await
3722 .unwrap();
3723
3724 assert!(client.is_active(), "Client should start active");
3725
3726 sleep(Duration::from_millis(100)).await;
3728
3729 for _ in 0..20 {
3731 let _ = client.send_text("ping".to_string(), None).await;
3732 sleep(Duration::from_millis(50)).await;
3733
3734 if !client.is_active() {
3735 break;
3736 }
3737 }
3738
3739 wait_until_async(|| async { !client.is_active() }, Duration::from_secs(5)).await;
3741
3742 assert!(
3744 client.is_closed() || client.is_disconnected(),
3745 "Stream mode should transition to CLOSED, not RECONNECT. \
3746 is_reconnecting={}, is_closed={}, is_disconnected={}",
3747 client.is_reconnecting(),
3748 client.is_closed(),
3749 client.is_disconnected(),
3750 );
3751 assert!(
3752 !client.is_reconnecting(),
3753 "Stream mode should never attempt reconnection"
3754 );
3755
3756 server.abort();
3757 }
3758
3759 #[tokio::test]
3760 async fn test_write_task_waits_for_auth_before_replaying_buffer() {
3761 use nautilus_common::testing::wait_until_async;
3762
3763 let server = RecordingServer::setup().await;
3764 let url = format!("ws://127.0.0.1:{}", server.port);
3765 let (writer, _reader) = WebSocketClientInner::connect_with_server(
3766 &url,
3767 vec![],
3768 TransportBackend::Tungstenite,
3769 None,
3770 )
3771 .await
3772 .unwrap();
3773
3774 let connection_state = Arc::new(AtomicU8::new(ConnectionMode::Reconnect.as_u8()));
3775 let state_notify = Arc::new(tokio::sync::Notify::new());
3776 let auth_tracker = Arc::new(OnceLock::new());
3777 let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(true));
3778 let tracker = AuthTracker::new();
3779 auth_tracker.set(tracker.clone()).unwrap();
3780
3781 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel();
3782 let write_task = WebSocketClientInner::spawn_write_task(
3783 Arc::clone(&connection_state),
3784 Arc::clone(&state_notify),
3785 writer,
3786 writer_rx,
3787 Arc::clone(&auth_tracker),
3788 Arc::clone(&reconnect_buffer_waits_for_auth),
3789 );
3790
3791 writer_tx
3792 .send(WriterCommand::Send(Message::Text("stale".into())))
3793 .unwrap();
3794
3795 let (new_writer, _reader) = WebSocketClientInner::connect_with_server(
3796 &url,
3797 vec![],
3798 TransportBackend::Tungstenite,
3799 None,
3800 )
3801 .await
3802 .unwrap();
3803 let (tx, rx) = tokio::sync::oneshot::channel();
3804 writer_tx
3805 .send(WriterCommand::Update(new_writer, tx))
3806 .unwrap();
3807 assert!(rx.await.unwrap());
3808
3809 connection_state.store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
3810
3811 tokio::time::sleep(Duration::from_millis(300)).await;
3812 assert!(
3813 server.messages().await.is_empty(),
3814 "buffered messages should wait for re-authentication"
3815 );
3816
3817 tracker.succeed();
3818
3819 wait_until_async(
3820 || {
3821 let messages = Arc::clone(&server.messages);
3822 async move { !messages.lock().await.is_empty() }
3823 },
3824 Duration::from_secs(3),
3825 )
3826 .await;
3827
3828 assert_eq!(server.messages().await, vec!["stale".to_string()]);
3829
3830 connection_state.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
3831 state_notify.notify_waiters();
3832 drop(writer_tx);
3833 write_task.abort();
3834 }
3835
    /// Checks that a message queued while the connection state is `Reconnect` is never
    /// delivered once the auth tracker reports a failure.
    ///
    /// Flow: spawn the write task in `Reconnect` state with the "wait for auth" flag set,
    /// queue one text frame, hand the task a fresh writer, flip the state to `Active`, fail
    /// authentication, and assert that the recording server saw nothing. A later
    /// `begin`/`succeed` cycle must not deliver the frame either.
    #[tokio::test]
    async fn test_write_task_discards_buffer_after_auth_failure() {
        let server = RecordingServer::setup().await;
        let url = format!("ws://127.0.0.1:{}", server.port);
        let (writer, _reader) = WebSocketClientInner::connect_with_server(
            &url,
            vec![],
            TransportBackend::Tungstenite,
            None,
        )
        .await
        .unwrap();

        // Shared state handed to the write task; the connection starts out in `Reconnect`.
        let connection_state = Arc::new(AtomicU8::new(ConnectionMode::Reconnect.as_u8()));
        let state_notify = Arc::new(tokio::sync::Notify::new());
        let auth_tracker = Arc::new(OnceLock::new());
        let reconnect_buffer_waits_for_auth = Arc::new(AtomicBool::new(true));
        let tracker = AuthTracker::new();
        auth_tracker.set(tracker.clone()).unwrap();

        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel();
        let write_task = WebSocketClientInner::spawn_write_task(
            Arc::clone(&connection_state),
            Arc::clone(&state_notify),
            writer,
            writer_rx,
            Arc::clone(&auth_tracker),
            Arc::clone(&reconnect_buffer_waits_for_auth),
        );

        // Queued while the state is still `Reconnect`; the test expects it to be held back
        // rather than written to the socket.
        writer_tx
            .send(WriterCommand::Send(Message::Text("stale".into())))
            .unwrap();

        // Open a second connection and swap its writer into the running task.
        let (new_writer, _reader) = WebSocketClientInner::connect_with_server(
            &url,
            vec![],
            TransportBackend::Tungstenite,
            None,
        )
        .await
        .unwrap();
        let (tx, rx) = tokio::sync::oneshot::channel();
        writer_tx
            .send(WriterCommand::Update(new_writer, tx))
            .unwrap();
        // The oneshot carries the outcome of the writer update; `true` is required here.
        assert!(rx.await.unwrap());

        connection_state.store(ConnectionMode::Active.as_u8(), Ordering::SeqCst);
        tracker.fail("rejected");
        // Fixed wait so that a (wrong) delivery would have time to reach the server.
        tokio::time::sleep(Duration::from_millis(300)).await;
        assert!(
            server.messages().await.is_empty(),
            "buffered messages should be discarded after authentication failure"
        );

        // Run a fresh auth attempt to success: nothing should be left over to replay.
        let _auth_receiver = tracker.begin();
        tracker.succeed();
        tokio::time::sleep(Duration::from_millis(300)).await;
        assert!(
            server.messages().await.is_empty(),
            "discarded buffered messages should not replay on a later auth success"
        );

        // Teardown: mark the connection closed, wake waiters, close the channel, stop the task.
        connection_state.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
        state_notify.notify_waiters();
        drop(writer_tx);
        write_task.abort();
    }
3905
3906 #[rstest]
3907 #[tokio::test]
3908 async fn test_zero_idle_timeout_rejected() {
3909 let (handler, _rx) = channel_message_handler();
3910
3911 let config = WebSocketConfig {
3912 url: "ws://127.0.0.1:9999".to_string(),
3913 headers: vec![],
3914 heartbeat: None,
3915 heartbeat_msg: None,
3916 reconnect_timeout_ms: None,
3917 reconnect_delay_initial_ms: None,
3918 reconnect_delay_max_ms: None,
3919 reconnect_backoff_factor: None,
3920 reconnect_jitter_ms: None,
3921 reconnect_max_attempts: None,
3922 idle_timeout_ms: Some(0),
3923 backend: TransportBackend::Tungstenite,
3924 proxy_url: None,
3925 };
3926
3927 let result =
3928 WebSocketClient::connect(config, Some(handler), None, None, vec![], None).await;
3929
3930 assert!(result.is_err(), "Zero idle timeout should be rejected");
3931 let err_msg = result.unwrap_err().to_string();
3932 assert!(
3933 err_msg.contains("Idle timeout cannot be zero"),
3934 "Error should mention zero idle timeout, was: {err_msg}"
3935 );
3936 }
3937
3938 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
3939 #[rstest]
3940 #[tokio::test]
3941 async fn test_sockudo_backend_rejects_reserved_headers_before_connect() {
3942 let (handler, _rx) = channel_message_handler();
3943
3944 let config = WebSocketConfig {
3945 url: "ws://127.0.0.1:1".to_string(),
3946 headers: vec![("Host".to_string(), "example.com".to_string())],
3947 heartbeat: None,
3948 heartbeat_msg: None,
3949 reconnect_timeout_ms: None,
3950 reconnect_delay_initial_ms: None,
3951 reconnect_delay_max_ms: None,
3952 reconnect_backoff_factor: None,
3953 reconnect_jitter_ms: None,
3954 reconnect_max_attempts: None,
3955 idle_timeout_ms: None,
3956 backend: TransportBackend::Sockudo,
3957 proxy_url: None,
3958 };
3959
3960 let err = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
3961 .await
3962 .expect_err("reserved header should fail before TCP connect");
3963
3964 assert!(
3965 err.to_string()
3966 .contains("reserved upgrade header not allowed in extra_headers"),
3967 "expected reserved-header failure, was: {err}"
3968 );
3969 }
3970
3971 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
3972 #[rstest]
3973 #[tokio::test]
3974 async fn test_sockudo_backend_replays_leftover_without_custom_headers() {
3975 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
3976 let port = listener.local_addr().unwrap().port();
3977
3978 let server = task::spawn(async move {
3979 if let Ok((mut stream, _)) = listener.accept().await {
3980 let request = read_http_request(&mut stream).await;
3981 let request = String::from_utf8(request).unwrap();
3982 let sec_websocket_key = extract_header(&request, "Sec-WebSocket-Key").unwrap();
3983 let accept = sockudo_handshake::generate_accept_key(sec_websocket_key);
3984 let mut response = format!(
3985 concat!(
3986 "HTTP/1.1 101 Switching Protocols\r\n",
3987 "Upgrade: websocket\r\n",
3988 "Connection: Upgrade\r\n",
3989 "Sec-WebSocket-Accept: {}\r\n",
3990 "\r\n",
3991 ),
3992 accept
3993 )
3994 .into_bytes();
3995 response.extend_from_slice(b"\x81\x05hello");
3996 stream.write_all(&response).await.unwrap();
3997 }
3998 });
3999
4000 let (handler, mut rx) = channel_message_handler();
4001
4002 let config = WebSocketConfig {
4003 url: format!("ws://127.0.0.1:{port}/ws"),
4004 headers: vec![],
4005 heartbeat: None,
4006 heartbeat_msg: None,
4007 reconnect_timeout_ms: Some(2_000),
4008 reconnect_delay_initial_ms: Some(50),
4009 reconnect_delay_max_ms: Some(100),
4010 reconnect_backoff_factor: Some(1.0),
4011 reconnect_jitter_ms: Some(0),
4012 reconnect_max_attempts: None,
4013 idle_timeout_ms: None,
4014 backend: TransportBackend::Sockudo,
4015 proxy_url: None,
4016 };
4017
4018 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
4019 .await
4020 .expect("sockudo connect without custom headers");
4021
4022 let received = tokio::time::timeout(Duration::from_secs(3), async {
4023 loop {
4024 if let Ok(msg) = rx.try_recv() {
4025 return msg;
4026 }
4027 tokio::time::sleep(Duration::from_millis(10)).await;
4028 }
4029 })
4030 .await
4031 .expect("did not receive leftover frame before timeout");
4032
4033 match received {
4034 WsMessage::Text(t) => assert_eq!(t.as_str(), "hello"),
4035 other => panic!("expected text, was {other:?}"),
4036 }
4037
4038 client.disconnect().await;
4039 tokio::time::timeout(Duration::from_secs(3), server)
4040 .await
4041 .expect("server did not close before timeout")
4042 .unwrap();
4043 }
4044
4045 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4046 #[rstest]
4047 #[tokio::test]
4048 async fn test_sockudo_backend_sends_custom_headers() {
4049 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
4050 let port = listener.local_addr().unwrap().port();
4051
4052 let server = task::spawn(async move {
4053 if let Ok((stream, _)) = listener.accept().await {
4054 let callback = HeaderAssertCallback {
4055 key: "X-Test".to_string(),
4056 value: HeaderValue::from_static("value"),
4057 };
4058
4059 if let Ok(mut ws) = accept_hdr_async(stream, callback).await {
4060 while let Some(Ok(msg)) = ws.next().await {
4061 if msg.is_text() || msg.is_binary() {
4062 if ws.send(msg).await.is_err() {
4063 break;
4064 }
4065
4066 continue;
4067 }
4068
4069 if msg.is_close() {
4070 let _ = ws.close(None).await;
4071 break;
4072 }
4073 }
4074 }
4075 }
4076 });
4077
4078 let (handler, mut rx) = channel_message_handler();
4079
4080 let config = WebSocketConfig {
4081 url: format!("ws://127.0.0.1:{port}"),
4082 headers: vec![("X-Test".to_string(), "value".to_string())],
4083 heartbeat: None,
4084 heartbeat_msg: None,
4085 reconnect_timeout_ms: Some(2_000),
4086 reconnect_delay_initial_ms: Some(50),
4087 reconnect_delay_max_ms: Some(100),
4088 reconnect_backoff_factor: Some(1.0),
4089 reconnect_jitter_ms: Some(0),
4090 reconnect_max_attempts: None,
4091 idle_timeout_ms: None,
4092 backend: TransportBackend::Sockudo,
4093 proxy_url: None,
4094 };
4095
4096 let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
4097 .await
4098 .expect("sockudo connect with custom headers");
4099
4100 client.send_text("ping".to_string(), None).await.unwrap();
4101
4102 let received = tokio::time::timeout(Duration::from_secs(3), async {
4103 loop {
4104 if let Ok(msg) = rx.try_recv() {
4105 return msg;
4106 }
4107 tokio::time::sleep(Duration::from_millis(10)).await;
4108 }
4109 })
4110 .await
4111 .expect("did not receive echo before timeout");
4112
4113 match received {
4114 WsMessage::Text(t) => assert_eq!(t.as_str(), "ping"),
4115 other => panic!("expected text, was {other:?}"),
4116 }
4117
4118 client.disconnect().await;
4119 tokio::time::timeout(Duration::from_secs(3), server)
4120 .await
4121 .expect("server did not close before timeout")
4122 .unwrap();
4123 }
4124
    /// Round-trips a text message through the Sockudo backend: a local echo server is
    /// started on an ephemeral port, the client connects, sends `ping`, and the handler
    /// channel must yield the same text within three seconds.
    #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
    #[rstest]
    #[tokio::test]
    async fn test_sockudo_backend_round_trip_text() {
        // Port 0 lets the OS pick a free port; it is read back for the client URL.
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let port = listener.local_addr().unwrap().port();

        // Echo server: accepts a single connection, echoes text/binary frames, answers a
        // close frame and stops, and ignores every other frame kind.
        let server = task::spawn(async move {
            if let Ok((stream, _)) = listener.accept().await
                && let Ok(mut ws) = accept_async(stream).await
            {
                while let Some(Ok(msg)) = ws.next().await {
                    // NOTE(review): the expectation implies clippy flags the nested `if` in
                    // the first arm; reshaping this match may leave it unfulfilled.
                    #[expect(clippy::collapsible_match)]
                    match msg {
                        WsMessage::Text(_) | WsMessage::Binary(_) => {
                            if ws.send(msg).await.is_err() {
                                break;
                            }
                        }
                        WsMessage::Close(_) => {
                            let _ = ws.close(None).await;
                            break;
                        }
                        _ => {}
                    }
                }
            }
        });

        let (handler, mut rx) = channel_message_handler();
        let config = WebSocketConfig {
            url: format!("ws://127.0.0.1:{port}"),
            headers: vec![],
            heartbeat: None,
            heartbeat_msg: None,
            reconnect_timeout_ms: Some(2_000),
            reconnect_delay_initial_ms: Some(50),
            reconnect_delay_max_ms: Some(100),
            reconnect_backoff_factor: Some(1.0),
            reconnect_jitter_ms: Some(0),
            reconnect_max_attempts: None,
            idle_timeout_ms: None,
            backend: TransportBackend::Sockudo,
            proxy_url: None,
        };

        let client = WebSocketClient::connect(config, Some(handler), None, None, vec![], None)
            .await
            .expect("sockudo connect");

        client.send_text("ping".to_string(), None).await.unwrap();

        // Poll the handler channel every 10 ms; the surrounding timeout bounds the wait.
        let received = tokio::time::timeout(Duration::from_secs(3), async {
            loop {
                if let Ok(msg) = rx.try_recv() {
                    return msg;
                }
                tokio::time::sleep(Duration::from_millis(10)).await;
            }
        })
        .await
        .expect("did not receive echo before timeout");

        match received {
            WsMessage::Text(t) => assert_eq!(t.as_str(), "ping"),
            other => panic!("expected text, was {other:?}"),
        }

        client.disconnect().await;
        // The server task is aborted rather than awaited.
        server.abort();
    }
4198
4199 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4200 #[rstest]
4201 #[case::ws_default_port("ws://example.com/ws", "example.com", "example.com", 80, "/ws", false)]
4202 #[case::wss_default_port(
4203 "wss://example.com/ws",
4204 "example.com",
4205 "example.com",
4206 443,
4207 "/ws",
4208 true
4209 )]
4210 #[case::ws_explicit_default(
4213 "ws://example.com:80/ws",
4214 "example.com",
4215 "example.com",
4216 80,
4217 "/ws",
4218 false
4219 )]
4220 #[case::ws_non_default(
4221 "ws://example.com:8443/feed",
4222 "example.com",
4223 "example.com:8443",
4224 8443,
4225 "/feed",
4226 false
4227 )]
4228 #[case::wss_non_default(
4229 "wss://example.com:9443/feed",
4230 "example.com",
4231 "example.com:9443",
4232 9443,
4233 "/feed",
4234 true
4235 )]
4236 #[case::root_path(
4237 "ws://example.com:9000/",
4238 "example.com",
4239 "example.com:9000",
4240 9000,
4241 "/",
4242 false
4243 )]
4244 #[case::query_string(
4245 "ws://example.com/feed?token=abc&channel=trades",
4246 "example.com",
4247 "example.com",
4248 80,
4249 "/feed?token=abc&channel=trades",
4250 false
4251 )]
4252 #[case::ipv6_default("ws://[::1]/feed", "::1", "[::1]", 80, "/feed", false)]
4254 #[case::ipv6_explicit_port("ws://[::1]:9000/feed", "::1", "[::1]:9000", 9000, "/feed", false)]
4255 #[case::ipv6_wss(
4256 "wss://[2001:db8::1]:8443/",
4257 "2001:db8::1",
4258 "[2001:db8::1]:8443",
4259 8443,
4260 "/",
4261 true
4262 )]
4263 fn sockudo_target_parses_url(
4264 #[case] url: &str,
4265 #[case] host: &str,
4266 #[case] host_header: &str,
4267 #[case] port: u16,
4268 #[case] path: &str,
4269 #[case] is_tls: bool,
4270 ) {
4271 let target = super::SockudoTarget::parse(url).expect("parse should succeed");
4272 assert_eq!(target.host, host);
4273 assert_eq!(target.host_header, host_header);
4274 assert_eq!(target.port, port);
4275 assert_eq!(target.path, path);
4276 assert_eq!(target.is_tls, is_tls);
4277 }
4278
4279 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4280 #[rstest]
4281 fn sockudo_target_rejects_unsupported_scheme() {
4282 let err = super::SockudoTarget::parse("http://example.com/feed").expect_err("not a ws URL");
4283 let msg = err.to_string();
4284 assert!(
4285 msg.contains("expected ws:// or wss://"),
4286 "unexpected error: {msg}"
4287 );
4288 }
4289
4290 #[cfg(all(feature = "transport-sockudo", not(feature = "turmoil")))]
4291 #[rstest]
4292 fn sockudo_target_rejects_malformed_url() {
4293 let err = super::SockudoTarget::parse("not a url").expect_err("malformed URL");
4294 assert!(
4295 matches!(err, super::TransportError::InvalidUrl(_)),
4296 "expected InvalidUrl, was: {err:?}"
4297 );
4298 }
4299}
4300
4301#[cfg(test)]
4302#[cfg(feature = "turmoil")]
4303mod turmoil_tests {
4304 use std::{sync::Arc, time::Duration};
4305
4306 use futures_util::{SinkExt, StreamExt};
4307 use nautilus_common::testing::wait_until_async;
4308 use rstest::rstest;
4309 use tokio_tungstenite::{accept_async, tungstenite::Message as WsMessage};
4310 use turmoil::{Builder, net};
4311
4312 use super::*;
4313 use crate::websocket::types::channel_message_handler;
4314
    /// Turmoil simulation: a message enqueued while the client is reconnecting must not
    /// reach the server until the auth tracker reports success on the new connection.
    #[rstest]
    fn test_turmoil_reconnect_buffer_waits_for_auth() {
        let mut sim = Builder::new().build();
        // Text frames recorded by the server; shared so the client side can inspect them.
        let messages = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let server_messages = Arc::clone(&messages);

        // `auth_buffer_server` drops its first connection and records text frames received
        // on the second one.
        sim.host("server", move || {
            let messages = Arc::clone(&server_messages);
            auth_buffer_server(messages)
        });

        sim.client("client", async move {
            let tracker = AuthTracker::new();
            let (handler, _rx) = channel_message_handler();
            let client = WebSocketClient::connect(
                turmoil_websocket_config(),
                Some(handler),
                None,
                None,
                vec![],
                None,
            )
            .await
            .expect("Should connect");

            // NOTE(review): the `true` argument presumably makes the reconnect buffer wait
            // for authentication; confirm against `set_auth_tracker`.
            client.set_auth_tracker(tracker.clone(), true);
            assert!(client.is_active(), "Client should start active");

            // The server closes the first socket; wait for the client to notice.
            wait_until_async(
                || async { client.is_reconnecting() },
                Duration::from_secs(3),
            )
            .await;

            // Enqueue directly on the writer channel while the client is reconnecting.
            client
                .writer_tx
                .send(WriterCommand::Send(Message::Text("stale".into())))
                .unwrap();

            wait_until_async(|| async { client.is_active() }, Duration::from_secs(3)).await;

            // Begin an auth attempt and leave it pending; the receiver stays bound so it is
            // not dropped early.
            let _auth_receiver = tracker.begin();

            // Fixed wait so that a premature flush would have time to reach the server.
            tokio::time::sleep(Duration::from_millis(300)).await;
            assert!(
                messages.lock().await.is_empty(),
                "buffered messages should wait for auth after reconnect"
            );

            tracker.succeed();

            // After success the server must end up with exactly the one buffered frame.
            wait_until_async(
                || {
                    let messages = Arc::clone(&messages);
                    async move { messages.lock().await.as_slice() == ["stale"] }
                },
                Duration::from_secs(3),
            )
            .await;

            assert_eq!(messages.lock().await.as_slice(), ["stale"]);

            client.disconnect().await;
            assert!(client.is_disconnected());

            Ok(())
        });

        sim.run().unwrap();
    }
4385
    /// Turmoil simulation: a message enqueued while the client is reconnecting must never
    /// reach the server if authentication fails, not even after a later auth attempt
    /// succeeds.
    #[rstest]
    fn test_turmoil_reconnect_buffer_discards_after_auth_failure() {
        let mut sim = Builder::new().build();
        // Text frames recorded by the server; shared so the client side can inspect them.
        let messages = Arc::new(tokio::sync::Mutex::new(Vec::new()));
        let server_messages = Arc::clone(&messages);

        // `auth_buffer_server` drops its first connection and records text frames received
        // on the second one.
        sim.host("server", move || {
            let messages = Arc::clone(&server_messages);
            auth_buffer_server(messages)
        });

        sim.client("client", async move {
            let tracker = AuthTracker::new();
            let (handler, _rx) = channel_message_handler();
            let client = WebSocketClient::connect(
                turmoil_websocket_config(),
                Some(handler),
                None,
                None,
                vec![],
                None,
            )
            .await
            .expect("Should connect");

            // NOTE(review): the `true` argument presumably makes the reconnect buffer wait
            // for authentication; confirm against `set_auth_tracker`.
            client.set_auth_tracker(tracker.clone(), true);
            assert!(client.is_active(), "Client should start active");

            // The server closes the first socket; wait for the client to notice.
            wait_until_async(
                || async { client.is_reconnecting() },
                Duration::from_secs(3),
            )
            .await;

            // Enqueue directly on the writer channel while the client is reconnecting.
            client
                .writer_tx
                .send(WriterCommand::Send(Message::Text("stale".into())))
                .unwrap();

            wait_until_async(|| async { client.is_active() }, Duration::from_secs(3)).await;

            // First auth attempt on the new connection is rejected.
            let _auth_receiver = tracker.begin();
            tracker.fail("rejected");

            // Fixed wait so that a (wrong) flush would have time to reach the server.
            tokio::time::sleep(Duration::from_millis(300)).await;
            assert!(
                messages.lock().await.is_empty(),
                "buffered messages should be discarded after auth failure"
            );

            // Second attempt succeeds; the earlier frame must stay gone.
            let _retry_auth_receiver = tracker.begin();
            tracker.succeed();

            tokio::time::sleep(Duration::from_millis(300)).await;
            assert!(
                messages.lock().await.is_empty(),
                "discarded messages should not replay on a later auth success"
            );

            client.disconnect().await;
            assert!(client.is_disconnected());

            Ok(())
        });

        sim.run().unwrap();
    }
4453
4454 fn turmoil_websocket_config() -> WebSocketConfig {
4455 WebSocketConfig {
4456 url: "ws://server:8080".to_string(),
4457 headers: vec![],
4458 heartbeat: None,
4459 heartbeat_msg: None,
4460 reconnect_timeout_ms: Some(5_000),
4461 reconnect_delay_initial_ms: Some(50),
4462 reconnect_delay_max_ms: Some(200),
4463 reconnect_backoff_factor: Some(1.0),
4464 reconnect_jitter_ms: Some(0),
4465 reconnect_max_attempts: None,
4466 idle_timeout_ms: None,
4467 backend: TransportBackend::Tungstenite,
4468 proxy_url: None,
4469 }
4470 }
4471
4472 async fn auth_buffer_server(
4473 messages: Arc<tokio::sync::Mutex<Vec<String>>>,
4474 ) -> Result<(), Box<dyn std::error::Error>> {
4475 let listener = net::TcpListener::bind("0.0.0.0:8080").await?;
4476
4477 let (stream, _) = listener.accept().await?;
4478 let mut websocket = accept_async(stream).await?;
4479 let _ = websocket.send(WsMessage::Text("first".into())).await;
4480 drop(websocket);
4481
4482 tokio::time::sleep(Duration::from_millis(200)).await;
4483
4484 let (stream, _) = listener.accept().await?;
4485 let mut websocket = accept_async(stream).await?;
4486
4487 while let Some(msg) = websocket.next().await {
4488 match msg {
4489 Ok(WsMessage::Text(text)) => {
4490 messages.lock().await.push(text.to_string());
4491 }
4492 Ok(WsMessage::Close(_)) => {
4493 let _ = websocket.close(None).await;
4494 break;
4495 }
4496 Ok(_) => {}
4497 Err(_) => break,
4498 }
4499 }
4500
4501 Ok(())
4502 }
4503}