1use std::{
22 num::NonZeroU32,
23 str::FromStr,
24 sync::{
25 Arc, LazyLock,
26 atomic::{AtomicBool, AtomicU8, Ordering},
27 },
28 time::Duration,
29};
30
31use arc_swap::ArcSwap;
32use nautilus_common::live::get_runtime;
33use nautilus_core::AtomicMap;
34use nautilus_model::{
35 data::BarType,
36 identifiers::{AccountId, InstrumentId},
37 instruments::{Instrument, InstrumentAny},
38};
39use nautilus_network::{
40 mode::ConnectionMode,
41 ratelimiter::quota::Quota,
42 websocket::{
43 SubscriptionState, TransportBackend, WebSocketClient, WebSocketConfig,
44 channel_message_handler,
45 },
46};
47use ustr::Ustr;
48
49use crate::{
50 common::{
51 consts::{
52 RECONNECT_BACKOFF_FACTOR, RECONNECT_BASE_BACKOFF, RECONNECT_JITTER_MS,
53 RECONNECT_MAX_BACKOFF, RECONNECT_TIMEOUT, WS_DISCONNECT_TIMEOUT, WS_HEARTBEAT_SECS,
54 },
55 credential::CoinbaseCredential,
56 enums::CoinbaseWsChannel,
57 },
58 websocket::{
59 handler::{FeedHandler, HandlerCommand, NautilusWsMessage},
60 messages::{CoinbaseWsAction, CoinbaseWsSubscription},
61 },
62};
63
64pub static COINBASE_WS_CONNECTION_QUOTA: LazyLock<Quota> = LazyLock::new(|| {
66 Quota::per_second(NonZeroU32::new(8).expect("non-zero")).expect("valid constant")
67});
68
69pub static COINBASE_WS_SUBSCRIPTION_QUOTA: LazyLock<Quota> = LazyLock::new(|| {
71 Quota::per_second(NonZeroU32::new(8).expect("non-zero")).expect("valid constant")
72});
73
/// Rate-limit key under which the subscription quota is registered when connecting.
pub const COINBASE_RATE_LIMIT_KEY_SUBSCRIPTION: &str = "subscription";
76
77pub static COINBASE_WS_SUBSCRIPTION_KEYS: LazyLock<[Ustr; 1]> =
79 LazyLock::new(|| [Ustr::from(COINBASE_RATE_LIMIT_KEY_SUBSCRIPTION)]);
80
/// WebSocket client for Coinbase feeds.
///
/// The client owns the connection settings and the state that must survive a
/// reconnect (instruments, bar types, account id, tracked subscriptions), and talks
/// to a spawned feed-handler task through a command channel. Messages produced by
/// the handler are delivered through the output receiver.
///
/// Clones share the connection mode, stop signal, command sender, instrument cache
/// and alias map, but never the output receiver or the task handle.
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.coinbase", from_py_object)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.coinbase")
)]
pub struct CoinbaseWebSocketClient {
    /// WebSocket endpoint used when connecting.
    url: String,
    /// Connection mode atomic; starts as `Closed` and is swapped for the underlying
    /// client's own atomic once `connect` succeeds.
    connection_mode: Arc<ArcSwap<AtomicU8>>,
    /// Stop flag handed to the feed handler: cleared by `connect`, set by `disconnect`.
    signal: Arc<AtomicBool>,
    /// Sender for handler commands; a placeholder with no receiver until `connect`
    /// installs a live channel.
    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
    /// Receiver for handler output; `None` before `connect`, after `take_out_rx`,
    /// and on every clone.
    out_rx: Option<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>,
    /// Instrument cache keyed by instrument id; replayed to the handler on connect.
    instruments: Arc<AtomicMap<InstrumentId, InstrumentAny>>,
    /// Alias map (canonical key -> subscribed key) shared with the feed handler.
    subscription_aliases: Arc<AtomicMap<Ustr, Ustr>>,
    /// Bar types by key; replayed to the handler on connect.
    bar_types: ahash::AHashMap<String, BarType>,
    /// Tracked topics, stored as `channel` or `channel|product_id`.
    subscriptions: SubscriptionState,
    /// Credential used to build the JWT attached to subscription messages.
    credential: Option<CoinbaseCredential>,
    /// Account id forwarded to the handler, if one has been set.
    account_id: Option<AccountId>,
    /// Handle of the spawned feed-handler task, if connected.
    task_handle: Option<tokio::task::JoinHandle<()>>,
    /// Transport backend passed through to the WebSocket configuration.
    transport_backend: TransportBackend,
    /// Optional proxy URL passed through to the WebSocket configuration.
    proxy_url: Option<String>,
}
114
115impl Clone for CoinbaseWebSocketClient {
116 fn clone(&self) -> Self {
117 Self {
118 url: self.url.clone(),
119 connection_mode: Arc::clone(&self.connection_mode),
120 signal: Arc::clone(&self.signal),
121 cmd_tx: Arc::clone(&self.cmd_tx),
122 out_rx: None,
123 instruments: Arc::clone(&self.instruments),
124 subscription_aliases: Arc::clone(&self.subscription_aliases),
125 bar_types: self.bar_types.clone(),
126 subscriptions: self.subscriptions.clone(),
127 credential: self.credential.clone(),
128 account_id: self.account_id,
129 task_handle: None,
130 transport_backend: self.transport_backend,
131 proxy_url: self.proxy_url.clone(),
132 }
133 }
134}
135
136impl CoinbaseWebSocketClient {
137 pub fn new(url: &str, transport_backend: TransportBackend, proxy_url: Option<String>) -> Self {
139 let (placeholder_tx, _) = tokio::sync::mpsc::unbounded_channel();
140
141 Self {
142 url: url.to_string(),
143 connection_mode: Arc::new(ArcSwap::from_pointee(AtomicU8::new(
144 ConnectionMode::Closed.as_u8(),
145 ))),
146 signal: Arc::new(AtomicBool::new(false)),
147 cmd_tx: Arc::new(tokio::sync::RwLock::new(placeholder_tx)),
148 out_rx: None,
149 instruments: Arc::new(AtomicMap::new()),
150 subscription_aliases: Arc::new(AtomicMap::new()),
151 bar_types: ahash::AHashMap::new(),
152 subscriptions: SubscriptionState::new('|'),
153 credential: None,
154 account_id: None,
155 task_handle: None,
156 transport_backend,
157 proxy_url,
158 }
159 }
160
161 pub fn with_credential(
163 url: &str,
164 credential: CoinbaseCredential,
165 transport_backend: TransportBackend,
166 proxy_url: Option<String>,
167 ) -> Self {
168 let mut client = Self::new(url, transport_backend, proxy_url);
169 client.credential = Some(credential);
170 client
171 }
172
173 pub async fn set_account_id(&mut self, account_id: AccountId) {
178 self.account_id = Some(account_id);
179
180 let cmd_tx = self.cmd_tx.read().await;
181 if let Err(e) = cmd_tx.send(HandlerCommand::SetAccountId(account_id)) {
182 log::debug!("Failed to send SetAccountId: {e}");
183 }
184 }
185
186 pub async fn initialize_instruments(&self, instruments: Vec<InstrumentAny>) {
193 for instrument in &instruments {
194 self.instruments.insert(instrument.id(), instrument.clone());
195 }
196
197 let cmd_tx = self.cmd_tx.read().await;
198 if let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(instruments)) {
199 log::debug!("Failed to send InitializeInstruments: {e}");
200 }
201 }
202
203 fn prime_default_subscriptions(&self) {
207 self.subscriptions
208 .mark_subscribe(CoinbaseWsChannel::Heartbeats.as_ref());
209 }
210
211 pub async fn connect(&mut self) -> anyhow::Result<()> {
213 if self.is_active() || self.is_reconnecting() {
214 log::warn!("WebSocket already connected or reconnecting");
215 return Ok(());
216 }
217
218 self.signal.store(false, Ordering::Relaxed);
220
221 let (message_handler, raw_rx) = channel_message_handler();
222 let cfg = WebSocketConfig {
223 url: self.url.clone(),
224 headers: vec![],
225 heartbeat: Some(WS_HEARTBEAT_SECS),
228 heartbeat_msg: None,
229 reconnect_timeout_ms: Some(RECONNECT_TIMEOUT.as_millis() as u64),
230 reconnect_delay_initial_ms: Some(RECONNECT_BASE_BACKOFF.as_millis() as u64),
231 reconnect_delay_max_ms: Some(RECONNECT_MAX_BACKOFF.as_millis() as u64),
232 reconnect_backoff_factor: Some(RECONNECT_BACKOFF_FACTOR),
233 reconnect_jitter_ms: Some(RECONNECT_JITTER_MS),
234 reconnect_max_attempts: None,
235 idle_timeout_ms: None,
236 backend: self.transport_backend,
237 proxy_url: self.proxy_url.clone(),
238 };
239
240 let keyed_quotas = vec![(
241 COINBASE_RATE_LIMIT_KEY_SUBSCRIPTION.to_string(),
242 *COINBASE_WS_SUBSCRIPTION_QUOTA,
243 )];
244
245 let client = WebSocketClient::connect(
246 cfg,
247 Some(message_handler),
248 None,
249 None,
250 keyed_quotas,
251 Some(*COINBASE_WS_CONNECTION_QUOTA),
252 )
253 .await?;
254
255 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
256 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
257
258 *self.cmd_tx.write().await = cmd_tx.clone();
259 self.out_rx = Some(out_rx);
260 self.connection_mode.store(client.connection_mode_atomic());
261 log::info!("Coinbase WebSocket connected: {}", self.url);
262
263 if let Err(e) = cmd_tx.send(HandlerCommand::SetClient(client)) {
264 anyhow::bail!("Failed to send SetClient command: {e}");
265 }
266
267 let instruments_vec: Vec<InstrumentAny> =
268 self.instruments.load().values().cloned().collect();
269
270 if !instruments_vec.is_empty()
271 && let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(instruments_vec))
272 {
273 log::error!("Failed to send InitializeInstruments: {e}");
274 }
275
276 for (key, bar_type) in &self.bar_types {
278 if let Err(e) = cmd_tx.send(HandlerCommand::AddBarType {
279 key: key.clone(),
280 bar_type: *bar_type,
281 }) {
282 log::error!("Failed to restore bar type {key}: {e}");
283 }
284 }
285
286 if let Some(account_id) = self.account_id
287 && let Err(e) = cmd_tx.send(HandlerCommand::SetAccountId(account_id))
288 {
289 log::error!("Failed to restore account_id: {e}");
290 }
291
292 self.prime_default_subscriptions();
293
294 resubscribe_all(
296 &self.subscriptions,
297 &self.credential,
298 &cmd_tx,
299 Some(&out_tx),
300 );
301
302 let signal = Arc::clone(&self.signal);
303 let subscriptions = self.subscriptions.clone();
304 let credential = self.credential.clone();
305 let cmd_tx_reconnect = cmd_tx.clone();
306 let aliases_for_handler = Arc::clone(&self.subscription_aliases);
307
308 let stream_handle = get_runtime().spawn(async move {
309 let mut handler = FeedHandler::new(signal, cmd_rx, raw_rx, aliases_for_handler);
310
311 loop {
312 match handler.next().await {
313 Some(NautilusWsMessage::Reconnected) => {
314 resubscribe_all(
315 &subscriptions,
316 &credential,
317 &cmd_tx_reconnect,
318 Some(&out_tx),
319 );
320
321 if let Err(e) = out_tx.send(NautilusWsMessage::Reconnected) {
322 log::debug!("Output channel closed: {e}");
323 break;
324 }
325 }
326 Some(msg) => {
327 if let Err(e) = out_tx.send(msg) {
328 log::debug!("Output channel closed: {e}");
329 break;
330 }
331 }
332 None => {
333 log::info!("Feed handler stopped");
334 break;
335 }
336 }
337 }
338 });
339
340 self.task_handle = Some(stream_handle);
341 Ok(())
342 }
343
344 pub async fn subscribe(
346 &self,
347 channel: CoinbaseWsChannel,
348 product_ids: &[Ustr],
349 ) -> anyhow::Result<()> {
350 let jwt = if channel.requires_auth() {
351 let credential = self
352 .credential
353 .as_ref()
354 .ok_or_else(|| anyhow::anyhow!("Credentials required for {channel}"))?;
355 Some(credential.build_ws_jwt()?)
356 } else {
357 self.credential.as_ref().and_then(|c| c.build_ws_jwt().ok())
358 };
359
360 let sub = CoinbaseWsSubscription {
361 msg_type: CoinbaseWsAction::Subscribe,
362 product_ids: product_ids.to_vec(),
363 channel,
364 jwt,
365 };
366
367 let channel_str = channel.as_ref();
368
369 if product_ids.is_empty() {
370 self.subscriptions.mark_subscribe(channel_str);
371 } else {
372 for product_id in product_ids {
373 let topic = format!("{channel_str}|{product_id}");
374 self.subscriptions.mark_subscribe(&topic);
375 }
376 }
377
378 let cmd_tx = self.cmd_tx.read().await;
379 cmd_tx
380 .send(HandlerCommand::Subscribe(sub))
381 .map_err(|e| anyhow::anyhow!("Failed to send Subscribe command: {e}"))
382 }
383
384 pub async fn unsubscribe(
386 &self,
387 channel: CoinbaseWsChannel,
388 product_ids: &[Ustr],
389 ) -> anyhow::Result<()> {
390 let jwt = self.credential.as_ref().and_then(|c| c.build_ws_jwt().ok());
391
392 let unsub = CoinbaseWsSubscription {
393 msg_type: CoinbaseWsAction::Unsubscribe,
394 product_ids: product_ids.to_vec(),
395 channel,
396 jwt,
397 };
398
399 let channel_str = channel.as_ref();
400
401 if product_ids.is_empty() {
402 self.subscriptions.mark_unsubscribe(channel_str);
403 } else {
404 for product_id in product_ids {
405 let topic = format!("{channel_str}|{product_id}");
406 self.subscriptions.mark_unsubscribe(&topic);
407 }
408 }
409
410 let cmd_tx = self.cmd_tx.read().await;
411 cmd_tx
412 .send(HandlerCommand::Unsubscribe(unsub))
413 .map_err(|e| anyhow::anyhow!("Failed to send Unsubscribe command: {e}"))
414 }
415
416 pub async fn next_message(&mut self) -> Option<NautilusWsMessage> {
418 self.out_rx.as_mut()?.recv().await
419 }
420
421 pub async fn disconnect(&mut self) {
423 let cmd_tx = self.cmd_tx.read().await;
426
427 if let Err(e) = cmd_tx.send(HandlerCommand::Disconnect) {
428 log::debug!("Failed to send Disconnect command: {e}");
429 }
430 drop(cmd_tx);
431
432 self.signal.store(true, Ordering::Release);
435
436 if let Some(handle) = self.task_handle.take() {
437 let abort_handle = handle.abort_handle();
440 match tokio::time::timeout(WS_DISCONNECT_TIMEOUT, handle).await {
441 Ok(_) => log::debug!("Feed handler task completed"),
442 Err(_) => {
443 log::warn!("Feed handler task did not complete within timeout, aborting");
444 abort_handle.abort();
445 }
446 }
447 }
448
449 let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
454
455 loop {
456 let mode_ptr = self.connection_mode.load();
457
458 if ConnectionMode::from_u8(mode_ptr.load(Ordering::Relaxed)).is_closed() {
459 break;
460 }
461
462 if tokio::time::Instant::now() >= deadline {
463 log::warn!("Timed out waiting for WebSocket to reach Closed state");
464 break;
465 }
466
467 tokio::time::sleep(Duration::from_millis(20)).await;
468 }
469 }
470
471 #[must_use]
473 pub fn is_active(&self) -> bool {
474 let mode_ptr = self.connection_mode.load();
475 let mode_val = mode_ptr.load(Ordering::Relaxed);
476 ConnectionMode::from_u8(mode_val).is_active()
477 }
478
479 #[must_use]
481 pub fn is_reconnecting(&self) -> bool {
482 let mode_ptr = self.connection_mode.load();
483 let mode_val = mode_ptr.load(Ordering::Relaxed);
484 ConnectionMode::from_u8(mode_val).is_reconnect()
485 }
486
    /// Returns the shared instrument cache, keyed by instrument id.
    #[must_use]
    pub fn instruments(&self) -> &Arc<AtomicMap<InstrumentId, InstrumentAny>> {
        &self.instruments
    }
492
    /// Returns the alias map that is shared with the feed handler.
    #[must_use]
    pub fn subscription_aliases(&self) -> &Arc<AtomicMap<Ustr, Ustr>> {
        &self.subscription_aliases
    }
498
    /// Inserts a `canonical` -> `subscribed` entry into the alias map shared with
    /// the feed handler.
    // NOTE(review): presumably `canonical` is the product id callers use and
    // `subscribed` the id actually sent to Coinbase; confirm against the handler.
    pub fn register_subscription_alias(&self, canonical: Ustr, subscribed: Ustr) {
        self.subscription_aliases.insert(canonical, subscribed);
    }
505
    /// Removes the alias entry stored under `canonical`, if any.
    pub fn unregister_subscription_alias(&self, canonical: &Ustr) {
        self.subscription_aliases.remove(canonical);
    }
510
    /// Returns the tracker holding the topics this client has marked as
    /// subscribed or unsubscribed.
    #[must_use]
    pub fn subscriptions(&self) -> &SubscriptionState {
        &self.subscriptions
    }
516
517 pub async fn update_instrument(&self, instrument: InstrumentAny) {
519 let id = instrument.id();
520 self.instruments.insert(id, instrument.clone());
521
522 let cmd_tx = self.cmd_tx.read().await;
523
524 if let Err(e) = cmd_tx.send(HandlerCommand::UpdateInstrument(Box::new(instrument))) {
525 log::debug!("Failed to send UpdateInstrument: {e}");
526 }
527 }
528
    /// Hands the output receiver to the caller, leaving `None` behind.
    ///
    /// After this call `next_message` returns `None` until `connect` installs a
    /// new receiver.
    pub fn take_out_rx(
        &mut self,
    ) -> Option<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>> {
        self.out_rx.take()
    }
537
    /// Records `bar_type` under `key` on the client only.
    ///
    /// Nothing is sent to the feed handler here; the stored entries are forwarded
    /// as `AddBarType` commands by `connect`.
    pub fn register_bar_type(&mut self, key: String, bar_type: BarType) {
        self.bar_types.insert(key, bar_type);
    }
545
546 pub async fn add_bar_type(&mut self, key: String, bar_type: BarType) {
548 self.bar_types.insert(key.clone(), bar_type);
549
550 let cmd_tx = self.cmd_tx.read().await;
551
552 if let Err(e) = cmd_tx.send(HandlerCommand::AddBarType { key, bar_type }) {
553 log::debug!("Failed to send AddBarType: {e}");
554 }
555 }
556}
557
558fn resubscribe_all(
559 subscriptions: &SubscriptionState,
560 credential: &Option<CoinbaseCredential>,
561 cmd_tx: &tokio::sync::mpsc::UnboundedSender<HandlerCommand>,
562 out_tx: Option<&tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>>,
563) {
564 let topics = subscriptions.all_topics();
565
566 if topics.is_empty() {
567 log::debug!("No active subscriptions to restore");
568 return;
569 }
570
571 log::info!(
572 "Resubscribing to {} topics after reconnection",
573 topics.len()
574 );
575
576 for topic in topics {
577 let (channel, product_id) = match topic.split_once('|') {
578 Some((ch, pid)) => (ch, Some(pid)),
579 None => (topic.as_str(), None),
580 };
581
582 let channel_enum = match CoinbaseWsChannel::from_str(channel) {
583 Ok(ch) => ch,
584 Err(_) => {
585 log::warn!("Unknown channel in topic: {topic}");
586 continue;
587 }
588 };
589
590 let jwt = match credential.as_ref() {
591 Some(c) => match c.build_ws_jwt() {
592 Ok(token) => Some(token),
593 Err(e) => {
594 if channel_enum.requires_auth() {
595 let msg = format!(
596 "JWT required for {channel} but build failed: {e}; topic {topic} not restored"
597 );
598 log::error!("{msg}");
599 if let Some(tx) = out_tx {
600 let _ = tx.send(NautilusWsMessage::Error(msg));
601 }
602 continue;
603 }
604 None
605 }
606 },
607 None => {
608 if channel_enum.requires_auth() {
609 let msg = format!(
610 "JWT required for {channel} but no credentials configured; topic {topic} not restored"
611 );
612 log::error!("{msg}");
613 if let Some(tx) = out_tx {
614 let _ = tx.send(NautilusWsMessage::Error(msg));
615 }
616 continue;
617 }
618 None
619 }
620 };
621
622 let product_ids = match product_id {
623 Some(pid) => vec![Ustr::from(pid)],
624 None => vec![],
625 };
626
627 let sub = CoinbaseWsSubscription {
628 msg_type: CoinbaseWsAction::Subscribe,
629 product_ids,
630 channel: channel_enum,
631 jwt,
632 };
633
634 if let Err(e) = cmd_tx.send(HandlerCommand::Subscribe(sub)) {
635 log::error!("Failed to resubscribe {topic}: {e}");
636 }
637 }
638}
639
#[cfg(test)]
mod tests {
    use nautilus_network::websocket::SubscriptionState;
    use rstest::rstest;

    use super::*;

    /// Unwraps a handler command that is expected to be a `Subscribe`.
    fn expect_subscribe(command: HandlerCommand) -> CoinbaseWsSubscription {
        match command {
            HandlerCommand::Subscribe(sub) => sub,
            other => panic!("Expected Subscribe, was {other:?}"),
        }
    }

    /// A `channel|product` topic is restored with that single product id and no JWT.
    #[rstest]
    fn test_resubscribe_all_product_level_topic() {
        let tracked = SubscriptionState::new('|');
        tracked.mark_subscribe("level2|BTC-USD");

        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        resubscribe_all(&tracked, &None, &cmd_tx, None);

        let sub = expect_subscribe(cmd_rx.try_recv().unwrap());

        assert_eq!(sub.channel, CoinbaseWsChannel::Level2);
        assert_eq!(sub.product_ids.len(), 1);
        assert_eq!(sub.product_ids[0], "BTC-USD");
        assert!(sub.jwt.is_none());
    }

    /// A bare channel topic is restored with an empty product list.
    #[rstest]
    fn test_resubscribe_all_channel_level_topic() {
        let tracked = SubscriptionState::new('|');
        tracked.mark_subscribe("heartbeats");

        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        resubscribe_all(&tracked, &None, &cmd_tx, None);

        let sub = expect_subscribe(cmd_rx.try_recv().unwrap());

        assert_eq!(sub.channel, CoinbaseWsChannel::Heartbeats);
        assert!(sub.product_ids.is_empty());
    }

    /// Every tracked topic yields exactly one `Subscribe` command.
    #[rstest]
    fn test_resubscribe_all_multiple_topics() {
        let tracked = SubscriptionState::new('|');
        tracked.mark_subscribe("market_trades|BTC-USD");
        tracked.mark_subscribe("ticker|ETH-USD");

        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        resubscribe_all(&tracked, &None, &cmd_tx, None);

        let first = cmd_rx.try_recv().unwrap();
        let second = cmd_rx.try_recv().unwrap();

        assert!(matches!(first, HandlerCommand::Subscribe(_)));
        assert!(matches!(second, HandlerCommand::Subscribe(_)));
        assert!(cmd_rx.try_recv().is_err());
    }

    /// Nothing is queued when no topic is tracked.
    #[rstest]
    fn test_resubscribe_all_empty_subscriptions() {
        let tracked = SubscriptionState::new('|');

        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        resubscribe_all(&tracked, &None, &cmd_tx, None);

        assert!(cmd_rx.try_recv().is_err());
    }

    /// Topics whose channel name is not recognized are dropped.
    #[rstest]
    fn test_resubscribe_all_unknown_channel_skipped() {
        let tracked = SubscriptionState::new('|');
        tracked.mark_subscribe("nonexistent_channel|BTC-USD");

        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        resubscribe_all(&tracked, &None, &cmd_tx, None);

        assert!(cmd_rx.try_recv().is_err());
    }

    /// Each public channel name maps back to its enum variant.
    #[rstest]
    #[case("level2|BTC-USD", CoinbaseWsChannel::Level2)]
    #[case("market_trades|ETH-USD", CoinbaseWsChannel::MarketTrades)]
    #[case("ticker|BTC-USD", CoinbaseWsChannel::Ticker)]
    #[case("ticker_batch|BTC-USD", CoinbaseWsChannel::TickerBatch)]
    #[case("candles|BTC-USD", CoinbaseWsChannel::Candles)]
    #[case("heartbeats", CoinbaseWsChannel::Heartbeats)]
    #[case("status", CoinbaseWsChannel::Status)]
    fn test_resubscribe_all_channel_mapping(
        #[case] topic: &str,
        #[case] expected_channel: CoinbaseWsChannel,
    ) {
        let tracked = SubscriptionState::new('|');
        tracked.mark_subscribe(topic);

        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        resubscribe_all(&tracked, &None, &cmd_tx, None);

        let sub = expect_subscribe(cmd_rx.try_recv().unwrap());

        assert_eq!(sub.channel, expected_channel);
    }

    /// Authenticated channels are not requested when no credential is available.
    #[rstest]
    #[case("user|BTC-USD")]
    #[case("futures_balance_summary")]
    fn test_resubscribe_all_auth_channel_skipped_without_credentials(#[case] topic: &str) {
        let tracked = SubscriptionState::new('|');
        tracked.mark_subscribe(topic);

        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        resubscribe_all(&tracked, &None, &cmd_tx, None);

        assert!(cmd_rx.try_recv().is_err());
    }

    /// Skipping an authenticated channel for lack of credentials surfaces an
    /// error event that names both the channel and the topic.
    #[rstest]
    #[case("user|BTC-USD", "user")]
    #[case("futures_balance_summary", "futures_balance_summary")]
    fn test_resubscribe_all_emits_error_for_auth_channel_without_credentials(
        #[case] topic: &str,
        #[case] channel: &str,
    ) {
        let tracked = SubscriptionState::new('|');
        tracked.mark_subscribe(topic);

        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
        resubscribe_all(&tracked, &None, &cmd_tx, Some(&out_tx));

        assert!(cmd_rx.try_recv().is_err());

        let event = out_rx
            .try_recv()
            .expect("Error event must be emitted when auth channel cannot resubscribe");
        let text = match event {
            NautilusWsMessage::Error(text) => text,
            other => panic!("expected Error variant, was {other:?}"),
        };

        assert!(
            text.contains(channel),
            "error must mention the channel, was: {text}"
        );
        assert!(
            text.contains(topic),
            "error must mention the topic, was: {text}"
        );
    }

    /// A credential whose key cannot produce a JWT is reported the same way as a
    /// missing credential for an authenticated channel.
    #[rstest]
    fn test_resubscribe_all_emits_error_when_jwt_build_fails() {
        let tracked = SubscriptionState::new('|');
        let topic = "user|BTC-USD";
        tracked.mark_subscribe(topic);

        let bad_credential = Some(CoinbaseCredential::new(
            "organizations/test/apiKeys/test".to_string(),
            "not-a-pem-key".to_string(),
        ));

        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
        resubscribe_all(&tracked, &bad_credential, &cmd_tx, Some(&out_tx));

        assert!(cmd_rx.try_recv().is_err(), "no subscribe should be sent");

        let event = out_rx
            .try_recv()
            .expect("Error event must be emitted when JWT build fails for an auth channel");
        let text = match event {
            NautilusWsMessage::Error(text) => text,
            other => panic!("expected Error variant, was {other:?}"),
        };

        assert!(text.contains("user"), "error must mention channel: {text}");
        assert!(text.contains(topic), "error must mention topic: {text}");
    }

    /// Priming a fresh client adds the heartbeats topic to its tracked subscriptions.
    #[rstest]
    fn test_prime_default_subscriptions_marks_heartbeats() {
        let client = CoinbaseWebSocketClient::new("wss://test", TransportBackend::default(), None);
        assert!(client.subscriptions.all_topics().is_empty());

        client.prime_default_subscriptions();

        let topics = client.subscriptions.all_topics();
        assert!(topics.iter().any(|t| t == "heartbeats"), "{topics:?}");
    }

    /// Both quotas are configured with a burst size of 8.
    #[rstest]
    fn test_ws_quotas_match_documented_limits() {
        let connection_burst = COINBASE_WS_CONNECTION_QUOTA.burst_size().get();
        let subscription_burst = COINBASE_WS_SUBSCRIPTION_QUOTA.burst_size().get();

        assert_eq!(connection_burst, 8);
        assert_eq!(subscription_burst, 8);
    }

    /// The rate-limit key literal and its pre-built `Ustr` form stay in sync.
    #[rstest]
    fn test_ws_subscription_rate_limit_key_is_stable() {
        assert_eq!(COINBASE_RATE_LIMIT_KEY_SUBSCRIPTION, "subscription");
        assert_eq!(
            COINBASE_WS_SUBSCRIPTION_KEYS[0].as_str(),
            COINBASE_RATE_LIMIT_KEY_SUBSCRIPTION,
        );
    }
}