Skip to main content

nautilus_coinbase/websocket/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! WebSocket client for the Coinbase Advanced Trade API.
17//!
18//! Manages connection lifecycle, JWT-authenticated subscriptions, and dispatches
19//! parsed Nautilus messages through the [`FeedHandler`].
20
21use std::{
22    num::NonZeroU32,
23    str::FromStr,
24    sync::{
25        Arc, LazyLock,
26        atomic::{AtomicBool, AtomicU8, Ordering},
27    },
28    time::Duration,
29};
30
31use arc_swap::ArcSwap;
32use nautilus_common::live::get_runtime;
33use nautilus_core::AtomicMap;
34use nautilus_model::{
35    data::BarType,
36    identifiers::{AccountId, InstrumentId},
37    instruments::{Instrument, InstrumentAny},
38};
39use nautilus_network::{
40    mode::ConnectionMode,
41    ratelimiter::quota::Quota,
42    websocket::{
43        SubscriptionState, TransportBackend, WebSocketClient, WebSocketConfig,
44        channel_message_handler,
45    },
46};
47use ustr::Ustr;
48
49use crate::{
50    common::{
51        consts::{
52            RECONNECT_BACKOFF_FACTOR, RECONNECT_BASE_BACKOFF, RECONNECT_JITTER_MS,
53            RECONNECT_MAX_BACKOFF, RECONNECT_TIMEOUT, WS_DISCONNECT_TIMEOUT, WS_HEARTBEAT_SECS,
54        },
55        credential::CoinbaseCredential,
56        enums::CoinbaseWsChannel,
57    },
58    websocket::{
59        handler::{FeedHandler, HandlerCommand, NautilusWsMessage},
60        messages::{CoinbaseWsAction, CoinbaseWsSubscription},
61    },
62};
63
64/// Coinbase WebSocket connection rate limit (8 per second per IP).
65pub static COINBASE_WS_CONNECTION_QUOTA: LazyLock<Quota> = LazyLock::new(|| {
66    Quota::per_second(NonZeroU32::new(8).expect("non-zero")).expect("valid constant")
67});
68
69/// Coinbase WebSocket subscribe/unsubscribe rate limit (8 per second per IP).
70pub static COINBASE_WS_SUBSCRIPTION_QUOTA: LazyLock<Quota> = LazyLock::new(|| {
71    Quota::per_second(NonZeroU32::new(8).expect("non-zero")).expect("valid constant")
72});
73
/// Key under which the subscribe/unsubscribe quota is registered with the
/// WebSocket rate limiter.
pub const COINBASE_RATE_LIMIT_KEY_SUBSCRIPTION: &str = "subscription";
76
77/// Pre-interned [`COINBASE_RATE_LIMIT_KEY_SUBSCRIPTION`] slice.
78pub static COINBASE_WS_SUBSCRIPTION_KEYS: LazyLock<[Ustr; 1]> =
79    LazyLock::new(|| [Ustr::from(COINBASE_RATE_LIMIT_KEY_SUBSCRIPTION)]);
80
81/// WebSocket client for Coinbase Advanced Trade market data and user streams.
82///
83/// Manages connection lifecycle, subscription state, and JWT authentication.
84/// Spawns a [`FeedHandler`] task that parses raw messages into Nautilus types.
85#[derive(Debug)]
86#[cfg_attr(
87    feature = "python",
88    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.coinbase", from_py_object)
89)]
90#[cfg_attr(
91    feature = "python",
92    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.coinbase")
93)]
94pub struct CoinbaseWebSocketClient {
95    url: String,
96    connection_mode: Arc<ArcSwap<AtomicU8>>,
97    signal: Arc<AtomicBool>,
98    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
99    out_rx: Option<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>>,
100    instruments: Arc<AtomicMap<InstrumentId, InstrumentAny>>,
101    /// Maps a canonical wire `product_id` to the `product_id` the caller
102    /// subscribed or submitted with. Coinbase rewrites aliased products to
103    /// their canonical form on the wire (e.g. `BTC-USDC -> BTC-USD`), so
104    /// inbound messages must be re-keyed to the caller's id before parsing.
105    subscription_aliases: Arc<AtomicMap<Ustr, Ustr>>,
106    bar_types: ahash::AHashMap<String, BarType>,
107    subscriptions: SubscriptionState,
108    credential: Option<CoinbaseCredential>,
109    account_id: Option<AccountId>,
110    task_handle: Option<tokio::task::JoinHandle<()>>,
111    transport_backend: TransportBackend,
112    proxy_url: Option<String>,
113}
114
115impl Clone for CoinbaseWebSocketClient {
116    fn clone(&self) -> Self {
117        Self {
118            url: self.url.clone(),
119            connection_mode: Arc::clone(&self.connection_mode),
120            signal: Arc::clone(&self.signal),
121            cmd_tx: Arc::clone(&self.cmd_tx),
122            out_rx: None,
123            instruments: Arc::clone(&self.instruments),
124            subscription_aliases: Arc::clone(&self.subscription_aliases),
125            bar_types: self.bar_types.clone(),
126            subscriptions: self.subscriptions.clone(),
127            credential: self.credential.clone(),
128            account_id: self.account_id,
129            task_handle: None,
130            transport_backend: self.transport_backend,
131            proxy_url: self.proxy_url.clone(),
132        }
133    }
134}
135
136impl CoinbaseWebSocketClient {
137    /// Creates a new [`CoinbaseWebSocketClient`] for public market data.
138    pub fn new(url: &str, transport_backend: TransportBackend, proxy_url: Option<String>) -> Self {
139        let (placeholder_tx, _) = tokio::sync::mpsc::unbounded_channel();
140
141        Self {
142            url: url.to_string(),
143            connection_mode: Arc::new(ArcSwap::from_pointee(AtomicU8::new(
144                ConnectionMode::Closed.as_u8(),
145            ))),
146            signal: Arc::new(AtomicBool::new(false)),
147            cmd_tx: Arc::new(tokio::sync::RwLock::new(placeholder_tx)),
148            out_rx: None,
149            instruments: Arc::new(AtomicMap::new()),
150            subscription_aliases: Arc::new(AtomicMap::new()),
151            bar_types: ahash::AHashMap::new(),
152            subscriptions: SubscriptionState::new('|'),
153            credential: None,
154            account_id: None,
155            task_handle: None,
156            transport_backend,
157            proxy_url,
158        }
159    }
160
161    /// Creates a new [`CoinbaseWebSocketClient`] with credentials for authenticated channels.
162    pub fn with_credential(
163        url: &str,
164        credential: CoinbaseCredential,
165        transport_backend: TransportBackend,
166        proxy_url: Option<String>,
167    ) -> Self {
168        let mut client = Self::new(url, transport_backend, proxy_url);
169        client.credential = Some(credential);
170        client
171    }
172
173    /// Sets the account ID used when emitting user-channel execution reports.
174    ///
175    /// Propagates to the feed handler when the connection is active so that
176    /// subsequent user events carry the correct account identifier.
177    pub async fn set_account_id(&mut self, account_id: AccountId) {
178        self.account_id = Some(account_id);
179
180        let cmd_tx = self.cmd_tx.read().await;
181        if let Err(e) = cmd_tx.send(HandlerCommand::SetAccountId(account_id)) {
182            log::debug!("Failed to send SetAccountId: {e}");
183        }
184    }
185
186    /// Bulk-populates the instrument cache.
187    ///
188    /// Safe to call before or after [`Self::connect`]. When called before
189    /// connect, instruments are picked up by the initial `InitializeInstruments`
190    /// command the client sends to the handler; when called after, a fresh
191    /// `InitializeInstruments` command is sent to refresh the handler's cache.
192    pub async fn initialize_instruments(&self, instruments: Vec<InstrumentAny>) {
193        for instrument in &instruments {
194            self.instruments.insert(instrument.id(), instrument.clone());
195        }
196
197        let cmd_tx = self.cmd_tx.read().await;
198        if let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(instruments)) {
199            log::debug!("Failed to send InitializeInstruments: {e}");
200        }
201    }
202
203    // Coinbase closes clients that idle without a subscribe inside 5s, and
204    // heartbeats keeps the connection alive when product topics are quiet.
205    // Marking before `resubscribe_all` replays it on every reconnect.
206    fn prime_default_subscriptions(&self) {
207        self.subscriptions
208            .mark_subscribe(CoinbaseWsChannel::Heartbeats.as_ref());
209    }
210
211    /// Establishes the WebSocket connection and spawns the feed handler.
212    pub async fn connect(&mut self) -> anyhow::Result<()> {
213        if self.is_active() || self.is_reconnecting() {
214            log::warn!("WebSocket already connected or reconnecting");
215            return Ok(());
216        }
217
218        // Clear stop signal from any previous disconnect
219        self.signal.store(false, Ordering::Relaxed);
220
221        let (message_handler, raw_rx) = channel_message_handler();
222        let cfg = WebSocketConfig {
223            url: self.url.clone(),
224            headers: vec![],
225            // Coinbase uses TCP control-frame pings for transport keep-alive;
226            // application-layer liveness comes from the heartbeats channel.
227            heartbeat: Some(WS_HEARTBEAT_SECS),
228            heartbeat_msg: None,
229            reconnect_timeout_ms: Some(RECONNECT_TIMEOUT.as_millis() as u64),
230            reconnect_delay_initial_ms: Some(RECONNECT_BASE_BACKOFF.as_millis() as u64),
231            reconnect_delay_max_ms: Some(RECONNECT_MAX_BACKOFF.as_millis() as u64),
232            reconnect_backoff_factor: Some(RECONNECT_BACKOFF_FACTOR),
233            reconnect_jitter_ms: Some(RECONNECT_JITTER_MS),
234            reconnect_max_attempts: None,
235            idle_timeout_ms: None,
236            backend: self.transport_backend,
237            proxy_url: self.proxy_url.clone(),
238        };
239
240        let keyed_quotas = vec![(
241            COINBASE_RATE_LIMIT_KEY_SUBSCRIPTION.to_string(),
242            *COINBASE_WS_SUBSCRIPTION_QUOTA,
243        )];
244
245        let client = WebSocketClient::connect(
246            cfg,
247            Some(message_handler),
248            None,
249            None,
250            keyed_quotas,
251            Some(*COINBASE_WS_CONNECTION_QUOTA),
252        )
253        .await?;
254
255        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
256        let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<NautilusWsMessage>();
257
258        *self.cmd_tx.write().await = cmd_tx.clone();
259        self.out_rx = Some(out_rx);
260        self.connection_mode.store(client.connection_mode_atomic());
261        log::info!("Coinbase WebSocket connected: {}", self.url);
262
263        if let Err(e) = cmd_tx.send(HandlerCommand::SetClient(client)) {
264            anyhow::bail!("Failed to send SetClient command: {e}");
265        }
266
267        let instruments_vec: Vec<InstrumentAny> =
268            self.instruments.load().values().cloned().collect();
269
270        if !instruments_vec.is_empty()
271            && let Err(e) = cmd_tx.send(HandlerCommand::InitializeInstruments(instruments_vec))
272        {
273            log::error!("Failed to send InitializeInstruments: {e}");
274        }
275
276        // Restore bar type registrations from previous session
277        for (key, bar_type) in &self.bar_types {
278            if let Err(e) = cmd_tx.send(HandlerCommand::AddBarType {
279                key: key.clone(),
280                bar_type: *bar_type,
281            }) {
282                log::error!("Failed to restore bar type {key}: {e}");
283            }
284        }
285
286        if let Some(account_id) = self.account_id
287            && let Err(e) = cmd_tx.send(HandlerCommand::SetAccountId(account_id))
288        {
289            log::error!("Failed to restore account_id: {e}");
290        }
291
292        self.prime_default_subscriptions();
293
294        // Replay retained subscriptions from previous session
295        resubscribe_all(
296            &self.subscriptions,
297            &self.credential,
298            &cmd_tx,
299            Some(&out_tx),
300        );
301
302        let signal = Arc::clone(&self.signal);
303        let subscriptions = self.subscriptions.clone();
304        let credential = self.credential.clone();
305        let cmd_tx_reconnect = cmd_tx.clone();
306        let aliases_for_handler = Arc::clone(&self.subscription_aliases);
307
308        let stream_handle = get_runtime().spawn(async move {
309            let mut handler = FeedHandler::new(signal, cmd_rx, raw_rx, aliases_for_handler);
310
311            loop {
312                match handler.next().await {
313                    Some(NautilusWsMessage::Reconnected) => {
314                        resubscribe_all(
315                            &subscriptions,
316                            &credential,
317                            &cmd_tx_reconnect,
318                            Some(&out_tx),
319                        );
320
321                        if let Err(e) = out_tx.send(NautilusWsMessage::Reconnected) {
322                            log::debug!("Output channel closed: {e}");
323                            break;
324                        }
325                    }
326                    Some(msg) => {
327                        if let Err(e) = out_tx.send(msg) {
328                            log::debug!("Output channel closed: {e}");
329                            break;
330                        }
331                    }
332                    None => {
333                        log::info!("Feed handler stopped");
334                        break;
335                    }
336                }
337            }
338        });
339
340        self.task_handle = Some(stream_handle);
341        Ok(())
342    }
343
344    /// Subscribes to a channel for the given product IDs.
345    pub async fn subscribe(
346        &self,
347        channel: CoinbaseWsChannel,
348        product_ids: &[Ustr],
349    ) -> anyhow::Result<()> {
350        let jwt = if channel.requires_auth() {
351            let credential = self
352                .credential
353                .as_ref()
354                .ok_or_else(|| anyhow::anyhow!("Credentials required for {channel}"))?;
355            Some(credential.build_ws_jwt()?)
356        } else {
357            self.credential.as_ref().and_then(|c| c.build_ws_jwt().ok())
358        };
359
360        let sub = CoinbaseWsSubscription {
361            msg_type: CoinbaseWsAction::Subscribe,
362            product_ids: product_ids.to_vec(),
363            channel,
364            jwt,
365        };
366
367        let channel_str = channel.as_ref();
368
369        if product_ids.is_empty() {
370            self.subscriptions.mark_subscribe(channel_str);
371        } else {
372            for product_id in product_ids {
373                let topic = format!("{channel_str}|{product_id}");
374                self.subscriptions.mark_subscribe(&topic);
375            }
376        }
377
378        let cmd_tx = self.cmd_tx.read().await;
379        cmd_tx
380            .send(HandlerCommand::Subscribe(sub))
381            .map_err(|e| anyhow::anyhow!("Failed to send Subscribe command: {e}"))
382    }
383
384    /// Unsubscribes from a channel for the given product IDs.
385    pub async fn unsubscribe(
386        &self,
387        channel: CoinbaseWsChannel,
388        product_ids: &[Ustr],
389    ) -> anyhow::Result<()> {
390        let jwt = self.credential.as_ref().and_then(|c| c.build_ws_jwt().ok());
391
392        let unsub = CoinbaseWsSubscription {
393            msg_type: CoinbaseWsAction::Unsubscribe,
394            product_ids: product_ids.to_vec(),
395            channel,
396            jwt,
397        };
398
399        let channel_str = channel.as_ref();
400
401        if product_ids.is_empty() {
402            self.subscriptions.mark_unsubscribe(channel_str);
403        } else {
404            for product_id in product_ids {
405                let topic = format!("{channel_str}|{product_id}");
406                self.subscriptions.mark_unsubscribe(&topic);
407            }
408        }
409
410        let cmd_tx = self.cmd_tx.read().await;
411        cmd_tx
412            .send(HandlerCommand::Unsubscribe(unsub))
413            .map_err(|e| anyhow::anyhow!("Failed to send Unsubscribe command: {e}"))
414    }
415
416    /// Returns the next parsed message from the feed handler.
417    pub async fn next_message(&mut self) -> Option<NautilusWsMessage> {
418        self.out_rx.as_mut()?.recv().await
419    }
420
421    /// Disconnects the WebSocket and stops the feed handler.
422    pub async fn disconnect(&mut self) {
423        // Send Disconnect command before setting the signal so the handler
424        // processes it and calls notify_closed() on the inner WebSocket client
425        let cmd_tx = self.cmd_tx.read().await;
426
427        if let Err(e) = cmd_tx.send(HandlerCommand::Disconnect) {
428            log::debug!("Failed to send Disconnect command: {e}");
429        }
430        drop(cmd_tx);
431
432        // Release pairs with the handler's Acquire load; fallback for when
433        // the command channel is full or closed.
434        self.signal.store(true, Ordering::Release);
435
436        if let Some(handle) = self.task_handle.take() {
437            // Capture an abort handle before awaiting so a stuck task can be
438            // forcibly stopped on timeout instead of leaking.
439            let abort_handle = handle.abort_handle();
440            match tokio::time::timeout(WS_DISCONNECT_TIMEOUT, handle).await {
441                Ok(_) => log::debug!("Feed handler task completed"),
442                Err(_) => {
443                    log::warn!("Feed handler task did not complete within timeout, aborting");
444                    abort_handle.abort();
445                }
446            }
447        }
448
449        // Wait for the inner WebSocket's connection_mode atomic to reach Closed
450        // before returning. Without this, a subsequent connect() can observe a
451        // stale Active/Reconnect state and early-return, leaving out_rx unset
452        // and causing "WebSocket output receiver not available" on take.
453        let deadline = tokio::time::Instant::now() + Duration::from_secs(5);
454
455        loop {
456            let mode_ptr = self.connection_mode.load();
457
458            if ConnectionMode::from_u8(mode_ptr.load(Ordering::Relaxed)).is_closed() {
459                break;
460            }
461
462            if tokio::time::Instant::now() >= deadline {
463                log::warn!("Timed out waiting for WebSocket to reach Closed state");
464                break;
465            }
466
467            tokio::time::sleep(Duration::from_millis(20)).await;
468        }
469    }
470
471    /// Returns true if the WebSocket connection is active.
472    #[must_use]
473    pub fn is_active(&self) -> bool {
474        let mode_ptr = self.connection_mode.load();
475        let mode_val = mode_ptr.load(Ordering::Relaxed);
476        ConnectionMode::from_u8(mode_val).is_active()
477    }
478
479    /// Returns true if the WebSocket is reconnecting after a transport drop.
480    #[must_use]
481    pub fn is_reconnecting(&self) -> bool {
482        let mode_ptr = self.connection_mode.load();
483        let mode_val = mode_ptr.load(Ordering::Relaxed);
484        ConnectionMode::from_u8(mode_val).is_reconnect()
485    }
486
487    /// Returns a reference to the instrument cache.
488    #[must_use]
489    pub fn instruments(&self) -> &Arc<AtomicMap<InstrumentId, InstrumentAny>> {
490        &self.instruments
491    }
492
493    /// Returns a reference to the canonical-to-subscribed alias map.
494    #[must_use]
495    pub fn subscription_aliases(&self) -> &Arc<AtomicMap<Ustr, Ustr>> {
496        &self.subscription_aliases
497    }
498
499    /// Records that inbound messages carrying `canonical` should be re-keyed to
500    /// `subscribed`. Caller is the data/exec client at subscribe or submit time
501    /// when the local product id differs from Coinbase's canonical alias.
502    pub fn register_subscription_alias(&self, canonical: Ustr, subscribed: Ustr) {
503        self.subscription_aliases.insert(canonical, subscribed);
504    }
505
506    /// Removes an alias registration. Safe to call if no entry exists.
507    pub fn unregister_subscription_alias(&self, canonical: &Ustr) {
508        self.subscription_aliases.remove(canonical);
509    }
510
511    /// Returns the subscription state.
512    #[must_use]
513    pub fn subscriptions(&self) -> &SubscriptionState {
514        &self.subscriptions
515    }
516
517    /// Updates an instrument in the cache and notifies the handler.
518    pub async fn update_instrument(&self, instrument: InstrumentAny) {
519        let id = instrument.id();
520        self.instruments.insert(id, instrument.clone());
521
522        let cmd_tx = self.cmd_tx.read().await;
523
524        if let Err(e) = cmd_tx.send(HandlerCommand::UpdateInstrument(Box::new(instrument))) {
525            log::debug!("Failed to send UpdateInstrument: {e}");
526        }
527    }
528
529    /// Takes the output message receiver, leaving `None` in its place.
530    ///
531    /// Used by the data client to move the receiver into a background consumption task.
532    pub fn take_out_rx(
533        &mut self,
534    ) -> Option<tokio::sync::mpsc::UnboundedReceiver<NautilusWsMessage>> {
535        self.out_rx.take()
536    }
537
538    /// Registers a bar type locally without notifying the handler.
539    ///
540    /// Used by the data client to persist registrations on the original client
541    /// before cloning for async command dispatch.
542    pub fn register_bar_type(&mut self, key: String, bar_type: BarType) {
543        self.bar_types.insert(key, bar_type);
544    }
545
546    /// Registers a bar type for candle parsing.
547    pub async fn add_bar_type(&mut self, key: String, bar_type: BarType) {
548        self.bar_types.insert(key.clone(), bar_type);
549
550        let cmd_tx = self.cmd_tx.read().await;
551
552        if let Err(e) = cmd_tx.send(HandlerCommand::AddBarType { key, bar_type }) {
553            log::debug!("Failed to send AddBarType: {e}");
554        }
555    }
556}
557
558fn resubscribe_all(
559    subscriptions: &SubscriptionState,
560    credential: &Option<CoinbaseCredential>,
561    cmd_tx: &tokio::sync::mpsc::UnboundedSender<HandlerCommand>,
562    out_tx: Option<&tokio::sync::mpsc::UnboundedSender<NautilusWsMessage>>,
563) {
564    let topics = subscriptions.all_topics();
565
566    if topics.is_empty() {
567        log::debug!("No active subscriptions to restore");
568        return;
569    }
570
571    log::info!(
572        "Resubscribing to {} topics after reconnection",
573        topics.len()
574    );
575
576    for topic in topics {
577        let (channel, product_id) = match topic.split_once('|') {
578            Some((ch, pid)) => (ch, Some(pid)),
579            None => (topic.as_str(), None),
580        };
581
582        let channel_enum = match CoinbaseWsChannel::from_str(channel) {
583            Ok(ch) => ch,
584            Err(_) => {
585                log::warn!("Unknown channel in topic: {topic}");
586                continue;
587            }
588        };
589
590        let jwt = match credential.as_ref() {
591            Some(c) => match c.build_ws_jwt() {
592                Ok(token) => Some(token),
593                Err(e) => {
594                    if channel_enum.requires_auth() {
595                        let msg = format!(
596                            "JWT required for {channel} but build failed: {e}; topic {topic} not restored"
597                        );
598                        log::error!("{msg}");
599                        if let Some(tx) = out_tx {
600                            let _ = tx.send(NautilusWsMessage::Error(msg));
601                        }
602                        continue;
603                    }
604                    None
605                }
606            },
607            None => {
608                if channel_enum.requires_auth() {
609                    let msg = format!(
610                        "JWT required for {channel} but no credentials configured; topic {topic} not restored"
611                    );
612                    log::error!("{msg}");
613                    if let Some(tx) = out_tx {
614                        let _ = tx.send(NautilusWsMessage::Error(msg));
615                    }
616                    continue;
617                }
618                None
619            }
620        };
621
622        let product_ids = match product_id {
623            Some(pid) => vec![Ustr::from(pid)],
624            None => vec![],
625        };
626
627        let sub = CoinbaseWsSubscription {
628            msg_type: CoinbaseWsAction::Subscribe,
629            product_ids,
630            channel: channel_enum,
631            jwt,
632        };
633
634        if let Err(e) = cmd_tx.send(HandlerCommand::Subscribe(sub)) {
635            log::error!("Failed to resubscribe {topic}: {e}");
636        }
637    }
638}
639
#[cfg(test)]
mod tests {
    use nautilus_network::websocket::SubscriptionState;
    use rstest::rstest;

    use super::*;

    type CommandRx = tokio::sync::mpsc::UnboundedReceiver<HandlerCommand>;

    /// Marks `topics` as subscribed, replays them with no credential and no
    /// output channel, and returns the receiver holding the queued commands.
    fn replay_without_credentials(topics: &[&str]) -> CommandRx {
        let subs = SubscriptionState::new('|');
        for &topic in topics {
            subs.mark_subscribe(topic);
        }

        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
        resubscribe_all(&subs, &None, &tx, None);
        rx
    }

    /// Unwraps a `Subscribe` command, panicking on any other variant.
    fn expect_subscribe(cmd: HandlerCommand) -> CoinbaseWsSubscription {
        match cmd {
            HandlerCommand::Subscribe(sub) => sub,
            other => panic!("Expected Subscribe, was {other:?}"),
        }
    }

    #[rstest]
    fn test_resubscribe_all_product_level_topic() {
        let mut rx = replay_without_credentials(&["level2|BTC-USD"]);

        let sub = expect_subscribe(rx.try_recv().unwrap());

        assert_eq!(sub.channel, CoinbaseWsChannel::Level2);
        assert_eq!(sub.product_ids.len(), 1);
        assert_eq!(sub.product_ids[0], "BTC-USD");
        assert!(sub.jwt.is_none());
    }

    #[rstest]
    fn test_resubscribe_all_channel_level_topic() {
        let mut rx = replay_without_credentials(&["heartbeats"]);

        let sub = expect_subscribe(rx.try_recv().unwrap());

        assert_eq!(sub.channel, CoinbaseWsChannel::Heartbeats);
        assert!(sub.product_ids.is_empty());
    }

    #[rstest]
    fn test_resubscribe_all_multiple_topics() {
        let mut rx = replay_without_credentials(&["market_trades|BTC-USD", "ticker|ETH-USD"]);

        let first = rx.try_recv().unwrap();
        let second = rx.try_recv().unwrap();

        assert!(matches!(first, HandlerCommand::Subscribe(_)));
        assert!(matches!(second, HandlerCommand::Subscribe(_)));
        assert!(rx.try_recv().is_err());
    }

    #[rstest]
    fn test_resubscribe_all_empty_subscriptions() {
        let mut rx = replay_without_credentials(&[]);

        assert!(rx.try_recv().is_err());
    }

    #[rstest]
    fn test_resubscribe_all_unknown_channel_skipped() {
        let mut rx = replay_without_credentials(&["nonexistent_channel|BTC-USD"]);

        assert!(rx.try_recv().is_err());
    }

    #[rstest]
    #[case("level2|BTC-USD", CoinbaseWsChannel::Level2)]
    #[case("market_trades|ETH-USD", CoinbaseWsChannel::MarketTrades)]
    #[case("ticker|BTC-USD", CoinbaseWsChannel::Ticker)]
    #[case("ticker_batch|BTC-USD", CoinbaseWsChannel::TickerBatch)]
    #[case("candles|BTC-USD", CoinbaseWsChannel::Candles)]
    #[case("heartbeats", CoinbaseWsChannel::Heartbeats)]
    #[case("status", CoinbaseWsChannel::Status)]
    fn test_resubscribe_all_channel_mapping(
        #[case] topic: &str,
        #[case] expected_channel: CoinbaseWsChannel,
    ) {
        let mut rx = replay_without_credentials(&[topic]);

        let sub = expect_subscribe(rx.try_recv().unwrap());

        assert_eq!(sub.channel, expected_channel);
    }

    #[rstest]
    #[case("user|BTC-USD")]
    #[case("futures_balance_summary")]
    fn test_resubscribe_all_auth_channel_skipped_without_credentials(#[case] topic: &str) {
        let mut rx = replay_without_credentials(&[topic]);

        // Without credentials an authenticated channel must not be replayed.
        assert!(rx.try_recv().is_err());
    }

    #[rstest]
    #[case("user|BTC-USD", "user")]
    #[case("futures_balance_summary", "futures_balance_summary")]
    fn test_resubscribe_all_emits_error_for_auth_channel_without_credentials(
        #[case] topic: &str,
        #[case] channel: &str,
    ) {
        let subs = SubscriptionState::new('|');
        subs.mark_subscribe(topic);

        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
        resubscribe_all(&subs, &None, &cmd_tx, Some(&out_tx));

        // The unauthenticated auth channel must not produce a subscribe.
        assert!(cmd_rx.try_recv().is_err());

        let reported = out_rx
            .try_recv()
            .expect("Error event must be emitted when auth channel cannot resubscribe");

        let NautilusWsMessage::Error(text) = reported else {
            panic!("expected Error variant, was {reported:?}");
        };
        assert!(
            text.contains(channel),
            "error must mention the channel, was: {text}"
        );
        assert!(
            text.contains(topic),
            "error must mention the topic, was: {text}"
        );
    }

    #[rstest]
    fn test_resubscribe_all_emits_error_when_jwt_build_fails() {
        let topic = "user|BTC-USD";
        let subs = SubscriptionState::new('|');
        subs.mark_subscribe(topic);

        // The malformed PEM secret makes build_ws_jwt() fail on every call,
        // which drives the JWT-build error branch.
        let bad_credential = Some(CoinbaseCredential::new(
            "organizations/test/apiKeys/test".to_string(),
            "not-a-pem-key".to_string(),
        ));

        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        let (out_tx, mut out_rx) = tokio::sync::mpsc::unbounded_channel();
        resubscribe_all(&subs, &bad_credential, &cmd_tx, Some(&out_tx));

        assert!(cmd_rx.try_recv().is_err(), "no subscribe should be sent");

        let reported = out_rx
            .try_recv()
            .expect("Error event must be emitted when JWT build fails for an auth channel");

        let NautilusWsMessage::Error(text) = reported else {
            panic!("expected Error variant, was {reported:?}");
        };
        assert!(text.contains("user"), "error must mention channel: {text}");
        assert!(text.contains(topic), "error must mention topic: {text}");
    }

    #[rstest]
    fn test_prime_default_subscriptions_marks_heartbeats() {
        let client = CoinbaseWebSocketClient::new("wss://test", TransportBackend::default(), None);
        assert!(client.subscriptions.all_topics().is_empty());

        client.prime_default_subscriptions();

        let topics = client.subscriptions.all_topics();
        assert!(topics.iter().any(|t| t == "heartbeats"), "{topics:?}");
    }

    #[rstest]
    fn test_ws_quotas_match_documented_limits() {
        let connection_burst = COINBASE_WS_CONNECTION_QUOTA.burst_size().get();
        let subscription_burst = COINBASE_WS_SUBSCRIPTION_QUOTA.burst_size().get();

        assert_eq!(connection_burst, 8);
        assert_eq!(subscription_burst, 8);
    }

    #[rstest]
    fn test_ws_subscription_rate_limit_key_is_stable() {
        assert_eq!(COINBASE_RATE_LIMIT_KEY_SUBSCRIPTION, "subscription");
        assert_eq!(
            COINBASE_WS_SUBSCRIPTION_KEYS[0].as_str(),
            COINBASE_RATE_LIMIT_KEY_SUBSCRIPTION,
        );
    }
}