Skip to main content

nautilus_kraken/websocket/spot_v2/
client.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------

//! WebSocket client for the Kraken v2 streaming API.
17
18use std::{
19    collections::HashMap,
20    sync::{
21        Arc, RwLock,
22        atomic::{AtomicBool, AtomicU8, Ordering},
23    },
24};
25
26use arc_swap::ArcSwap;
27use nautilus_common::live::get_runtime;
28use nautilus_core::AtomicMap;
29use nautilus_model::{
30    data::BarType,
31    enums::BarAggregation,
32    identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
33    instruments::{Instrument, InstrumentAny},
34};
35use nautilus_network::{
36    mode::ConnectionMode,
37    websocket::{
38        AuthTracker, SubscriptionState, TransportBackend, WebSocketClient, WebSocketConfig,
39        channel_message_handler,
40    },
41};
42use tokio_util::sync::CancellationToken;
43use ustr::Ustr;
44
/// Topic delimiter for Kraken Spot v2 WebSocket subscriptions.
///
/// Topics use colon format: `channel:symbol` (e.g., `Trade:ETH/USD`).
/// Interval-based subscriptions (OHLC) append the interval as a third part:
/// `channel:symbol:interval`. This is the delimiter the client hands to
/// `SubscriptionState::new` when it is constructed.
pub const KRAKEN_SPOT_WS_TOPIC_DELIMITER: char = ':';
49
50use super::{
51    enums::{KrakenWsChannel, KrakenWsMethod},
52    handler::{SpotFeedHandler, SpotHandlerCommand},
53    messages::{KrakenSpotWsMessage, KrakenWsParams, KrakenWsRequest},
54};
55use crate::{
56    common::parse::normalize_spot_symbol,
57    config::KrakenDataClientConfig,
58    http::{KrakenSpotHttpClient, spot::client::KRAKEN_SPOT_DEFAULT_RATE_LIMIT_PER_SECOND},
59    websocket::error::KrakenWsError,
60};
61
/// Application-level ping payload, configured as the WebSocket heartbeat message.
const WS_PING_MSG: &str = "{\"method\":\"ping\"}";
63
/// WebSocket client for the Kraken Spot v2 streaming API.
///
/// Most state is held behind `Arc`s so that clones of the client observe the same
/// connection, command channel, subscription payloads, auth token and caches.
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.kraken", from_py_object)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.kraken")
)]
pub struct KrakenSpotWebSocketClient {
    /// Endpoint URL: the private URL when one is configured, otherwise the public URL.
    url: String,
    /// Client configuration; its `proxy_url` is overwritten with the constructor's `proxy_url`.
    config: KrakenDataClientConfig,
    /// Stop flag: cleared by `connect()`, set by `disconnect()`. While set, `is_active()`
    /// reports false and the reconnect handler skips resubscription.
    signal: Arc<AtomicBool>,
    /// Connection state; swapped to the live `WebSocketClient`'s atomic on `connect()`.
    connection_mode: Arc<ArcSwap<AtomicU8>>,
    /// Sender for commands to the handler task; replaced on every `connect()`.
    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<SpotHandlerCommand>>>,
    /// Receiver of messages forwarded by the handler task; taken by `stream()`.
    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<KrakenSpotWsMessage>>>,
    /// Handle of the message-processing task spawned by `connect()`.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    /// Subscription state and reference counts, keyed by topic.
    subscriptions: SubscriptionState,
    /// Subscribe payload stored per topic, replayed after a reconnect.
    subscription_payloads: Arc<tokio::sync::RwLock<HashMap<String, String>>>,
    /// Tracks the progress and outcome of WebSocket authentication.
    auth_tracker: AuthTracker,
    /// Token cancelled by `cancel_all_requests()`.
    cancellation_token: CancellationToken,
    /// Counter used to allocate request IDs (first ID is 1).
    req_id_counter: Arc<tokio::sync::RwLock<u64>>,
    /// WebSocket auth token obtained by `authenticate()`; required for private channels.
    auth_token: Arc<tokio::sync::RwLock<Option<String>>>,
    /// Account ID used for execution report parsing.
    account_id: Arc<RwLock<Option<AccountId>>>,
    /// Maps truncated client order IDs back to the full `ClientOrderId`.
    truncated_id_map: Arc<AtomicMap<String, ClientOrderId>>,
    /// Instruments cached via `cache_instrument()`, keyed by instrument ID.
    instruments: Arc<AtomicMap<InstrumentId, InstrumentAny>>,
    /// Transport backend taken from the config and passed to the WebSocket config.
    transport_backend: TransportBackend,
    /// Optional proxy URL for the WebSocket connection.
    proxy_url: Option<String>,
}
94
95impl Clone for KrakenSpotWebSocketClient {
96    fn clone(&self) -> Self {
97        Self {
98            url: self.url.clone(),
99            config: self.config.clone(),
100            signal: Arc::clone(&self.signal),
101            connection_mode: Arc::clone(&self.connection_mode),
102            cmd_tx: Arc::clone(&self.cmd_tx),
103            out_rx: self.out_rx.clone(),
104            task_handle: self.task_handle.clone(),
105            subscriptions: self.subscriptions.clone(),
106            subscription_payloads: Arc::clone(&self.subscription_payloads),
107            auth_tracker: self.auth_tracker.clone(),
108            cancellation_token: self.cancellation_token.clone(),
109            req_id_counter: self.req_id_counter.clone(),
110            auth_token: self.auth_token.clone(),
111            account_id: Arc::clone(&self.account_id),
112            truncated_id_map: Arc::clone(&self.truncated_id_map),
113            instruments: Arc::clone(&self.instruments),
114            transport_backend: self.transport_backend,
115            proxy_url: self.proxy_url.clone(),
116        }
117    }
118}
119
120impl KrakenSpotWebSocketClient {
121    /// Creates a new client with the given configuration.
122    pub fn new(
123        mut config: KrakenDataClientConfig,
124        cancellation_token: CancellationToken,
125        proxy_url: Option<String>,
126    ) -> Self {
127        // Prefer private URL if explicitly set (for authenticated endpoints)
128        let url = if config.ws_private_url.is_some() {
129            config.ws_private_url()
130        } else {
131            config.ws_public_url()
132        };
133        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<SpotHandlerCommand>();
134        let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
135        let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
136
137        let transport_backend = config.transport_backend;
138
139        // Keep the config's proxy_url in sync with the constructor argument so
140        // refresh_auth_token() (which reads config.proxy_url) goes through the
141        // same proxy as the WebSocket connection.
142        config.proxy_url = proxy_url.clone();
143
144        Self {
145            url,
146            config,
147            signal: Arc::new(AtomicBool::new(false)),
148            connection_mode,
149            cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
150            out_rx: None,
151            task_handle: None,
152            subscriptions: SubscriptionState::new(KRAKEN_SPOT_WS_TOPIC_DELIMITER),
153            subscription_payloads: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
154            auth_tracker: AuthTracker::new(),
155            cancellation_token,
156            req_id_counter: Arc::new(tokio::sync::RwLock::new(0)),
157            auth_token: Arc::new(tokio::sync::RwLock::new(None)),
158            account_id: Arc::new(RwLock::new(None)),
159            truncated_id_map: Arc::new(AtomicMap::new()),
160            instruments: Arc::new(AtomicMap::new()),
161            transport_backend,
162            proxy_url,
163        }
164    }
165
166    async fn get_next_req_id(&self) -> u64 {
167        let mut counter = self.req_id_counter.write().await;
168        *counter += 1;
169        *counter
170    }
171
172    /// Connects to the WebSocket server.
173    pub async fn connect(&mut self) -> Result<(), KrakenWsError> {
174        log::debug!("Connecting to {}", self.url);
175
176        self.signal.store(false, Ordering::Relaxed);
177
178        let (raw_handler, raw_rx) = channel_message_handler();
179
180        let ws_config = WebSocketConfig {
181            url: self.url.clone(),
182            headers: vec![],
183            heartbeat: Some(self.config.heartbeat_interval_secs),
184            heartbeat_msg: Some(WS_PING_MSG.to_string()),
185            reconnect_timeout_ms: Some(5_000),
186            reconnect_delay_initial_ms: Some(500),
187            reconnect_delay_max_ms: Some(5_000),
188            reconnect_backoff_factor: Some(1.5),
189            reconnect_jitter_ms: Some(250),
190            reconnect_max_attempts: None,
191            idle_timeout_ms: None,
192            backend: self.transport_backend,
193            proxy_url: self.proxy_url.clone(),
194        };
195
196        let ws_client = WebSocketClient::connect(
197            ws_config,
198            Some(raw_handler),
199            None,   // ping_handler
200            None,   // post_reconnection
201            vec![], // keyed_quotas
202            None,   // default_quota
203        )
204        .await
205        .map_err(|e| KrakenWsError::ConnectionError(e.to_string()))?;
206
207        // Share connection state across clones via ArcSwap
208        self.connection_mode
209            .store(ws_client.connection_mode_atomic());
210
211        let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<KrakenSpotWsMessage>();
212        self.out_rx = Some(Arc::new(out_rx));
213
214        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<SpotHandlerCommand>();
215        *self.cmd_tx.write().await = cmd_tx.clone();
216
217        if let Err(e) = cmd_tx.send(SpotHandlerCommand::SetClient(ws_client)) {
218            return Err(KrakenWsError::ConnectionError(format!(
219                "Failed to send WebSocketClient to handler: {e}"
220            )));
221        }
222
223        let signal = self.signal.clone();
224        let subscriptions = self.subscriptions.clone();
225        let subscription_payloads = self.subscription_payloads.clone();
226        let config_for_reconnect = self.config.clone();
227        let auth_token_for_reconnect = self.auth_token.clone();
228        let auth_tracker_for_reconnect = self.auth_tracker.clone();
229        let cmd_tx_for_reconnect = cmd_tx.clone();
230
231        let stream_handle = get_runtime().spawn(async move {
232            let mut handler =
233                SpotFeedHandler::new(signal.clone(), cmd_rx, raw_rx, subscriptions.clone());
234
235            loop {
236                match handler.next().await {
237                    Some(KrakenSpotWsMessage::Reconnected) => {
238                        if signal.load(Ordering::Relaxed) {
239                            continue;
240                        }
241                        log::info!("WebSocket reconnected, resubscribing");
242
243                        let confirmed_topics = subscriptions.all_topics();
244                        for topic in &confirmed_topics {
245                            subscriptions.mark_failure(topic);
246                        }
247
248                        let payloads = subscription_payloads.read().await;
249                        if payloads.is_empty() {
250                            log::debug!("No subscriptions to restore after reconnection");
251                        } else {
252                            let had_auth = auth_token_for_reconnect.read().await.is_some();
253
254                            if had_auth && config_for_reconnect.has_api_credentials() {
255                                log::debug!("Re-authenticating after reconnect");
256
257                                auth_tracker_for_reconnect.invalidate();
258                                let _rx = auth_tracker_for_reconnect.begin();
259
260                                match refresh_auth_token(&config_for_reconnect).await {
261                                    Ok(new_token) => {
262                                        *auth_token_for_reconnect.write().await = Some(new_token);
263                                        auth_tracker_for_reconnect.succeed();
264                                        log::debug!("Re-authentication successful");
265                                    }
266                                    Err(e) => {
267                                        log::error!(
268                                            "Failed to re-authenticate after reconnect: {e}"
269                                        );
270                                        *auth_token_for_reconnect.write().await = None;
271                                        auth_tracker_for_reconnect.fail(e.to_string());
272                                    }
273                                }
274                            }
275
276                            log::info!(
277                                "Resubscribing after reconnection: count={}",
278                                payloads.len()
279                            );
280
281                            for (topic, payload) in payloads.iter() {
282                                let payload = if topic == "executions" {
283                                    let auth_token = auth_token_for_reconnect.read().await.clone();
284                                    match auth_token {
285                                        Some(token) => {
286                                            match update_auth_token_in_payload(payload, &token) {
287                                                Ok(p) => p,
288                                                Err(e) => {
289                                                    log::error!("Failed to update auth token: {e}");
290                                                    continue;
291                                                }
292                                            }
293                                        }
294                                        None => {
295                                            log::warn!(
296                                                "Cannot resubscribe to executions: no auth token"
297                                            );
298                                            continue;
299                                        }
300                                    }
301                                } else {
302                                    payload.clone()
303                                };
304
305                                if let Err(e) = cmd_tx_for_reconnect
306                                    .send(SpotHandlerCommand::Subscribe { payload })
307                                {
308                                    log::error!(
309                                        "Failed to send resubscribe command: error={e}, \
310                                        topic={topic}"
311                                    );
312                                }
313
314                                subscriptions.mark_subscribe(topic);
315                            }
316                        }
317
318                        if out_tx.send(KrakenSpotWsMessage::Reconnected).is_err() {
319                            log::error!("Failed to send message (receiver dropped)");
320                            break;
321                        }
322                    }
323                    Some(msg) => {
324                        if out_tx.send(msg).is_err() {
325                            log::error!("Failed to send message (receiver dropped)");
326                            break;
327                        }
328                    }
329                    None => {
330                        if handler.is_stopped() {
331                            log::debug!("Stop signal received, ending message processing");
332                            break;
333                        }
334                        log::warn!("WebSocket stream ended unexpectedly");
335                        break;
336                    }
337                }
338            }
339
340            log::debug!("Handler task exiting");
341        });
342
343        self.task_handle = Some(Arc::new(stream_handle));
344
345        log::debug!("WebSocket connected successfully");
346        Ok(())
347    }
348
349    /// Disconnects from the WebSocket server.
350    pub async fn disconnect(&mut self) -> Result<(), KrakenWsError> {
351        log::debug!("Disconnecting WebSocket");
352
353        self.signal.store(true, Ordering::Relaxed);
354
355        if let Err(e) = self
356            .cmd_tx
357            .read()
358            .await
359            .send(SpotHandlerCommand::Disconnect)
360        {
361            log::debug!(
362                "Failed to send disconnect command (handler may already be shut down): {e}"
363            );
364        }
365
366        if let Some(task_handle) = self.task_handle.take() {
367            match Arc::try_unwrap(task_handle) {
368                Ok(handle) => {
369                    log::debug!("Waiting for task handle to complete");
370                    match tokio::time::timeout(tokio::time::Duration::from_secs(2), handle).await {
371                        Ok(Ok(())) => log::debug!("Task handle completed successfully"),
372                        Ok(Err(e)) => log::error!("Task handle encountered an error: {e:?}"),
373                        Err(_) => {
374                            log::warn!(
375                                "Timeout waiting for task handle, task may still be running"
376                            );
377                        }
378                    }
379                }
380                Err(arc_handle) => {
381                    log::debug!(
382                        "Cannot take ownership of task handle - other references exist, aborting task"
383                    );
384                    arc_handle.abort();
385                }
386            }
387        } else {
388            log::debug!("No task handle to await");
389        }
390
391        self.subscriptions.clear();
392        self.subscription_payloads.write().await.clear();
393        self.auth_tracker.fail("Disconnected");
394
395        Ok(())
396    }
397
398    /// Closes the WebSocket connection.
399    pub async fn close(&mut self) -> Result<(), KrakenWsError> {
400        self.disconnect().await
401    }
402
403    /// Waits until the connection is active or timeout.
404    pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), KrakenWsError> {
405        let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
406
407        tokio::time::timeout(timeout, async {
408            while !self.is_active() {
409                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
410            }
411        })
412        .await
413        .map_err(|_| {
414            KrakenWsError::ConnectionError(format!(
415                "WebSocket connection timeout after {timeout_secs} seconds"
416            ))
417        })?;
418
419        Ok(())
420    }
421
422    /// Returns true if the WebSocket is authenticated for private subscriptions.
423    #[must_use]
424    pub fn is_authenticated(&self) -> bool {
425        self.auth_tracker.is_authenticated()
426    }
427
428    /// Waits until the WebSocket is authenticated or the timeout elapses.
429    ///
430    /// Returns an error on timeout or explicit auth failure.
431    pub async fn wait_until_authenticated(&self, timeout_secs: f64) -> Result<(), KrakenWsError> {
432        let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
433
434        if self.auth_tracker.wait_for_authenticated(timeout).await {
435            Ok(())
436        } else {
437            Err(KrakenWsError::AuthenticationError(format!(
438                "Authentication not completed within {timeout_secs} seconds"
439            )))
440        }
441    }
442
443    /// Authenticates with the Kraken API to enable private subscriptions.
444    pub async fn authenticate(&self) -> Result<(), KrakenWsError> {
445        if !self.config.has_api_credentials() {
446            return Err(KrakenWsError::AuthenticationError(
447                "API credentials required for authentication".to_string(),
448            ));
449        }
450
451        let _receiver = self.auth_tracker.begin();
452
453        match refresh_auth_token(&self.config).await {
454            Ok(token) => {
455                *self.auth_token.write().await = Some(token);
456                self.auth_tracker.succeed();
457                Ok(())
458            }
459            Err(e) => {
460                *self.auth_token.write().await = None;
461                self.auth_tracker.fail(e.to_string());
462                Err(e)
463            }
464        }
465    }
466
467    /// Cancels all pending requests.
468    pub fn cancel_all_requests(&self) {
469        self.cancellation_token.cancel();
470    }
471
472    /// Returns the cancellation token for this client.
473    pub fn cancellation_token(&self) -> &CancellationToken {
474        &self.cancellation_token
475    }
476
477    /// Subscribes to a channel for the given symbols.
478    pub async fn subscribe(
479        &self,
480        channel: KrakenWsChannel,
481        symbols: Vec<Ustr>,
482        depth: Option<u32>,
483    ) -> Result<(), KrakenWsError> {
484        let mut symbols_to_subscribe = Vec::new();
485        let channel_str = channel.as_ref();
486        for symbol in &symbols {
487            let key = format!("{channel_str}:{symbol}");
488            if self.subscriptions.add_reference(&key) {
489                self.subscriptions.mark_subscribe(&key);
490                symbols_to_subscribe.push(*symbol);
491            }
492        }
493
494        if symbols_to_subscribe.is_empty() {
495            return Ok(());
496        }
497
498        let is_private = matches!(
499            channel,
500            KrakenWsChannel::Executions | KrakenWsChannel::Balances
501        );
502        let token = if is_private {
503            Some(self.auth_token.read().await.clone().ok_or_else(|| {
504                KrakenWsError::AuthenticationError(
505                    "Authentication token required for private channels. Call authenticate() first"
506                        .to_string(),
507                )
508            })?)
509        } else {
510            None
511        };
512
513        let req_id = self.get_next_req_id().await;
514        let request = KrakenWsRequest {
515            method: KrakenWsMethod::Subscribe,
516            params: Some(KrakenWsParams {
517                channel,
518                symbol: Some(symbols_to_subscribe.clone()),
519                snapshot: None,
520                depth,
521                interval: None,
522                event_trigger: None,
523                token,
524                snap_orders: None,
525                snap_trades: None,
526            }),
527            req_id: Some(req_id),
528        };
529
530        let payload = self.send_command(&request).await?;
531
532        for symbol in &symbols_to_subscribe {
533            let key = format!("{channel_str}:{symbol}");
534            self.subscriptions.confirm_subscribe(&key);
535            self.subscription_payloads
536                .write()
537                .await
538                .insert(key, payload.clone());
539        }
540
541        Ok(())
542    }
543
544    /// Subscribes to a channel with a specific interval (for OHLC).
545    async fn subscribe_with_interval(
546        &self,
547        channel: KrakenWsChannel,
548        symbols: Vec<Ustr>,
549        interval: u32,
550    ) -> Result<(), KrakenWsError> {
551        let mut symbols_to_subscribe = Vec::new();
552        let channel_str = channel.as_ref();
553        for symbol in &symbols {
554            let key = format!("{channel_str}:{symbol}:{interval}");
555            if self.subscriptions.add_reference(&key) {
556                self.subscriptions.mark_subscribe(&key);
557                symbols_to_subscribe.push(*symbol);
558            }
559        }
560
561        if symbols_to_subscribe.is_empty() {
562            return Ok(());
563        }
564
565        let req_id = self.get_next_req_id().await;
566        let request = KrakenWsRequest {
567            method: KrakenWsMethod::Subscribe,
568            params: Some(KrakenWsParams {
569                channel,
570                symbol: Some(symbols_to_subscribe.clone()),
571                snapshot: Some(false),
572                depth: None,
573                interval: Some(interval),
574                event_trigger: None,
575                token: None,
576                snap_orders: None,
577                snap_trades: None,
578            }),
579            req_id: Some(req_id),
580        };
581
582        let payload = self.send_command(&request).await?;
583
584        for symbol in &symbols_to_subscribe {
585            let key = format!("{channel_str}:{symbol}:{interval}");
586            self.subscriptions.confirm_subscribe(&key);
587            self.subscription_payloads
588                .write()
589                .await
590                .insert(key, payload.clone());
591        }
592
593        Ok(())
594    }
595
596    /// Unsubscribes from a channel with a specific interval (for OHLC).
597    async fn unsubscribe_with_interval(
598        &self,
599        channel: KrakenWsChannel,
600        symbols: Vec<Ustr>,
601        interval: u32,
602    ) -> Result<(), KrakenWsError> {
603        let mut symbols_to_unsubscribe = Vec::new();
604        let channel_str = channel.as_ref();
605        for symbol in &symbols {
606            let key = format!("{channel_str}:{symbol}:{interval}");
607            if self.subscriptions.remove_reference(&key) {
608                self.subscriptions.mark_unsubscribe(&key);
609                symbols_to_unsubscribe.push(*symbol);
610            }
611        }
612
613        if symbols_to_unsubscribe.is_empty() {
614            return Ok(());
615        }
616
617        let req_id = self.get_next_req_id().await;
618        let request = KrakenWsRequest {
619            method: KrakenWsMethod::Unsubscribe,
620            params: Some(KrakenWsParams {
621                channel,
622                symbol: Some(symbols_to_unsubscribe.clone()),
623                snapshot: None,
624                depth: None,
625                interval: Some(interval),
626                event_trigger: None,
627                token: None,
628                snap_orders: None,
629                snap_trades: None,
630            }),
631            req_id: Some(req_id),
632        };
633
634        self.send_command(&request).await?;
635
636        for symbol in &symbols_to_unsubscribe {
637            let key = format!("{channel_str}:{symbol}:{interval}");
638            self.subscriptions.confirm_unsubscribe(&key);
639            self.subscription_payloads.write().await.remove(&key);
640        }
641
642        Ok(())
643    }
644
645    /// Unsubscribes from a channel for the given symbols.
646    pub async fn unsubscribe(
647        &self,
648        channel: KrakenWsChannel,
649        symbols: Vec<Ustr>,
650    ) -> Result<(), KrakenWsError> {
651        let mut symbols_to_unsubscribe = Vec::new();
652        let channel_str = channel.as_ref();
653        for symbol in &symbols {
654            let key = format!("{channel_str}:{symbol}");
655            if self.subscriptions.remove_reference(&key) {
656                self.subscriptions.mark_unsubscribe(&key);
657                symbols_to_unsubscribe.push(*symbol);
658            } else {
659                log::debug!(
660                    "Channel {channel_str} symbol {symbol} still has active subscriptions, not unsubscribing"
661                );
662            }
663        }
664
665        if symbols_to_unsubscribe.is_empty() {
666            return Ok(());
667        }
668
669        let is_private = matches!(
670            channel,
671            KrakenWsChannel::Executions | KrakenWsChannel::Balances
672        );
673        let token = if is_private {
674            Some(self.auth_token.read().await.clone().ok_or_else(|| {
675                KrakenWsError::AuthenticationError(
676                    "Authentication token required for private channels. Call authenticate() first"
677                        .to_string(),
678                )
679            })?)
680        } else {
681            None
682        };
683
684        let req_id = self.get_next_req_id().await;
685        let request = KrakenWsRequest {
686            method: KrakenWsMethod::Unsubscribe,
687            params: Some(KrakenWsParams {
688                channel,
689                symbol: Some(symbols_to_unsubscribe.clone()),
690                snapshot: None,
691                depth: None,
692                interval: None,
693                event_trigger: None,
694                token,
695                snap_orders: None,
696                snap_trades: None,
697            }),
698            req_id: Some(req_id),
699        };
700
701        self.send_command(&request).await?;
702
703        for symbol in &symbols_to_unsubscribe {
704            let key = format!("{channel_str}:{symbol}");
705            self.subscriptions.confirm_unsubscribe(&key);
706            self.subscription_payloads.write().await.remove(&key);
707        }
708
709        Ok(())
710    }
711
712    /// Sends a ping message to keep the connection alive.
713    pub async fn send_ping(&self) -> Result<(), KrakenWsError> {
714        let req_id = self.get_next_req_id().await;
715
716        let request = KrakenWsRequest {
717            method: KrakenWsMethod::Ping,
718            params: None,
719            req_id: Some(req_id),
720        };
721
722        self.send_command(&request).await?;
723        Ok(())
724    }
725
726    async fn send_command(&self, request: &KrakenWsRequest) -> Result<String, KrakenWsError> {
727        let payload =
728            serde_json::to_string(request).map_err(|e| KrakenWsError::JsonError(e.to_string()))?;
729
730        log::trace!("Sending message: {payload}");
731
732        let cmd = match request.method {
733            KrakenWsMethod::Subscribe => SpotHandlerCommand::Subscribe {
734                payload: payload.clone(),
735            },
736            KrakenWsMethod::Unsubscribe => SpotHandlerCommand::Unsubscribe {
737                payload: payload.clone(),
738            },
739            KrakenWsMethod::Ping | KrakenWsMethod::Pong => SpotHandlerCommand::Ping {
740                payload: payload.clone(),
741            },
742        };
743
744        self.cmd_tx
745            .read()
746            .await
747            .send(cmd)
748            .map_err(|e| KrakenWsError::ConnectionError(format!("Failed to send request: {e}")))?;
749
750        Ok(payload)
751    }
752
753    /// Returns true if connected (not closed).
754    pub fn is_connected(&self) -> bool {
755        let connection_mode_arc = self.connection_mode.load();
756        !ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
757    }
758
759    /// Returns true if the connection is active.
760    pub fn is_active(&self) -> bool {
761        let connection_mode_arc = self.connection_mode.load();
762        ConnectionMode::from_atomic(&connection_mode_arc).is_active()
763            && !self.signal.load(Ordering::Relaxed)
764    }
765
766    /// Returns true if the connection is closed.
767    pub fn is_closed(&self) -> bool {
768        let connection_mode_arc = self.connection_mode.load();
769        ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
770            || self.signal.load(Ordering::Relaxed)
771    }
772
773    /// Returns the WebSocket URL.
774    pub fn url(&self) -> &str {
775        &self.url
776    }
777
778    /// Returns all active subscriptions.
779    pub fn get_subscriptions(&self) -> Vec<String> {
780        self.subscriptions.all_topics()
781    }
782
783    /// Sets the account ID for execution report parsing.
784    pub fn set_account_id(&self, account_id: AccountId) {
785        if let Ok(mut guard) = self.account_id.write() {
786            *guard = Some(account_id);
787        }
788    }
789
790    /// Returns the account ID if set.
791    #[must_use]
792    pub fn account_id(&self) -> Option<AccountId> {
793        self.account_id.read().ok().and_then(|g| *g)
794    }
795
796    /// Caches an instrument for execution report parsing.
797    pub fn cache_instrument(&self, instrument: InstrumentAny) {
798        self.instruments.insert(instrument.id(), instrument);
799    }
800
801    /// Returns a shared reference to the account ID.
802    pub fn account_id_shared(&self) -> &Arc<RwLock<Option<AccountId>>> {
803        &self.account_id
804    }
805
806    /// Returns a shared reference to the truncated ID map.
807    pub fn truncated_id_map(&self) -> &Arc<AtomicMap<String, ClientOrderId>> {
808        &self.truncated_id_map
809    }
810
811    /// Caches a client order for truncated ID resolution.
812    pub fn cache_client_order(
813        &self,
814        client_order_id: ClientOrderId,
815        _venue_order_id: Option<VenueOrderId>,
816        _instrument_id: InstrumentId,
817        _trader_id: TraderId,
818        _strategy_id: StrategyId,
819    ) {
820        let truncated = crate::common::parse::truncate_cl_ord_id(&client_order_id);
821
822        if truncated != client_order_id.as_str() {
823            self.truncated_id_map.insert(truncated, client_order_id);
824        }
825    }
826
827    /// Returns a stream of WebSocket messages.
828    ///
829    /// # Errors
830    ///
831    /// Returns an error if:
832    /// - The stream receiver has already been taken
833    /// - Other clones of this client still hold references to the receiver
834    pub fn stream(
835        &mut self,
836    ) -> Result<impl futures_util::Stream<Item = KrakenSpotWsMessage> + use<>, KrakenWsError> {
837        let rx = self.out_rx.take().ok_or_else(|| {
838            KrakenWsError::ChannelError(
839                "Stream receiver already taken or client not connected".to_string(),
840            )
841        })?;
842        let mut rx = Arc::try_unwrap(rx).map_err(|_| {
843            KrakenWsError::ChannelError(
844                "Cannot take ownership of stream - other client clones still hold references"
845                    .to_string(),
846            )
847        })?;
848        Ok(async_stream::stream! {
849            while let Some(msg) = rx.recv().await {
850                yield msg;
851            }
852        })
853    }
854
855    /// Subscribes to order book updates for the given instrument.
856    pub async fn subscribe_book(
857        &self,
858        instrument_id: InstrumentId,
859        depth: Option<u32>,
860    ) -> Result<(), KrakenWsError> {
861        let symbol = to_ws_v2_symbol(instrument_id.symbol.inner());
862        self.subscribe(KrakenWsChannel::Book, vec![symbol], depth)
863            .await
864    }
865
866    /// Subscribes to quote updates for the given instrument.
867    ///
868    /// Uses the Ticker channel with `event_trigger: "bbo"` for updates only on
869    /// best bid/offer changes.
870    pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), KrakenWsError> {
871        let symbol = to_ws_v2_symbol(instrument_id.symbol.inner());
872        let key = format!("quotes:{symbol}");
873
874        if !self.subscriptions.add_reference(&key) {
875            return Ok(());
876        }
877
878        self.subscriptions.mark_subscribe(&key);
879
880        let req_id = self.get_next_req_id().await;
881        let request = KrakenWsRequest {
882            method: KrakenWsMethod::Subscribe,
883            params: Some(KrakenWsParams {
884                channel: KrakenWsChannel::Ticker,
885                symbol: Some(vec![symbol]),
886                snapshot: None,
887                depth: None,
888                interval: None,
889                event_trigger: Some("bbo".to_string()),
890                token: None,
891                snap_orders: None,
892                snap_trades: None,
893            }),
894            req_id: Some(req_id),
895        };
896
897        let payload = self.send_command(&request).await?;
898        self.subscriptions.confirm_subscribe(&key);
899        self.subscription_payloads
900            .write()
901            .await
902            .insert(key, payload);
903        Ok(())
904    }
905
906    /// Subscribes to trade updates for the given instrument.
907    pub async fn subscribe_trades(&self, instrument_id: InstrumentId) -> Result<(), KrakenWsError> {
908        let symbol = to_ws_v2_symbol(instrument_id.symbol.inner());
909        self.subscribe(KrakenWsChannel::Trade, vec![symbol], None)
910            .await
911    }
912
913    /// Subscribes to bar/OHLC updates for the given bar type.
914    ///
915    /// # Errors
916    ///
917    /// Returns an error if the bar aggregation is not supported by Kraken.
918    pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), KrakenWsError> {
919        let symbol = to_ws_v2_symbol(bar_type.instrument_id().symbol.inner());
920        let interval = bar_type_to_ws_interval(bar_type)?;
921        self.subscribe_with_interval(KrakenWsChannel::Ohlc, vec![symbol], interval)
922            .await
923    }
924
925    /// Subscribes to execution updates (order and fill events).
926    ///
927    /// Requires authentication - call `authenticate()` first.
928    pub async fn subscribe_executions(
929        &self,
930        snap_orders: bool,
931        snap_trades: bool,
932    ) -> Result<(), KrakenWsError> {
933        let req_id = self.get_next_req_id().await;
934
935        let token = self.auth_token.read().await.clone().ok_or_else(|| {
936            KrakenWsError::AuthenticationError(
937                "Authentication token required for executions channel. Call authenticate() first"
938                    .to_string(),
939            )
940        })?;
941
942        let request = KrakenWsRequest {
943            method: KrakenWsMethod::Subscribe,
944            params: Some(KrakenWsParams {
945                channel: KrakenWsChannel::Executions,
946                symbol: None,
947                snapshot: None,
948                depth: None,
949                interval: None,
950                event_trigger: None,
951                token: Some(token),
952                snap_orders: Some(snap_orders),
953                snap_trades: Some(snap_trades),
954            }),
955            req_id: Some(req_id),
956        };
957
958        let payload = self.send_command(&request).await?;
959
960        let key = "executions";
961        if self.subscriptions.add_reference(key) {
962            self.subscriptions.mark_subscribe(key);
963            self.subscriptions.confirm_subscribe(key);
964            self.subscription_payloads
965                .write()
966                .await
967                .insert(key.to_string(), payload);
968        }
969
970        Ok(())
971    }
972
973    /// Unsubscribes from order book updates for the given instrument.
974    pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), KrakenWsError> {
975        let symbol = to_ws_v2_symbol(instrument_id.symbol.inner());
976        self.unsubscribe(KrakenWsChannel::Book, vec![symbol]).await
977    }
978
979    /// Unsubscribes from quote updates for the given instrument.
980    pub async fn unsubscribe_quotes(
981        &self,
982        instrument_id: InstrumentId,
983    ) -> Result<(), KrakenWsError> {
984        let symbol = to_ws_v2_symbol(instrument_id.symbol.inner());
985        let key = format!("quotes:{symbol}");
986
987        if !self.subscriptions.remove_reference(&key) {
988            return Ok(());
989        }
990
991        self.subscriptions.mark_unsubscribe(&key);
992
993        let req_id = self.get_next_req_id().await;
994        let request = KrakenWsRequest {
995            method: KrakenWsMethod::Unsubscribe,
996            params: Some(KrakenWsParams {
997                channel: KrakenWsChannel::Ticker,
998                symbol: Some(vec![symbol]),
999                snapshot: None,
1000                depth: None,
1001                interval: None,
1002                event_trigger: Some("bbo".to_string()),
1003                token: None,
1004                snap_orders: None,
1005                snap_trades: None,
1006            }),
1007            req_id: Some(req_id),
1008        };
1009
1010        self.send_command(&request).await?;
1011        self.subscriptions.confirm_unsubscribe(&key);
1012        self.subscription_payloads.write().await.remove(&key);
1013        Ok(())
1014    }
1015
1016    /// Unsubscribes from trade updates for the given instrument.
1017    pub async fn unsubscribe_trades(
1018        &self,
1019        instrument_id: InstrumentId,
1020    ) -> Result<(), KrakenWsError> {
1021        let symbol = to_ws_v2_symbol(instrument_id.symbol.inner());
1022        self.unsubscribe(KrakenWsChannel::Trade, vec![symbol]).await
1023    }
1024
1025    /// Unsubscribes from bar/OHLC updates for the given bar type.
1026    ///
1027    /// # Errors
1028    ///
1029    /// Returns an error if the bar aggregation is not supported by Kraken.
1030    pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), KrakenWsError> {
1031        let symbol = to_ws_v2_symbol(bar_type.instrument_id().symbol.inner());
1032        let interval = bar_type_to_ws_interval(bar_type)?;
1033        self.unsubscribe_with_interval(KrakenWsChannel::Ohlc, vec![symbol], interval)
1034            .await
1035    }
1036}
1037
1038/// Helper function to refresh authentication token via HTTP API.
1039async fn refresh_auth_token(config: &KrakenDataClientConfig) -> Result<String, KrakenWsError> {
1040    let api_key = config
1041        .api_key
1042        .clone()
1043        .ok_or_else(|| KrakenWsError::AuthenticationError("Missing API key".to_string()))?;
1044    let api_secret = config
1045        .api_secret
1046        .clone()
1047        .ok_or_else(|| KrakenWsError::AuthenticationError("Missing API secret".to_string()))?;
1048
1049    let http_client = KrakenSpotHttpClient::with_credentials(
1050        api_key,
1051        api_secret,
1052        config.environment,
1053        Some(config.http_base_url()),
1054        config.timeout_secs,
1055        None,
1056        None,
1057        None,
1058        config.proxy_url.clone(),
1059        config
1060            .max_requests_per_second
1061            .unwrap_or(KRAKEN_SPOT_DEFAULT_RATE_LIMIT_PER_SECOND),
1062    )
1063    .map_err(|e| {
1064        KrakenWsError::AuthenticationError(format!("Failed to create HTTP client: {e}"))
1065    })?;
1066
1067    let ws_token = http_client.get_websockets_token().await.map_err(|e| {
1068        KrakenWsError::AuthenticationError(format!("Failed to get WebSocket token: {e}"))
1069    })?;
1070
1071    log::debug!(
1072        "WebSocket authentication token refreshed: token_length={}, expires={}",
1073        ws_token.token.len(),
1074        ws_token.expires
1075    );
1076
1077    Ok(ws_token.token)
1078}
1079
1080fn update_auth_token_in_payload(payload: &str, new_token: &str) -> Result<String, KrakenWsError> {
1081    let mut value: serde_json::Value =
1082        serde_json::from_str(payload).map_err(|e| KrakenWsError::JsonError(e.to_string()))?;
1083
1084    if let Some(params) = value.get_mut("params") {
1085        params["token"] = serde_json::Value::String(new_token.to_string());
1086    }
1087
1088    serde_json::to_string(&value).map_err(|e| KrakenWsError::JsonError(e.to_string()))
1089}
1090
1091#[inline]
1092fn to_ws_v2_symbol(symbol: Ustr) -> Ustr {
1093    Ustr::from(&normalize_spot_symbol(symbol.as_str()))
1094}
1095
1096fn bar_type_to_ws_interval(bar_type: BarType) -> Result<u32, KrakenWsError> {
1097    const VALID_INTERVALS: [u32; 9] = [1, 5, 15, 30, 60, 240, 1440, 10080, 21600];
1098
1099    let spec = bar_type.spec();
1100    let step = spec.step.get() as u32;
1101
1102    let base_minutes = match spec.aggregation {
1103        BarAggregation::Minute => 1,
1104        BarAggregation::Hour => 60,
1105        BarAggregation::Day => 1440,
1106        BarAggregation::Week => 10080,
1107        other => {
1108            return Err(KrakenWsError::SubscriptionError(format!(
1109                "Unsupported bar aggregation for Kraken OHLC streaming: {other:?}"
1110            )));
1111        }
1112    };
1113
1114    let interval = base_minutes * step;
1115
1116    if !VALID_INTERVALS.contains(&interval) {
1117        return Err(KrakenWsError::SubscriptionError(format!(
1118            "Invalid bar interval {interval} minutes for Kraken OHLC streaming. \
1119             Supported intervals: 1, 5, 15, 30, 60, 240, 1440, 10080, 21600"
1120        )));
1121    }
1122
1123    Ok(interval)
1124}
1125
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    #[rstest]
    #[case("XBT/EUR", "BTC/EUR")]
    #[case("XBT/USD", "BTC/USD")]
    #[case("XBT/USDT", "BTC/USDT")]
    #[case("ETH/USD", "ETH/USD")]
    #[case("ETH/XBT", "ETH/BTC")]
    #[case("SOL/XBT", "SOL/BTC")]
    #[case("SOL/USD", "SOL/USD")]
    #[case("BTC/USD", "BTC/USD")]
    #[case("ETH/BTC", "ETH/BTC")]
    fn test_to_kraken_ws_v2_symbol(#[case] input: &str, #[case] expected: &str) {
        let symbol = Ustr::from(input);
        let result = to_ws_v2_symbol(symbol);
        assert_eq!(result.as_str(), expected);
    }

    // Supported aggregations map to their Kraken OHLC interval in minutes.
    #[rstest]
    #[case("BTC/USD.KRAKEN-1-MINUTE-LAST-EXTERNAL", 1)]
    #[case("BTC/USD.KRAKEN-5-MINUTE-LAST-EXTERNAL", 5)]
    #[case("BTC/USD.KRAKEN-1-HOUR-LAST-EXTERNAL", 60)]
    #[case("BTC/USD.KRAKEN-4-HOUR-LAST-EXTERNAL", 240)]
    #[case("BTC/USD.KRAKEN-1-DAY-LAST-EXTERNAL", 1440)]
    #[case("BTC/USD.KRAKEN-1-WEEK-LAST-EXTERNAL", 10080)]
    #[case("BTC/USD.KRAKEN-15-DAY-LAST-EXTERNAL", 21600)]
    fn test_bar_type_to_ws_interval_supported(#[case] bar_type: &str, #[case] expected: u32) {
        let bar_type = BarType::from(bar_type);
        assert_eq!(bar_type_to_ws_interval(bar_type).unwrap(), expected);
    }

    // Steps that produce an unlisted interval, and unsupported aggregations, are rejected.
    #[rstest]
    #[case("BTC/USD.KRAKEN-2-MINUTE-LAST-EXTERNAL")]
    #[case("BTC/USD.KRAKEN-2-HOUR-LAST-EXTERNAL")]
    #[case("BTC/USD.KRAKEN-1-SECOND-LAST-EXTERNAL")]
    #[case("BTC/USD.KRAKEN-100-TICK-LAST-EXTERNAL")]
    fn test_bar_type_to_ws_interval_rejected(#[case] bar_type: &str) {
        let err = bar_type_to_ws_interval(BarType::from(bar_type)).expect_err("should be rejected");
        assert!(
            matches!(err, KrakenWsError::SubscriptionError(_)),
            "unexpected error: {err:?}"
        );
    }

    #[rstest]
    fn test_update_auth_token_in_payload_replaces_token() {
        let payload =
            r#"{"method":"subscribe","params":{"channel":"executions","token":"old"},"req_id":7}"#;

        let updated = update_auth_token_in_payload(payload, "new-token").unwrap();
        let value: serde_json::Value = serde_json::from_str(&updated).unwrap();

        assert_eq!(value["params"]["token"], "new-token");
        assert_eq!(value["params"]["channel"], "executions");
        assert_eq!(value["method"], "subscribe");
        assert_eq!(value["req_id"], 7);
    }

    #[rstest]
    fn test_update_auth_token_in_payload_without_params_is_unchanged() {
        let payload = r#"{"method":"ping","req_id":1}"#;

        let updated = update_auth_token_in_payload(payload, "new-token").unwrap();
        let original: serde_json::Value = serde_json::from_str(payload).unwrap();
        let value: serde_json::Value = serde_json::from_str(&updated).unwrap();

        assert_eq!(value, original);
    }

    #[rstest]
    fn test_update_auth_token_in_payload_invalid_json_errors() {
        let err = update_auth_token_in_payload("not json", "token").expect_err("should fail");
        assert!(
            matches!(err, KrakenWsError::JsonError(_)),
            "unexpected error: {err:?}"
        );
    }

    fn test_client_without_credentials() -> KrakenSpotWebSocketClient {
        KrakenSpotWebSocketClient::new(
            KrakenDataClientConfig::default(),
            CancellationToken::new(),
            None,
        )
    }

    #[rstest]
    #[tokio::test]
    async fn test_authenticate_without_credentials_errors() {
        let client = test_client_without_credentials();

        let err = client.authenticate().await.expect_err("should fail");
        assert!(
            matches!(err, KrakenWsError::AuthenticationError(ref msg) if msg.contains("API credentials required")),
            "unexpected error: {err:?}"
        );
        assert!(!client.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_until_authenticated_times_out() {
        let client = test_client_without_credentials();

        let err = client
            .wait_until_authenticated(0.05)
            .await
            .expect_err("should time out");
        assert!(matches!(err, KrakenWsError::AuthenticationError(_)));
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_until_authenticated_resolves_after_succeed() {
        let client = test_client_without_credentials();

        let tracker = client.auth_tracker.clone();
        let _rx = tracker.begin();

        tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
            tracker.succeed();
        });

        client
            .wait_until_authenticated(1.0)
            .await
            .expect("should resolve once tracker succeeds");
        assert!(client.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_is_authenticated_flips_on_fail() {
        let client = test_client_without_credentials();

        let _rx = client.auth_tracker.begin();
        client.auth_tracker.succeed();
        assert!(client.is_authenticated());

        client.auth_tracker.fail("test failure");
        assert!(!client.is_authenticated());
    }
}