Skip to main content

nautilus_network/python/
socket.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use std::{sync::atomic::Ordering, time::Duration};
17
18use nautilus_core::python::{clone_py_object, to_pyruntime_err};
19use pyo3::{Py, prelude::*};
20use tokio_tungstenite::tungstenite::stream::Mode;
21
22use crate::{
23    mode::ConnectionMode,
24    socket::{SocketClient, SocketConfig, TcpMessageHandler, WriterCommand},
25};
26
27#[pymethods]
28#[pyo3_stub_gen::derive::gen_stub_pymethods]
29impl SocketConfig {
30    /// Configuration for TCP socket connection.
31    #[new]
32    #[expect(clippy::too_many_arguments, clippy::needless_pass_by_value)]
33    #[pyo3(signature = (url, ssl, suffix, handler, heartbeat=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100, connection_max_retries=5, reconnect_max_attempts=None, idle_timeout_ms=None, certs_dir=None))]
34    fn py_new(
35        url: String,
36        ssl: bool,
37        suffix: Vec<u8>,
38        handler: Py<PyAny>,
39        heartbeat: Option<(u64, Vec<u8>)>,
40        reconnect_timeout_ms: Option<u64>,
41        reconnect_delay_initial_ms: Option<u64>,
42        reconnect_delay_max_ms: Option<u64>,
43        reconnect_backoff_factor: Option<f64>,
44        reconnect_jitter_ms: Option<u64>,
45        connection_max_retries: Option<u32>,
46        reconnect_max_attempts: Option<u32>,
47        idle_timeout_ms: Option<u64>,
48        certs_dir: Option<String>,
49    ) -> Self {
50        let mode = if ssl { Mode::Tls } else { Mode::Plain };
51
52        // Create function pointer that calls Python handler
53        let handler_clone = clone_py_object(&handler);
54        let message_handler: TcpMessageHandler = std::sync::Arc::new(move |data: &[u8]| {
55            Python::attach(|py| {
56                if let Err(e) = handler_clone.call1(py, (data,)) {
57                    log::error!("Error calling Python message handler: {e}");
58                }
59            });
60        });
61
62        Self {
63            url,
64            mode,
65            suffix,
66            message_handler: Some(message_handler),
67            heartbeat,
68            reconnect_timeout_ms,
69            reconnect_delay_initial_ms,
70            reconnect_delay_max_ms,
71            reconnect_backoff_factor,
72            reconnect_jitter_ms,
73            connection_max_retries,
74            reconnect_max_attempts,
75            idle_timeout_ms,
76            certs_dir,
77        }
78    }
79}
80
81#[pymethods]
82#[pyo3_stub_gen::derive::gen_stub_pymethods]
83impl SocketClient {
84    /// Connect to the server.
85    #[staticmethod]
86    #[pyo3(name = "connect")]
87    #[pyo3(signature = (config, post_connection=None, post_reconnection=None, post_disconnection=None))]
88    fn py_connect(
89        config: SocketConfig,
90        post_connection: Option<Py<PyAny>>,
91        post_reconnection: Option<Py<PyAny>>,
92        post_disconnection: Option<Py<PyAny>>,
93        py: Python<'_>,
94    ) -> PyResult<Bound<'_, PyAny>> {
95        // Convert Python callbacks to function pointers
96        let post_connection_fn = post_connection.map(|callback| {
97            let callback_clone = clone_py_object(&callback);
98            std::sync::Arc::new(move || {
99                Python::attach(|py| {
100                    if let Err(e) = callback_clone.call0(py) {
101                        log::error!("Error calling post_connection handler: {e}");
102                    }
103                });
104            }) as std::sync::Arc<dyn Fn() + Send + Sync>
105        });
106
107        let post_reconnection_fn = post_reconnection.map(|callback| {
108            let callback_clone = clone_py_object(&callback);
109            std::sync::Arc::new(move || {
110                Python::attach(|py| {
111                    if let Err(e) = callback_clone.call0(py) {
112                        log::error!("Error calling post_reconnection handler: {e}");
113                    }
114                });
115            }) as std::sync::Arc<dyn Fn() + Send + Sync>
116        });
117
118        let post_disconnection_fn = post_disconnection.map(|callback| {
119            let callback_clone = clone_py_object(&callback);
120            std::sync::Arc::new(move || {
121                Python::attach(|py| {
122                    if let Err(e) = callback_clone.call0(py) {
123                        log::error!("Error calling post_disconnection handler: {e}");
124                    }
125                });
126            }) as std::sync::Arc<dyn Fn() + Send + Sync>
127        });
128
129        pyo3_async_runtimes::tokio::future_into_py(py, async move {
130            Self::connect(
131                config,
132                post_connection_fn,
133                post_reconnection_fn,
134                post_disconnection_fn,
135            )
136            .await
137            .map_err(to_pyruntime_err)
138        })
139    }
140
141    /// Check if the client connection is active.
142    ///
143    /// Returns `true` if the client is connected and has not been signalled to disconnect.
144    /// The client will automatically retry connection based on its configuration.
145    #[pyo3(name = "is_active")]
146    #[expect(clippy::needless_pass_by_value)]
147    fn py_is_active(slf: PyRef<'_, Self>) -> bool {
148        slf.is_active()
149    }
150
151    /// Check if the client is reconnecting.
152    ///
153    /// Returns `true` if the client lost connection and is attempting to reestablish it.
154    /// The client will automatically retry connection based on its configuration.
155    #[pyo3(name = "is_reconnecting")]
156    #[expect(clippy::needless_pass_by_value)]
157    fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
158        slf.is_reconnecting()
159    }
160
161    /// Check if the client is disconnecting.
162    ///
163    /// Returns `true` if the client is in disconnect mode.
164    #[pyo3(name = "is_disconnecting")]
165    #[expect(clippy::needless_pass_by_value)]
166    fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
167        slf.is_disconnecting()
168    }
169
170    /// Check if the client is closed.
171    ///
172    /// Returns `true` if the client has been explicitly disconnected or reached
173    /// maximum reconnection attempts. In this state, the client cannot be reused
174    /// and a new client must be created for further connections.
175    #[pyo3(name = "is_closed")]
176    #[expect(clippy::needless_pass_by_value)]
177    fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
178        slf.is_closed()
179    }
180
181    #[pyo3(name = "mode")]
182    #[expect(clippy::needless_pass_by_value)]
183    fn py_mode(slf: PyRef<'_, Self>) -> String {
184        slf.connection_mode().to_string()
185    }
186
187    /// Reconnect the client.
188    #[pyo3(name = "reconnect")]
189    #[expect(clippy::needless_pass_by_value)]
190    fn py_reconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
191        let connection_mode = slf.connection_mode.clone();
192        let state_notify = slf.state_notify.clone();
193        let mode_str = ConnectionMode::from_atomic(&connection_mode).to_string();
194        log::debug!("Reconnect from mode {mode_str}");
195
196        pyo3_async_runtimes::tokio::future_into_py(py, async move {
197            match ConnectionMode::from_atomic(&connection_mode) {
198                ConnectionMode::Reconnect => {
199                    log::warn!("Cannot reconnect - socket already reconnecting");
200                }
201                ConnectionMode::Disconnect => {
202                    log::warn!("Cannot reconnect - socket disconnecting");
203                }
204                ConnectionMode::Closed => {
205                    log::warn!("Cannot reconnect - socket closed");
206                }
207                ConnectionMode::Active => {
208                    connection_mode.store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
209                    state_notify.notify_one();
210
211                    let fallback_interval = Duration::from_millis(100);
212                    let timeout = tokio::time::timeout(Duration::from_secs(30), async {
213                        loop {
214                            let notified = state_notify.notified();
215
216                            let current = ConnectionMode::from_atomic(&connection_mode);
217                            if current.is_active() {
218                                return Ok(());
219                            }
220
221                            if current.is_closed() || current.is_disconnect() {
222                                return Err("Connection closed during reconnect");
223                            }
224
225                            tokio::select! {
226                                biased;
227                                () = notified => {}
228                                () = tokio::time::sleep(fallback_interval) => {}
229                            }
230                        }
231                    })
232                    .await;
233
234                    match timeout {
235                        Ok(Ok(())) => log::debug!("Reconnected successfully"),
236                        Ok(Err(e)) => log::warn!("Reconnect aborted: {e}"),
237                        Err(_) => log::error!("Reconnect timed out after 30s"),
238                    }
239                }
240            }
241
242            Ok(())
243        })
244    }
245
246    /// Close the client.
247    ///
248    /// Controller task will periodically check the disconnect mode
249    /// and shutdown the client if it is not alive.
250    #[pyo3(name = "close")]
251    #[expect(clippy::needless_pass_by_value)]
252    fn py_close<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
253        let connection_mode = slf.connection_mode.clone();
254        let state_notify = slf.state_notify.clone();
255        let mode_str = ConnectionMode::from_atomic(&connection_mode).to_string();
256        log::debug!("Close from mode {mode_str}");
257
258        pyo3_async_runtimes::tokio::future_into_py(py, async move {
259            match ConnectionMode::from_atomic(&connection_mode) {
260                ConnectionMode::Closed => {
261                    log::debug!("Socket already closed");
262                }
263                ConnectionMode::Disconnect => {
264                    log::debug!("Socket already disconnecting");
265                }
266                _ => {
267                    connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
268                    state_notify.notify_one();
269
270                    let timeout = tokio::time::timeout(Duration::from_secs(5), async {
271                        while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
272                            tokio::time::sleep(Duration::from_millis(10)).await;
273                        }
274                    })
275                    .await;
276
277                    if timeout.is_err() {
278                        log::error!("Timeout waiting for socket to close, forcing closed state");
279                        connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
280                    }
281                }
282            }
283
284            Ok(())
285        })
286    }
287
288    /// Send bytes data to the connection.
289    ///
290    /// # Errors
291    ///
292    /// - Throws an Exception if it is not able to send data.
293    #[pyo3(name = "send")]
294    #[expect(clippy::needless_pass_by_value)]
295    fn py_send<'py>(
296        slf: PyRef<'_, Self>,
297        data: Vec<u8>,
298        py: Python<'py>,
299    ) -> PyResult<Bound<'py, PyAny>> {
300        log::trace!("Sending {}", String::from_utf8_lossy(&data));
301
302        let connection_mode = slf.connection_mode.clone();
303        let state_notify = slf.state_notify.clone();
304        let writer_tx = slf.writer_tx.clone();
305
306        pyo3_async_runtimes::tokio::future_into_py(py, async move {
307            match ConnectionMode::from_atomic(&connection_mode) {
308                ConnectionMode::Disconnect | ConnectionMode::Closed => {
309                    let msg = format!(
310                        "Cannot send data ({}): socket closed",
311                        String::from_utf8_lossy(&data)
312                    );
313
314                    let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, msg);
315                    return Err(to_pyruntime_err(io_err));
316                }
317                mode if !mode.is_active() => {
318                    let timeout = Duration::from_secs(2);
319                    let fallback_interval = Duration::from_millis(100);
320
321                    log::debug!("Waiting for client to become ACTIVE before sending (2s)...");
322
323                    match tokio::time::timeout(timeout, async {
324                        loop {
325                            let notified = state_notify.notified();
326
327                            let mode = ConnectionMode::from_atomic(&connection_mode);
328                            if mode.is_active() {
329                                return Ok(());
330                            }
331
332                            if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
333                                return Err("Client disconnected waiting to send");
334                            }
335
336                            tokio::select! {
337                                biased;
338                                () = notified => {}
339                                () = tokio::time::sleep(fallback_interval) => {}
340                            }
341                        }
342                    })
343                    .await
344                    {
345                        Ok(Ok(())) => log::debug!("Client now active"),
346                        Ok(Err(e)) => {
347                            let err_msg = format!(
348                                "Failed sending data ({}): {e}",
349                                String::from_utf8_lossy(&data)
350                            );
351
352                            let io_err =
353                                std::io::Error::new(std::io::ErrorKind::NotConnected, err_msg);
354                            return Err(to_pyruntime_err(io_err));
355                        }
356                        Err(_) => {
357                            let err_msg = format!(
358                                "Failed sending data ({}): timeout waiting to become ACTIVE",
359                                String::from_utf8_lossy(&data)
360                            );
361
362                            let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, err_msg);
363                            return Err(to_pyruntime_err(io_err));
364                        }
365                    }
366                }
367                _ => {}
368            }
369
370            let msg = WriterCommand::Send(data.into());
371            writer_tx.send(msg).map_err(to_pyruntime_err)
372        })
373    }
374}