1use std::{sync::atomic::Ordering, time::Duration};
17
18use nautilus_core::python::{clone_py_object, to_pyruntime_err};
19use pyo3::{Py, prelude::*};
20use tokio_tungstenite::tungstenite::stream::Mode;
21
22use crate::{
23 mode::ConnectionMode,
24 socket::{SocketClient, SocketConfig, TcpMessageHandler, WriterCommand},
25};
26
27#[pymethods]
28#[pyo3_stub_gen::derive::gen_stub_pymethods]
29impl SocketConfig {
30 #[new]
32 #[expect(clippy::too_many_arguments, clippy::needless_pass_by_value)]
33 #[pyo3(signature = (url, ssl, suffix, handler, heartbeat=None, reconnect_timeout_ms=10_000, reconnect_delay_initial_ms=2_000, reconnect_delay_max_ms=30_000, reconnect_backoff_factor=1.5, reconnect_jitter_ms=100, connection_max_retries=5, reconnect_max_attempts=None, idle_timeout_ms=None, certs_dir=None))]
34 fn py_new(
35 url: String,
36 ssl: bool,
37 suffix: Vec<u8>,
38 handler: Py<PyAny>,
39 heartbeat: Option<(u64, Vec<u8>)>,
40 reconnect_timeout_ms: Option<u64>,
41 reconnect_delay_initial_ms: Option<u64>,
42 reconnect_delay_max_ms: Option<u64>,
43 reconnect_backoff_factor: Option<f64>,
44 reconnect_jitter_ms: Option<u64>,
45 connection_max_retries: Option<u32>,
46 reconnect_max_attempts: Option<u32>,
47 idle_timeout_ms: Option<u64>,
48 certs_dir: Option<String>,
49 ) -> Self {
50 let mode = if ssl { Mode::Tls } else { Mode::Plain };
51
52 let handler_clone = clone_py_object(&handler);
54 let message_handler: TcpMessageHandler = std::sync::Arc::new(move |data: &[u8]| {
55 Python::attach(|py| {
56 if let Err(e) = handler_clone.call1(py, (data,)) {
57 log::error!("Error calling Python message handler: {e}");
58 }
59 });
60 });
61
62 Self {
63 url,
64 mode,
65 suffix,
66 message_handler: Some(message_handler),
67 heartbeat,
68 reconnect_timeout_ms,
69 reconnect_delay_initial_ms,
70 reconnect_delay_max_ms,
71 reconnect_backoff_factor,
72 reconnect_jitter_ms,
73 connection_max_retries,
74 reconnect_max_attempts,
75 idle_timeout_ms,
76 certs_dir,
77 }
78 }
79}
80
81#[pymethods]
82#[pyo3_stub_gen::derive::gen_stub_pymethods]
83impl SocketClient {
84 #[staticmethod]
86 #[pyo3(name = "connect")]
87 #[pyo3(signature = (config, post_connection=None, post_reconnection=None, post_disconnection=None))]
88 fn py_connect(
89 config: SocketConfig,
90 post_connection: Option<Py<PyAny>>,
91 post_reconnection: Option<Py<PyAny>>,
92 post_disconnection: Option<Py<PyAny>>,
93 py: Python<'_>,
94 ) -> PyResult<Bound<'_, PyAny>> {
95 let post_connection_fn = post_connection.map(|callback| {
97 let callback_clone = clone_py_object(&callback);
98 std::sync::Arc::new(move || {
99 Python::attach(|py| {
100 if let Err(e) = callback_clone.call0(py) {
101 log::error!("Error calling post_connection handler: {e}");
102 }
103 });
104 }) as std::sync::Arc<dyn Fn() + Send + Sync>
105 });
106
107 let post_reconnection_fn = post_reconnection.map(|callback| {
108 let callback_clone = clone_py_object(&callback);
109 std::sync::Arc::new(move || {
110 Python::attach(|py| {
111 if let Err(e) = callback_clone.call0(py) {
112 log::error!("Error calling post_reconnection handler: {e}");
113 }
114 });
115 }) as std::sync::Arc<dyn Fn() + Send + Sync>
116 });
117
118 let post_disconnection_fn = post_disconnection.map(|callback| {
119 let callback_clone = clone_py_object(&callback);
120 std::sync::Arc::new(move || {
121 Python::attach(|py| {
122 if let Err(e) = callback_clone.call0(py) {
123 log::error!("Error calling post_disconnection handler: {e}");
124 }
125 });
126 }) as std::sync::Arc<dyn Fn() + Send + Sync>
127 });
128
129 pyo3_async_runtimes::tokio::future_into_py(py, async move {
130 Self::connect(
131 config,
132 post_connection_fn,
133 post_reconnection_fn,
134 post_disconnection_fn,
135 )
136 .await
137 .map_err(to_pyruntime_err)
138 })
139 }
140
141 #[pyo3(name = "is_active")]
146 #[expect(clippy::needless_pass_by_value)]
147 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
148 slf.is_active()
149 }
150
151 #[pyo3(name = "is_reconnecting")]
156 #[expect(clippy::needless_pass_by_value)]
157 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
158 slf.is_reconnecting()
159 }
160
161 #[pyo3(name = "is_disconnecting")]
165 #[expect(clippy::needless_pass_by_value)]
166 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
167 slf.is_disconnecting()
168 }
169
170 #[pyo3(name = "is_closed")]
176 #[expect(clippy::needless_pass_by_value)]
177 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
178 slf.is_closed()
179 }
180
181 #[pyo3(name = "mode")]
182 #[expect(clippy::needless_pass_by_value)]
183 fn py_mode(slf: PyRef<'_, Self>) -> String {
184 slf.connection_mode().to_string()
185 }
186
187 #[pyo3(name = "reconnect")]
189 #[expect(clippy::needless_pass_by_value)]
190 fn py_reconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
191 let connection_mode = slf.connection_mode.clone();
192 let state_notify = slf.state_notify.clone();
193 let mode_str = ConnectionMode::from_atomic(&connection_mode).to_string();
194 log::debug!("Reconnect from mode {mode_str}");
195
196 pyo3_async_runtimes::tokio::future_into_py(py, async move {
197 match ConnectionMode::from_atomic(&connection_mode) {
198 ConnectionMode::Reconnect => {
199 log::warn!("Cannot reconnect - socket already reconnecting");
200 }
201 ConnectionMode::Disconnect => {
202 log::warn!("Cannot reconnect - socket disconnecting");
203 }
204 ConnectionMode::Closed => {
205 log::warn!("Cannot reconnect - socket closed");
206 }
207 ConnectionMode::Active => {
208 connection_mode.store(ConnectionMode::Reconnect.as_u8(), Ordering::SeqCst);
209 state_notify.notify_one();
210
211 let fallback_interval = Duration::from_millis(100);
212 let timeout = tokio::time::timeout(Duration::from_secs(30), async {
213 loop {
214 let notified = state_notify.notified();
215
216 let current = ConnectionMode::from_atomic(&connection_mode);
217 if current.is_active() {
218 return Ok(());
219 }
220
221 if current.is_closed() || current.is_disconnect() {
222 return Err("Connection closed during reconnect");
223 }
224
225 tokio::select! {
226 biased;
227 () = notified => {}
228 () = tokio::time::sleep(fallback_interval) => {}
229 }
230 }
231 })
232 .await;
233
234 match timeout {
235 Ok(Ok(())) => log::debug!("Reconnected successfully"),
236 Ok(Err(e)) => log::warn!("Reconnect aborted: {e}"),
237 Err(_) => log::error!("Reconnect timed out after 30s"),
238 }
239 }
240 }
241
242 Ok(())
243 })
244 }
245
246 #[pyo3(name = "close")]
251 #[expect(clippy::needless_pass_by_value)]
252 fn py_close<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
253 let connection_mode = slf.connection_mode.clone();
254 let state_notify = slf.state_notify.clone();
255 let mode_str = ConnectionMode::from_atomic(&connection_mode).to_string();
256 log::debug!("Close from mode {mode_str}");
257
258 pyo3_async_runtimes::tokio::future_into_py(py, async move {
259 match ConnectionMode::from_atomic(&connection_mode) {
260 ConnectionMode::Closed => {
261 log::debug!("Socket already closed");
262 }
263 ConnectionMode::Disconnect => {
264 log::debug!("Socket already disconnecting");
265 }
266 _ => {
267 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
268 state_notify.notify_one();
269
270 let timeout = tokio::time::timeout(Duration::from_secs(5), async {
271 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
272 tokio::time::sleep(Duration::from_millis(10)).await;
273 }
274 })
275 .await;
276
277 if timeout.is_err() {
278 log::error!("Timeout waiting for socket to close, forcing closed state");
279 connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
280 }
281 }
282 }
283
284 Ok(())
285 })
286 }
287
288 #[pyo3(name = "send")]
294 #[expect(clippy::needless_pass_by_value)]
295 fn py_send<'py>(
296 slf: PyRef<'_, Self>,
297 data: Vec<u8>,
298 py: Python<'py>,
299 ) -> PyResult<Bound<'py, PyAny>> {
300 log::trace!("Sending {}", String::from_utf8_lossy(&data));
301
302 let connection_mode = slf.connection_mode.clone();
303 let state_notify = slf.state_notify.clone();
304 let writer_tx = slf.writer_tx.clone();
305
306 pyo3_async_runtimes::tokio::future_into_py(py, async move {
307 match ConnectionMode::from_atomic(&connection_mode) {
308 ConnectionMode::Disconnect | ConnectionMode::Closed => {
309 let msg = format!(
310 "Cannot send data ({}): socket closed",
311 String::from_utf8_lossy(&data)
312 );
313
314 let io_err = std::io::Error::new(std::io::ErrorKind::NotConnected, msg);
315 return Err(to_pyruntime_err(io_err));
316 }
317 mode if !mode.is_active() => {
318 let timeout = Duration::from_secs(2);
319 let fallback_interval = Duration::from_millis(100);
320
321 log::debug!("Waiting for client to become ACTIVE before sending (2s)...");
322
323 match tokio::time::timeout(timeout, async {
324 loop {
325 let notified = state_notify.notified();
326
327 let mode = ConnectionMode::from_atomic(&connection_mode);
328 if mode.is_active() {
329 return Ok(());
330 }
331
332 if matches!(mode, ConnectionMode::Disconnect | ConnectionMode::Closed) {
333 return Err("Client disconnected waiting to send");
334 }
335
336 tokio::select! {
337 biased;
338 () = notified => {}
339 () = tokio::time::sleep(fallback_interval) => {}
340 }
341 }
342 })
343 .await
344 {
345 Ok(Ok(())) => log::debug!("Client now active"),
346 Ok(Err(e)) => {
347 let err_msg = format!(
348 "Failed sending data ({}): {e}",
349 String::from_utf8_lossy(&data)
350 );
351
352 let io_err =
353 std::io::Error::new(std::io::ErrorKind::NotConnected, err_msg);
354 return Err(to_pyruntime_err(io_err));
355 }
356 Err(_) => {
357 let err_msg = format!(
358 "Failed sending data ({}): timeout waiting to become ACTIVE",
359 String::from_utf8_lossy(&data)
360 );
361
362 let io_err = std::io::Error::new(std::io::ErrorKind::TimedOut, err_msg);
363 return Err(to_pyruntime_err(io_err));
364 }
365 }
366 }
367 _ => {}
368 }
369
370 let msg = WriterCommand::Send(data.into());
371 writer_tx.send(msg).map_err(to_pyruntime_err)
372 })
373 }
374}