1use std::sync::{
23 Arc, OnceLock,
24 atomic::{AtomicBool, AtomicU64, Ordering},
25};
26
27use bytes::Bytes;
28use nautilus_network::socket::{SocketClient, SocketConfig, TcpMessageHandler, WriterCommand};
29use tokio::sync::watch; use tokio_tungstenite::tungstenite::stream::Mode;
31
32use super::{
33 config::BetfairStreamConfig,
34 error::BetfairStreamError,
35 messages::{
36 Authentication, MarketDataFilter, MarketSubscription, OrderFilter, OrderSubscription,
37 RaceSubscription, StreamMarketFilter, StreamMessage, stream_decode,
38 },
39};
40use crate::common::{
41 consts::{STREAM_OP_MARKET_SUBSCRIPTION, STREAM_OP_ORDER_SUBSCRIPTION},
42 credential::BetfairCredential,
43 enums::StatusErrorCode,
44};
45
#[derive(Debug)]
/// Client for the Betfair market/order stream.
///
/// Besides the socket itself, the client keeps the most recent subscription
/// requests and stream clocks in `watch` channels. The receiving halves of those
/// channels are captured by the reconnect callback built in
/// [`BetfairStreamClient::connect`], which uses them to replay authentication and
/// subscriptions.
pub struct BetfairStreamClient {
    /// Underlying socket used for all sends and for lifecycle queries.
    socket: SocketClient,
    /// Latest market subscription request (`None` until `subscribe_markets` is called);
    /// read by the reconnect callback.
    market_sub_tx: watch::Sender<Option<MarketSubscription>>,
    /// Latest `clk` recorded from market change messages; reset to `None` on a new
    /// market subscription and on an `INVALID_CLOCK` status.
    market_clk_tx: watch::Sender<Option<String>>,
    /// Latest `initialClk` recorded from market change messages; reset alongside
    /// `market_clk_tx`.
    market_initial_clk_tx: watch::Sender<Option<String>>,
    /// Latest order subscription request (`None` until `subscribe_orders` is called);
    /// read by the reconnect callback.
    order_sub_tx: watch::Sender<Option<OrderSubscription>>,
    /// Latest `clk` recorded from order change messages; reset to `None` on a new
    /// order subscription and on an `INVALID_CLOCK` status.
    order_clk_tx: watch::Sender<Option<String>>,
    /// Latest `initialClk` recorded from order change messages; reset alongside
    /// `order_clk_tx`.
    order_initial_clk_tx: watch::Sender<Option<String>>,
    /// Request id of the most recent market subscription. Starts at 0, and the
    /// message handler records no market clocks while it is 0.
    market_active_sub_id: Arc<AtomicU64>,
    /// Request id of the most recent order subscription. Starts at 0, and the
    /// message handler records no order clocks while it is 0.
    order_active_sub_id: Arc<AtomicU64>,
    /// Counter that hands out subscription request ids; initialised to 1.
    request_id: AtomicU64,
    /// Serialized authentication message replayed by the reconnect callback;
    /// replaced by `update_auth`.
    auth_bytes_tx: watch::Sender<Bytes>,
    /// Set by `close`; the subscribe methods refuse to run once it is `true`.
    closed: AtomicBool,
}
70
71impl BetfairStreamClient {
72 pub async fn connect(
78 credential: &BetfairCredential,
79 session_token: String,
80 handler: TcpMessageHandler,
81 config: BetfairStreamConfig,
82 ) -> Result<Self, BetfairStreamError> {
83 let auth = Authentication::new(credential.app_key().to_string(), session_token);
84 let auth_bytes_vec = serde_json::to_vec(&auth)?;
85 let auth_bytes = Bytes::from(auth_bytes_vec.clone());
86 let (auth_bytes_tx, auth_bytes_rx) = watch::channel(auth_bytes);
87 let mode = if config.use_tls {
88 Mode::Tls
89 } else {
90 Mode::Plain
91 };
92
93 let (market_clk_tx, market_clk_rx) = watch::channel(None::<String>);
94 let (market_initial_clk_tx, market_initial_clk_rx) = watch::channel(None::<String>);
95 let (order_clk_tx, order_clk_rx) = watch::channel(None::<String>);
96 let (order_initial_clk_tx, order_initial_clk_rx) = watch::channel(None::<String>);
97 let (market_sub_tx, market_sub_rx) = watch::channel(None::<MarketSubscription>);
98 let (order_sub_tx, order_sub_rx) = watch::channel(None::<OrderSubscription>);
99
100 let shared_tx: Arc<OnceLock<tokio::sync::mpsc::UnboundedSender<WriterCommand>>> =
102 Arc::new(OnceLock::new());
103
104 let (market_clk_tx_h, market_initial_clk_tx_h) =
106 (market_clk_tx.clone(), market_initial_clk_tx.clone());
107 let (order_clk_tx_h, order_initial_clk_tx_h) =
108 (order_clk_tx.clone(), order_initial_clk_tx.clone());
109
110 let market_active_sub_id = Arc::new(AtomicU64::new(0));
111 let order_active_sub_id = Arc::new(AtomicU64::new(0));
112 let market_active_sub_id_h = Arc::clone(&market_active_sub_id);
113 let order_active_sub_id_h = Arc::clone(&order_active_sub_id);
114
115 let message_handler: TcpMessageHandler = Arc::new(move |data: &[u8]| {
116 if let Ok(msg) = stream_decode(data) {
117 match &msg {
118 StreamMessage::MarketChange(mcm) => {
119 let active = market_active_sub_id_h.load(Ordering::SeqCst);
120 if active > 0 && mcm.id.is_none_or(|id| id == active) {
125 if mcm.clk.is_some() {
126 let _ = market_clk_tx_h.send(mcm.clk.clone());
127 }
128
129 if mcm.initial_clk.is_some() {
130 let _ = market_initial_clk_tx_h.send(mcm.initial_clk.clone());
131 }
132 }
133 }
134 StreamMessage::OrderChange(ocm) => {
135 let active = order_active_sub_id_h.load(Ordering::SeqCst);
136 if active > 0 && ocm.id.is_none_or(|id| id == active) {
137 if ocm.clk.is_some() {
138 let _ = order_clk_tx_h.send(ocm.clk.clone());
139 }
140
141 if ocm.initial_clk.is_some() {
142 let _ = order_initial_clk_tx_h.send(ocm.initial_clk.clone());
143 }
144 }
145 }
146 StreamMessage::Status(status) => {
147 if status.error_code == Some(StatusErrorCode::InvalidClock) {
152 let _ = market_clk_tx_h.send(None);
153 let _ = market_initial_clk_tx_h.send(None);
154 let _ = order_clk_tx_h.send(None);
155 let _ = order_initial_clk_tx_h.send(None);
156 log::warn!(
157 "Betfair stream INVALID_CLOCK: clocks cleared, \
158 next reconnect will request a full image",
159 );
160 } else if status.connection_closed {
161 log::error!(
162 "Betfair stream connection closed by server: {:?} - {:?}",
163 status.error_code,
164 status.error_message,
165 );
166 } else if status.error_code.is_some() {
167 log::warn!(
168 "Betfair stream status error: {:?} - {:?}",
169 status.error_code,
170 status.error_message,
171 );
172 }
173 }
174 _ => {}
175 }
176 }
177 handler(data);
178 });
179
180 let auth_bytes_reconnect = auth_bytes_rx;
181 let shared_tx_reconnect = Arc::clone(&shared_tx);
182 let post_reconnection: Arc<dyn Fn() + Send + Sync> = Arc::new(move || {
183 let Some(tx) = shared_tx_reconnect.get() else {
184 return;
185 };
186
187 let auth = auth_bytes_reconnect.borrow().clone();
188 let market_sub = market_sub_rx.borrow().clone();
189 let order_sub = order_sub_rx.borrow().clone();
190
191 let _ = tx.send(WriterCommand::Send(auth));
192
193 if let Some(mut sub) = market_sub {
194 sub.clk = market_clk_rx.borrow().clone();
195 sub.initial_clk = market_initial_clk_rx.borrow().clone();
196 if let Ok(sub_bytes) = serde_json::to_vec(&sub) {
197 let _ = tx.send(WriterCommand::Send(Bytes::from(sub_bytes)));
198 }
199 }
200
201 if let Some(mut sub) = order_sub {
202 sub.clk = order_clk_rx.borrow().clone();
203 sub.initial_clk = order_initial_clk_rx.borrow().clone();
204 if let Ok(sub_bytes) = serde_json::to_vec(&sub) {
205 let _ = tx.send(WriterCommand::Send(Bytes::from(sub_bytes)));
206 }
207 }
208 });
209
210 let url = format!("{}:{}", config.host, config.port);
211 let socket_config = SocketConfig {
212 url,
213 mode,
214 suffix: b"\r\n".to_vec(),
215 message_handler: Some(message_handler),
216 heartbeat: Some((
218 config.heartbeat_ms.div_ceil(1_000),
219 b"{\"op\":\"heartbeat\"}".to_vec(),
220 )),
221 reconnect_timeout_ms: None,
222 reconnect_delay_initial_ms: Some(config.reconnect_delay_initial_ms),
223 reconnect_delay_max_ms: Some(config.reconnect_delay_max_ms),
224 reconnect_backoff_factor: None,
225 reconnect_jitter_ms: None,
226 connection_max_retries: None,
227 reconnect_max_attempts: None,
228 idle_timeout_ms: Some(config.idle_timeout_ms),
229 certs_dir: None,
230 };
231
232 let socket = SocketClient::connect(socket_config, None, Some(post_reconnection), None)
233 .await
234 .map_err(|e| BetfairStreamError::ConnectionFailed(e.to_string()))?;
235
236 let _ = shared_tx.set(socket.writer_tx.clone());
238
239 socket
240 .send_bytes(auth_bytes_vec)
241 .await
242 .map_err(|e| BetfairStreamError::ConnectionFailed(e.to_string()))?;
243
244 Ok(Self {
245 socket,
246 market_sub_tx,
247 market_clk_tx,
248 market_initial_clk_tx,
249 order_sub_tx,
250 order_clk_tx,
251 order_initial_clk_tx,
252 market_active_sub_id,
253 order_active_sub_id,
254 request_id: AtomicU64::new(1),
255 auth_bytes_tx,
256 closed: AtomicBool::new(false),
257 })
258 }
259
260 pub async fn subscribe_markets(
268 &self,
269 market_filter: StreamMarketFilter,
270 data_filter: MarketDataFilter,
271 heartbeat_ms: Option<u64>,
272 conflate_ms: Option<u64>,
273 ) -> Result<(), BetfairStreamError> {
274 if self.closed.load(Ordering::SeqCst) || self.socket.is_closed() {
275 return Err(BetfairStreamError::Disconnected(
276 "stream client is closed".to_string(),
277 ));
278 }
279 let id = self.request_id.fetch_add(1, Ordering::Relaxed);
280 self.market_active_sub_id.store(id, Ordering::SeqCst);
283 let sub = MarketSubscription {
284 op: STREAM_OP_MARKET_SUBSCRIPTION.to_string(),
285 id: Some(id),
286 market_filter,
287 market_data_filter: data_filter,
288 clk: None,
289 conflate_ms,
290 heartbeat_ms,
291 initial_clk: None,
292 segmentation_enabled: None,
293 };
294
295 let _ = self.market_clk_tx.send(None);
298 let _ = self.market_initial_clk_tx.send(None);
299 let _ = self.market_sub_tx.send(Some(sub.clone()));
300
301 let sub_bytes = serde_json::to_vec(&sub)?;
302 self.socket
303 .send_bytes(sub_bytes)
304 .await
305 .map_err(|e| BetfairStreamError::ConnectionFailed(e.to_string()))?;
306 Ok(())
307 }
308
309 pub async fn subscribe_orders(
317 &self,
318 order_filter: Option<OrderFilter>,
319 heartbeat_ms: Option<u64>,
320 ) -> Result<(), BetfairStreamError> {
321 if self.closed.load(Ordering::SeqCst) || self.socket.is_closed() {
322 return Err(BetfairStreamError::Disconnected(
323 "stream client is closed".to_string(),
324 ));
325 }
326 let id = self.request_id.fetch_add(1, Ordering::Relaxed);
327 self.order_active_sub_id.store(id, Ordering::SeqCst);
328 let sub = OrderSubscription {
329 op: STREAM_OP_ORDER_SUBSCRIPTION.to_string(),
330 id: Some(id),
331 order_filter,
332 clk: None,
333 conflate_ms: None,
334 heartbeat_ms,
335 initial_clk: None,
336 segmentation_enabled: None,
337 };
338
339 let _ = self.order_clk_tx.send(None);
342 let _ = self.order_initial_clk_tx.send(None);
343 let _ = self.order_sub_tx.send(Some(sub.clone()));
344
345 let sub_bytes = serde_json::to_vec(&sub)?;
346 self.socket
347 .send_bytes(sub_bytes)
348 .await
349 .map_err(|e| BetfairStreamError::ConnectionFailed(e.to_string()))?;
350 Ok(())
351 }
352
353 #[must_use]
355 pub fn is_active(&self) -> bool {
356 self.socket.is_active()
357 }
358
359 pub fn update_auth(&self, app_key: &str, session_token: String) {
362 let auth = Authentication::new(app_key.to_string(), session_token);
363 if let Ok(bytes) = serde_json::to_vec(&auth) {
364 let _ = self.auth_bytes_tx.send(Bytes::from(bytes));
365 }
366 }
367
368 pub async fn close(&self) {
370 self.closed.store(true, Ordering::SeqCst);
371 self.socket.close().await;
372 }
373}
374
#[derive(Debug)]
/// Client for the Betfair race stream.
///
/// The race subscription is fixed when the client connects; only the
/// authentication message can be changed afterwards.
pub struct BetfairRaceStreamClient {
    /// Underlying socket used for sends and lifecycle queries.
    socket: SocketClient,
    /// Serialized authentication message replayed by the reconnect callback;
    /// replaced by `update_auth`.
    auth_bytes_tx: watch::Sender<Bytes>,
    /// Set to `true` by `close`. No method on this type reads it.
    closed: AtomicBool,
}
386
387impl BetfairRaceStreamClient {
388 pub async fn connect(
398 credential: &BetfairCredential,
399 session_token: String,
400 handler: TcpMessageHandler,
401 config: BetfairStreamConfig,
402 race_fatal_tx: tokio::sync::mpsc::UnboundedSender<()>,
403 ) -> Result<Self, BetfairStreamError> {
404 let auth = Authentication::new(credential.app_key().to_string(), session_token);
405 let auth_bytes_vec = serde_json::to_vec(&auth)?;
406 let auth_bytes = Bytes::from(auth_bytes_vec.clone());
407 let (auth_bytes_tx, auth_bytes_rx) = watch::channel(auth_bytes.clone());
408
409 let race_sub = RaceSubscription::new(1);
410 let race_sub_bytes = Bytes::from(serde_json::to_vec(&race_sub)?);
411
412 let mode = if config.use_tls {
413 Mode::Tls
414 } else {
415 Mode::Plain
416 };
417
418 let shared_tx: Arc<OnceLock<tokio::sync::mpsc::UnboundedSender<WriterCommand>>> =
419 Arc::new(OnceLock::new());
420
421 let message_handler: TcpMessageHandler = Arc::new(move |data: &[u8]| {
422 if let Ok(StreamMessage::Status(status)) = stream_decode(data) {
423 if let Some(ref code) = status.error_code
424 && code.is_race_stream_fatal()
425 {
426 log::error!(
427 "Betfair race stream fatal error: {:?} - {:?} \
428 (check TPD entitlement on your Betfair app key)",
429 status.error_code,
430 status.error_message,
431 );
432 let _ = race_fatal_tx.send(());
433 return;
434 }
435
436 if status.connection_closed {
437 log::error!(
438 "Betfair race stream closed: {:?} - {:?}",
439 status.error_code,
440 status.error_message,
441 );
442 } else if status.error_code.is_some() {
443 log::warn!(
444 "Betfair race stream status: {:?} - {:?}",
445 status.error_code,
446 status.error_message,
447 );
448 }
449 }
450 handler(data);
451 });
452
453 let auth_bytes_reconnect = auth_bytes_rx;
454 let sub_reconnect = race_sub_bytes.clone();
455 let shared_tx_reconnect = Arc::clone(&shared_tx);
456 let post_reconnection: Arc<dyn Fn() + Send + Sync> = Arc::new(move || {
457 let Some(tx) = shared_tx_reconnect.get() else {
458 return;
459 };
460 let auth = auth_bytes_reconnect.borrow().clone();
461 let mut combined = Vec::with_capacity(auth.len() + 2 + sub_reconnect.len());
462 combined.extend_from_slice(&auth);
463 combined.extend_from_slice(b"\r\n");
464 combined.extend_from_slice(&sub_reconnect);
465 let _ = tx.send(WriterCommand::Send(Bytes::from(combined)));
466 });
467
468 let url = format!("{}:{}", config.host, config.port);
469 let socket_config = SocketConfig {
470 url,
471 mode,
472 suffix: b"\r\n".to_vec(),
473 message_handler: Some(message_handler),
474 heartbeat: Some((
475 config.heartbeat_ms.div_ceil(1_000),
476 b"{\"op\":\"heartbeat\"}".to_vec(),
477 )),
478 reconnect_timeout_ms: None,
479 reconnect_delay_initial_ms: Some(config.reconnect_delay_initial_ms),
480 reconnect_delay_max_ms: Some(config.reconnect_delay_max_ms),
481 reconnect_backoff_factor: None,
482 reconnect_jitter_ms: None,
483 connection_max_retries: None,
484 reconnect_max_attempts: None,
485 idle_timeout_ms: Some(config.idle_timeout_ms),
486 certs_dir: None,
487 };
488
489 let socket = SocketClient::connect(socket_config, None, Some(post_reconnection), None)
490 .await
491 .map_err(|e| BetfairStreamError::ConnectionFailed(e.to_string()))?;
492
493 let _ = shared_tx.set(socket.writer_tx.clone());
494
495 let mut combined = Vec::with_capacity(auth_bytes_vec.len() + 2 + race_sub_bytes.len());
496 combined.extend_from_slice(&auth_bytes_vec);
497 combined.extend_from_slice(b"\r\n");
498 combined.extend_from_slice(&race_sub_bytes);
499 socket
500 .send_bytes(combined)
501 .await
502 .map_err(|e| BetfairStreamError::ConnectionFailed(e.to_string()))?;
503
504 Ok(Self {
505 socket,
506 auth_bytes_tx,
507 closed: AtomicBool::new(false),
508 })
509 }
510
511 #[must_use]
513 pub fn is_active(&self) -> bool {
514 self.socket.is_active()
515 }
516
517 pub fn update_auth(&self, app_key: &str, session_token: String) {
520 let auth = Authentication::new(app_key.to_string(), session_token);
521 if let Ok(bytes) = serde_json::to_vec(&auth) {
522 let _ = self.auth_bytes_tx.send(Bytes::from(bytes));
523 }
524 }
525
526 pub async fn close(&self) {
528 self.closed.store(true, Ordering::SeqCst);
529 self.socket.close().await;
530 }
531}
532
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::stream::messages::{
        Authentication, MarketDataFilter, RaceSubscription, StreamMarketFilter,
    };

    type WriterTx = tokio::sync::mpsc::UnboundedSender<WriterCommand>;
    type WriterRx = tokio::sync::mpsc::UnboundedReceiver<WriterCommand>;

    /// Returns a writer slot that already holds a sender, plus the matching receiver
    /// used to observe what a reconnect callback queues.
    fn populated_writer_slot() -> (Arc<OnceLock<WriterTx>>, WriterRx) {
        let (writer_tx, writer_rx) = tokio::sync::mpsc::unbounded_channel::<WriterCommand>();
        let slot: Arc<OnceLock<WriterTx>> = Arc::new(OnceLock::new());
        let _ = slot.set(writer_tx);
        (slot, writer_rx)
    }

    /// Extracts the payload of a `WriterCommand::Send`, panicking on any other command.
    fn send_payload(cmd: WriterCommand) -> Bytes {
        let WriterCommand::Send(payload) = cmd else {
            panic!("expected Send");
        };
        payload
    }

    /// An `INVALID_CLOCK` status must wipe every stored clock.
    #[rstest]
    fn test_invalid_clock_status_resets_clocks() {
        let seeded = |value: &str| watch::channel(Some(value.to_string()));
        let (market_clk_tx, market_clk_rx) = seeded("old-market-clk");
        let (market_initial_clk_tx, market_initial_clk_rx) = seeded("old-market-iclk");
        let (order_clk_tx, order_clk_rx) = seeded("old-order-clk");
        let (order_initial_clk_tx, order_initial_clk_rx) = seeded("old-order-iclk");

        let handler: TcpMessageHandler = Arc::new(move |data: &[u8]| {
            let Ok(StreamMessage::Status(status)) = stream_decode(data) else {
                return;
            };

            if status.error_code != Some(StatusErrorCode::InvalidClock) {
                return;
            }

            let _ = market_clk_tx.send(None);
            let _ = market_initial_clk_tx.send(None);
            let _ = order_clk_tx.send(None);
            let _ = order_initial_clk_tx.send(None);
        });

        handler(
            br#"{"op":"status","statusCode":"503","errorCode":"INVALID_CLOCK","connectionClosed":true}"#,
        );

        let expectations = [
            (&market_clk_rx, "market clk must be cleared"),
            (&market_initial_clk_rx, "market initialClk must be cleared"),
            (&order_clk_rx, "order clk must be cleared"),
            (&order_initial_clk_rx, "order initialClk must be cleared"),
        ];

        for (clock_rx, failure) in expectations {
            assert!(clock_rx.borrow().is_none(), "{failure}");
        }
    }

    /// The authentication message serializes with the expected wire field names.
    #[rstest]
    fn test_auth_message_serialization() {
        let json = serde_json::to_string(&Authentication::new(
            "my-app-key".to_string(),
            "my-session".to_string(),
        ))
        .unwrap();

        for fragment in [
            "\"op\":\"authentication\"",
            "\"appKey\":\"my-app-key\"",
            "\"session\":\"my-session\"",
        ] {
            assert!(json.contains(fragment));
        }
    }

    /// Clocks follow change messages for the active id and id-less messages, and
    /// ignore messages tagged with a different id.
    #[rstest]
    fn test_clk_is_updated_from_mcm() {
        let (market_clk_tx, market_clk_rx) = watch::channel(None::<String>);
        let (market_initial_clk_tx, market_initial_clk_rx) = watch::channel(None::<String>);
        let (order_clk_tx, order_clk_rx) = watch::channel(None::<String>);
        let (order_initial_clk_tx, order_initial_clk_rx) = watch::channel(None::<String>);
        let market_active_sub_id = Arc::new(AtomicU64::new(5));
        let order_active_sub_id = Arc::new(AtomicU64::new(6));

        let handler: TcpMessageHandler = Arc::new(move |data: &[u8]| {
            let Ok(msg) = stream_decode(data) else {
                return;
            };

            match &msg {
                StreamMessage::MarketChange(mcm) => {
                    let active = market_active_sub_id.load(Ordering::SeqCst);
                    let tracked = active > 0 && mcm.id.is_none_or(|id| id == active);

                    if tracked {
                        if let Some(clk) = mcm.clk.clone() {
                            let _ = market_clk_tx.send(Some(clk));
                        }

                        if let Some(initial_clk) = mcm.initial_clk.clone() {
                            let _ = market_initial_clk_tx.send(Some(initial_clk));
                        }
                    }
                }
                StreamMessage::OrderChange(ocm) => {
                    let active = order_active_sub_id.load(Ordering::SeqCst);
                    let tracked = active > 0 && ocm.id.is_none_or(|id| id == active);

                    if tracked {
                        if let Some(clk) = ocm.clk.clone() {
                            let _ = order_clk_tx.send(Some(clk));
                        }

                        if let Some(initial_clk) = ocm.initial_clk.clone() {
                            let _ = order_initial_clk_tx.send(Some(initial_clk));
                        }
                    }
                }
                _ => {}
            }
        });

        let latest = |clock_rx: &watch::Receiver<Option<String>>| clock_rx.borrow().clone();

        // Messages tagged with the active ids populate all four clocks.
        handler(br#"{"op":"mcm","id":5,"pt":1000,"initialClk":"mcm-iclk","clk":"mcm-clk"}"#);
        handler(br#"{"op":"ocm","id":6,"pt":2000,"initialClk":"ocm-iclk","clk":"ocm-clk"}"#);

        assert_eq!(latest(&market_clk_rx).as_deref(), Some("mcm-clk"));
        assert_eq!(latest(&market_initial_clk_rx).as_deref(), Some("mcm-iclk"));
        assert_eq!(latest(&order_clk_rx).as_deref(), Some("ocm-clk"));
        assert_eq!(latest(&order_initial_clk_rx).as_deref(), Some("ocm-iclk"));

        // A message without an id still advances the clock.
        handler(br#"{"op":"mcm","pt":1001,"clk":"hb-clk"}"#);
        assert_eq!(latest(&market_clk_rx).as_deref(), Some("hb-clk"));

        // A message tagged with a different id is ignored.
        handler(br#"{"op":"mcm","id":4,"pt":1002,"clk":"stale-clk"}"#);
        assert_eq!(latest(&market_clk_rx).as_deref(), Some("hb-clk"));
    }

    /// The reconnect callback queues auth first, then both subscriptions stamped
    /// with the stored clocks.
    #[rstest]
    fn test_reconnect_callback_sends_auth_and_subscription() {
        let (market_clk_tx, market_clk_rx) = watch::channel(Some("mcm-clk1".to_string()));
        let (market_initial_clk_tx, market_initial_clk_rx) =
            watch::channel(Some("mcm-iclk1".to_string()));
        let (order_clk_tx, order_clk_rx) = watch::channel(Some("ocm-clk1".to_string()));
        let (order_initial_clk_tx, order_initial_clk_rx) =
            watch::channel(Some("ocm-iclk1".to_string()));
        let (market_sub_tx, market_sub_rx) = watch::channel(None::<MarketSubscription>);
        let (order_sub_tx, order_sub_rx) = watch::channel(None::<OrderSubscription>);

        let auth_bytes = Bytes::from(
            serde_json::to_vec(&Authentication::new("key".to_string(), "token".to_string()))
                .unwrap(),
        );

        let market_request = MarketSubscription {
            op: STREAM_OP_MARKET_SUBSCRIPTION.to_string(),
            id: Some(1),
            market_filter: StreamMarketFilter::default(),
            market_data_filter: MarketDataFilter::default(),
            clk: None,
            conflate_ms: None,
            heartbeat_ms: None,
            initial_clk: None,
            segmentation_enabled: None,
        };
        let order_request = OrderSubscription {
            op: STREAM_OP_ORDER_SUBSCRIPTION.to_string(),
            id: Some(2),
            order_filter: None,
            clk: None,
            conflate_ms: None,
            heartbeat_ms: None,
            initial_clk: None,
            segmentation_enabled: None,
        };
        let _ = market_sub_tx.send(Some(market_request));
        let _ = order_sub_tx.send(Some(order_request));

        let (shared_tx, mut rx) = populated_writer_slot();

        let shared_tx_reconnect = Arc::clone(&shared_tx);
        let post_reconnection: Arc<dyn Fn() + Send + Sync> = Arc::new(move || {
            let Some(writer) = shared_tx_reconnect.get() else {
                return;
            };

            let stored_market = market_sub_rx.borrow().clone();
            let stored_order = order_sub_rx.borrow().clone();

            let _ = writer.send(WriterCommand::Send(auth_bytes.clone()));

            if let Some(stored) = stored_market {
                let replay = MarketSubscription {
                    clk: market_clk_rx.borrow().clone(),
                    initial_clk: market_initial_clk_rx.borrow().clone(),
                    ..stored
                };

                if let Ok(encoded) = serde_json::to_vec(&replay) {
                    let _ = writer.send(WriterCommand::Send(Bytes::from(encoded)));
                }
            }

            if let Some(stored) = stored_order {
                let replay = OrderSubscription {
                    clk: order_clk_rx.borrow().clone(),
                    initial_clk: order_initial_clk_rx.borrow().clone(),
                    ..stored
                };

                if let Ok(encoded) = serde_json::to_vec(&replay) {
                    let _ = writer.send(WriterCommand::Send(Bytes::from(encoded)));
                }
            }
        });

        // The receivers must keep serving the last value after their senders are gone.
        drop(market_clk_tx);
        drop(market_initial_clk_tx);
        drop(order_clk_tx);
        drop(order_initial_clk_tx);

        post_reconnection();

        let auth_cmd = rx.try_recv().expect("auth replay message");
        let market_cmd = rx.try_recv().expect("market subscription message");
        let order_cmd = rx.try_recv().expect("order subscription message");
        assert!(rx.try_recv().is_err(), "no further messages expected");

        let [auth_payload, market_payload, order_payload] =
            [auth_cmd, market_cmd, order_cmd].map(send_payload);

        let auth_str = std::str::from_utf8(&auth_payload).unwrap();
        let market_str = std::str::from_utf8(&market_payload).unwrap();
        let order_str = std::str::from_utf8(&order_payload).unwrap();

        assert!(auth_str.contains("\"op\":\"authentication\""));

        for fragment in [
            "\"op\":\"marketSubscription\"",
            "\"clk\":\"mcm-clk1\"",
            "\"initialClk\":\"mcm-iclk1\"",
        ] {
            assert!(market_str.contains(fragment));
        }

        for fragment in [
            "\"op\":\"orderSubscription\"",
            "\"clk\":\"ocm-clk1\"",
            "\"initialClk\":\"ocm-iclk1\"",
        ] {
            assert!(order_str.contains(fragment));
        }
    }

    /// The race subscription serializes with its op name and the given id.
    #[rstest]
    fn test_race_subscription_serialization() {
        let json = serde_json::to_string(&RaceSubscription::new(42)).unwrap();

        for fragment in ["\"op\":\"raceSubscription\"", "\"id\":42"] {
            assert!(json.contains(fragment));
        }
    }

    /// The race reconnect callback queues one payload: auth, CRLF, subscription.
    #[rstest]
    fn test_race_stream_reconnect_replays_auth_and_subscription() {
        let auth_reconnect = Bytes::from(
            serde_json::to_vec(&Authentication::new("key".to_string(), "token".to_string()))
                .unwrap(),
        );
        let sub_reconnect = Bytes::from(serde_json::to_vec(&RaceSubscription::new(1)).unwrap());

        let (shared_tx, mut rx) = populated_writer_slot();

        let shared_tx_reconnect = Arc::clone(&shared_tx);
        let post_reconnection: Arc<dyn Fn() + Send + Sync> = Arc::new(move || {
            let Some(writer) = shared_tx_reconnect.get() else {
                return;
            };

            let mut framed = Vec::with_capacity(auth_reconnect.len() + 2 + sub_reconnect.len());
            framed.extend_from_slice(&auth_reconnect);
            framed.extend_from_slice(b"\r\n");
            framed.extend_from_slice(&sub_reconnect);

            let _ = writer.send(WriterCommand::Send(Bytes::from(framed)));
        });

        post_reconnection();

        let cmd = rx.try_recv().expect("auth+race subscription message");
        assert!(rx.try_recv().is_err(), "no further messages expected");

        let payload = send_payload(cmd);
        let text = std::str::from_utf8(&payload).unwrap();
        let (auth_part, sub_part) = text
            .split_once("\r\n")
            .expect("CRLF separator in combined message");

        assert!(auth_part.contains("\"op\":\"authentication\""));
        assert!(sub_part.contains("\"op\":\"raceSubscription\""));
    }

    /// Only error codes classed as race-stream fatal trigger the kill signal.
    #[rstest]
    fn test_race_stream_handler_fatal_status_sends_kill_signal() {
        let (race_fatal_tx, mut race_fatal_rx) = tokio::sync::mpsc::unbounded_channel::<()>();
        let inner_handler: TcpMessageHandler = Arc::new(|_data: &[u8]| {});

        let handler: TcpMessageHandler = Arc::new(move |data: &[u8]| {
            let is_fatal = match stream_decode(data) {
                Ok(StreamMessage::Status(status)) => status
                    .error_code
                    .as_ref()
                    .is_some_and(|code| code.is_race_stream_fatal()),
                _ => false,
            };

            if is_fatal {
                let _ = race_fatal_tx.send(());
            } else {
                inner_handler(data);
            }
        });

        handler(
            br#"{"op":"status","statusCode":"503","errorCode":"NOT_AUTHORIZED","connectionClosed":true}"#,
        );
        assert!(
            race_fatal_rx.try_recv().is_ok(),
            "fatal error must send kill signal"
        );

        handler(
            br#"{"op":"status","statusCode":"503","errorCode":"INVALID_CLOCK","connectionClosed":true}"#,
        );
        assert!(
            race_fatal_rx.try_recv().is_err(),
            "non-fatal error must not send kill signal"
        );
    }
}