Skip to main content

nautilus_network/socket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! High-performance raw TCP client implementation with TLS capability, automatic reconnection
17//! with exponential backoff and state management.
18//!
19//! **Key features**:
20//! - Connection state tracking (ACTIVE/RECONNECTING/DISCONNECTING/CLOSED).
21//! - Synchronized reconnection with backoff.
22//! - Split read/write architecture.
23//! - Python callback integration.
24//!
25//! **Design**:
26//! - Single reader, multiple writer model.
27//! - Read half runs in dedicated task.
28//! - Write half runs in dedicated task connected with channel.
29//! - Controller task manages lifecycle.
30//! - Event-driven state notification via `Notify` for immediate wakeup on transitions.
31
32use std::{
33    collections::VecDeque,
34    fmt::Debug,
35    path::Path,
36    sync::{
37        Arc,
38        atomic::{AtomicU8, Ordering},
39    },
40    time::Duration,
41};
42
43use bytes::Bytes;
44use nautilus_core::CleanDrop;
45use nautilus_cryptography::providers::install_cryptographic_provider;
46use tokio::io::{AsyncReadExt, AsyncWriteExt};
47use tokio_tungstenite::tungstenite::{Error, client::IntoClientRequest, stream::Mode};
48
49use super::{SocketConfig, TcpMessageHandler, TcpReader, TcpWriter, WriterCommand};
50use crate::{
51    backoff::ExponentialBackoff,
52    dst,
53    error::SendError,
54    logging::{log_task_aborted, log_task_started, log_task_stopped},
55    mode::ConnectionMode,
56    net::TcpStream,
57    tls::{Connector, create_tls_config_from_certs_dir, tcp_tls},
58};
59
// Connection timing constants
/// Timeout (ms) applied to each read/receive poll in the read and write tasks, so
/// both loops re-check the shared connection mode at least this often.
const CONNECTION_STATE_CHECK_INTERVAL_MS: u64 = 10;
/// Pause (ms) taken during a reconnect, after the writer has confirmed the swap and
/// before the previous read task is aborted.
const GRACEFUL_SHUTDOWN_DELAY_MS: u64 = 100;
/// Upper bound (s) on waiting for a writer half to shut down; errors and timeouts
/// from that shutdown are ignored.
const GRACEFUL_SHUTDOWN_TIMEOUT_SECS: u64 = 5;

// Maximum buffer size for read operations (10 MB)
/// If the unframed remainder of the read buffer exceeds this many bytes, the read
/// task logs an error and stops.
const MAX_READ_BUFFER_BYTES: usize = 10 * 1024 * 1024;
67
/// Internal state of a socket connection: the live tasks plus everything needed to
/// reconnect.
///
/// The stream can be encrypted with TLS or Plain. The stream is split into
/// read and write ends:
/// - The read end is passed to the task that keeps receiving
///   messages from the server and passing them to a handler.
/// - The write end is passed to a task which receives messages over a channel
///   to send to the server.
///
/// The heartbeat is optional and can be configured with an interval and data to
/// send.
///
/// The client uses a suffix to separate messages on the byte stream. It is
/// appended to all sent messages and heartbeats. It is also used to split
/// the received byte stream.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
struct SocketClientInner {
    // Configuration the client was created with; `reconnect` reads the URL, mode,
    // suffix and idle timeout from here.
    config: SocketConfig,
    // TLS connector built from `config.certs_dir`, or `None` when no directory is set.
    connector: Option<Connector>,
    // Handle of the current read task; replaced on every successful reconnect.
    read_task: Arc<tokio::task::JoinHandle<()>>,
    // Handle of the write task, which is spawned once in `connect_url`.
    write_task: tokio::task::JoinHandle<()>,
    // Sending side of the command channel consumed by the write task.
    writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
    // Handle of the heartbeat task, present only when a heartbeat is configured.
    heartbeat_task: Option<tokio::task::JoinHandle<()>>,
    // Current `ConnectionMode` stored as its `u8` value and shared with all tasks.
    connection_mode: Arc<AtomicU8>,
    // Notified by the write task when a failed send switches the mode to reconnect.
    state_notify: Arc<tokio::sync::Notify>,
    // Upper bound on the duration of a single `reconnect` call.
    reconnect_timeout: Duration,
    // Backoff built from the `reconnect_delay_*`, factor and jitter settings.
    // NOTE(review): the code that consumes it is outside this section.
    backoff: ExponentialBackoff,
    // Clone of `config.message_handler`, handed to each newly spawned read task.
    handler: Option<TcpMessageHandler>,
    // Copied from the configuration; presumably caps reconnect attempts — TODO confirm
    // against the controller logic, which is outside this section.
    reconnect_max_attempts: Option<u32>,
    // Starts at zero; presumably counts consecutive reconnect attempts — TODO confirm.
    reconnect_attempt_count: u32,
}
102
103impl SocketClientInner {
104    /// Connect to a URL with the specified configuration.
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if connection fails or configuration is invalid.
109    pub async fn connect_url(config: SocketConfig) -> anyhow::Result<Self> {
110        const CONNECTION_TIMEOUT_SECS: u64 = 10;
111
112        install_cryptographic_provider();
113
114        // Validate suffix is non-empty to prevent panic in read loop (windows(0) panics)
115        if config.suffix.is_empty() {
116            anyhow::bail!("Socket suffix cannot be empty: suffix is required for message framing");
117        }
118
119        if let Some((interval_secs, _)) = &config.heartbeat
120            && *interval_secs == 0
121        {
122            anyhow::bail!("Heartbeat interval cannot be zero");
123        }
124
125        if config.idle_timeout_ms == Some(0) {
126            anyhow::bail!("Idle timeout cannot be zero");
127        }
128
129        let SocketConfig {
130            url,
131            mode,
132            heartbeat,
133            suffix,
134            message_handler,
135            reconnect_timeout_ms,
136            reconnect_delay_initial_ms,
137            reconnect_delay_max_ms,
138            reconnect_backoff_factor,
139            reconnect_jitter_ms,
140            connection_max_retries,
141            reconnect_max_attempts,
142            idle_timeout_ms,
143            certs_dir,
144        } = &config.clone();
145        let connector = if let Some(dir) = certs_dir {
146            let config = create_tls_config_from_certs_dir(Path::new(dir), false)?;
147            Some(Connector::Rustls(Arc::new(config)))
148        } else {
149            None
150        };
151
152        // Retry initial connection with exponential backoff to handle transient DNS/network issues
153        let max_retries = connection_max_retries.unwrap_or(5);
154
155        let mut backoff = ExponentialBackoff::new(
156            Duration::from_millis(500),
157            Duration::from_secs(5),
158            2.0,
159            250,
160            false,
161        )?;
162
163        #[allow(unused_assignments)]
164        let mut last_error = String::new();
165        let mut attempt = 0;
166        let (reader, writer) = loop {
167            attempt += 1;
168
169            match dst::time::timeout(
170                Duration::from_secs(CONNECTION_TIMEOUT_SECS),
171                Self::tls_connect_with_server(url, *mode, connector.clone()),
172            )
173            .await
174            {
175                Ok(Ok(result)) => {
176                    if attempt > 1 {
177                        log::info!("Socket connection established after {attempt} attempts");
178                    }
179                    break result;
180                }
181                Ok(Err(e)) => {
182                    last_error = e.to_string();
183                    log::warn!(
184                        "Socket connection attempt {attempt}/{max_retries} to {url} failed: {last_error}"
185                    );
186                }
187                Err(_) => {
188                    last_error = format!(
189                        "Connection timeout after {CONNECTION_TIMEOUT_SECS}s (possible DNS resolution failure)"
190                    );
191                    log::warn!(
192                        "Socket connection attempt {attempt}/{max_retries} to {url} timed out"
193                    );
194                }
195            }
196
197            if attempt >= max_retries {
198                anyhow::bail!(
199                    "Failed to connect to {} after {} attempts: {}. \
200                    If this is a DNS error, check your network configuration and DNS settings.",
201                    url,
202                    max_retries,
203                    if last_error.is_empty() {
204                        "unknown error"
205                    } else {
206                        &last_error
207                    }
208                );
209            }
210
211            let delay = backoff.next_duration();
212            log::debug!(
213                "Retrying in {delay:?} (attempt {}/{})",
214                attempt + 1,
215                max_retries
216            );
217            dst::time::sleep(delay).await;
218        };
219
220        log::debug!("Connected");
221
222        let connection_mode = Arc::new(AtomicU8::new(ConnectionMode::Active.as_u8()));
223        let state_notify = Arc::new(tokio::sync::Notify::new());
224
225        let read_task = Arc::new(Self::spawn_read_task(
226            connection_mode.clone(),
227            reader,
228            message_handler.clone(),
229            suffix.clone(),
230            *idle_timeout_ms,
231        ));
232
233        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
234
235        let write_task = Self::spawn_write_task(
236            connection_mode.clone(),
237            state_notify.clone(),
238            writer,
239            writer_rx,
240            suffix.clone(),
241        );
242
243        // Optionally spawn a heartbeat task to periodically ping server
244        let heartbeat_task = heartbeat.as_ref().map(|heartbeat| {
245            Self::spawn_heartbeat_task(
246                connection_mode.clone(),
247                heartbeat.clone(),
248                writer_tx.clone(),
249            )
250        });
251
252        let reconnect_timeout = Duration::from_millis(reconnect_timeout_ms.unwrap_or(10_000));
253        let backoff = ExponentialBackoff::new(
254            Duration::from_millis(reconnect_delay_initial_ms.unwrap_or(2_000)),
255            Duration::from_millis(reconnect_delay_max_ms.unwrap_or(30_000)),
256            reconnect_backoff_factor.unwrap_or(1.5),
257            reconnect_jitter_ms.unwrap_or(100),
258            true, // immediate-first
259        )?;
260
261        Ok(Self {
262            config,
263            connector,
264            read_task,
265            write_task,
266            writer_tx,
267            heartbeat_task,
268            connection_mode,
269            state_notify,
270            reconnect_timeout,
271            backoff,
272            handler: message_handler.clone(),
273            reconnect_max_attempts: *reconnect_max_attempts,
274            reconnect_attempt_count: 0,
275        })
276    }
277
278    /// Parse URL and extract socket address and request URL.
279    ///
280    /// Accepts either:
281    /// - Raw socket address: "host:port" → returns ("host:port", "scheme://host:port")
282    /// - Full URL: "scheme://host:port/path" → returns ("host:port", original URL)
283    ///
284    /// # Errors
285    ///
286    /// Returns an error if the URL is invalid or missing required components.
287    fn parse_socket_url(url: &str, mode: Mode) -> Result<(String, String), Error> {
288        if url.contains("://") {
289            // URL with scheme (e.g., "wss://host:port/path")
290            let parsed = url.parse::<http::Uri>().map_err(|e| {
291                Error::Io(std::io::Error::new(
292                    std::io::ErrorKind::InvalidInput,
293                    format!("Invalid URL: {e}"),
294                ))
295            })?;
296
297            let host = parsed.host().ok_or_else(|| {
298                Error::Io(std::io::Error::new(
299                    std::io::ErrorKind::InvalidInput,
300                    "URL missing host",
301                ))
302            })?;
303
304            let port = parsed
305                .port_u16()
306                .unwrap_or_else(|| match parsed.scheme_str() {
307                    Some("wss" | "https") => 443,
308                    Some("ws" | "http") => 80,
309                    _ => match mode {
310                        Mode::Tls => 443,
311                        Mode::Plain => 80,
312                    },
313                });
314
315            Ok((format!("{host}:{port}"), url.to_string()))
316        } else {
317            // Raw socket address (e.g., "host:port")
318            // Construct a proper URL for the request based on mode
319            let scheme = match mode {
320                Mode::Tls => "wss",
321                Mode::Plain => "ws",
322            };
323            Ok((url.to_string(), format!("{scheme}://{url}")))
324        }
325    }
326
327    /// Establish a TLS or plain TCP connection with the server.
328    ///
329    /// Accepts either a raw socket address (e.g., "host:port") or a full URL with scheme
330    /// (e.g., "wss://host:port"). For FIX/raw socket connections, use the host:port format.
331    /// For WebSocket-style connections, include the scheme.
332    ///
333    /// # Errors
334    ///
335    /// Returns an error if the connection cannot be established.
336    pub async fn tls_connect_with_server(
337        url: &str,
338        mode: Mode,
339        connector: Option<Connector>,
340    ) -> Result<(TcpReader, TcpWriter), Error> {
341        log::debug!("Connecting to {url}");
342
343        let (socket_addr, request_url) = Self::parse_socket_url(url, mode)?;
344        let tcp_result = TcpStream::connect(&socket_addr).await;
345
346        match tcp_result {
347            Ok(stream) => {
348                log::debug!("TCP connection established to {socket_addr}, proceeding with TLS");
349
350                if let Err(e) = stream.set_nodelay(true) {
351                    log::warn!("Failed to enable TCP_NODELAY for socket client: {e:?}");
352                }
353                let request = request_url.into_client_request()?;
354                tcp_tls(&request, mode, stream, connector)
355                    .await
356                    .map(tokio::io::split)
357            }
358            Err(e) => {
359                log::error!("TCP connection failed to {socket_addr}: {e:?}");
360                Err(Error::Io(e))
361            }
362        }
363    }
364
365    /// Reconnect with server.
366    ///
367    /// Makes a new connection with server, uses the new read and write halves
368    /// to update the reader and writer.
369    async fn reconnect(&mut self) -> Result<(), Error> {
370        log::debug!("Reconnecting");
371
372        if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
373            log::debug!("Reconnect aborted due to disconnect state");
374            return Ok(());
375        }
376
377        dst::time::timeout(self.reconnect_timeout, async {
378            let SocketConfig {
379                url,
380                mode,
381                heartbeat: _,
382                suffix,
383                message_handler: _,
384                reconnect_timeout_ms: _,
385                reconnect_delay_initial_ms: _,
386                reconnect_backoff_factor: _,
387                reconnect_delay_max_ms: _,
388                reconnect_jitter_ms: _,
389                connection_max_retries: _,
390                reconnect_max_attempts: _,
391                idle_timeout_ms,
392                certs_dir: _,
393            } = &self.config;
394            // Create a fresh connection
395            let connector = self.connector.clone();
396            // Attempt to connect; abort early if a disconnect was requested
397            let (reader, new_writer) = Self::tls_connect_with_server(url, *mode, connector).await?;
398
399            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
400                log::debug!("Reconnect aborted mid-flight (after connect)");
401                return Ok(());
402            }
403            log::debug!("Connected");
404
405            // Use a oneshot channel to synchronize with the writer task.
406            // We must verify that the buffer was successfully drained before transitioning to ACTIVE
407            // to prevent silent message loss if the new connection drops immediately.
408            let (tx, rx) = tokio::sync::oneshot::channel();
409            if let Err(e) = self.writer_tx.send(WriterCommand::Update(new_writer, tx)) {
410                log::error!("{e}");
411                return Err(Error::Io(std::io::Error::new(
412                    std::io::ErrorKind::BrokenPipe,
413                    format!("Failed to send update command: {e}"),
414                )));
415            }
416
417            // Wait for writer to confirm it has drained the buffer
418            match rx.await {
419                Ok(true) => log::debug!("Writer confirmed buffer drain success"),
420                Ok(false) => {
421                    log::warn!("Writer failed to drain buffer, aborting reconnect");
422                    // Return error to trigger retry logic in controller
423                    return Err(Error::Io(std::io::Error::other(
424                        "Failed to drain reconnection buffer",
425                    )));
426                }
427                Err(e) => {
428                    log::error!("Writer dropped update channel: {e}");
429                    return Err(Error::Io(std::io::Error::new(
430                        std::io::ErrorKind::BrokenPipe,
431                        "Writer task dropped response channel",
432                    )));
433                }
434            }
435
436            // Delay before closing connection
437            dst::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
438
439            if ConnectionMode::from_atomic(&self.connection_mode).is_disconnect() {
440                log::debug!("Reconnect aborted mid-flight (after delay)");
441                return Ok(());
442            }
443
444            if !self.read_task.is_finished() {
445                self.read_task.abort();
446                log_task_aborted("read");
447            }
448
449            // Atomically transition from Reconnect to Active
450            // This prevents race condition where disconnect could be requested between check and store
451            if self
452                .connection_mode
453                .compare_exchange(
454                    ConnectionMode::Reconnect.as_u8(),
455                    ConnectionMode::Active.as_u8(),
456                    Ordering::SeqCst,
457                    Ordering::SeqCst,
458                )
459                .is_err()
460            {
461                log::debug!("Reconnect aborted (state changed during reconnect)");
462                return Ok(());
463            }
464
465            // Spawn new read task
466            self.read_task = Arc::new(Self::spawn_read_task(
467                self.connection_mode.clone(),
468                reader,
469                self.handler.clone(),
470                suffix.clone(),
471                *idle_timeout_ms,
472            ));
473
474            log::debug!("Reconnect succeeded");
475            Ok(())
476        })
477        .await
478        .map_err(|_| {
479            Error::Io(std::io::Error::new(
480                std::io::ErrorKind::TimedOut,
481                format!(
482                    "reconnection timed out after {}s",
483                    self.reconnect_timeout.as_secs_f64()
484                ),
485            ))
486        })?
487    }
488
489    /// Check if the client is still alive.
490    ///
491    /// Returns `true` if both the read and write tasks are still running.
492    /// There may be some delay between the connection closing and the
493    /// client detecting it.
494    #[inline]
495    #[must_use]
496    pub fn is_alive(&self) -> bool {
497        !self.read_task.is_finished() && !self.write_task.is_finished()
498    }
499
500    #[must_use]
501    fn spawn_read_task(
502        connection_state: Arc<AtomicU8>,
503        mut reader: TcpReader,
504        handler: Option<TcpMessageHandler>,
505        suffix: Vec<u8>,
506        idle_timeout_ms: Option<u64>,
507    ) -> tokio::task::JoinHandle<()> {
508        log_task_started("read");
509
510        // Interval between checking the connection mode
511        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
512        let idle_timeout = idle_timeout_ms.map(Duration::from_millis);
513
514        tokio::task::spawn(async move {
515            let mut buf = Vec::new();
516            let mut last_data_time = dst::time::Instant::now();
517
518            loop {
519                if !ConnectionMode::from_atomic(&connection_state).is_active() {
520                    break;
521                }
522
523                match dst::time::timeout(check_interval, reader.read_buf(&mut buf)).await {
524                    // Connection has been terminated or vector buffer is complete
525                    Ok(Ok(0)) => {
526                        log::debug!("Connection closed by server");
527                        break;
528                    }
529                    Ok(Err(e)) => {
530                        log::debug!("Connection ended: {e}");
531                        break;
532                    }
533                    // Received bytes of data
534                    Ok(Ok(bytes)) => {
535                        log::trace!("Received <binary> {bytes} bytes");
536                        last_data_time = dst::time::Instant::now();
537
538                        while let Some((i, _)) = &buf
539                            .windows(suffix.len())
540                            .enumerate()
541                            .find(|(_, pair)| pair.eq(&suffix))
542                        {
543                            let mut data: Vec<u8> = buf.drain(0..i + suffix.len()).collect();
544                            data.truncate(data.len() - suffix.len());
545
546                            if let Some(ref handler) = handler {
547                                handler(&data);
548                            }
549                        }
550
551                        if buf.len() > MAX_READ_BUFFER_BYTES {
552                            log::error!(
553                                "Read buffer exceeded maximum size ({MAX_READ_BUFFER_BYTES} bytes), closing connection"
554                            );
555                            break;
556                        }
557                    }
558                    Err(_) => {
559                        if let Some(timeout) = idle_timeout {
560                            let idle_duration = last_data_time.elapsed();
561                            if idle_duration >= timeout {
562                                log::warn!(
563                                    "Read idle timeout: no data received for {:.1}s",
564                                    idle_duration.as_secs_f64()
565                                );
566                                break;
567                            }
568                        }
569                    }
570                }
571            }
572
573            log_task_stopped("read");
574        })
575    }
576
577    /// Drains buffered messages after reconnection completes.
578    ///
579    /// Attempts to send all buffered messages that were queued during reconnection.
580    /// Uses a peek-and-pop pattern to preserve messages if sending fails midway through the buffer.
581    ///
582    /// # Returns
583    ///
584    /// Returns `true` if a send error occurred (buffer may still contain unsent messages),
585    /// `false` if all messages were sent successfully (buffer is empty).
586    async fn drain_reconnect_buffer(
587        buffer: &mut VecDeque<Bytes>,
588        writer: &mut TcpWriter,
589        suffix: &[u8],
590    ) -> bool {
591        if buffer.is_empty() {
592            return false;
593        }
594
595        let initial_buffer_len = buffer.len();
596        log::info!("Sending {initial_buffer_len} buffered messages after reconnection");
597
598        let mut send_error_occurred = false;
599
600        while let Some(buffered_msg) = buffer.front() {
601            let mut combined_msg = Vec::with_capacity(buffered_msg.len() + suffix.len());
602            combined_msg.extend_from_slice(buffered_msg);
603            combined_msg.extend_from_slice(suffix);
604
605            if let Err(e) = writer.write_all(&combined_msg).await {
606                log::error!(
607                    "Failed to send buffered message with suffix after reconnection: {e}, {} messages remain in buffer",
608                    buffer.len()
609                );
610                send_error_occurred = true;
611                break;
612            }
613
614            buffer.pop_front();
615        }
616
617        if buffer.is_empty() {
618            log::info!("Successfully sent all {initial_buffer_len} buffered messages");
619        }
620
621        send_error_occurred
622    }
623
624    fn spawn_write_task(
625        connection_state: Arc<AtomicU8>,
626        state_notify: Arc<tokio::sync::Notify>,
627        writer: TcpWriter,
628        mut writer_rx: tokio::sync::mpsc::UnboundedReceiver<WriterCommand>,
629        suffix: Vec<u8>,
630    ) -> tokio::task::JoinHandle<()> {
631        log_task_started("write");
632
633        // Interval between checking the connection mode
634        let check_interval = Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS);
635
636        tokio::task::spawn(async move {
637            let mut active_writer = writer;
638            let mut reconnect_buffer: VecDeque<Bytes> = VecDeque::new();
639            let mut write_buf: Vec<u8> = Vec::new();
640
641            loop {
642                if matches!(
643                    ConnectionMode::from_atomic(&connection_state),
644                    ConnectionMode::Disconnect | ConnectionMode::Closed
645                ) {
646                    break;
647                }
648
649                match dst::time::timeout(check_interval, writer_rx.recv()).await {
650                    Ok(Some(msg)) => {
651                        // Re-check connection mode after receiving a message
652                        let mode = ConnectionMode::from_atomic(&connection_state);
653                        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
654                            break;
655                        }
656
657                        match msg {
658                            WriterCommand::Update(new_writer, tx) => {
659                                log::debug!("Received new writer");
660
661                                // Delay before closing connection
662                                dst::time::sleep(Duration::from_millis(100)).await;
663
664                                // Attempt to shutdown the writer gracefully before updating,
665                                // we ignore any error as the writer may already be closed.
666                                _ = dst::time::timeout(
667                                    Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
668                                    active_writer.shutdown(),
669                                )
670                                .await;
671
672                                active_writer = new_writer;
673                                log::debug!("Updated writer");
674
675                                let send_error = Self::drain_reconnect_buffer(
676                                    &mut reconnect_buffer,
677                                    &mut active_writer,
678                                    &suffix,
679                                )
680                                .await;
681
682                                if let Err(e) = tx.send(!send_error) {
683                                    log::error!(
684                                        "Failed to report drain status to controller: {e:?}"
685                                    );
686                                }
687                            }
688                            _ if mode.is_reconnect() => {
689                                if let WriterCommand::Send(data) = msg {
690                                    log::debug!(
691                                        "Buffering message while reconnecting ({} bytes)",
692                                        data.len()
693                                    );
694                                    reconnect_buffer.push_back(data);
695                                }
696                            }
697                            WriterCommand::Send(msg) => {
698                                write_buf.clear();
699                                write_buf.extend_from_slice(&msg);
700                                write_buf.extend_from_slice(&suffix);
701
702                                if let Err(e) = active_writer.write_all(&write_buf).await {
703                                    log::error!("Failed to send message: {e}");
704                                    log::warn!("Writer triggering reconnect");
705
706                                    reconnect_buffer.push_back(msg);
707                                    connection_state
708                                        .store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
709                                    state_notify.notify_one();
710                                }
711                            }
712                        }
713                    }
714                    Ok(None) => {
715                        // Channel closed - writer task should terminate
716                        log::debug!("Writer channel closed, terminating writer task");
717                        break;
718                    }
719                    Err(_) => {
720                        // Timeout - just continue the loop
721                    }
722                }
723            }
724
725            // Attempt to shutdown the writer gracefully before exiting,
726            // we ignore any error as the writer may already be closed.
727            _ = dst::time::timeout(
728                Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS),
729                active_writer.shutdown(),
730            )
731            .await;
732
733            log_task_stopped("write");
734        })
735    }
736
737    fn spawn_heartbeat_task(
738        connection_state: Arc<AtomicU8>,
739        heartbeat: (u64, Vec<u8>),
740        writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
741    ) -> tokio::task::JoinHandle<()> {
742        log_task_started("heartbeat");
743        let (interval_secs, message) = heartbeat;
744
745        tokio::task::spawn(async move {
746            let interval = Duration::from_secs(interval_secs);
747
748            loop {
749                dst::time::sleep(interval).await;
750
751                match ConnectionMode::from_u8(connection_state.load(Ordering::SeqCst)) {
752                    ConnectionMode::Active => {
753                        let msg = WriterCommand::Send(message.clone().into());
754
755                        match writer_tx.send(msg) {
756                            Ok(()) => log::trace!("Sent heartbeat to writer task"),
757                            Err(e) => {
758                                log::error!("Failed to send heartbeat to writer task: {e}");
759                            }
760                        }
761                    }
762                    ConnectionMode::Reconnect => {}
763                    ConnectionMode::Disconnect | ConnectionMode::Closed => break,
764                }
765            }
766
767            log_task_stopped("heartbeat");
768        })
769    }
770}
771
772impl Drop for SocketClientInner {
773    fn drop(&mut self) {
774        // Delegate to explicit cleanup handler
775        self.clean_drop();
776    }
777}
778
779/// Cleanup on drop: aborts background tasks and clears handlers to break reference cycles.
780impl CleanDrop for SocketClientInner {
781    fn clean_drop(&mut self) {
782        if !self.read_task.is_finished() {
783            self.read_task.abort();
784            log_task_aborted("read");
785        }
786
787        if !self.write_task.is_finished() {
788            self.write_task.abort();
789            log_task_aborted("write");
790        }
791
792        if let Some(ref handle) = self.heartbeat_task.take()
793            && !handle.is_finished()
794        {
795            handle.abort();
796            log_task_aborted("heartbeat");
797        }
798
799        #[cfg(feature = "python")]
800        {
801            // Remove stored handler to break ref cycle
802            self.config.message_handler = None;
803        }
804    }
805}
806
/// Handle to a socket connection whose lifecycle is driven by a background controller task.
///
/// Messages are enqueued through `writer_tx`; the connection state is shared with the
/// controller through `connection_mode` and `state_notify`.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.network")
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.network")
)]
pub struct SocketClient {
    /// Join handle of the controller task spawned in `connect`.
    pub(crate) controller_task: tokio::task::JoinHandle<()>,
    /// Shared connection mode, stored as the `u8` form of `ConnectionMode`.
    pub(crate) connection_mode: Arc<AtomicU8>,
    /// Notifier signalled on connection mode transitions.
    pub(crate) state_notify: Arc<tokio::sync::Notify>,
    /// Upper bound on how long a send waits for the client to become active.
    pub(crate) reconnect_timeout: Duration,
    /// Sender half of the channel feeding the writer task.
    pub writer_tx: tokio::sync::mpsc::UnboundedSender<WriterCommand>,
}
822
823impl Debug for SocketClient {
824    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
825        f.debug_struct(stringify!(SocketClient)).finish()
826    }
827}
828
829impl SocketClient {
830    /// Connect to the server.
831    ///
832    /// # Errors
833    ///
834    /// Returns any error connecting to the server.
835    pub async fn connect(
836        config: SocketConfig,
837        post_connection: Option<Arc<dyn Fn() + Send + Sync>>,
838        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
839        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
840    ) -> anyhow::Result<Self> {
841        let inner = SocketClientInner::connect_url(config).await?;
842        let writer_tx = inner.writer_tx.clone();
843        let connection_mode = inner.connection_mode.clone();
844        let state_notify = inner.state_notify.clone();
845        let reconnect_timeout = inner.reconnect_timeout;
846
847        let controller_task = Self::spawn_controller_task(
848            inner,
849            connection_mode.clone(),
850            state_notify.clone(),
851            post_reconnection,
852            post_disconnection,
853        );
854
855        if let Some(handler) = post_connection {
856            handler();
857            log::debug!("Called `post_connection` handler");
858        }
859
860        Ok(Self {
861            controller_task,
862            connection_mode,
863            state_notify,
864            reconnect_timeout,
865            writer_tx,
866        })
867    }
868
869    /// Returns the current connection mode.
870    #[must_use]
871    pub fn connection_mode(&self) -> ConnectionMode {
872        ConnectionMode::from_atomic(&self.connection_mode)
873    }
874
875    /// Check if the client connection is active.
876    ///
877    /// Returns `true` if the client is connected and has not been signalled to disconnect.
878    /// The client will automatically retry connection based on its configuration.
879    #[inline]
880    #[must_use]
881    pub fn is_active(&self) -> bool {
882        self.connection_mode().is_active()
883    }
884
885    /// Check if the client is reconnecting.
886    ///
887    /// Returns `true` if the client lost connection and is attempting to reestablish it.
888    /// The client will automatically retry connection based on its configuration.
889    #[inline]
890    #[must_use]
891    pub fn is_reconnecting(&self) -> bool {
892        self.connection_mode().is_reconnect()
893    }
894
895    /// Check if the client is disconnecting.
896    ///
897    /// Returns `true` if the client is in disconnect mode.
898    #[inline]
899    #[must_use]
900    pub fn is_disconnecting(&self) -> bool {
901        self.connection_mode().is_disconnect()
902    }
903
904    /// Check if the client is closed.
905    ///
906    /// Returns `true` if the client has been explicitly disconnected or reached
907    /// maximum reconnection attempts. In this state, the client cannot be reused
908    /// and a new client must be created for further connections.
909    #[inline]
910    #[must_use]
911    pub fn is_closed(&self) -> bool {
912        self.connection_mode().is_closed()
913    }
914
915    /// Close the client.
916    ///
917    /// Controller task will periodically check the disconnect mode
918    /// and shutdown the client if it is not alive.
919    pub async fn close(&self) {
920        self.connection_mode
921            .store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
922        self.state_notify.notify_waiters();
923
924        if dst::time::timeout(Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS), async {
925            while !self.is_closed() {
926                dst::time::sleep(Duration::from_millis(CONNECTION_STATE_CHECK_INTERVAL_MS)).await;
927            }
928
929            if !self.controller_task.is_finished() {
930                self.controller_task.abort();
931                log_task_aborted("controller");
932            }
933        })
934        .await
935            == Ok(())
936        {
937            log_task_stopped("controller");
938        } else {
939            log::error!("Timeout waiting for controller task to finish");
940
941            if !self.controller_task.is_finished() {
942                self.controller_task.abort();
943                log_task_aborted("controller");
944            }
945            self.connection_mode
946                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
947        }
948    }
949
950    /// Checks whether the connection is in a terminal state (disconnecting or closed).
951    ///
952    /// Single atomic load to fail fast before waiting.
953    #[inline]
954    fn check_not_terminal(&self) -> Result<(), SendError> {
955        match self.connection_mode() {
956            ConnectionMode::Disconnect | ConnectionMode::Closed => Err(SendError::Closed),
957            _ => Ok(()),
958        }
959    }
960
961    /// Waits for the client to become active before sending.
962    ///
963    /// Uses `state_notify` for event-driven wakeup so sends resume immediately
964    /// after reconnection completes. A fallback interval guards against missed
965    /// notifications.
966    async fn wait_for_active(&self) -> Result<(), SendError> {
967        const FALLBACK_INTERVAL_MS: u64 = 100;
968
969        let mode = self.connection_mode();
970        if mode.is_active() {
971            return Ok(());
972        }
973
974        if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
975            return Err(SendError::Closed);
976        }
977
978        log::debug!("Waiting for client to become ACTIVE before sending...");
979
980        let fallback_interval = Duration::from_millis(FALLBACK_INTERVAL_MS);
981
982        dst::time::timeout(self.reconnect_timeout, async {
983            loop {
984                let notified = self.state_notify.notified();
985
986                let mode = self.connection_mode();
987                if mode.is_active() {
988                    return Ok(());
989                }
990
991                if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
992                    return Err(());
993                }
994
995                tokio::select! {
996                    biased;
997                    () = notified => {}
998                    () = dst::time::sleep(fallback_interval) => {}
999                }
1000            }
1001        })
1002        .await
1003        .map_err(|_| SendError::Timeout)?
1004        .map_err(|()| SendError::Closed)
1005    }
1006
1007    /// Sends a message of the given `data`.
1008    ///
1009    /// Returns `Ok(())` when the message is enqueued to the writer channel. This does NOT
1010    /// guarantee delivery: if a disconnect occurs concurrently, the writer task may drop the
1011    /// message. During reconnection, messages are buffered and replayed on the new connection.
1012    ///
1013    /// # Errors
1014    ///
1015    /// Returns an error if sending fails.
1016    pub async fn send_bytes(&self, data: Vec<u8>) -> Result<(), SendError> {
1017        self.check_not_terminal()?;
1018        self.wait_for_active().await?;
1019
1020        let msg = WriterCommand::Send(data.into());
1021        self.writer_tx
1022            .send(msg)
1023            .map_err(|e| SendError::BrokenPipe(e.to_string()))
1024    }
1025
1026    fn spawn_controller_task(
1027        mut inner: SocketClientInner,
1028        connection_mode: Arc<AtomicU8>,
1029        state_notify: Arc<tokio::sync::Notify>,
1030        post_reconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1031        post_disconnection: Option<Arc<dyn Fn() + Send + Sync>>,
1032    ) -> tokio::task::JoinHandle<()> {
1033        const CONTROLLER_FALLBACK_INTERVAL_MS: u64 = 100;
1034
1035        tokio::task::spawn(async move {
1036            log_task_started("controller");
1037
1038            let fallback_interval = Duration::from_millis(CONTROLLER_FALLBACK_INTERVAL_MS);
1039
1040            loop {
1041                tokio::select! {
1042                    biased;
1043                    () = state_notify.notified() => {}
1044                    () = dst::time::sleep(fallback_interval) => {}
1045                }
1046
1047                let mut mode = ConnectionMode::from_atomic(&connection_mode);
1048
1049                if mode.is_disconnect() {
1050                    log::debug!("Disconnecting");
1051
1052                    let timeout = Duration::from_secs(GRACEFUL_SHUTDOWN_TIMEOUT_SECS);
1053                    if dst::time::timeout(timeout, async {
1054                        // Delay awaiting graceful shutdown
1055                        dst::time::sleep(Duration::from_millis(GRACEFUL_SHUTDOWN_DELAY_MS)).await;
1056
1057                        if !inner.read_task.is_finished() {
1058                            inner.read_task.abort();
1059                            log_task_aborted("read");
1060                        }
1061
1062                        if let Some(task) = &inner.heartbeat_task
1063                            && !task.is_finished()
1064                        {
1065                            task.abort();
1066                            log_task_aborted("heartbeat");
1067                        }
1068                    })
1069                    .await
1070                    .is_err()
1071                    {
1072                        log::error!("Shutdown timed out after {}s", timeout.as_secs());
1073                    }
1074
1075                    log::debug!("Closed");
1076
1077                    if let Some(ref handler) = post_disconnection {
1078                        handler();
1079                        log::debug!("Called `post_disconnection` handler");
1080                    }
1081                    break; // Controller finished
1082                }
1083
1084                if mode.is_closed() {
1085                    log::debug!("Connection closed");
1086                    break;
1087                }
1088
1089                if mode.is_active() && !inner.is_alive() {
1090                    if connection_mode
1091                        .compare_exchange(
1092                            ConnectionMode::Active.as_u8(),
1093                            ConnectionMode::Reconnect.as_u8(),
1094                            Ordering::SeqCst,
1095                            Ordering::SeqCst,
1096                        )
1097                        .is_ok()
1098                    {
1099                        log::debug!("Detected dead read task, transitioning to RECONNECT");
1100                    }
1101                    mode = ConnectionMode::from_atomic(&connection_mode);
1102                }
1103
1104                if mode.is_reconnect() {
1105                    // Check max reconnection attempts before attempting reconnect
1106                    if let Some(max_attempts) = inner.reconnect_max_attempts
1107                        && inner.reconnect_attempt_count >= max_attempts
1108                    {
1109                        log::error!(
1110                            "Max reconnection attempts ({max_attempts}) exceeded, transitioning to CLOSED"
1111                        );
1112                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1113                        state_notify.notify_waiters();
1114                        break;
1115                    }
1116
1117                    inner.reconnect_attempt_count += 1;
1118
1119                    // Race reconnect against disconnect notification
1120                    let reconnect_result = tokio::select! {
1121                        biased;
1122                        result = inner.reconnect() => Some(result),
1123                        () = async {
1124                            loop {
1125                                state_notify.notified().await;
1126
1127                                if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
1128                                    break;
1129                                }
1130                            }
1131                        } => None,
1132                    };
1133
1134                    match reconnect_result {
1135                        None => {
1136                            log::debug!("Reconnect interrupted by disconnect");
1137                        }
1138                        Some(Ok(())) => {
1139                            log::debug!("Reconnected successfully");
1140                            inner.backoff.reset();
1141                            inner.reconnect_attempt_count = 0;
1142
1143                            state_notify.notify_waiters();
1144
1145                            if ConnectionMode::from_atomic(&connection_mode).is_active() {
1146                                if let Some(ref handler) = post_reconnection {
1147                                    handler();
1148                                    log::debug!("Called `post_reconnection` handler");
1149                                }
1150                            } else {
1151                                log::debug!(
1152                                    "Skipping post_reconnection handlers due to disconnect state"
1153                                );
1154                            }
1155                        }
1156                        Some(Err(e)) => {
1157                            let duration = inner.backoff.next_duration();
1158                            log::warn!(
1159                                "Reconnect attempt {} failed: {e}",
1160                                inner.reconnect_attempt_count
1161                            );
1162
1163                            if !duration.is_zero() {
1164                                log::warn!("Backing off for {}s...", duration.as_secs_f64());
1165                                // Race backoff sleep against disconnect
1166                                tokio::select! {
1167                                    biased;
1168                                    () = dst::time::sleep(duration) => {}
1169                                    () = async {
1170                                        loop {
1171                                            state_notify.notified().await;
1172
1173                                            if ConnectionMode::from_atomic(&connection_mode).is_disconnect() {
1174                                                break;
1175                                            }
1176                                        }
1177                                    } => {
1178                                        log::debug!("Backoff interrupted by disconnect");
1179                                    }
1180                                }
1181                            }
1182                        }
1183                    }
1184                }
1185            }
1186            inner
1187                .connection_mode
1188                .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
1189
1190            log_task_stopped("controller");
1191        })
1192    }
1193}
1194
1195// Abort controller task on drop to clean up background tasks
1196impl Drop for SocketClient {
1197    fn drop(&mut self) {
1198        if !self.controller_task.is_finished() {
1199            self.controller_task.abort();
1200            log_task_aborted("controller");
1201        }
1202    }
1203}
1204
1205#[cfg(test)]
1206#[cfg(feature = "python")]
1207#[cfg(not(all(feature = "simulation", madsim)))] // transport-layer I/O not simulated
1208#[cfg(target_os = "linux")] // Only run network tests on Linux (CI stability)
1209mod tests {
1210    use nautilus_common::testing::wait_until_async;
1211    use pyo3::Python;
1212    use tokio::{
1213        io::{AsyncReadExt, AsyncWriteExt},
1214        net::{TcpListener, TcpStream},
1215        sync::Mutex,
1216        task,
1217        time::{Duration, sleep},
1218    };
1219
1220    use super::*;
1221
1222    async fn bind_test_server() -> (u16, TcpListener) {
1223        let listener = TcpListener::bind("127.0.0.1:0")
1224            .await
1225            .expect("Failed to bind ephemeral port");
1226        let port = listener.local_addr().unwrap().port();
1227        (port, listener)
1228    }
1229
1230    async fn run_echo_server(mut socket: TcpStream) {
1231        let mut buf = Vec::new();
1232        loop {
1233            match socket.read_buf(&mut buf).await {
1234                Ok(0) => {
1235                    break;
1236                }
1237                Ok(_n) => {
1238                    while let Some(idx) = buf.array_windows().position(|w| w == b"\r\n") {
1239                        let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1240                        // Remove trailing \r\n
1241                        line.truncate(line.len() - 2);
1242
1243                        if line == b"close" {
1244                            let _ = socket.shutdown().await;
1245                            return;
1246                        }
1247
1248                        let mut echo_data = line;
1249                        echo_data.extend_from_slice(b"\r\n");
1250                        if socket.write_all(&echo_data).await.is_err() {
1251                            break;
1252                        }
1253                    }
1254                }
1255                Err(e) => {
1256                    eprintln!("Server read error: {e}");
1257                    break;
1258                }
1259            }
1260        }
1261    }
1262
1263    #[tokio::test]
1264    async fn test_basic_send_receive() {
1265        Python::initialize();
1266
1267        let (port, listener) = bind_test_server().await;
1268        let server_task = task::spawn(async move {
1269            let (socket, _) = listener.accept().await.unwrap();
1270            run_echo_server(socket).await;
1271        });
1272
1273        let config = SocketConfig {
1274            url: format!("127.0.0.1:{port}"),
1275            mode: Mode::Plain,
1276            suffix: b"\r\n".to_vec(),
1277            message_handler: None,
1278            heartbeat: None,
1279            reconnect_timeout_ms: None,
1280            reconnect_delay_initial_ms: None,
1281            reconnect_backoff_factor: None,
1282            reconnect_delay_max_ms: None,
1283            reconnect_jitter_ms: None,
1284            reconnect_max_attempts: None,
1285            connection_max_retries: None,
1286            idle_timeout_ms: None,
1287            certs_dir: None,
1288        };
1289
1290        let client = SocketClient::connect(config, None, None, None)
1291            .await
1292            .expect("Client connect failed unexpectedly");
1293
1294        client.send_bytes(b"Hello".into()).await.unwrap();
1295        client.send_bytes(b"World".into()).await.unwrap();
1296
1297        // Wait a bit for the server to echo them back
1298        sleep(Duration::from_millis(100)).await;
1299
1300        client.send_bytes(b"close".into()).await.unwrap();
1301        server_task.await.unwrap();
1302        assert!(!client.is_closed());
1303    }
1304
1305    #[tokio::test]
1306    async fn test_reconnect_fail_exhausted() {
1307        Python::initialize();
1308
1309        let (port, listener) = bind_test_server().await;
1310        drop(listener); // We drop it immediately -> no server is listening
1311
1312        // Wait until port is truly unavailable (OS has released it)
1313        wait_until_async(
1314            || async {
1315                TcpStream::connect(format!("127.0.0.1:{port}"))
1316                    .await
1317                    .is_err()
1318            },
1319            Duration::from_secs(2),
1320        )
1321        .await;
1322
1323        let config = SocketConfig {
1324            url: format!("127.0.0.1:{port}"),
1325            mode: Mode::Plain,
1326            suffix: b"\r\n".to_vec(),
1327            message_handler: None,
1328            heartbeat: None,
1329            reconnect_timeout_ms: Some(100),
1330            reconnect_delay_initial_ms: Some(50),
1331            reconnect_backoff_factor: Some(1.0),
1332            reconnect_delay_max_ms: Some(50),
1333            reconnect_jitter_ms: Some(0),
1334            connection_max_retries: Some(1),
1335            reconnect_max_attempts: None,
1336            idle_timeout_ms: None,
1337            certs_dir: None,
1338        };
1339
1340        let client_res = SocketClient::connect(config, None, None, None).await;
1341        assert!(
1342            client_res.is_err(),
1343            "Should fail quickly with no server listening"
1344        );
1345    }
1346
1347    #[tokio::test]
1348    async fn test_user_disconnect() {
1349        Python::initialize();
1350
1351        let (port, listener) = bind_test_server().await;
1352        let server_task = task::spawn(async move {
1353            let (socket, _) = listener.accept().await.unwrap();
1354            let mut buf = [0u8; 1024];
1355            let _ = socket.try_read(&mut buf);
1356
1357            loop {
1358                sleep(Duration::from_secs(1)).await;
1359            }
1360        });
1361
1362        let config = SocketConfig {
1363            url: format!("127.0.0.1:{port}"),
1364            mode: Mode::Plain,
1365            suffix: b"\r\n".to_vec(),
1366            message_handler: None,
1367            heartbeat: None,
1368            reconnect_timeout_ms: None,
1369            reconnect_delay_initial_ms: None,
1370            reconnect_backoff_factor: None,
1371            reconnect_delay_max_ms: None,
1372            reconnect_jitter_ms: None,
1373            reconnect_max_attempts: None,
1374            connection_max_retries: None,
1375            idle_timeout_ms: None,
1376            certs_dir: None,
1377        };
1378
1379        let client = SocketClient::connect(config, None, None, None)
1380            .await
1381            .unwrap();
1382
1383        client.close().await;
1384        assert!(client.is_closed());
1385        server_task.abort();
1386    }
1387
1388    #[tokio::test]
1389    async fn test_heartbeat() {
1390        Python::initialize();
1391
1392        let (port, listener) = bind_test_server().await;
1393        let received = Arc::new(Mutex::new(Vec::new()));
1394        let received2 = received.clone();
1395
1396        let server_task = task::spawn(async move {
1397            let (socket, _) = listener.accept().await.unwrap();
1398
1399            let mut buf = Vec::new();
1400            loop {
1401                match socket.try_read_buf(&mut buf) {
1402                    Ok(0) => break,
1403                    Ok(_) => {
1404                        while let Some(idx) = buf.array_windows().position(|w| w == b"\r\n") {
1405                            let mut line = buf.drain(..idx + 2).collect::<Vec<u8>>();
1406                            line.truncate(line.len() - 2);
1407                            received2.lock().await.push(line);
1408                        }
1409                    }
1410                    Err(_) => {
1411                        tokio::time::sleep(Duration::from_millis(10)).await;
1412                    }
1413                }
1414            }
1415        });
1416
1417        // Heartbeat every 1 second
1418        let heartbeat = Some((1, b"ping".to_vec()));
1419
1420        let config = SocketConfig {
1421            url: format!("127.0.0.1:{port}"),
1422            mode: Mode::Plain,
1423            suffix: b"\r\n".to_vec(),
1424            message_handler: None,
1425            heartbeat,
1426            reconnect_timeout_ms: None,
1427            reconnect_delay_initial_ms: None,
1428            reconnect_backoff_factor: None,
1429            reconnect_delay_max_ms: None,
1430            reconnect_jitter_ms: None,
1431            reconnect_max_attempts: None,
1432            connection_max_retries: None,
1433            idle_timeout_ms: None,
1434            certs_dir: None,
1435        };
1436
1437        let client = SocketClient::connect(config, None, None, None)
1438            .await
1439            .unwrap();
1440
1441        // Wait ~3 seconds to collect some heartbeats
1442        sleep(Duration::from_secs(3)).await;
1443
1444        {
1445            let lock = received.lock().await;
1446            let pings = lock
1447                .iter()
1448                .filter(|line| line == &&b"ping".to_vec())
1449                .count();
1450            assert!(
1451                pings >= 2,
1452                "Expected at least 2 heartbeat pings; got {pings}"
1453            );
1454        }
1455
1456        client.close().await;
1457        server_task.abort();
1458    }
1459
1460    #[tokio::test]
1461    async fn test_reconnect_success() {
1462        Python::initialize();
1463
1464        let (port, listener) = bind_test_server().await;
1465
1466        // Spawn a server task that:
1467        // 1. Accepts the first connection and then drops it after a short delay (simulate disconnect)
1468        // 2. Waits a bit and then accepts a new connection and runs the echo server
1469        let server_task = task::spawn(async move {
1470            // Accept first connection
1471            let (mut socket, _) = listener.accept().await.expect("First accept failed");
1472
1473            // Wait briefly and then force-close the connection
1474            sleep(Duration::from_millis(500)).await;
1475            let _ = socket.shutdown().await;
1476
1477            // Wait for the client's reconnect attempt
1478            sleep(Duration::from_millis(500)).await;
1479
1480            // Run the echo server on the new connection
1481            let (socket, _) = listener.accept().await.expect("Second accept failed");
1482            run_echo_server(socket).await;
1483        });
1484
1485        let config = SocketConfig {
1486            url: format!("127.0.0.1:{port}"),
1487            mode: Mode::Plain,
1488            suffix: b"\r\n".to_vec(),
1489            message_handler: None,
1490            heartbeat: None,
1491            reconnect_timeout_ms: Some(5_000),
1492            reconnect_delay_initial_ms: Some(500),
1493            reconnect_delay_max_ms: Some(5_000),
1494            reconnect_backoff_factor: Some(2.0),
1495            reconnect_jitter_ms: Some(50),
1496            reconnect_max_attempts: None,
1497            connection_max_retries: None,
1498            idle_timeout_ms: None,
1499            certs_dir: None,
1500        };
1501
1502        let client = SocketClient::connect(config, None, None, None)
1503            .await
1504            .expect("Client connect failed unexpectedly");
1505
1506        // Initially, the client should be active
1507        assert!(client.is_active(), "Client should start as active");
1508
1509        // Wait until the client loses connection (i.e. not active),
1510        // then wait until it reconnects (active again).
1511        wait_until_async(|| async { client.is_active() }, Duration::from_secs(10)).await;
1512
1513        client
1514            .send_bytes(b"TestReconnect".into())
1515            .await
1516            .expect("Send failed");
1517
1518        client.close().await;
1519        server_task.abort();
1520    }
1521}
1522
1523#[cfg(test)]
1524#[cfg(not(feature = "turmoil"))]
1525#[cfg(not(all(feature = "simulation", madsim)))] // transport-layer I/O not simulated
1526mod rust_tests {
1527    use nautilus_common::testing::wait_until_async;
1528    use rstest::rstest;
1529    use tokio::{
1530        io::{AsyncReadExt, AsyncWriteExt},
1531        net::TcpListener,
1532        task,
1533        time::{Duration, sleep},
1534    };
1535
1536    use super::*;
1537
1538    #[rstest]
1539    #[tokio::test]
1540    async fn test_reconnect_then_close() {
1541        // Bind an ephemeral port
1542        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1543        let port = listener.local_addr().unwrap().port();
1544
1545        // Server task: accept one connection and then drop it
1546        let server = task::spawn(async move {
1547            if let Ok((mut sock, _)) = listener.accept().await {
1548                drop(sock.shutdown());
1549            }
1550            // Keep listener alive briefly to avoid premature exit
1551            sleep(Duration::from_secs(1)).await;
1552        });
1553
1554        // Configure client with a short reconnect backoff
1555        let config = SocketConfig {
1556            url: format!("127.0.0.1:{port}"),
1557            mode: Mode::Plain,
1558            suffix: b"\r\n".to_vec(),
1559            message_handler: None,
1560            heartbeat: None,
1561            reconnect_timeout_ms: Some(1_000),
1562            reconnect_delay_initial_ms: Some(50),
1563            reconnect_delay_max_ms: Some(100),
1564            reconnect_backoff_factor: Some(1.0),
1565            reconnect_jitter_ms: Some(0),
1566            connection_max_retries: Some(1),
1567            reconnect_max_attempts: None,
1568            idle_timeout_ms: None,
1569            certs_dir: None,
1570        };
1571
1572        // Connect client (handler=None)
1573        let client = SocketClient::connect(config.clone(), None, None, None)
1574            .await
1575            .unwrap();
1576
1577        // Wait for client to detect dropped connection and enter reconnect state
1578        wait_until_async(
1579            || async { client.is_reconnecting() },
1580            Duration::from_secs(2),
1581        )
1582        .await;
1583
1584        // Now close the client
1585        client.close().await;
1586        assert!(client.is_closed());
1587        server.abort();
1588    }
1589
1590    #[rstest]
1591    #[tokio::test]
1592    async fn test_reconnect_state_flips_when_reader_stops() {
1593        // Bind an ephemeral port and accept a single connection which we immediately close.
1594        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1595        let port = listener.local_addr().unwrap().port();
1596
1597        let server = task::spawn(async move {
1598            if let Ok((sock, _)) = listener.accept().await {
1599                drop(sock);
1600            }
1601            // Give the client a moment to observe the closed connection.
1602            sleep(Duration::from_millis(50)).await;
1603        });
1604
1605        let config = SocketConfig {
1606            url: format!("127.0.0.1:{port}"),
1607            mode: Mode::Plain,
1608            suffix: b"\r\n".to_vec(),
1609            message_handler: None,
1610            heartbeat: None,
1611            reconnect_timeout_ms: Some(1_000),
1612            reconnect_delay_initial_ms: Some(50),
1613            reconnect_delay_max_ms: Some(100),
1614            reconnect_backoff_factor: Some(1.0),
1615            reconnect_jitter_ms: Some(0),
1616            connection_max_retries: Some(1),
1617            reconnect_max_attempts: None,
1618            idle_timeout_ms: None,
1619            certs_dir: None,
1620        };
1621
1622        let client = SocketClient::connect(config, None, None, None)
1623            .await
1624            .unwrap();
1625
1626        wait_until_async(
1627            || async { client.is_reconnecting() },
1628            Duration::from_secs(2),
1629        )
1630        .await;
1631
1632        client.close().await;
1633        server.abort();
1634    }
1635
1636    #[rstest]
1637    fn test_parse_socket_url_raw_address() {
1638        // Raw socket address with TLS mode
1639        let (socket_addr, request_url) =
1640            SocketClientInner::parse_socket_url("example.com:6130", Mode::Tls).unwrap();
1641        assert_eq!(socket_addr, "example.com:6130");
1642        assert_eq!(request_url, "wss://example.com:6130");
1643
1644        // Raw socket address with Plain mode
1645        let (socket_addr, request_url) =
1646            SocketClientInner::parse_socket_url("localhost:8080", Mode::Plain).unwrap();
1647        assert_eq!(socket_addr, "localhost:8080");
1648        assert_eq!(request_url, "ws://localhost:8080");
1649    }
1650
1651    #[rstest]
1652    fn test_parse_socket_url_with_scheme() {
1653        // Full URL with wss scheme
1654        let (socket_addr, request_url) =
1655            SocketClientInner::parse_socket_url("wss://example.com:443/path", Mode::Tls).unwrap();
1656        assert_eq!(socket_addr, "example.com:443");
1657        assert_eq!(request_url, "wss://example.com:443/path");
1658
1659        // Full URL with ws scheme
1660        let (socket_addr, request_url) =
1661            SocketClientInner::parse_socket_url("ws://localhost:8080", Mode::Plain).unwrap();
1662        assert_eq!(socket_addr, "localhost:8080");
1663        assert_eq!(request_url, "ws://localhost:8080");
1664    }
1665
1666    #[rstest]
1667    fn test_parse_socket_url_default_ports() {
1668        // wss without explicit port defaults to 443
1669        let (socket_addr, _) =
1670            SocketClientInner::parse_socket_url("wss://example.com", Mode::Tls).unwrap();
1671        assert_eq!(socket_addr, "example.com:443");
1672
1673        // ws without explicit port defaults to 80
1674        let (socket_addr, _) =
1675            SocketClientInner::parse_socket_url("ws://example.com", Mode::Plain).unwrap();
1676        assert_eq!(socket_addr, "example.com:80");
1677
1678        // https defaults to 443
1679        let (socket_addr, _) =
1680            SocketClientInner::parse_socket_url("https://example.com", Mode::Tls).unwrap();
1681        assert_eq!(socket_addr, "example.com:443");
1682
1683        // http defaults to 80
1684        let (socket_addr, _) =
1685            SocketClientInner::parse_socket_url("http://example.com", Mode::Plain).unwrap();
1686        assert_eq!(socket_addr, "example.com:80");
1687    }
1688
1689    #[rstest]
1690    fn test_parse_socket_url_unknown_scheme_uses_mode() {
1691        // Unknown scheme defaults to mode-based port
1692        let (socket_addr, _) =
1693            SocketClientInner::parse_socket_url("custom://example.com", Mode::Tls).unwrap();
1694        assert_eq!(socket_addr, "example.com:443");
1695
1696        let (socket_addr, _) =
1697            SocketClientInner::parse_socket_url("custom://example.com", Mode::Plain).unwrap();
1698        assert_eq!(socket_addr, "example.com:80");
1699    }
1700
1701    #[rstest]
1702    fn test_parse_socket_url_ipv6() {
1703        // IPv6 address with port
1704        let (socket_addr, request_url) =
1705            SocketClientInner::parse_socket_url("[::1]:8080", Mode::Plain).unwrap();
1706        assert_eq!(socket_addr, "[::1]:8080");
1707        assert_eq!(request_url, "ws://[::1]:8080");
1708
1709        // IPv6 in URL
1710        let (socket_addr, _) =
1711            SocketClientInner::parse_socket_url("ws://[::1]:8080", Mode::Plain).unwrap();
1712        assert_eq!(socket_addr, "[::1]:8080");
1713    }
1714
1715    #[rstest]
1716    #[tokio::test]
1717    async fn test_url_parsing_raw_socket_address() {
1718        // Test that raw socket addresses (host:port) work correctly
1719        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1720        let port = listener.local_addr().unwrap().port();
1721
1722        let server = task::spawn(async move {
1723            if let Ok((sock, _)) = listener.accept().await {
1724                drop(sock);
1725            }
1726            sleep(Duration::from_millis(50)).await;
1727        });
1728
1729        let config = SocketConfig {
1730            url: format!("127.0.0.1:{port}"), // Raw socket address format
1731            mode: Mode::Plain,
1732            suffix: b"\r\n".to_vec(),
1733            message_handler: None,
1734            heartbeat: None,
1735            reconnect_timeout_ms: Some(1_000),
1736            reconnect_delay_initial_ms: Some(50),
1737            reconnect_delay_max_ms: Some(100),
1738            reconnect_backoff_factor: Some(1.0),
1739            reconnect_jitter_ms: Some(0),
1740            connection_max_retries: Some(1),
1741            reconnect_max_attempts: None,
1742            idle_timeout_ms: None,
1743            certs_dir: None,
1744        };
1745
1746        // Should successfully connect with raw socket address
1747        let client = SocketClient::connect(config, None, None, None).await;
1748        assert!(
1749            client.is_ok(),
1750            "Client should connect with raw socket address format"
1751        );
1752
1753        if let Ok(client) = client {
1754            client.close().await;
1755        }
1756        server.abort();
1757    }
1758
1759    #[rstest]
1760    #[tokio::test]
1761    async fn test_url_parsing_with_scheme() {
1762        // Test that URLs with schemes also work
1763        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1764        let port = listener.local_addr().unwrap().port();
1765
1766        let server = task::spawn(async move {
1767            if let Ok((sock, _)) = listener.accept().await {
1768                drop(sock);
1769            }
1770            sleep(Duration::from_millis(50)).await;
1771        });
1772
1773        let config = SocketConfig {
1774            url: format!("ws://127.0.0.1:{port}"), // URL with scheme
1775            mode: Mode::Plain,
1776            suffix: b"\r\n".to_vec(),
1777            message_handler: None,
1778            heartbeat: None,
1779            reconnect_timeout_ms: Some(1_000),
1780            reconnect_delay_initial_ms: Some(50),
1781            reconnect_delay_max_ms: Some(100),
1782            reconnect_backoff_factor: Some(1.0),
1783            reconnect_jitter_ms: Some(0),
1784            connection_max_retries: Some(1),
1785            reconnect_max_attempts: None,
1786            idle_timeout_ms: None,
1787            certs_dir: None,
1788        };
1789
1790        // Should successfully connect with URL format
1791        let client = SocketClient::connect(config, None, None, None).await;
1792        assert!(
1793            client.is_ok(),
1794            "Client should connect with URL scheme format"
1795        );
1796
1797        if let Ok(client) = client {
1798            client.close().await;
1799        }
1800        server.abort();
1801    }
1802
1803    #[rstest]
1804    fn test_parse_socket_url_ipv6_with_zone() {
1805        // IPv6 with zone ID (link-local address)
1806        let (socket_addr, request_url) =
1807            SocketClientInner::parse_socket_url("[fe80::1%eth0]:8080", Mode::Plain).unwrap();
1808        assert_eq!(socket_addr, "[fe80::1%eth0]:8080");
1809        assert_eq!(request_url, "ws://[fe80::1%eth0]:8080");
1810
1811        // Verify zone is preserved in URL format too
1812        let (socket_addr, request_url) =
1813            SocketClientInner::parse_socket_url("ws://[fe80::1%lo]:9090", Mode::Plain).unwrap();
1814        assert_eq!(socket_addr, "[fe80::1%lo]:9090");
1815        assert_eq!(request_url, "ws://[fe80::1%lo]:9090");
1816    }
1817
1818    #[rstest]
1819    #[tokio::test]
1820    async fn test_ipv6_loopback_connection() {
1821        // Test IPv6 loopback address connection
1822        // Skip if IPv6 is not available on the system
1823        if TcpListener::bind("[::1]:0").await.is_err() {
1824            eprintln!("IPv6 not available, skipping test");
1825            return;
1826        }
1827
1828        let listener = TcpListener::bind("[::1]:0").await.unwrap();
1829        let port = listener.local_addr().unwrap().port();
1830
1831        let server = task::spawn(async move {
1832            if let Ok((mut sock, _)) = listener.accept().await {
1833                let mut buf = vec![0u8; 1024];
1834                if let Ok(n) = sock.read(&mut buf).await {
1835                    // Echo back
1836                    let _ = sock.write_all(&buf[..n]).await;
1837                }
1838            }
1839            sleep(Duration::from_millis(50)).await;
1840        });
1841
1842        let config = SocketConfig {
1843            url: format!("[::1]:{port}"), // IPv6 loopback
1844            mode: Mode::Plain,
1845            suffix: b"\r\n".to_vec(),
1846            message_handler: None,
1847            heartbeat: None,
1848            reconnect_timeout_ms: Some(1_000),
1849            reconnect_delay_initial_ms: Some(50),
1850            reconnect_delay_max_ms: Some(100),
1851            reconnect_backoff_factor: Some(1.0),
1852            reconnect_jitter_ms: Some(0),
1853            connection_max_retries: Some(1),
1854            reconnect_max_attempts: None,
1855            idle_timeout_ms: None,
1856            certs_dir: None,
1857        };
1858
1859        let client = SocketClient::connect(config, None, None, None).await;
1860        assert!(
1861            client.is_ok(),
1862            "Client should connect to IPv6 loopback address"
1863        );
1864
1865        if let Ok(client) = client {
1866            client.close().await;
1867        }
1868        server.abort();
1869    }
1870
1871    #[rstest]
1872    #[tokio::test]
1873    async fn test_send_waits_during_reconnection() {
1874        // Test that send operations wait for reconnection to complete (up to configured timeout)
1875        use nautilus_common::testing::wait_until_async;
1876
1877        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1878        let port = listener.local_addr().unwrap().port();
1879
1880        let server = task::spawn(async move {
1881            // First connection - accept and immediately close
1882            if let Ok((sock, _)) = listener.accept().await {
1883                drop(sock);
1884            }
1885
1886            // Wait before accepting second connection
1887            sleep(Duration::from_millis(500)).await;
1888
1889            // Second connection - accept and keep alive
1890            if let Ok((mut sock, _)) = listener.accept().await {
1891                // Echo messages
1892                let mut buf = vec![0u8; 1024];
1893                while let Ok(n) = sock.read(&mut buf).await {
1894                    if n == 0 {
1895                        break;
1896                    }
1897
1898                    if sock.write_all(&buf[..n]).await.is_err() {
1899                        break;
1900                    }
1901                }
1902            }
1903        });
1904
1905        let config = SocketConfig {
1906            url: format!("127.0.0.1:{port}"),
1907            mode: Mode::Plain,
1908            suffix: b"\r\n".to_vec(),
1909            message_handler: None,
1910            heartbeat: None,
1911            reconnect_timeout_ms: Some(5_000), // 5s timeout - enough for reconnect
1912            reconnect_delay_initial_ms: Some(100),
1913            reconnect_delay_max_ms: Some(200),
1914            reconnect_backoff_factor: Some(1.0),
1915            reconnect_jitter_ms: Some(0),
1916            connection_max_retries: Some(1),
1917            reconnect_max_attempts: None,
1918            idle_timeout_ms: None,
1919            certs_dir: None,
1920        };
1921
1922        let client = SocketClient::connect(config, None, None, None)
1923            .await
1924            .unwrap();
1925
1926        // Wait for reconnection to trigger
1927        wait_until_async(
1928            || async { client.is_reconnecting() },
1929            Duration::from_secs(2),
1930        )
1931        .await;
1932
1933        // Try to send while reconnecting - should wait and succeed after reconnect
1934        let send_result = tokio::time::timeout(
1935            Duration::from_secs(3),
1936            client.send_bytes(b"test_message".to_vec()),
1937        )
1938        .await;
1939
1940        assert!(
1941            send_result.is_ok() && send_result.unwrap().is_ok(),
1942            "Send should succeed after waiting for reconnection"
1943        );
1944
1945        client.close().await;
1946        server.abort();
1947    }
1948
1949    #[rstest]
1950    #[tokio::test]
1951    async fn test_send_bytes_timeout_uses_configured_reconnect_timeout() {
1952        // Test that send_bytes operations respect the configured reconnect_timeout.
1953        // When a client is stuck in RECONNECT longer than the timeout, sends should fail with Timeout.
1954        use nautilus_common::testing::wait_until_async;
1955
1956        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
1957        let port = listener.local_addr().unwrap().port();
1958
1959        let server = task::spawn(async move {
1960            // Accept first connection and immediately close it
1961            if let Ok((sock, _)) = listener.accept().await {
1962                drop(sock);
1963            }
1964            // Drop listener entirely so reconnection fails completely
1965            drop(listener);
1966            sleep(Duration::from_mins(1)).await;
1967        });
1968
1969        let config = SocketConfig {
1970            url: format!("127.0.0.1:{port}"),
1971            mode: Mode::Plain,
1972            suffix: b"\r\n".to_vec(),
1973            message_handler: None,
1974            heartbeat: None,
1975            reconnect_timeout_ms: Some(1_000), // 1s timeout for faster test
1976            reconnect_delay_initial_ms: Some(200), // Short backoff (but > timeout) to keep client in RECONNECT
1977            reconnect_delay_max_ms: Some(200),
1978            reconnect_backoff_factor: Some(1.0),
1979            reconnect_jitter_ms: Some(0),
1980            connection_max_retries: Some(1),
1981            reconnect_max_attempts: None,
1982            idle_timeout_ms: None,
1983            certs_dir: None,
1984        };
1985
1986        let client = SocketClient::connect(config, None, None, None)
1987            .await
1988            .unwrap();
1989
1990        // Wait for client to enter RECONNECT state
1991        wait_until_async(
1992            || async { client.is_reconnecting() },
1993            Duration::from_secs(3),
1994        )
1995        .await;
1996
1997        // Attempt send while stuck in RECONNECT - should timeout after 1s (configured timeout)
1998        // The client will try to reconnect for 1s, fail, then wait 5s backoff before next attempt
1999        let start = std::time::Instant::now();
2000        let send_result = client.send_bytes(b"test".to_vec()).await;
2001        let elapsed = start.elapsed();
2002
2003        assert!(
2004            send_result.is_err(),
2005            "Send should fail when client stuck in RECONNECT, was: {send_result:?}"
2006        );
2007        assert!(
2008            matches!(send_result, Err(crate::error::SendError::Timeout)),
2009            "Send should return Timeout error, was: {send_result:?}"
2010        );
2011        // Verify timeout respects configured value (1s), but don't check upper bound
2012        // as CI scheduler jitter can cause legitimate delays beyond the timeout
2013        assert!(
2014            elapsed >= Duration::from_millis(900),
2015            "Send should timeout after at least 1s (configured timeout), took {elapsed:?}"
2016        );
2017
2018        client.close().await;
2019        server.abort();
2020    }
2021
2022    #[rstest]
2023    #[tokio::test]
2024    async fn test_idle_timeout_triggers_reconnect() {
2025        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2026        let port = listener.local_addr().unwrap().port();
2027
2028        // Server accepts connection but sends nothing (simulates silent death)
2029        let server = task::spawn(async move {
2030            let (_sock1, _) = listener.accept().await.unwrap();
2031            // Hold connection open but send nothing, wait for reconnect attempt
2032            sleep(Duration::from_secs(5)).await;
2033        });
2034
2035        let config = SocketConfig {
2036            url: format!("127.0.0.1:{port}"),
2037            mode: Mode::Plain,
2038            suffix: b"\r\n".to_vec(),
2039            message_handler: None,
2040            heartbeat: None,
2041            reconnect_timeout_ms: Some(2_000),
2042            reconnect_delay_initial_ms: Some(50),
2043            reconnect_delay_max_ms: Some(100),
2044            reconnect_backoff_factor: Some(1.0),
2045            reconnect_jitter_ms: Some(0),
2046            connection_max_retries: Some(1),
2047            reconnect_max_attempts: Some(1),
2048            idle_timeout_ms: Some(500),
2049            certs_dir: None,
2050        };
2051
2052        let client = SocketClient::connect(config, None, None, None)
2053            .await
2054            .unwrap();
2055
2056        assert!(client.is_active());
2057
2058        // Wait for idle timeout to fire and client to enter reconnect
2059        wait_until_async(
2060            || async { client.is_reconnecting() || client.is_closed() },
2061            Duration::from_secs(3),
2062        )
2063        .await;
2064
2065        assert!(
2066            !client.is_active(),
2067            "Client should not be active after idle timeout"
2068        );
2069
2070        client.close().await;
2071        server.abort();
2072    }
2073
2074    #[rstest]
2075    #[tokio::test]
2076    async fn test_idle_timeout_resets_on_data() {
2077        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2078        let port = listener.local_addr().unwrap().port();
2079
2080        // Server sends data every 200ms (well within the 1s idle timeout)
2081        let server = task::spawn(async move {
2082            let (mut sock, _) = listener.accept().await.unwrap();
2083            for _ in 0..10 {
2084                sleep(Duration::from_millis(200)).await;
2085
2086                if sock.write_all(b"ping\r\n").await.is_err() {
2087                    break;
2088                }
2089            }
2090        });
2091
2092        let config = SocketConfig {
2093            url: format!("127.0.0.1:{port}"),
2094            mode: Mode::Plain,
2095            suffix: b"\r\n".to_vec(),
2096            message_handler: None,
2097            heartbeat: None,
2098            reconnect_timeout_ms: Some(2_000),
2099            reconnect_delay_initial_ms: Some(50),
2100            reconnect_delay_max_ms: Some(100),
2101            reconnect_backoff_factor: Some(1.0),
2102            reconnect_jitter_ms: Some(0),
2103            connection_max_retries: Some(1),
2104            reconnect_max_attempts: Some(1),
2105            idle_timeout_ms: Some(1_000),
2106            certs_dir: None,
2107        };
2108
2109        let client = SocketClient::connect(config, None, None, None)
2110            .await
2111            .unwrap();
2112
2113        assert!(client.is_active());
2114
2115        // Wait 1.5s - data arrives every 200ms so idle timeout (1s) should NOT fire
2116        sleep(Duration::from_millis(1_500)).await;
2117
2118        assert!(
2119            client.is_active(),
2120            "Client should remain active when data is flowing"
2121        );
2122
2123        client.close().await;
2124        server.abort();
2125    }
2126
2127    #[rstest]
2128    #[tokio::test]
2129    async fn test_close_during_backoff_exits_promptly() {
2130        // Verify that close() interrupts backoff sleep (Finding 1).
2131        // Server accepts then drops, no second listener -> reconnect fails -> enters backoff.
2132        // We close while backing off and assert the client shuts down quickly.
2133        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
2134        let port = listener.local_addr().unwrap().port();
2135
2136        let server = task::spawn(async move {
2137            // Accept first connection, close immediately
2138            if let Ok((mut sock, _)) = listener.accept().await {
2139                drop(sock.shutdown());
2140            }
2141            // Don't accept again so reconnect fails and enters backoff
2142            sleep(Duration::from_mins(1)).await;
2143        });
2144
2145        let config = SocketConfig {
2146            url: format!("127.0.0.1:{port}"),
2147            mode: Mode::Plain,
2148            suffix: b"\r\n".to_vec(),
2149            message_handler: None,
2150            heartbeat: None,
2151            reconnect_timeout_ms: Some(1_000),
2152            reconnect_delay_initial_ms: Some(10_000), // 10s backoff to ensure we're sleeping
2153            reconnect_delay_max_ms: Some(10_000),
2154            reconnect_backoff_factor: Some(1.0),
2155            reconnect_jitter_ms: Some(0),
2156            connection_max_retries: None,
2157            reconnect_max_attempts: None,
2158            idle_timeout_ms: None,
2159            certs_dir: None,
2160        };
2161
2162        let client = SocketClient::connect(config, None, None, None)
2163            .await
2164            .unwrap();
2165
2166        // Wait for client to enter reconnect
2167        wait_until_async(
2168            || async { client.is_reconnecting() },
2169            Duration::from_secs(3),
2170        )
2171        .await;
2172
2173        // Wait for the reconnect attempt to fail and enter backoff sleep
2174        sleep(Duration::from_millis(1_500)).await;
2175
2176        // Close while backing off
2177        let start = std::time::Instant::now();
2178        client.close().await;
2179        let elapsed = start.elapsed();
2180
2181        assert!(client.is_closed(), "Client should be closed");
2182        // Should exit well before the 10s backoff sleep completes
2183        assert!(
2184            elapsed < Duration::from_secs(2),
2185            "Close should interrupt backoff sleep, took {elapsed:?}"
2186        );
2187
2188        server.abort();
2189    }
2190
2191    #[rstest]
2192    #[tokio::test]
2193    async fn test_zero_idle_timeout_rejected() {
2194        let config = SocketConfig {
2195            url: "127.0.0.1:9999".to_string(),
2196            mode: Mode::Plain,
2197            suffix: b"\r\n".to_vec(),
2198            message_handler: None,
2199            heartbeat: None,
2200            reconnect_timeout_ms: None,
2201            reconnect_delay_initial_ms: None,
2202            reconnect_delay_max_ms: None,
2203            reconnect_backoff_factor: None,
2204            reconnect_jitter_ms: None,
2205            reconnect_max_attempts: None,
2206            connection_max_retries: Some(1),
2207            idle_timeout_ms: Some(0),
2208            certs_dir: None,
2209        };
2210
2211        let result = SocketClient::connect(config, None, None, None).await;
2212
2213        assert!(result.is_err(), "Zero idle timeout should be rejected");
2214        let err_msg = result.unwrap_err().to_string();
2215        assert!(
2216            err_msg.contains("Idle timeout cannot be zero"),
2217            "Error should mention zero idle timeout, was: {err_msg}"
2218        );
2219    }
2220
2221    #[rstest]
2222    #[tokio::test]
2223    async fn test_empty_suffix_rejected() {
2224        let config = SocketConfig {
2225            url: "127.0.0.1:9999".to_string(),
2226            mode: Mode::Plain,
2227            suffix: vec![],
2228            message_handler: None,
2229            heartbeat: None,
2230            reconnect_timeout_ms: None,
2231            reconnect_delay_initial_ms: None,
2232            reconnect_delay_max_ms: None,
2233            reconnect_backoff_factor: None,
2234            reconnect_jitter_ms: None,
2235            reconnect_max_attempts: None,
2236            connection_max_retries: Some(1),
2237            idle_timeout_ms: None,
2238            certs_dir: None,
2239        };
2240
2241        let result = SocketClient::connect(config, None, None, None).await;
2242
2243        assert!(
2244            result.is_err(),
2245            "Empty suffix should cause connection to fail"
2246        );
2247        let err_msg = result.unwrap_err().to_string();
2248        assert!(
2249            err_msg.contains("suffix cannot be empty"),
2250            "Error should mention empty suffix, was: {err_msg}"
2251        );
2252    }
2253}