1use std::{
33 collections::VecDeque,
34 fmt::Debug,
35 path::Path,
36 sync::{
37 Arc,
38 atomic::{AtomicU8, Ordering},
39 },
40 time::Duration,
41};
42
43use bytes::Bytes;
44use nautilus_core::CleanDrop;
45use nautilus_cryptography::providers::install_cryptographic_provider;
46use tokio::io::{AsyncReadExt, AsyncWriteExt};
47use tokio_tungstenite::tungstenite::{Error, client::IntoClientRequest, stream::Mode};
48
49use super::{SocketConfig, TcpMessageHandler, TcpReader, TcpWriter, WriterCommand};
50use crate::{
51 backoff::ExponentialBackoff,
52 dst,
53 error::SendError,
54 logging::{log_task_aborted, log_task_started, log_task_stopped},
55 mode::ConnectionMode,
56 net::TcpStream,
57 tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
58};
59
/// Interval (ms) at which the read and write tasks wake from a pending wait to
/// re-check the connection state; also the poll interval used by `SocketClient::close`.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Delay (ms) applied before the read task is aborted during reconnect and disconnect.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (s) on writer shutdown and on waiting for the client to reach the closed state.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;

/// Maximum number of bytes (10 MiB) the read task buffers without finding a complete
/// frame before it closes the connection.
const MAX_READ_BUFFER_BYTES: usize = 10 * 1024 * 1024;
67
/// Internal connection state owned by the controller task.
///
/// Holds the configuration, the handles of the background read/write/heartbeat
/// tasks, and the bookkeeping needed to reconnect.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    /// Configuration the client was created with; reused on every reconnect.
    config: SocketConfig,
    /// TLS connector built from `certs_dir`, if one was configured.
    connector: Option<Connector>,
    /// Handle of the task reading frames from the socket (replaced on reconnect).
    read_task: Arc<tokio::task::JoinHandle<()>>,
    /// Handle of the task writing frames to the socket.
    write_task: tokio::task::JoinHandle<()>,
    /// Channel used to send commands (data or a replacement writer) to the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    /// Handle of the heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    /// Shared connection mode, stored as the `u8` form of `ConnectionMode`.
    connection_mode: Arc<AtomicU8>,
    /// Notifier used to wake tasks waiting on a connection mode change.
    state_notify: Arc<tokio::sync::Notify>,
    /// Maximum duration allowed for a single reconnect attempt.
    reconnect_timeout: Duration,
    /// Backoff used between failed reconnect attempts.
    backoff: ExponentialBackoff,
    /// Message handler passed to each newly spawned read task.
    handler: Option<TcpMessageHandler>,
    /// Maximum number of reconnect attempts before closing (`None` = unlimited).
    reconnect_max_attempts: Option<u32>,
    /// Number of reconnect attempts made since the last successful connection.
    reconnect_attempt_count: u32,
}
102
103impl SocketClientInner {
104 pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
110 const CONNECTION_TIMEOUT_SECS: u64 = 10;
111
112 install_cryptographic_provider();
113
114 if config.suffix.is_empty() {
116 anyhow::bail!("Socket suffix cannot be empty: suffix is required for message framing");
117 }
118
119 if let Some((interval_secs, _)) = &config.heartbeat
120 && *interval_secs == 0
121 {
122 anyhow::bail!("Heartbeat interval cannot be zero");
123 }
124
125 if config.idle_timeout_ms == Some(0) {
126 anyhow::bail!("Idle timeout cannot be zero");
127 }
128
129 let SocketConfig {
130 url,
131 mode,
132 heartbeat,
133 suffix,
134 message_handler,
135 reconnect_timeout_ms,
136 reconnect_delay_initial_ms,
137 reconnect_delay_max_ms,
138 reconnect_backoff_factor,
139 reconnect_jitter_ms,
140 connection_max_retries,
141 reconnect_max_attempts,
142 idle_timeout_ms,
143 certs_dir,
144 } = &config.clone();
145 let connector = if let Some(dir) = certs_dir {
146 let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
147 Some(Connector::Rustls(Arc::new(config)))
148 } else {
149 None
150 };
151
152 let max_retries = connection_max_retries.unwrap_or(5);
154
155 let mut backoff = ExponentialBackoff::new(
156 Duration::from_millis(500),
157 Duration::from_secs(5),
158 2.0,
159 250,
160 false,
161 )?;
162
163 #[allow(unused_assignments)]
164 let mut last_error = String::new();
165 let mut attempt = 0;
166 let (reader, writer) = loop {
167 attempt += 1;
168
169 match dst::time::timeout(
170 Duration::from_secs(CONNECTION_TIMEOUT_SECS),
171 Self::tls_connect_with_server(url, *mode, connector.clone()),
172 )
173 .await
174 {
175 Ok(Ok(result)) => {
176 if attempt > 1 {
177 log::info!("Socket connection established after {attempt} attempts");
178 }
179 break result;
180 }
181 Ok(Err(e)) => {
182 last_error = e.to_string();
183 log::warn!(
184 "Socket connection attempt {attempt}/{max_retries} to {url} failed: {last_error}"
185 );
186 }
187 Err(_) => {
188 last_error = format!(
189 "Connection timeout after {CONNECTION_TIMEOUT_SECS}s (possible DNS resolution failure)"
190 );
191 log::warn!(
192 "Socket connection attempt {attempt}/{max_retries} to {url} timed out"
193 );
194 }
195 }
196
197 if attempt >= max_retries {
198 anyhow::bail!(
199 "Failed to connect to {} after {} attempts: {}. \
200 If this is a DNS error, check your network configuration and DNS settings.",
201 url,
202 max_retries,
203 if last_error.is_empty() {
204 "unknown error"
205 } else {
206 &last_error
207 }
208 );
209 }
210
211 let delay = backoff.next_duration();
212 log::debug!(
213 "Retrying in {delay:?} (attempt {}/{})",
214 attempt + 1,
215 max_retries
216 );
217 dst::time::sleep(delay).await;
218 };
219
220 log::debug!("Connected");
221
222 let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
223 let state_notify = Arc::new(tokio::sync::Notify::new());
224
225 let read_task = Arc::new(Self::spawn_read_task(
226 connection_mode.clone(),
227 reader,
228 message_handler.clone(),
229 suffix.clone(),
230 *idle_timeout_ms,
231 ));
232
233 let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
234
235 let write_task = Self::spawn_write_task(
236 connection_mode.clone(),
237 state_notify.clone(),
238 writer,
239 writer_rx,
240 suffix.clone(),
241 );
242
243 let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
245 Self::spawn_heartbeat_task(
246 connection_mode.clone(),
247 heartbeat.clone(),
248 writer_tx.clone(),
249 )
250 });
251
252 let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
253 let backoff = ExponentialBackoff::new(
254 Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
255 Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
256 reconnect_backoff_factor.unwrap_or(1.5),
257 reconnect_jitter_ms.unwrap_or(100),
258 true, )?;
260
261 Ok(Self {
262 config,
263 connector,
264 read_task,
265 write_task,
266 writer_tx,
267 heartbeat_task,
268 connection_mode,
269 state_notify,
270 reconnect_timeout,
271 backoff,
272 handler: message_handler.clone(),
273 reconnect_max_attempts: *reconnect_max_attempts,
274 reconnect_attempt_count: 0,
275 })
276 }
277
278 fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
288 if url.contains("://") {
289 let parsed = url.parse::<http::Uri>().map_err(|e| {
291 Error::Io(std::io::Error::new(
292 std::io::ErrorKind::InvalidInput,
293 format!("Invalid URL: {e}"),
294 ))
295 })?;
296
297 let host = parsed.host().ok_or_else(|| {
298 Error::Io(std::io::Error::new(
299 std::io::ErrorKind::InvalidInput,
300 "URL missing host",
301 ))
302 })?;
303
304 let port = parsed
305 .port_u16()
306 .unwrap_or_else(|| match parsed.scheme_str() {
307 Some("wss" | "https") => 443,
308 Some("ws" | "http") => 80,
309 _ => match mode {
310 Mode::Tls => 443,
311 Mode::Plain => 80,
312 },
313 });
314
315 Ok((format!("{host}:{port}"), url.to_string()))
316 } else {
317 let scheme = match mode {
320 Mode::Tls => "wss",
321 Mode::Plain => "ws",
322 };
323 Ok((url.to_string(), format!("{scheme}://{url}")))
324 }
325 }
326
327 pub async fn tls_connect_with_server(
337 url: &str,
338 mode: Mode,
339 connector: Option<Connector>,
340 ) -> Result<(TcpReader, TcpWriter), Error> {
341 log::debug!("Connecting to {url}");
342
343 let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
344 let tcp_result = TcpStream::connect(&socket_addr).await;
345
346 match tcp_result {
347 Ok(stream) => {
348 log::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
349
350 if let Err(e) = stream.set_nodelay(true) {
351 log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
352 }
353 let request = request_url.into_client_request()?;
354 tcp_tls(&request, mode, stream, connector)
355 .await
356 .map(tokio::io::split)
357 }
358 Err(e) => {
359 log::error!("TCP connection failed to {socket_addr}: {e:?}");
360 Err(Error::Io(e))
361 }
362 }
363 }
364
    /// Re-establishes the connection using the stored configuration and connector.
    ///
    /// The whole attempt is bounded by `reconnect_timeout`. On success the new
    /// write half is handed to the writer task (which drains its buffered
    /// messages), the old read task is aborted, the mode is switched from
    /// RECONNECT back to ACTIVE, and a new read task is spawned. If a disconnect
    /// is observed at any checkpoint the attempt is abandoned and `Ok(())` is
    /// returned without activating the connection.
    ///
    /// # Errors
    ///
    /// Returns an `Error::Io` if connecting fails, the writer task cannot be
    /// reached, the buffered messages could not be drained, or the timeout elapses.
    async fn reconnect(&mut self) -> Result<(), Error> {
        log::debug!("Reconnecting");

        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
            log::debug!("Reconnect aborted due to disconnect state");
            return Ok(());
        }

        dst::time::timeout(self.reconnect_timeout, async {
            // Only the fields needed to reconnect are bound; the rest are ignored explicitly
            let SocketConfig {
                url,
                mode,
                heartbeat: _,
                suffix,
                message_handler: _,
                reconnect_timeout_ms: _,
                reconnect_delay_initial_ms: _,
                reconnect_backoff_factor: _,
                reconnect_delay_max_ms: _,
                reconnect_jitter_ms: _,
                connection_max_retries: _,
                reconnect_max_attempts: _,
                idle_timeout_ms,
                certs_dir: _,
            } = &self.config;
            let connector = self.connector.clone();
            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;

            // Checkpoint: a disconnect may have been requested while connecting
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after connect)");
                return Ok(());
            }
            log::debug!("Connected");

            // Hand the new writer to the write task and wait for its drain report
            let (tx, rx) = tokio::sync::oneshot::channel();
            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
                log::error!("{e}");
                return Err(Error::Io(std::io::Error::new(
                    std::io::ErrorKind::BrokenPipe,
                    format!("Failed to send update command: {e}"),
                )));
            }

            match rx.await {
                Ok(true) => log::debug!("Writer confirmed buffer drain success"),
                Ok(false) => {
                    log::warn!("Writer failed to drain buffer, aborting reconnect");
                    return Err(Error::Io(std::io::Error::other(
                        "Failed to drain reconnection buffer",
                    )));
                }
                Err(e) => {
                    log::error!("Writer dropped update channel: {e}");
                    return Err(Error::Io(std::io::Error::new(
                        std::io::ErrorKind::BrokenPipe,
                        "Writer task dropped response channel",
                    )));
                }
            }

            // Delay before the old read task is aborted
            dst::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

            // Checkpoint: a disconnect may have been requested during the delay
            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
                log::debug!("Reconnect aborted mid-flight (after delay)");
                return Ok(());
            }

            if !self.read_task.is_finished() {
                self.read_task.abort();
                log_task_aborted("read");
            }

            // Only activate if the mode is still RECONNECT, so a concurrent
            // state change is never overwritten
            if self
                .connection_mode
                .compare_exchange(
                    ConnectionMode::Reconnect.as_u8(),
                    ConnectionMode::Active.as_u8(),
                    Ordering::SeqCst,
                    Ordering::SeqCst,
                )
                .is_err()
            {
                log::debug!("Reconnect aborted (state changed during reconnect)");
                return Ok(());
            }

            // Spawn the new read task only after the mode is ACTIVE, since the
            // read loop exits as soon as it observes a non-active mode
            self.read_task = Arc::new(Self::spawn_read_task(
                self.connection_mode.clone(),
                reader,
                self.handler.clone(),
                suffix.clone(),
                *idle_timeout_ms,
            ));

            log::debug!("Reconnect succeeded");
            Ok(())
        })
        .await
        // Outer error: the timeout elapsed; the trailing `?` then yields the inner result
        .map_err(|_| {
            Error::Io(std::io::Error::new(
                std::io::ErrorKind::TimedOut,
                format!(
                    "reconnection timed out after {}s",
                    self.reconnect_timeout.as_secs_f64()
                ),
            ))
        })?
    }
488
489 #[inline]
495 #[must_use]
496 pub fn is_alive(&self) -> bool {
497 !self.read_task.is_finished() && !self.write_task.is_finished()
498 }
499
500 #[must_use]
501 fn spawn_read_task(
502 connection_state: Arc<AtomicU8>,
503 mut reader: TcpReader,
504 handler: Option<TcpMessageHandler>,
505 suffix: Vec<u8>,
506 idle_timeout_ms: Option<u64>,
507 ) -> tokio::task::JoinHandle<()> {
508 log_task_started("read");
509
510 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
512 let idle_timeout = idle_timeout_ms.map(Duration::from_millis);
513
514 tokio::task::spawn(async move {
515 let mut buf = Vec::new();
516 let mut last_data_time = dst::time::Instant::now();
517
518 loop {
519 if !ConnectionMode::from_atomic(&connection_state).is_active() {
520 break;
521 }
522
523 match dst::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
524 Ok(Ok(0)) => {
526 log::debug!("Connection closed by server");
527 break;
528 }
529 Ok(Err(e)) => {
530 log::debug!("Connection ended: {e}");
531 break;
532 }
533 Ok(Ok(bytes)) => {
535 log::trace!("Received <binary> {bytes} bytes");
536 last_data_time = dst::time::Instant::now();
537
538 while let Some((i, _)) = &buf
539 .windows(suffix.len())
540 .enumerate()
541 .find(|(_, pair)| pair.eq(&suffix))
542 {
543 let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
544 data.truncate(data.len() - suffix.len());
545
546 if let Some(ref handler) = handler {
547 handler(&data);
548 }
549 }
550
551 if buf.len() > MAX_READ_BUFFER_BYTES {
552 log::error!(
553 "Read buffer exceeded maximum size ({MAX_READ_BUFFER_BYTES} bytes), closing connection"
554 );
555 break;
556 }
557 }
558 Err(_) => {
559 if let Some(timeout) = idle_timeout {
560 let idle_duration = last_data_time.elapsed();
561 if idle_duration >= timeout {
562 log::warn!(
563 "Read idle timeout: no data received for {:.1}s",
564 idle_duration.as_secs_f64()
565 );
566 break;
567 }
568 }
569 }
570 }
571 }
572
573 log_task_stopped("read");
574 })
575 }
576
577 async fn drain_reconnect_buffer(
587 buffer: &mut VecDeque<Bytes>,
588 writer: &mut TcpWriter,
589 suffix: &[u8],
590 ) -> bool {
591 if buffer.is_empty() {
592 return false;
593 }
594
595 let initial_buffer_len = buffer.len();
596 log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
597
598 let mut send_error_occurred = false;
599
600 while let Some(buffered_msg) = buffer.front() {
601 let mut combined_msg = Vec::with_capacity(buffered_msg.len() + suffix.len());
602 combined_msg.extend_from_slice(buffered_msg);
603 combined_msg.extend_from_slice(suffix);
604
605 if let Err(e) = writer.write_all(&combined_msg).await {
606 log::error!(
607 "Failed to send buffered message with suffix after reconnection: {e}, {} messages remain in buffer",
608 buffer.len()
609 );
610 send_error_occurred = true;
611 break;
612 }
613
614 buffer.pop_front();
615 }
616
617 if buffer.is_empty() {
618 log::info!("Successfully sent all {initial_buffer_len} buffered messages");
619 }
620
621 send_error_occurred
622 }
623
624 fn spawn_write_task(
625 connection_state: Arc<AtomicU8>,
626 state_notify: Arc<tokio::sync::Notify>,
627 writer: TcpWriter,
628 mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
629 suffix: Vec<u8>,
630 ) -> tokio::task::JoinHandle<()> {
631 log_task_started("write");
632
633 let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
635
636 tokio::task::spawn(async move {
637 let mut active_writer = writer;
638 let mut reconnect_buffer: VecDeque<Bytes> = VecDeque::new();
639 let mut write_buf: Vec<u8> = Vec::new();
640
641 loop {
642 if matches!(
643 ConnectionMode::from_atomic(&connection_state),
644 ConnectionMode::Disconnect | ConnectionMode::Closed
645 ) {
646 break;
647 }
648
649 match dst::time::timeout(check_interval, writer_rx.recv()).await {
650 Ok(Some(msg)) => {
651 let mode = ConnectionMode::from_atomic(&connection_state);
653 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
654 break;
655 }
656
657 match msg {
658 WriterCommand::Update(new_writer, tx) => {
659 log::debug!("Received new writer");
660
661 dst::time::sleep(Duration::from_millis(100)).await;
663
664 _ = dst::time::timeout(
667 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
668 active_writer.shutdown(),
669 )
670 .await;
671
672 active_writer = new_writer;
673 log::debug!("Updated writer");
674
675 let send_error = Self::drain_reconnect_buffer(
676 &mut reconnect_buffer,
677 &mut active_writer,
678 &suffix,
679 )
680 .await;
681
682 if let Err(e) = tx.send(!send_error) {
683 log::error!(
684 "Failed to report drain status to controller: {e:?}"
685 );
686 }
687 }
688 _ if mode.is_reconnect() => {
689 if let WriterCommand::Send(data) = msg {
690 log::debug!(
691 "Buffering message while reconnecting ({} bytes)",
692 data.len()
693 );
694 reconnect_buffer.push_back(data);
695 }
696 }
697 WriterCommand::Send(msg) => {
698 write_buf.clear();
699 write_buf.extend_from_slice(&msg);
700 write_buf.extend_from_slice(&suffix);
701
702 if let Err(e) = active_writer.write_all(&write_buf).await {
703 log::error!("Failed to send message: {e}");
704 log::warn!("Writer triggering reconnect");
705
706 reconnect_buffer.push_back(msg);
707 connection_state
708 .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
709 state_notify.notify_one();
710 }
711 }
712 }
713 }
714 Ok(None) => {
715 log::debug!("Writer channel closed, terminating writer task");
717 break;
718 }
719 Err(_) => {
720 }
722 }
723 }
724
725 _ = dst::time::timeout(
728 Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
729 active_writer.shutdown(),
730 )
731 .await;
732
733 log_task_stopped("write");
734 })
735 }
736
737 fn spawn_heartbeat_task(
738 connection_state: Arc<AtomicU8>,
739 heartbeat: (u64, Vec<u8>),
740 writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
741 ) -> tokio::task::JoinHandle<()> {
742 log_task_started("heartbeat");
743 let (interval_secs, message) = heartbeat;
744
745 tokio::task::spawn(async move {
746 let interval = Duration::from_secs(interval_secs);
747
748 loop {
749 dst::time::sleep(interval).await;
750
751 match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
752 ConnectionMode::Active => {
753 let msg = WriterCommand::Send(message.clone().into());
754
755 match writer_tx.send(msg) {
756 Ok(()) => log::trace!("Sent heartbeat to writer task"),
757 Err(e) => {
758 log::error!("Failed to send heartbeat to writer task: {e}");
759 }
760 }
761 }
762 ConnectionMode::Reconnect => {}
763 ConnectionMode::Disconnect | ConnectionMode::Closed => break,
764 }
765 }
766
767 log_task_stopped("heartbeat");
768 })
769 }
770}
771
772impl Drop for SocketClientInner {
773 fn drop(&mut self) {
774 self.clean_drop();
776 }
777}
778
779impl CleanDrop for SocketClientInner {
781 fn clean_drop(&mut self) {
782 if !self.read_task.is_finished() {
783 self.read_task.abort();
784 log_task_aborted("read");
785 }
786
787 if !self.write_task.is_finished() {
788 self.write_task.abort();
789 log_task_aborted("write");
790 }
791
792 if let Some(ref handle) = self.heartbeat_task.take()
793 && !handle.is_finished()
794 {
795 handle.abort();
796 log_task_aborted("heartbeat");
797 }
798
799 #[cfg(feature = "python")]
800 {
801 self.config.message_handler = None;
803 }
804 }
805}
806
/// Public handle to a socket connection.
///
/// The connection itself is owned by a background controller task; this handle
/// shares the connection mode and notifier with it and exposes the channel used
/// to queue outgoing data.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.network")
)]
pub struct SocketClient {
    /// Handle of the controller task that owns the inner client.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Connection mode shared with the controller and the background tasks.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Notifier used to signal and await connection mode changes.
    pub(crate) state_notify: Arc<tokio::sync::Notify>,
    /// Maximum time `send_bytes` waits for the client to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Channel for queueing commands to the write task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
822
823impl Debug for SocketClient {
824 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
825 f.debug_struct(stringify!(SocketClient)).finish()
826 }
827}
828
829impl SocketClient {
830 pub async fn connect(
836 config: SocketConfig,
837 post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
838 post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
839 post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
840 ) -> anyhow::Result<Self> {
841 let inner = SocketClientInner::connect_url(config).await?;
842 let writer_tx = inner.writer_tx.clone();
843 let connection_mode = inner.connection_mode.clone();
844 let state_notify = inner.state_notify.clone();
845 let reconnect_timeout = inner.reconnect_timeout;
846
847 let controller_task = Self::spawn_controller_task(
848 inner,
849 connection_mode.clone(),
850 state_notify.clone(),
851 post_reconnection,
852 post_disconnection,
853 );
854
855 if let Some(handler) = post_connection {
856 handler();
857 log::debug!("Called `post_connection` handler");
858 }
859
860 Ok(Self {
861 controller_task,
862 connection_mode,
863 state_notify,
864 reconnect_timeout,
865 writer_tx,
866 })
867 }
868
869 #[must_use]
871 pub fn connection_mode(&self) -> ConnectionMode {
872 ConnectionMode::from_atomic(&self.connection_mode)
873 }
874
875 #[inline]
880 #[must_use]
881 pub fn is_active(&self) -> bool {
882 self.connection_mode().is_active()
883 }
884
885 #[inline]
890 #[must_use]
891 pub fn is_reconnecting(&self) -> bool {
892 self.connection_mode().is_reconnect()
893 }
894
895 #[inline]
899 #[must_use]
900 pub fn is_disconnecting(&self) -> bool {
901 self.connection_mode().is_disconnect()
902 }
903
904 #[inline]
910 #[must_use]
911 pub fn is_closed(&self) -> bool {
912 self.connection_mode().is_closed()
913 }
914
915 pub async fn close(&self) {
920 self.connection_mode
921 .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
922 self.state_notify.notify_waiters();
923
924 if dst::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
925 while !self.is_closed() {
926 dst::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
927 }
928
929 if !self.controller_task.is_finished() {
930 self.controller_task.abort();
931 log_task_aborted("controller");
932 }
933 })
934 .await
935 == Ok(())
936 {
937 log_task_stopped("controller");
938 } else {
939 log::error!("Timeout waiting for controller task to finish");
940
941 if !self.controller_task.is_finished() {
942 self.controller_task.abort();
943 log_task_aborted("controller");
944 }
945 self.connection_mode
946 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
947 }
948 }
949
950 #[inline]
954 fn check_not_terminal(&self) -> Result<(), SendError> {
955 match self.connection_mode() {
956 ConnectionMode::Disconnect | ConnectionMode::Closed => Err(SendError::Closed),
957 _ => Ok(()),
958 }
959 }
960
961 async fn wait_for_active(&self) -> Result<(), SendError> {
967 const FALLBACK_INTERVAL_MS: u64 = 100;
968
969 let mode = self.connection_mode();
970 if mode.is_active() {
971 return Ok(());
972 }
973
974 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
975 return Err(SendError::Closed);
976 }
977
978 log::debug!("Waiting for client to become ACTIVE before sending...");
979
980 let fallback_interval = Duration::from_millis(FALLBACK_INTERVAL_MS);
981
982 dst::time::timeout(self.reconnect_timeout, async {
983 loop {
984 let notified = self.state_notify.notified();
985
986 let mode = self.connection_mode();
987 if mode.is_active() {
988 return Ok(());
989 }
990
991 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
992 return Err(());
993 }
994
995 tokio::select! {
996 biased;
997 () = notified => {}
998 () = dst::time::sleep(fallback_interval) => {}
999 }
1000 }
1001 })
1002 .await
1003 .map_err(|_| SendError::Timeout)?
1004 .map_err(|()| SendError::Closed)
1005 }
1006
1007 pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
1017 self.check_not_terminal()?;
1018 self.wait_for_active().await?;
1019
1020 let msg = WriterCommand::Send(data.into());
1021 self.writer_tx
1022 .send(msg)
1023 .map_err(|e| SendError::BrokenPipe(e.to_string()))
1024 }
1025
    /// Spawns the controller task, which owns `inner` and drives the connection
    /// state machine.
    ///
    /// The task wakes on `state_notify` or every 100 ms and then:
    /// - on DISCONNECT: aborts the read and heartbeat tasks, invokes
    ///   `post_disconnection`, and exits;
    /// - on CLOSED: exits;
    /// - on ACTIVE with a finished read or write task: switches to RECONNECT;
    /// - on RECONNECT: runs `inner.reconnect()` (limited by
    ///   `reconnect_max_attempts`, when set), invoking `post_reconnection` on
    ///   success or backing off on failure.
    ///
    /// The mode is set to CLOSED when the loop exits.
    fn spawn_controller_task(
        mut inner: SocketClientInner,
        connection_mode: Arc<AtomicU8>,
        state_notify: Arc<tokio::sync::Notify>,
        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
    ) -> tokio::task::JoinHandle<()> {
        const CONTROLLER_FALLBACK_INTERVAL_MS: u64 = 100;

        tokio::task::spawn(async move {
            log_task_started("controller");

            let fallback_interval = Duration::from_millis(CONTROLLER_FALLBACK_INTERVAL_MS);

            loop {
                // Wake on a state notification, or after the fallback interval
                tokio::select! {
                    biased;
                    () = state_notify.notified() => {}
                    () = dst::time::sleep(fallback_interval) => {}
                }

                let mut mode = ConnectionMode::from_atomic(&connection_mode);

                if mode.is_disconnect() {
                    log::debug!("Disconnecting");

                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
                    if dst::time::timeout(timeout, async {
                        // Delay before the tasks are aborted
                        dst::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;

                        if !inner.read_task.is_finished() {
                            inner.read_task.abort();
                            log_task_aborted("read");
                        }

                        if let Some(task) = &inner.heartbeat_task
                            && !task.is_finished()
                        {
                            task.abort();
                            log_task_aborted("heartbeat");
                        }
                    })
                    .await
                    .is_err()
                    {
                        log::error!("Shutdown timed out after {}s", timeout.as_secs());
                    }

                    log::debug!("Closed");

                    if let Some(ref handler) = post_disconnection {
                        handler();
                        log::debug!("Called `post_disconnection` handler");
                    }
                    break;
                }

                if mode.is_closed() {
                    log::debug!("Connection closed");
                    break;
                }

                // A finished read or write task while ACTIVE means the connection is dead
                if mode.is_active() && !inner.is_alive() {
                    if connection_mode
                        .compare_exchange(
                            ConnectionMode::Active.as_u8(),
                            ConnectionMode::Reconnect.as_u8(),
                            Ordering::SeqCst,
                            Ordering::SeqCst,
                        )
                        .is_ok()
                    {
                        log::debug!("Detected dead read task, transitioning to RECONNECT");
                    }
                    // Re-read the mode: the exchange may have failed due to a concurrent change
                    mode = ConnectionMode::from_atomic(&connection_mode);
                }

                if mode.is_reconnect() {
                    if let Some(max_attempts) = inner.reconnect_max_attempts
                        && inner.reconnect_attempt_count >= max_attempts
                    {
                        log::error!(
                            "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
                        );
                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
                        state_notify.notify_waiters();
                        break;
                    }

                    inner.reconnect_attempt_count += 1;

                    // Race the reconnect against a disconnect request so that
                    // closing the client interrupts an in-flight attempt
                    let reconnect_result = tokio::select! {
                        biased;
                        result = inner.reconnect() => Some(result),
                        () = async {
                            loop {
                                state_notify.notified().await;

                                if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
                                    break;
                                }
                            }
                        } => None,
                    };

                    match reconnect_result {
                        None => {
                            log::debug!("Reconnect interrupted by disconnect");
                        }
                        Some(Ok(())) => {
                            log::debug!("Reconnected successfully");
                            inner.backoff.reset();
                            inner.reconnect_attempt_count = 0;

                            // Wake anything waiting for the state change
                            state_notify.notify_waiters();

                            // `reconnect` also returns Ok when it was abandoned,
                            // so only call the handler if the mode is ACTIVE
                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
                                if let Some(ref handler) = post_reconnection {
                                    handler();
                                    log::debug!("Called `post_reconnection` handler");
                                }
                            } else {
                                log::debug!(
                                    "Skipping post_reconnection handlers due to disconnect state"
                                );
                            }
                        }
                        Some(Err(e)) => {
                            let duration = inner.backoff.next_duration();
                            log::warn!(
                                "Reconnect attempt {} failed: {e}",
                                inner.reconnect_attempt_count
                            );

                            if !duration.is_zero() {
                                log::warn!("Backing off for {}s...", duration.as_secs_f64());
                                // The backoff sleep is likewise interruptible by a disconnect
                                tokio::select! {
                                    biased;
                                    () = dst::time::sleep(duration) => {}
                                    () = async {
                                        loop {
                                            state_notify.notified().await;

                                            if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
                                                break;
                                            }
                                        }
                                    } => {
                                        log::debug!("Backoff interrupted by disconnect");
                                    }
                                }
                            }
                        }
                    }
                }
            }
            // Every exit path from the loop ends in the CLOSED state
            inner
                .connection_mode
                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);

            log_task_stopped("controller");
        })
    }
1193}
1194
1195impl Drop for SocketClient {
1197 fn drop(&mut self) {
1198 if !self.controller_task.is_finished() {
1199 self.controller_task.abort();
1200 log_task_aborted("controller");
1201 }
1202 }
1203}
1204
1205#[cfg(test)]
1206#[cfg(feature = "python")]
1207#[cfg(not(all(feature = "simulation", madsim)))] #[cfg(target_os = "linux")] mod tests {
1210 use nautilus_common::testing::wait_until_async;
1211 use pyo3::Python;
1212 use tokio::{
1213 io::{AsyncReadExt, AsyncWriteExt},
1214 net::{TcpListener, TcpStream},
1215 sync::Mutex,
1216 task,
1217 time::{Duration, sleep},
1218 };
1219
1220 use super::*;
1221
1222 async fn bind_test_server() -> (u16, TcpListener) {
1223 let listener = TcpListener::bind("127.0.0.1:0")
1224 .await
1225 .expect("Failed to bind ephemeral port");
1226 let port = listener.local_addr().unwrap().port();
1227 (port, listener)
1228 }
1229
1230 async fn run_echo_server(mut socket: TcpStream) {
1231 let mut buf = Vec::new();
1232 loop {
1233 match socket.read_buf(&mut buf).await {
1234 Ok(0) => {
1235 break;
1236 }
1237 Ok(_n) => {
1238 while let Some(idx) = buf.array_windows().position(|w| w == b"\r\n") {
1239 let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1240 line.truncate(line.len() - 2);
1242
1243 if line == b"close" {
1244 let _ = socket.shutdown().await;
1245 return;
1246 }
1247
1248 let mut echo_data = line;
1249 echo_data.extend_from_slice(b"\r\n");
1250 if socket.write_all(&echo_data).await.is_err() {
1251 break;
1252 }
1253 }
1254 }
1255 Err(e) => {
1256 eprintln!("Server read error: {e}");
1257 break;
1258 }
1259 }
1260 }
1261 }
1262
1263 #[tokio::test]
1264 async fn test_basic_send_receive() {
1265 Python::initialize();
1266
1267 let (port, listener) = bind_test_server().await;
1268 let server_task = task::spawn(async move {
1269 let (socket, _) = listener.accept().await.unwrap();
1270 run_echo_server(socket).await;
1271 });
1272
1273 let config = SocketConfig {
1274 url: format!("127.0.0.1:{port}"),
1275 mode: Mode::Plain,
1276 suffix: b"\r\n".to_vec(),
1277 message_handler: None,
1278 heartbeat: None,
1279 reconnect_timeout_ms: None,
1280 reconnect_delay_initial_ms: None,
1281 reconnect_backoff_factor: None,
1282 reconnect_delay_max_ms: None,
1283 reconnect_jitter_ms: None,
1284 reconnect_max_attempts: None,
1285 connection_max_retries: None,
1286 idle_timeout_ms: None,
1287 certs_dir: None,
1288 };
1289
1290 let client = SocketClient::connect(config, None, None, None)
1291 .await
1292 .expect("Client connect failed unexpectedly");
1293
1294 client.send_bytes(b"Hello".into()).await.unwrap();
1295 client.send_bytes(b"World".into()).await.unwrap();
1296
1297 sleep(Duration::from_millis(100)).await;
1299
1300 client.send_bytes(b"close".into()).await.unwrap();
1301 server_task.await.unwrap();
1302 assert!(!client.is_closed());
1303 }
1304
1305 #[tokio::test]
1306 async fn test_reconnect_fail_exhausted() {
1307 Python::initialize();
1308
1309 let (port, listener) = bind_test_server().await;
1310 drop(listener); wait_until_async(
1314 || async {
1315 TcpStream::connect(format!("127.0.0.1:{port}"))
1316 .await
1317 .is_err()
1318 },
1319 Duration::from_secs(2),
1320 )
1321 .await;
1322
1323 let config = SocketConfig {
1324 url: format!("127.0.0.1:{port}"),
1325 mode: Mode::Plain,
1326 suffix: b"\r\n".to_vec(),
1327 message_handler: None,
1328 heartbeat: None,
1329 reconnect_timeout_ms: Some(100),
1330 reconnect_delay_initial_ms: Some(50),
1331 reconnect_backoff_factor: Some(1.0),
1332 reconnect_delay_max_ms: Some(50),
1333 reconnect_jitter_ms: Some(0),
1334 connection_max_retries: Some(1),
1335 reconnect_max_attempts: None,
1336 idle_timeout_ms: None,
1337 certs_dir: None,
1338 };
1339
1340 let client_res = SocketClient::connect(config, None, None, None).await;
1341 assert!(
1342 client_res.is_err(),
1343 "Should fail quickly with no server listening"
1344 );
1345 }
1346
1347 #[tokio::test]
1348 async fn test_user_disconnect() {
1349 Python::initialize();
1350
1351 let (port, listener) = bind_test_server().await;
1352 let server_task = task::spawn(async move {
1353 let (socket, _) = listener.accept().await.unwrap();
1354 let mut buf = [0u8; 1024];
1355 let _ = socket.try_read(&mut buf);
1356
1357 loop {
1358 sleep(Duration::from_secs(1)).await;
1359 }
1360 });
1361
1362 let config = SocketConfig {
1363 url: format!("127.0.0.1:{port}"),
1364 mode: Mode::Plain,
1365 suffix: b"\r\n".to_vec(),
1366 message_handler: None,
1367 heartbeat: None,
1368 reconnect_timeout_ms: None,
1369 reconnect_delay_initial_ms: None,
1370 reconnect_backoff_factor: None,
1371 reconnect_delay_max_ms: None,
1372 reconnect_jitter_ms: None,
1373 reconnect_max_attempts: None,
1374 connection_max_retries: None,
1375 idle_timeout_ms: None,
1376 certs_dir: None,
1377 };
1378
1379 let client = SocketClient::connect(config, None, None, None)
1380 .await
1381 .unwrap();
1382
1383 client.close().await;
1384 assert!(client.is_closed());
1385 server_task.abort();
1386 }
1387
1388 #[tokio::test]
1389 async fn test_heartbeat() {
1390 Python::initialize();
1391
1392 let (port, listener) = bind_test_server().await;
1393 let received = Arc::new(Mutex::new(Vec::new()));
1394 let received2 = received.clone();
1395
1396 let server_task = task::spawn(async move {
1397 let (socket, _) = listener.accept().await.unwrap();
1398
1399 let mut buf = Vec::new();
1400 loop {
1401 match socket.try_read_buf(&mut buf) {
1402 Ok(0) => break,
1403 Ok(_) => {
1404 while let Some(idx) = buf.array_windows().position(|w| w == b"\r\n") {
1405 let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1406 line.truncate(line.len() - 2);
1407 received2.lock().await.push(line);
1408 }
1409 }
1410 Err(_) => {
1411 tokio::time::sleep(Duration::from_millis(10)).await;
1412 }
1413 }
1414 }
1415 });
1416
1417 let heartbeat = Some((1, b"ping".to_vec()));
1419
1420 let config = SocketConfig {
1421 url: format!("127.0.0.1:{port}"),
1422 mode: Mode::Plain,
1423 suffix: b"\r\n".to_vec(),
1424 message_handler: None,
1425 heartbeat,
1426 reconnect_timeout_ms: None,
1427 reconnect_delay_initial_ms: None,
1428 reconnect_backoff_factor: None,
1429 reconnect_delay_max_ms: None,
1430 reconnect_jitter_ms: None,
1431 reconnect_max_attempts: None,
1432 connection_max_retries: None,
1433 idle_timeout_ms: None,
1434 certs_dir: None,
1435 };
1436
1437 let client = SocketClient::connect(config, None, None, None)
1438 .await
1439 .unwrap();
1440
1441 sleep(Duration::from_secs(3)).await;
1443
1444 {
1445 let lock = received.lock().await;
1446 let pings = lock
1447 .iter()
1448 .filter(|line| line == &&b"ping".to_vec())
1449 .count();
1450 assert!(
1451 pings >= 2,
1452 "Expected at least 2 heartbeat pings; got {pings}"
1453 );
1454 }
1455
1456 client.close().await;
1457 server_task.abort();
1458 }
1459
1460 #[tokio::test]
1461 async fn test_reconnect_success() {
1462 Python::initialize();
1463
1464 let (port, listener) = bind_test_server().await;
1465
1466 let server_task = task::spawn(async move {
1470 let (mut socket, _) = listener.accept().await.expect("First accept failed");
1472
1473 sleep(Duration::from_millis(500)).await;
1475 let _ = socket.shutdown().await;
1476
1477 sleep(Duration::from_millis(500)).await;
1479
1480 let (socket, _) = listener.accept().await.expect("Second accept failed");
1482 run_echo_server(socket).await;
1483 });
1484
1485 let config = SocketConfig {
1486 url: format!("127.0.0.1:{port}"),
1487 mode: Mode::Plain,
1488 suffix: b"\r\n".to_vec(),
1489 message_handler: None,
1490 heartbeat: None,
1491 reconnect_timeout_ms: Some(5_000),
1492 reconnect_delay_initial_ms: Some(500),
1493 reconnect_delay_max_ms: Some(5_000),
1494 reconnect_backoff_factor: Some(2.0),
1495 reconnect_jitter_ms: Some(50),
1496 reconnect_max_attempts: None,
1497 connection_max_retries: None,
1498 idle_timeout_ms: None,
1499 certs_dir: None,
1500 };
1501
1502 let client = SocketClient::connect(config, None, None, None)
1503 .await
1504 .expect("Client connect failed unexpectedly");
1505
1506 assert!(client.is_active(), "Client should start as active");
1508
1509 wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;
1512
1513 client
1514 .send_bytes(b"TestReconnect".into())
1515 .await
1516 .expect("Send failed");
1517
1518 client.close().await;
1519 server_task.abort();
1520 }
1521}
1522
1523#[cfg(test)]
1524#[cfg(not(feature = "turmoil"))]
1525#[cfg(not(all(feature = "simulation", madsim)))] mod rust_tests {
1527 use nautilus_common::testing::wait_until_async;
1528 use rstest::rstest;
1529 use tokio::{
1530 io::{AsyncReadExt, AsyncWriteExt},
1531 net::TcpListener,
1532 task,
1533 time::{Duration, sleep},
1534 };
1535
1536 use super::*;
1537
1538 #[rstest]
1539 #[tokio::test]
1540 async fn test_reconnect_then_close() {
1541 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1543 let port = listener.local_addr().unwrap().port();
1544
1545 let server = task::spawn(async move {
1547 if let Ok((mut sock, _)) = listener.accept().await {
1548 drop(sock.shutdown());
1549 }
1550 sleep(Duration::from_secs(1)).await;
1552 });
1553
1554 let config = SocketConfig {
1556 url: format!("127.0.0.1:{port}"),
1557 mode: Mode::Plain,
1558 suffix: b"\r\n".to_vec(),
1559 message_handler: None,
1560 heartbeat: None,
1561 reconnect_timeout_ms: Some(1_000),
1562 reconnect_delay_initial_ms: Some(50),
1563 reconnect_delay_max_ms: Some(100),
1564 reconnect_backoff_factor: Some(1.0),
1565 reconnect_jitter_ms: Some(0),
1566 connection_max_retries: Some(1),
1567 reconnect_max_attempts: None,
1568 idle_timeout_ms: None,
1569 certs_dir: None,
1570 };
1571
1572 let client = SocketClient::connect(config.clone(), None, None, None)
1574 .await
1575 .unwrap();
1576
1577 wait_until_async(
1579 || async { client.is_reconnecting() },
1580 Duration::from_secs(2),
1581 )
1582 .await;
1583
1584 client.close().await;
1586 assert!(client.is_closed());
1587 server.abort();
1588 }
1589
1590 #[rstest]
1591 #[tokio::test]
1592 async fn test_reconnect_state_flips_when_reader_stops() {
1593 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1595 let port = listener.local_addr().unwrap().port();
1596
1597 let server = task::spawn(async move {
1598 if let Ok((sock, _)) = listener.accept().await {
1599 drop(sock);
1600 }
1601 sleep(Duration::from_millis(50)).await;
1603 });
1604
1605 let config = SocketConfig {
1606 url: format!("127.0.0.1:{port}"),
1607 mode: Mode::Plain,
1608 suffix: b"\r\n".to_vec(),
1609 message_handler: None,
1610 heartbeat: None,
1611 reconnect_timeout_ms: Some(1_000),
1612 reconnect_delay_initial_ms: Some(50),
1613 reconnect_delay_max_ms: Some(100),
1614 reconnect_backoff_factor: Some(1.0),
1615 reconnect_jitter_ms: Some(0),
1616 connection_max_retries: Some(1),
1617 reconnect_max_attempts: None,
1618 idle_timeout_ms: None,
1619 certs_dir: None,
1620 };
1621
1622 let client = SocketClient::connect(config, None, None, None)
1623 .await
1624 .unwrap();
1625
1626 wait_until_async(
1627 || async { client.is_reconnecting() },
1628 Duration::from_secs(2),
1629 )
1630 .await;
1631
1632 client.close().await;
1633 server.abort();
1634 }
1635
1636 #[rstest]
1637 fn test_parse_socket_url_raw_address() {
1638 let (socket_addr, request_url) =
1640 SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
1641 assert_eq!(socket_addr, "example.com:6130");
1642 assert_eq!(request_url, "wss://example.com:6130");
1643
1644 let (socket_addr, request_url) =
1646 SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
1647 assert_eq!(socket_addr, "localhost:8080");
1648 assert_eq!(request_url, "ws://localhost:8080");
1649 }
1650
1651 #[rstest]
1652 fn test_parse_socket_url_with_scheme() {
1653 let (socket_addr, request_url) =
1655 SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
1656 assert_eq!(socket_addr, "example.com:443");
1657 assert_eq!(request_url, "wss://example.com:443/path");
1658
1659 let (socket_addr, request_url) =
1661 SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
1662 assert_eq!(socket_addr, "localhost:8080");
1663 assert_eq!(request_url, "ws://localhost:8080");
1664 }
1665
1666 #[rstest]
1667 fn test_parse_socket_url_default_ports() {
1668 let (socket_addr, _) =
1670 SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
1671 assert_eq!(socket_addr, "example.com:443");
1672
1673 let (socket_addr, _) =
1675 SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
1676 assert_eq!(socket_addr, "example.com:80");
1677
1678 let (socket_addr, _) =
1680 SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
1681 assert_eq!(socket_addr, "example.com:443");
1682
1683 let (socket_addr, _) =
1685 SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
1686 assert_eq!(socket_addr, "example.com:80");
1687 }
1688
1689 #[rstest]
1690 fn test_parse_socket_url_unknown_scheme_uses_mode() {
1691 let (socket_addr, _) =
1693 SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
1694 assert_eq!(socket_addr, "example.com:443");
1695
1696 let (socket_addr, _) =
1697 SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
1698 assert_eq!(socket_addr, "example.com:80");
1699 }
1700
1701 #[rstest]
1702 fn test_parse_socket_url_ipv6() {
1703 let (socket_addr, request_url) =
1705 SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
1706 assert_eq!(socket_addr, "[::1]:8080");
1707 assert_eq!(request_url, "ws://[::1]:8080");
1708
1709 let (socket_addr, _) =
1711 SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
1712 assert_eq!(socket_addr, "[::1]:8080");
1713 }
1714
1715 #[rstest]
1716 #[tokio::test]
1717 async fn test_url_parsing_raw_socket_address() {
1718 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1720 let port = listener.local_addr().unwrap().port();
1721
1722 let server = task::spawn(async move {
1723 if let Ok((sock, _)) = listener.accept().await {
1724 drop(sock);
1725 }
1726 sleep(Duration::from_millis(50)).await;
1727 });
1728
1729 let config = SocketConfig {
1730 url: format!("127.0.0.1:{port}"), mode: Mode::Plain,
1732 suffix: b"\r\n".to_vec(),
1733 message_handler: None,
1734 heartbeat: None,
1735 reconnect_timeout_ms: Some(1_000),
1736 reconnect_delay_initial_ms: Some(50),
1737 reconnect_delay_max_ms: Some(100),
1738 reconnect_backoff_factor: Some(1.0),
1739 reconnect_jitter_ms: Some(0),
1740 connection_max_retries: Some(1),
1741 reconnect_max_attempts: None,
1742 idle_timeout_ms: None,
1743 certs_dir: None,
1744 };
1745
1746 let client = SocketClient::connect(config, None, None, None).await;
1748 assert!(
1749 client.is_ok(),
1750 "Client should connect with raw socket address format"
1751 );
1752
1753 if let Ok(client) = client {
1754 client.close().await;
1755 }
1756 server.abort();
1757 }
1758
1759 #[rstest]
1760 #[tokio::test]
1761 async fn test_url_parsing_with_scheme() {
1762 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1764 let port = listener.local_addr().unwrap().port();
1765
1766 let server = task::spawn(async move {
1767 if let Ok((sock, _)) = listener.accept().await {
1768 drop(sock);
1769 }
1770 sleep(Duration::from_millis(50)).await;
1771 });
1772
1773 let config = SocketConfig {
1774 url: format!("ws://127.0.0.1:{port}"), mode: Mode::Plain,
1776 suffix: b"\r\n".to_vec(),
1777 message_handler: None,
1778 heartbeat: None,
1779 reconnect_timeout_ms: Some(1_000),
1780 reconnect_delay_initial_ms: Some(50),
1781 reconnect_delay_max_ms: Some(100),
1782 reconnect_backoff_factor: Some(1.0),
1783 reconnect_jitter_ms: Some(0),
1784 connection_max_retries: Some(1),
1785 reconnect_max_attempts: None,
1786 idle_timeout_ms: None,
1787 certs_dir: None,
1788 };
1789
1790 let client = SocketClient::connect(config, None, None, None).await;
1792 assert!(
1793 client.is_ok(),
1794 "Client should connect with URL scheme format"
1795 );
1796
1797 if let Ok(client) = client {
1798 client.close().await;
1799 }
1800 server.abort();
1801 }
1802
1803 #[rstest]
1804 fn test_parse_socket_url_ipv6_with_zone() {
1805 let (socket_addr, request_url) =
1807 SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
1808 assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
1809 assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");
1810
1811 let (socket_addr, request_url) =
1813 SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
1814 assert_eq!(socket_addr, "[fe80::1%lo]:9090");
1815 assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
1816 }
1817
1818 #[rstest]
1819 #[tokio::test]
1820 async fn test_ipv6_loopback_connection() {
1821 if TcpListener::bind("[::1]:0").await.is_err() {
1824 eprintln!("IPv6 not available, skipping test");
1825 return;
1826 }
1827
1828 let listener = TcpListener::bind("[::1]:0").await.unwrap();
1829 let port = listener.local_addr().unwrap().port();
1830
1831 let server = task::spawn(async move {
1832 if let Ok((mut sock, _)) = listener.accept().await {
1833 let mut buf = vec![0u8; 1024];
1834 if let Ok(n) = sock.read(&mut buf).await {
1835 let _ = sock.write_all(&buf[..n]).await;
1837 }
1838 }
1839 sleep(Duration::from_millis(50)).await;
1840 });
1841
1842 let config = SocketConfig {
1843 url: format!("[::1]:{port}"), mode: Mode::Plain,
1845 suffix: b"\r\n".to_vec(),
1846 message_handler: None,
1847 heartbeat: None,
1848 reconnect_timeout_ms: Some(1_000),
1849 reconnect_delay_initial_ms: Some(50),
1850 reconnect_delay_max_ms: Some(100),
1851 reconnect_backoff_factor: Some(1.0),
1852 reconnect_jitter_ms: Some(0),
1853 connection_max_retries: Some(1),
1854 reconnect_max_attempts: None,
1855 idle_timeout_ms: None,
1856 certs_dir: None,
1857 };
1858
1859 let client = SocketClient::connect(config, None, None, None).await;
1860 assert!(
1861 client.is_ok(),
1862 "Client should connect to IPv6 loopback address"
1863 );
1864
1865 if let Ok(client) = client {
1866 client.close().await;
1867 }
1868 server.abort();
1869 }
1870
1871 #[rstest]
1872 #[tokio::test]
1873 async fn test_send_waits_during_reconnection() {
1874 use nautilus_common::testing::wait_until_async;
1876
1877 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1878 let port = listener.local_addr().unwrap().port();
1879
1880 let server = task::spawn(async move {
1881 if let Ok((sock, _)) = listener.accept().await {
1883 drop(sock);
1884 }
1885
1886 sleep(Duration::from_millis(500)).await;
1888
1889 if let Ok((mut sock, _)) = listener.accept().await {
1891 let mut buf = vec![0u8; 1024];
1893 while let Ok(n) = sock.read(&mut buf).await {
1894 if n == 0 {
1895 break;
1896 }
1897
1898 if sock.write_all(&buf[..n]).await.is_err() {
1899 break;
1900 }
1901 }
1902 }
1903 });
1904
1905 let config = SocketConfig {
1906 url: format!("127.0.0.1:{port}"),
1907 mode: Mode::Plain,
1908 suffix: b"\r\n".to_vec(),
1909 message_handler: None,
1910 heartbeat: None,
1911 reconnect_timeout_ms: Some(5_000), reconnect_delay_initial_ms: Some(100),
1913 reconnect_delay_max_ms: Some(200),
1914 reconnect_backoff_factor: Some(1.0),
1915 reconnect_jitter_ms: Some(0),
1916 connection_max_retries: Some(1),
1917 reconnect_max_attempts: None,
1918 idle_timeout_ms: None,
1919 certs_dir: None,
1920 };
1921
1922 let client = SocketClient::connect(config, None, None, None)
1923 .await
1924 .unwrap();
1925
1926 wait_until_async(
1928 || async { client.is_reconnecting() },
1929 Duration::from_secs(2),
1930 )
1931 .await;
1932
1933 let send_result = tokio::time::timeout(
1935 Duration::from_secs(3),
1936 client.send_bytes(b"test_message".to_vec()),
1937 )
1938 .await;
1939
1940 assert!(
1941 send_result.is_ok() && send_result.unwrap().is_ok(),
1942 "Send should succeed after waiting for reconnection"
1943 );
1944
1945 client.close().await;
1946 server.abort();
1947 }
1948
1949 #[rstest]
1950 #[tokio::test]
1951 async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
1952 use nautilus_common::testing::wait_until_async;
1955
1956 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1957 let port = listener.local_addr().unwrap().port();
1958
1959 let server = task::spawn(async move {
1960 if let Ok((sock, _)) = listener.accept().await {
1962 drop(sock);
1963 }
1964 drop(listener);
1966 sleep(Duration::from_mins(1)).await;
1967 });
1968
1969 let config = SocketConfig {
1970 url: format!("127.0.0.1:{port}"),
1971 mode: Mode::Plain,
1972 suffix: b"\r\n".to_vec(),
1973 message_handler: None,
1974 heartbeat: None,
1975 reconnect_timeout_ms: Some(1_000), reconnect_delay_initial_ms: Some(200), reconnect_delay_max_ms: Some(200),
1978 reconnect_backoff_factor: Some(1.0),
1979 reconnect_jitter_ms: Some(0),
1980 connection_max_retries: Some(1),
1981 reconnect_max_attempts: None,
1982 idle_timeout_ms: None,
1983 certs_dir: None,
1984 };
1985
1986 let client = SocketClient::connect(config, None, None, None)
1987 .await
1988 .unwrap();
1989
1990 wait_until_async(
1992 || async { client.is_reconnecting() },
1993 Duration::from_secs(3),
1994 )
1995 .await;
1996
1997 let start = std::time::Instant::now();
2000 let send_result = client.send_bytes(b"test".to_vec()).await;
2001 let elapsed = start.elapsed();
2002
2003 assert!(
2004 send_result.is_err(),
2005 "Send should fail when client stuck in RECONNECT, was: {send_result:?}"
2006 );
2007 assert!(
2008 matches!(send_result, Err(crate::error::SendError::Timeout)),
2009 "Send should return Timeout error, was: {send_result:?}"
2010 );
2011 assert!(
2014 elapsed >= Duration::from_millis(900),
2015 "Send should timeout after at least 1s (configured timeout), took {elapsed:?}"
2016 );
2017
2018 client.close().await;
2019 server.abort();
2020 }
2021
2022 #[rstest]
2023 #[tokio::test]
2024 async fn test_idle_timeout_triggers_reconnect() {
2025 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2026 let port = listener.local_addr().unwrap().port();
2027
2028 let server = task::spawn(async move {
2030 let (_sock1, _) = listener.accept().await.unwrap();
2031 sleep(Duration::from_secs(5)).await;
2033 });
2034
2035 let config = SocketConfig {
2036 url: format!("127.0.0.1:{port}"),
2037 mode: Mode::Plain,
2038 suffix: b"\r\n".to_vec(),
2039 message_handler: None,
2040 heartbeat: None,
2041 reconnect_timeout_ms: Some(2_000),
2042 reconnect_delay_initial_ms: Some(50),
2043 reconnect_delay_max_ms: Some(100),
2044 reconnect_backoff_factor: Some(1.0),
2045 reconnect_jitter_ms: Some(0),
2046 connection_max_retries: Some(1),
2047 reconnect_max_attempts: Some(1),
2048 idle_timeout_ms: Some(500),
2049 certs_dir: None,
2050 };
2051
2052 let client = SocketClient::connect(config, None, None, None)
2053 .await
2054 .unwrap();
2055
2056 assert!(client.is_active());
2057
2058 wait_until_async(
2060 || async { client.is_reconnecting() || client.is_closed() },
2061 Duration::from_secs(3),
2062 )
2063 .await;
2064
2065 assert!(
2066 !client.is_active(),
2067 "Client should not be active after idle timeout"
2068 );
2069
2070 client.close().await;
2071 server.abort();
2072 }
2073
2074 #[rstest]
2075 #[tokio::test]
2076 async fn test_idle_timeout_resets_on_data() {
2077 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2078 let port = listener.local_addr().unwrap().port();
2079
2080 let server = task::spawn(async move {
2082 let (mut sock, _) = listener.accept().await.unwrap();
2083 for _ in 0..10 {
2084 sleep(Duration::from_millis(200)).await;
2085
2086 if sock.write_all(b"ping\r\n").await.is_err() {
2087 break;
2088 }
2089 }
2090 });
2091
2092 let config = SocketConfig {
2093 url: format!("127.0.0.1:{port}"),
2094 mode: Mode::Plain,
2095 suffix: b"\r\n".to_vec(),
2096 message_handler: None,
2097 heartbeat: None,
2098 reconnect_timeout_ms: Some(2_000),
2099 reconnect_delay_initial_ms: Some(50),
2100 reconnect_delay_max_ms: Some(100),
2101 reconnect_backoff_factor: Some(1.0),
2102 reconnect_jitter_ms: Some(0),
2103 connection_max_retries: Some(1),
2104 reconnect_max_attempts: Some(1),
2105 idle_timeout_ms: Some(1_000),
2106 certs_dir: None,
2107 };
2108
2109 let client = SocketClient::connect(config, None, None, None)
2110 .await
2111 .unwrap();
2112
2113 assert!(client.is_active());
2114
2115 sleep(Duration::from_millis(1_500)).await;
2117
2118 assert!(
2119 client.is_active(),
2120 "Client should remain active when data is flowing"
2121 );
2122
2123 client.close().await;
2124 server.abort();
2125 }
2126
2127 #[rstest]
2128 #[tokio::test]
2129 async fn test_close_during_backoff_exits_promptly() {
2130 let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2134 let port = listener.local_addr().unwrap().port();
2135
2136 let server = task::spawn(async move {
2137 if let Ok((mut sock, _)) = listener.accept().await {
2139 drop(sock.shutdown());
2140 }
2141 sleep(Duration::from_mins(1)).await;
2143 });
2144
2145 let config = SocketConfig {
2146 url: format!("127.0.0.1:{port}"),
2147 mode: Mode::Plain,
2148 suffix: b"\r\n".to_vec(),
2149 message_handler: None,
2150 heartbeat: None,
2151 reconnect_timeout_ms: Some(1_000),
2152 reconnect_delay_initial_ms: Some(10_000), reconnect_delay_max_ms: Some(10_000),
2154 reconnect_backoff_factor: Some(1.0),
2155 reconnect_jitter_ms: Some(0),
2156 connection_max_retries: None,
2157 reconnect_max_attempts: None,
2158 idle_timeout_ms: None,
2159 certs_dir: None,
2160 };
2161
2162 let client = SocketClient::connect(config, None, None, None)
2163 .await
2164 .unwrap();
2165
2166 wait_until_async(
2168 || async { client.is_reconnecting() },
2169 Duration::from_secs(3),
2170 )
2171 .await;
2172
2173 sleep(Duration::from_millis(1_500)).await;
2175
2176 let start = std::time::Instant::now();
2178 client.close().await;
2179 let elapsed = start.elapsed();
2180
2181 assert!(client.is_closed(), "Client should be closed");
2182 assert!(
2184 elapsed < Duration::from_secs(2),
2185 "Close should interrupt backoff sleep, took {elapsed:?}"
2186 );
2187
2188 server.abort();
2189 }
2190
2191 #[rstest]
2192 #[tokio::test]
2193 async fn test_zero_idle_timeout_rejected() {
2194 let config = SocketConfig {
2195 url: "127.0.0.1:9999".to_string(),
2196 mode: Mode::Plain,
2197 suffix: b"\r\n".to_vec(),
2198 message_handler: None,
2199 heartbeat: None,
2200 reconnect_timeout_ms: None,
2201 reconnect_delay_initial_ms: None,
2202 reconnect_delay_max_ms: None,
2203 reconnect_backoff_factor: None,
2204 reconnect_jitter_ms: None,
2205 reconnect_max_attempts: None,
2206 connection_max_retries: Some(1),
2207 idle_timeout_ms: Some(0),
2208 certs_dir: None,
2209 };
2210
2211 let result = SocketClient::connect(config, None, None, None).await;
2212
2213 assert!(result.is_err(), "Zero idle timeout should be rejected");
2214 let err_msg = result.unwrap_err().to_string();
2215 assert!(
2216 err_msg.contains("Idle timeout cannot be zero"),
2217 "Error should mention zero idle timeout, was: {err_msg}"
2218 );
2219 }
2220
2221 #[rstest]
2222 #[tokio::test]
2223 async fn test_empty_suffix_rejected() {
2224 let config = SocketConfig {
2225 url: "127.0.0.1:9999".to_string(),
2226 mode: Mode::Plain,
2227 suffix: vec![],
2228 message_handler: None,
2229 heartbeat: None,
2230 reconnect_timeout_ms: None,
2231 reconnect_delay_initial_ms: None,
2232 reconnect_delay_max_ms: None,
2233 reconnect_backoff_factor: None,
2234 reconnect_jitter_ms: None,
2235 reconnect_max_attempts: None,
2236 connection_max_retries: Some(1),
2237 idle_timeout_ms: None,
2238 certs_dir: None,
2239 };
2240
2241 let result = SocketClient::connect(config, None, None, None).await;
2242
2243 assert!(
2244 result.is_err(),
2245 "Empty suffix should cause connection to fail"
2246 );
2247 let err_msg = result.unwrap_err().to_string();
2248 assert!(
2249 err_msg.contains("suffix cannot be empty"),
2250 "Error should mention empty suffix, was: {err_msg}"
2251 );
2252 }
2253}