1use std::{
17 sync::{
18 Arc,
19 atomic::{AtomicU8, Ordering},
20 },
21 time::Duration,
22};
23
24use nautilus_core::{
25 collections::into_ustr_vec,
26 python::{clone_py_object, to_pyruntime_err, to_pyvalue_err},
27};
28use pyo3::{Py, create_exception, exceptions::PyException, prelude::*, types::PyBytes};
29
30use crate::{
31 RECONNECTED,
32 mode::ConnectionMode,
33 ratelimiter::quota::Quota,
34 transport::{Message, TransportError},
35 websocket::{
36 TransportBackend, WebSocketClient, WebSocketConfig,
37 types::{MessageHandler, PingHandler, WriterCommand},
38 },
39};
40
41create_exception!(network, WebSocketClientError, PyException);
42
43#[expect(clippy::needless_pass_by_value)]
44fn to_websocket_pyerr(e: TransportError) -> PyErr {
45 PyErr::new::<WebSocketClientError, _>(e.to_string())
46}
47
48#[pymethods]
49#[pyo3_stub_gen::derive::gen_stub_pymethods]
50impl WebSocketConfig {
51 #[new]
75 #[expect(clippy::too_many_arguments)]
76 #[pyo3(signature = (
77 url,
78 headers,
79 heartbeat=None,
80 heartbeat_msg=None,
81 reconnect_timeout_ms=10_000,
82 reconnect_delay_initial_ms=2_000,
83 reconnect_delay_max_ms=30_000,
84 reconnect_backoff_factor=1.5,
85 reconnect_jitter_ms=100,
86 reconnect_max_attempts=None,
87 idle_timeout_ms=None,
88 proxy_url=None,
89 ))]
90 fn py_new(
91 url: String,
92 headers: Vec<(String, String)>,
93 heartbeat: Option<u64>,
94 heartbeat_msg: Option<String>,
95 reconnect_timeout_ms: Option<u64>,
96 reconnect_delay_initial_ms: Option<u64>,
97 reconnect_delay_max_ms: Option<u64>,
98 reconnect_backoff_factor: Option<f64>,
99 reconnect_jitter_ms: Option<u64>,
100 reconnect_max_attempts: Option<u32>,
101 idle_timeout_ms: Option<u64>,
102 proxy_url: Option<String>,
103 ) -> Self {
104 Self {
105 url,
106 headers,
107 heartbeat,
108 heartbeat_msg,
109 reconnect_timeout_ms,
110 reconnect_delay_initial_ms,
111 reconnect_delay_max_ms,
112 reconnect_backoff_factor,
113 reconnect_jitter_ms,
114 reconnect_max_attempts,
115 idle_timeout_ms,
116 backend: TransportBackend::default(),
117 proxy_url,
118 }
119 }
120}
121
122#[pymethods]
123#[pyo3_stub_gen::derive::gen_stub_pymethods]
124impl WebSocketClient {
125 #[staticmethod]
137 #[pyo3(name = "connect", signature = (loop_, config, handler, ping_handler = None, post_reconnection = None, keyed_quotas = Vec::new(), default_quota = None))]
138 #[expect(clippy::too_many_arguments, clippy::needless_pass_by_value)]
139 fn py_connect(
140 loop_: Py<PyAny>,
141 config: WebSocketConfig,
142 handler: Py<PyAny>,
143 ping_handler: Option<Py<PyAny>>,
144 post_reconnection: Option<Py<PyAny>>,
145 keyed_quotas: Vec<(String, Quota)>,
146 default_quota: Option<Quota>,
147 py: Python<'_>,
148 ) -> PyResult<Bound<'_, PyAny>> {
149 let call_soon_threadsafe: Py<PyAny> = loop_.getattr(py, "call_soon_threadsafe")?;
150 let call_soon_clone = clone_py_object(&call_soon_threadsafe);
151 let handler_clone = clone_py_object(&handler);
152
153 let message_handler: MessageHandler = Arc::new(move |msg: Message| {
154 if matches!(msg, Message::Text(ref text) if text.as_ref() == RECONNECTED.as_bytes()) {
155 return;
156 }
157
158 Python::attach(|py| {
159 let py_bytes = match &msg {
160 Message::Binary(data) | Message::Text(data) => PyBytes::new(py, data.as_ref()),
161 _ => return,
162 };
163
164 if let Err(e) = call_soon_clone.call1(py, (&handler_clone, py_bytes)) {
165 log::error!("Error scheduling message handler on event loop: {e}");
166 }
167 });
168 });
169
170 let ping_handler_fn = ping_handler.map(|ping_handler| {
171 let ping_handler_clone = clone_py_object(&ping_handler);
172 let call_soon_clone = clone_py_object(&call_soon_threadsafe);
173
174 let ping_handler_fn: PingHandler = Arc::new(move |data: Vec<u8>| {
175 Python::attach(|py| {
176 let py_bytes = PyBytes::new(py, &data);
177 if let Err(e) = call_soon_clone.call1(py, (&ping_handler_clone, py_bytes)) {
178 log::error!("Error scheduling ping handler on event loop: {e}");
179 }
180 });
181 });
182 ping_handler_fn
183 });
184
185 let post_reconnection_fn = post_reconnection.map(|callback| {
186 let callback_clone = clone_py_object(&callback);
187 Arc::new(move || {
188 Python::attach(|py| {
189 if let Err(e) = callback_clone.call0(py) {
190 log::error!("Error calling post_reconnection handler: {e}");
191 }
192 });
193 }) as std::sync::Arc<dyn Fn() + Send + Sync>
194 });
195
196 pyo3_async_runtimes::tokio::future_into_py(py, async move {
197 Box::pin(Self::connect(
198 config,
199 Some(message_handler),
200 ping_handler_fn,
201 post_reconnection_fn,
202 keyed_quotas,
203 default_quota,
204 ))
205 .await
206 .map_err(to_websocket_pyerr)
207 })
208 }
209
210 #[pyo3(name = "disconnect")]
215 #[expect(clippy::needless_pass_by_value)]
216 fn py_disconnect<'py>(slf: PyRef<'_, Self>, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
217 let connection_mode = slf.connection_mode.clone();
218 let state_notify = slf.state_notify.clone();
219 let mode = ConnectionMode::from_atomic(&connection_mode);
220 log::debug!("Close from mode {mode}");
221
222 pyo3_async_runtimes::tokio::future_into_py(py, async move {
223 match ConnectionMode::from_atomic(&connection_mode) {
224 ConnectionMode::Closed => {
225 log::debug!("WebSocket already closed");
226 }
227 ConnectionMode::Disconnect => {
228 log::debug!("WebSocket already disconnecting");
229 }
230 _ => {
231 connection_mode.store(ConnectionMode::Disconnect.as_u8(), Ordering::SeqCst);
232 state_notify.notify_one();
233
234 let timeout = tokio::time::timeout(Duration::from_secs(5), async {
235 while !ConnectionMode::from_atomic(&connection_mode).is_closed() {
236 tokio::time::sleep(Duration::from_millis(10)).await;
237 }
238 })
239 .await;
240
241 if timeout.is_err() {
242 log::error!("Timeout waiting for WebSocket to close, forcing closed state");
243 connection_mode.store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
244 }
245 }
246 }
247
248 Ok(())
249 })
250 }
251
252 #[pyo3(name = "is_active")]
257 #[expect(clippy::needless_pass_by_value)]
258 fn py_is_active(slf: PyRef<'_, Self>) -> bool {
259 !slf.controller_task.is_finished()
260 }
261
262 #[pyo3(name = "is_reconnecting")]
267 #[expect(clippy::needless_pass_by_value)]
268 fn py_is_reconnecting(slf: PyRef<'_, Self>) -> bool {
269 slf.is_reconnecting()
270 }
271
272 #[pyo3(name = "is_disconnecting")]
276 #[expect(clippy::needless_pass_by_value)]
277 fn py_is_disconnecting(slf: PyRef<'_, Self>) -> bool {
278 slf.is_disconnecting()
279 }
280
281 #[pyo3(name = "is_closed")]
287 #[expect(clippy::needless_pass_by_value)]
288 fn py_is_closed(slf: PyRef<'_, Self>) -> bool {
289 slf.is_closed()
290 }
291
292 #[pyo3(name = "send")]
298 #[pyo3(signature = (data, keys=None))]
299 #[expect(clippy::needless_pass_by_value)]
300 fn py_send<'py>(
301 slf: PyRef<'_, Self>,
302 data: Vec<u8>,
303 py: Python<'py>,
304 keys: Option<Vec<String>>,
305 ) -> PyResult<Bound<'py, PyAny>> {
306 let rate_limiter = slf.rate_limiter.clone();
307 let writer_tx = slf.writer_tx.clone();
308 let mode = slf.connection_mode.clone();
309 let keys = keys.map(into_ustr_vec);
310
311 pyo3_async_runtimes::tokio::future_into_py(py, async move {
312 if !ConnectionMode::from_atomic(&mode).is_active() {
313 let msg = "Cannot send data: connection not active".to_string();
314 log::error!("{msg}");
315 return Err(to_pyruntime_err(std::io::Error::new(
316 std::io::ErrorKind::NotConnected,
317 msg,
318 )));
319 }
320
321 tokio::select! {
322 biased;
323 () = rate_limiter.await_keys_ready(keys.as_deref()) => {}
324 () = poll_until_closed(&mode) => {
325 return Err(to_pyruntime_err(std::io::Error::new(
326 std::io::ErrorKind::ConnectionAborted,
327 "Connection closed while waiting for rate limit",
328 )));
329 }
330 }
331
332 log::trace!("Sending binary: {data:?}");
333
334 let msg = Message::Binary(data.into());
335 writer_tx
336 .send(WriterCommand::Send(msg))
337 .map_err(to_pyruntime_err)
338 })
339 }
340
341 #[pyo3(name = "send_text")]
347 #[pyo3(signature = (data, keys=None))]
348 #[expect(clippy::needless_pass_by_value)]
349 fn py_send_text<'py>(
350 slf: PyRef<'_, Self>,
351 data: Vec<u8>,
352 py: Python<'py>,
353 keys: Option<Vec<String>>,
354 ) -> PyResult<Bound<'py, PyAny>> {
355 let data_str = String::from_utf8(data).map_err(to_pyvalue_err)?;
356 let rate_limiter = slf.rate_limiter.clone();
357 let writer_tx = slf.writer_tx.clone();
358 let mode = slf.connection_mode.clone();
359 let keys = keys.map(into_ustr_vec);
360
361 pyo3_async_runtimes::tokio::future_into_py(py, async move {
362 if !ConnectionMode::from_atomic(&mode).is_active() {
363 let e = std::io::Error::new(
364 std::io::ErrorKind::NotConnected,
365 "Cannot send text: connection not active",
366 );
367 return Err(to_pyruntime_err(e));
368 }
369
370 tokio::select! {
371 biased;
372 () = rate_limiter.await_keys_ready(keys.as_deref()) => {}
373 () = poll_until_closed(&mode) => {
374 return Err(to_pyruntime_err(std::io::Error::new(
375 std::io::ErrorKind::ConnectionAborted,
376 "Connection closed while waiting for rate limit",
377 )));
378 }
379 }
380
381 log::trace!("Sending text: {data_str}");
382
383 let msg = Message::Text(data_str.into());
384 writer_tx
385 .send(WriterCommand::Send(msg))
386 .map_err(to_pyruntime_err)
387 })
388 }
389
390 #[pyo3(name = "send_pong")]
392 #[expect(clippy::needless_pass_by_value)]
393 fn py_send_pong<'py>(
394 slf: PyRef<'_, Self>,
395 data: Vec<u8>,
396 py: Python<'py>,
397 ) -> PyResult<Bound<'py, PyAny>> {
398 let writer_tx = slf.writer_tx.clone();
399 let mode = slf.connection_mode.clone();
400 let data_len = data.len();
401
402 pyo3_async_runtimes::tokio::future_into_py(py, async move {
403 if !ConnectionMode::from_atomic(&mode).is_active() {
404 let e = std::io::Error::new(
405 std::io::ErrorKind::NotConnected,
406 "Cannot send pong: connection not active",
407 );
408 return Err(to_pyruntime_err(e));
409 }
410 log::trace!("Sending pong frame ({data_len} bytes)");
411
412 let msg = Message::Pong(data.into());
413 writer_tx
414 .send(WriterCommand::Send(msg))
415 .map_err(to_pyruntime_err)
416 })
417 }
418}
419
420async fn poll_until_closed(mode: &Arc<AtomicU8>) {
421 loop {
422 if matches!(
423 ConnectionMode::from_atomic(mode),
424 ConnectionMode::Disconnect | ConnectionMode::Closed
425 ) {
426 break;
427 }
428
429 tokio::time::sleep(Duration::from_millis(100)).await;
430 }
431}
432
433#[cfg(test)]
434#[cfg(target_os = "linux")] mod tests {
436 use std::ffi::CString;
437
438 use futures_util::{SinkExt, StreamExt};
439 use nautilus_core::python::IntoPyObjectNautilusExt;
440 use pyo3::{prelude::*, types::PyBytes};
441 use tokio::{
442 net::TcpListener,
443 task::{self, JoinHandle},
444 time::{Duration, sleep},
445 };
446 use tokio_tungstenite::{
447 accept_hdr_async,
448 tungstenite::{
449 handshake::server::{self, Callback},
450 http::HeaderValue,
451 },
452 };
453
454 use crate::{
455 transport::Message,
456 websocket::{MessageHandler, WebSocketClient, WebSocketConfig},
457 };
458
459 struct TestServer {
460 task: JoinHandle<()>,
461 port: u16,
462 }
463
464 #[derive(Debug, Clone)]
465 struct TestCallback {
466 key: String,
467 value: HeaderValue,
468 }
469
470 impl Callback for TestCallback {
471 #[expect(clippy::panic_in_result_fn)]
472 fn on_request(
473 self,
474 request: &server::Request,
475 response: server::Response,
476 ) -> Result<server::Response, server::ErrorResponse> {
477 let _ = response;
478 let value = request.headers().get(&self.key);
479 assert!(value.is_some());
480
481 if let Some(value) = request.headers().get(&self.key) {
482 assert_eq!(value, self.value);
483 }
484
485 Ok(response)
486 }
487 }
488
489 impl TestServer {
490 async fn setup(key: String, value: String) -> Self {
491 let server = TcpListener::bind("127.0.0.1:0").await.unwrap();
492 let port = TcpListener::local_addr(&server).unwrap().port();
493
494 let test_call_back = TestCallback {
495 key,
496 value: HeaderValue::from_str(&value).unwrap(),
497 };
498
499 let task = task::spawn(async move {
501 loop {
503 let (conn, _) = server.accept().await.unwrap();
504 let mut websocket = accept_hdr_async(conn, test_call_back.clone())
505 .await
506 .unwrap();
507
508 task::spawn(async move {
509 #[expect(clippy::collapsible_match)]
511 while let Some(Ok(msg)) = websocket.next().await {
512 match msg {
513 tokio_tungstenite::tungstenite::protocol::Message::Text(txt)
514 if txt == "close-now" =>
515 {
516 log::debug!("Forcibly closing from server side");
517 let _ = websocket.close(None).await;
519 break;
520 }
521 tokio_tungstenite::tungstenite::protocol::Message::Text(_)
523 | tokio_tungstenite::tungstenite::protocol::Message::Binary(_) => {
524 if websocket.send(msg).await.is_err() {
525 break;
526 }
527 }
528 tokio_tungstenite::tungstenite::protocol::Message::Close(
530 _frame,
531 ) => {
532 let _ = websocket.close(None).await;
533 break;
534 }
535 _ => {}
537 }
538 }
539 });
540 }
541 });
542
543 Self { task, port }
544 }
545 }
546
547 impl Drop for TestServer {
548 fn drop(&mut self) {
549 self.task.abort();
550 }
551 }
552
553 fn create_test_handler() -> (Py<PyAny>, Py<PyAny>) {
554 let code_raw = "
555class Counter:
556 def __init__(self):
557 self.count = 0
558 self.check = False
559
560 def handler(self, bytes):
561 msg = bytes.decode()
562 if msg == 'ping':
563 self.count += 1
564 elif msg == 'heartbeat message':
565 self.check = True
566
567 def get_check(self):
568 return self.check
569
570 def get_count(self):
571 return self.count
572
573counter = Counter()
574";
575
576 let code = CString::new(code_raw).unwrap();
577 let filename = CString::new("test".to_string()).unwrap();
578 let module = CString::new("test".to_string()).unwrap();
579 Python::attach(|py| {
580 let pymod = PyModule::from_code(py, &code, &filename, &module).unwrap();
581
582 let counter = pymod.getattr("counter").unwrap().into_py_any_unwrap(py);
583 let handler = counter
584 .getattr(py, "handler")
585 .unwrap()
586 .into_py_any_unwrap(py);
587
588 (counter, handler)
589 })
590 }
591
592 #[tokio::test]
593 async fn basic_client_test() {
594 const N: usize = 10;
595
596 Python::initialize();
597
598 let mut success_count = 0;
599 let header_key = "hello-custom-key".to_string();
600 let header_value = "hello-custom-value".to_string();
601
602 let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
603 let (counter, handler) = create_test_handler();
604
605 let config = WebSocketConfig::py_new(
606 format!("ws://127.0.0.1:{}", server.port),
607 vec![(header_key, header_value)],
608 None,
609 None,
610 None,
611 None,
612 None,
613 None,
614 None,
615 None,
616 None,
617 None,
618 );
619
620 let handler_clone = Python::attach(|py| handler.clone_ref(py));
621
622 let message_handler: MessageHandler = std::sync::Arc::new(move |msg: Message| {
623 Python::attach(|py| {
624 let data = match msg {
625 Message::Binary(data) | Message::Text(data) => data.to_vec(),
626 _ => return,
627 };
628 let py_bytes = PyBytes::new(py, &data);
629 if let Err(e) = handler_clone.call1(py, (py_bytes,)) {
630 log::error!("Error calling handler: {e}");
631 }
632 });
633 });
634
635 let client =
636 WebSocketClient::connect(config, Some(message_handler), None, None, vec![], None)
637 .await
638 .unwrap();
639
640 for _ in 0..N {
641 client.send_bytes(b"ping".to_vec(), None).await.unwrap();
642 success_count += 1;
643 }
644
645 sleep(Duration::from_secs(1)).await;
646 let count_value: usize = Python::attach(|py| {
647 counter
648 .getattr(py, "get_count")
649 .unwrap()
650 .call0(py)
651 .unwrap()
652 .extract(py)
653 .unwrap()
654 });
655 assert_eq!(count_value, success_count);
656
657 client.send_close_message().await.unwrap();
659
660 sleep(Duration::from_secs(2)).await;
662
663 for _ in 0..N {
664 client.send_bytes(b"ping".to_vec(), None).await.unwrap();
665 success_count += 1;
666 }
667
668 sleep(Duration::from_secs(1)).await;
669 let count_value: usize = Python::attach(|py| {
670 counter
671 .getattr(py, "get_count")
672 .unwrap()
673 .call0(py)
674 .unwrap()
675 .extract(py)
676 .unwrap()
677 });
678 assert_eq!(count_value, success_count);
679 assert_eq!(success_count, N + N);
680
681 client.disconnect().await;
682 assert!(client.is_disconnected());
683 }
684
685 #[tokio::test]
686 async fn message_ping_test() {
687 Python::initialize();
688
689 let header_key = "hello-custom-key".to_string();
690 let header_value = "hello-custom-value".to_string();
691
692 let (checker, handler) = create_test_handler();
693
694 let server = TestServer::setup(header_key.clone(), header_value.clone()).await;
695 let config = WebSocketConfig::py_new(
696 format!("ws://127.0.0.1:{}", server.port),
697 vec![(header_key, header_value)],
698 Some(1),
699 Some("heartbeat message".to_string()),
700 None,
701 None,
702 None,
703 None,
704 None,
705 None,
706 None,
707 None,
708 );
709
710 let handler_clone = Python::attach(|py| handler.clone_ref(py));
711
712 let message_handler: MessageHandler = std::sync::Arc::new(move |msg: Message| {
713 Python::attach(|py| {
714 let data = match msg {
715 Message::Binary(data) | Message::Text(data) => data.to_vec(),
716 _ => return,
717 };
718 let py_bytes = PyBytes::new(py, &data);
719 if let Err(e) = handler_clone.call1(py, (py_bytes,)) {
720 log::error!("Error calling handler: {e}");
721 }
722 });
723 });
724
725 let client =
726 WebSocketClient::connect(config, Some(message_handler), None, None, vec![], None)
727 .await
728 .unwrap();
729
730 sleep(Duration::from_secs(2)).await;
731 let check_value: bool = Python::attach(|py| {
732 checker
733 .getattr(py, "get_check")
734 .unwrap()
735 .call0(py)
736 .unwrap()
737 .extract(py)
738 .unwrap()
739 });
740 assert!(check_value);
741
742 client.disconnect().await;
743 assert!(client.is_disconnected());
744 }
745}