1use std::{
19 collections::HashMap,
20 sync::{
21 Arc, RwLock,
22 atomic::{AtomicBool, AtomicU8, Ordering},
23 },
24};
25
26use arc_swap::ArcSwap;
27use nautilus_common::live::get_runtime;
28use nautilus_core::AtomicMap;
29use nautilus_model::{
30 data::BarType,
31 enums::BarAggregation,
32 identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
33 instruments::{Instrument, InstrumentAny},
34};
35use nautilus_network::{
36 mode::ConnectionMode,
37 websocket::{
38 AuthTracker, SubscriptionState, TransportBackend, WebSocketClient, WebSocketConfig,
39 channel_message_handler,
40 },
41};
42use tokio_util::sync::CancellationToken;
43use ustr::Ustr;
44
/// Delimiter between the parts of a subscription topic key (e.g. `channel:symbol`), passed to
/// `SubscriptionState::new` when a client is constructed.
pub const KRAKEN_SPOT_WS_TOPIC_DELIMITER: char = ':';
49
50use super::{
51 enums::{KrakenWsChannel, KrakenWsMethod},
52 handler::{SpotFeedHandler, SpotHandlerCommand},
53 messages::{KrakenSpotWsMessage, KrakenWsParams, KrakenWsRequest},
54};
55use crate::{
56 common::parse::normalize_spot_symbol,
57 config::KrakenDataClientConfig,
58 http::{KrakenSpotHttpClient, spot::client::KRAKEN_SPOT_DEFAULT_RATE_LIMIT_PER_SECOND},
59 websocket::error::KrakenWsError,
60};
61
/// Application-level heartbeat message sent by the transport (`heartbeat_msg`) on each
/// heartbeat interval.
const WS_PING_MSG: &str = "{\"method\":\"ping\"}";
63
/// Kraken Spot WebSocket v2 client.
///
/// Clones share the connection, command channel, stored payloads, auth token and caches through
/// `Arc`s; the URL and config are copied.
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.kraken", from_py_object)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.kraken")
)]
pub struct KrakenSpotWebSocketClient {
    /// WebSocket endpoint chosen in `new` (private URL if configured, otherwise public).
    url: String,
    /// Client configuration; `new` overwrites its `proxy_url` with the supplied proxy.
    config: KrakenDataClientConfig,
    /// Stop flag: set by `disconnect`, cleared by `connect`. While set, `is_active` is false,
    /// `is_closed` is true and the handler task skips resubscription on reconnect.
    signal: Arc<AtomicBool>,
    /// Connection-mode atomic; starts `Closed` and is swapped for the live client's atomic on
    /// `connect`.
    connection_mode: Arc<ArcSwap<AtomicU8>>,
    /// Sender for commands to the handler task; replaced on every `connect`. Before the first
    /// connect its receiver has already been dropped, so sends fail.
    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<SpotHandlerCommand>>>,
    /// Receiver of messages forwarded by the handler task; set by `connect`, taken by `stream`.
    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<KrakenSpotWsMessage>>>,
    /// Handle of the spawned handler task; awaited (or aborted) by `disconnect`.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    /// Reference-counted topic state. Keys look like `channel:symbol`, `channel:symbol:interval`,
    /// `quotes:symbol` or `executions`.
    subscriptions: SubscriptionState,
    /// Last JSON subscribe payload per topic key, replayed after a reconnect.
    subscription_payloads: Arc<tokio::sync::RwLock<HashMap<String, String>>>,
    /// Tracks authentication progress (`authenticate`, re-auth after reconnect).
    auth_tracker: AuthTracker,
    /// Token cancelled by `cancel_all_requests` and exposed via `cancellation_token()`.
    cancellation_token: CancellationToken,
    /// Counter backing `get_next_req_id`.
    req_id_counter: Arc<tokio::sync::RwLock<u64>>,
    /// WebSocket auth token fetched over REST; `None` when not authenticated or after a failed
    /// refresh.
    auth_token: Arc<tokio::sync::RwLock<Option<String>>>,
    /// Account ID set via `set_account_id`.
    account_id: Arc<RwLock<Option<AccountId>>>,
    /// Maps truncated client order ID strings back to the full `ClientOrderId` (populated by
    /// `cache_client_order` when truncation changes the ID).
    truncated_id_map: Arc<AtomicMap<String, ClientOrderId>>,
    /// Instruments cached via `cache_instrument`, keyed by instrument ID.
    instruments: Arc<AtomicMap<InstrumentId, InstrumentAny>>,
    /// Transport backend copied from the config and passed to the WebSocket config on `connect`.
    transport_backend: TransportBackend,
    /// Optional proxy URL for the WebSocket transport.
    proxy_url: Option<String>,
}
94
95impl Clone for KrakenSpotWebSocketClient {
96 fn clone(&self) -> Self {
97 Self {
98 url: self.url.clone(),
99 config: self.config.clone(),
100 signal: Arc::clone(&self.signal),
101 connection_mode: Arc::clone(&self.connection_mode),
102 cmd_tx: Arc::clone(&self.cmd_tx),
103 out_rx: self.out_rx.clone(),
104 task_handle: self.task_handle.clone(),
105 subscriptions: self.subscriptions.clone(),
106 subscription_payloads: Arc::clone(&self.subscription_payloads),
107 auth_tracker: self.auth_tracker.clone(),
108 cancellation_token: self.cancellation_token.clone(),
109 req_id_counter: self.req_id_counter.clone(),
110 auth_token: self.auth_token.clone(),
111 account_id: Arc::clone(&self.account_id),
112 truncated_id_map: Arc::clone(&self.truncated_id_map),
113 instruments: Arc::clone(&self.instruments),
114 transport_backend: self.transport_backend,
115 proxy_url: self.proxy_url.clone(),
116 }
117 }
118}
119
120impl KrakenSpotWebSocketClient {
121 pub fn new(
123 mut config: KrakenDataClientConfig,
124 cancellation_token: CancellationToken,
125 proxy_url: Option<String>,
126 ) -> Self {
127 let url = if config.ws_private_url.is_some() {
129 config.ws_private_url()
130 } else {
131 config.ws_public_url()
132 };
133 let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<SpotHandlerCommand>();
134 let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
135 let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
136
137 let transport_backend = config.transport_backend;
138
139 config.proxy_url = proxy_url.clone();
143
144 Self {
145 url,
146 config,
147 signal: Arc::new(AtomicBool::new(false)),
148 connection_mode,
149 cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
150 out_rx: None,
151 task_handle: None,
152 subscriptions: SubscriptionState::new(KRAKEN_SPOT_WS_TOPIC_DELIMITER),
153 subscription_payloads: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
154 auth_tracker: AuthTracker::new(),
155 cancellation_token,
156 req_id_counter: Arc::new(tokio::sync::RwLock::new(0)),
157 auth_token: Arc::new(tokio::sync::RwLock::new(None)),
158 account_id: Arc::new(RwLock::new(None)),
159 truncated_id_map: Arc::new(AtomicMap::new()),
160 instruments: Arc::new(AtomicMap::new()),
161 transport_backend,
162 proxy_url,
163 }
164 }
165
166 async fn get_next_req_id(&self) -> u64 {
167 let mut counter = self.req_id_counter.write().await;
168 *counter += 1;
169 *counter
170 }
171
172 pub async fn connect(&mut self) -> Result<(), KrakenWsError> {
174 log::debug!("Connecting to {}", self.url);
175
176 self.signal.store(false, Ordering::Relaxed);
177
178 let (raw_handler, raw_rx) = channel_message_handler();
179
180 let ws_config = WebSocketConfig {
181 url: self.url.clone(),
182 headers: vec![],
183 heartbeat: Some(self.config.heartbeat_interval_secs),
184 heartbeat_msg: Some(WS_PING_MSG.to_string()),
185 reconnect_timeout_ms: Some(5_000),
186 reconnect_delay_initial_ms: Some(500),
187 reconnect_delay_max_ms: Some(5_000),
188 reconnect_backoff_factor: Some(1.5),
189 reconnect_jitter_ms: Some(250),
190 reconnect_max_attempts: None,
191 idle_timeout_ms: None,
192 backend: self.transport_backend,
193 proxy_url: self.proxy_url.clone(),
194 };
195
196 let ws_client = WebSocketClient::connect(
197 ws_config,
198 Some(raw_handler),
199 None, None, vec![], None, )
204 .await
205 .map_err(|e| KrakenWsError::ConnectionError(e.to_string()))?;
206
207 self.connection_mode
209 .store(ws_client.connection_mode_atomic());
210
211 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<KrakenSpotWsMessage>();
212 self.out_rx = Some(Arc::new(out_rx));
213
214 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<SpotHandlerCommand>();
215 *self.cmd_tx.write().await = cmd_tx.clone();
216
217 if let Err(e) = cmd_tx.send(SpotHandlerCommand::SetClient(ws_client)) {
218 return Err(KrakenWsError::ConnectionError(format!(
219 "Failed to send WebSocketClient to handler: {e}"
220 )));
221 }
222
223 let signal = self.signal.clone();
224 let subscriptions = self.subscriptions.clone();
225 let subscription_payloads = self.subscription_payloads.clone();
226 let config_for_reconnect = self.config.clone();
227 let auth_token_for_reconnect = self.auth_token.clone();
228 let auth_tracker_for_reconnect = self.auth_tracker.clone();
229 let cmd_tx_for_reconnect = cmd_tx.clone();
230
231 let stream_handle = get_runtime().spawn(async move {
232 let mut handler =
233 SpotFeedHandler::new(signal.clone(), cmd_rx, raw_rx, subscriptions.clone());
234
235 loop {
236 match handler.next().await {
237 Some(KrakenSpotWsMessage::Reconnected) => {
238 if signal.load(Ordering::Relaxed) {
239 continue;
240 }
241 log::info!("WebSocket reconnected, resubscribing");
242
243 let confirmed_topics = subscriptions.all_topics();
244 for topic in &confirmed_topics {
245 subscriptions.mark_failure(topic);
246 }
247
248 let payloads = subscription_payloads.read().await;
249 if payloads.is_empty() {
250 log::debug!("No subscriptions to restore after reconnection");
251 } else {
252 let had_auth = auth_token_for_reconnect.read().await.is_some();
253
254 if had_auth && config_for_reconnect.has_api_credentials() {
255 log::debug!("Re-authenticating after reconnect");
256
257 auth_tracker_for_reconnect.invalidate();
258 let _rx = auth_tracker_for_reconnect.begin();
259
260 match refresh_auth_token(&config_for_reconnect).await {
261 Ok(new_token) => {
262 *auth_token_for_reconnect.write().await = Some(new_token);
263 auth_tracker_for_reconnect.succeed();
264 log::debug!("Re-authentication successful");
265 }
266 Err(e) => {
267 log::error!(
268 "Failed to re-authenticate after reconnect: {e}"
269 );
270 *auth_token_for_reconnect.write().await = None;
271 auth_tracker_for_reconnect.fail(e.to_string());
272 }
273 }
274 }
275
276 log::info!(
277 "Resubscribing after reconnection: count={}",
278 payloads.len()
279 );
280
281 for (topic, payload) in payloads.iter() {
282 let payload = if topic == "executions" {
283 let auth_token = auth_token_for_reconnect.read().await.clone();
284 match auth_token {
285 Some(token) => {
286 match update_auth_token_in_payload(payload, &token) {
287 Ok(p) => p,
288 Err(e) => {
289 log::error!("Failed to update auth token: {e}");
290 continue;
291 }
292 }
293 }
294 None => {
295 log::warn!(
296 "Cannot resubscribe to executions: no auth token"
297 );
298 continue;
299 }
300 }
301 } else {
302 payload.clone()
303 };
304
305 if let Err(e) = cmd_tx_for_reconnect
306 .send(SpotHandlerCommand::Subscribe { payload })
307 {
308 log::error!(
309 "Failed to send resubscribe command: error={e}, \
310 topic={topic}"
311 );
312 }
313
314 subscriptions.mark_subscribe(topic);
315 }
316 }
317
318 if out_tx.send(KrakenSpotWsMessage::Reconnected).is_err() {
319 log::error!("Failed to send message (receiver dropped)");
320 break;
321 }
322 }
323 Some(msg) => {
324 if out_tx.send(msg).is_err() {
325 log::error!("Failed to send message (receiver dropped)");
326 break;
327 }
328 }
329 None => {
330 if handler.is_stopped() {
331 log::debug!("Stop signal received, ending message processing");
332 break;
333 }
334 log::warn!("WebSocket stream ended unexpectedly");
335 break;
336 }
337 }
338 }
339
340 log::debug!("Handler task exiting");
341 });
342
343 self.task_handle = Some(Arc::new(stream_handle));
344
345 log::debug!("WebSocket connected successfully");
346 Ok(())
347 }
348
349 pub async fn disconnect(&mut self) -> Result<(), KrakenWsError> {
351 log::debug!("Disconnecting WebSocket");
352
353 self.signal.store(true, Ordering::Relaxed);
354
355 if let Err(e) = self
356 .cmd_tx
357 .read()
358 .await
359 .send(SpotHandlerCommand::Disconnect)
360 {
361 log::debug!(
362 "Failed to send disconnect command (handler may already be shut down): {e}"
363 );
364 }
365
366 if let Some(task_handle) = self.task_handle.take() {
367 match Arc::try_unwrap(task_handle) {
368 Ok(handle) => {
369 log::debug!("Waiting for task handle to complete");
370 match tokio::time::timeout(tokio::time::Duration::from_secs(2), handle).await {
371 Ok(Ok(())) => log::debug!("Task handle completed successfully"),
372 Ok(Err(e)) => log::error!("Task handle encountered an error: {e:?}"),
373 Err(_) => {
374 log::warn!(
375 "Timeout waiting for task handle, task may still be running"
376 );
377 }
378 }
379 }
380 Err(arc_handle) => {
381 log::debug!(
382 "Cannot take ownership of task handle - other references exist, aborting task"
383 );
384 arc_handle.abort();
385 }
386 }
387 } else {
388 log::debug!("No task handle to await");
389 }
390
391 self.subscriptions.clear();
392 self.subscription_payloads.write().await.clear();
393 self.auth_tracker.fail("Disconnected");
394
395 Ok(())
396 }
397
398 pub async fn close(&mut self) -> Result<(), KrakenWsError> {
400 self.disconnect().await
401 }
402
403 pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), KrakenWsError> {
405 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
406
407 tokio::time::timeout(timeout, async {
408 while !self.is_active() {
409 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
410 }
411 })
412 .await
413 .map_err(|_| {
414 KrakenWsError::ConnectionError(format!(
415 "WebSocket connection timeout after {timeout_secs} seconds"
416 ))
417 })?;
418
419 Ok(())
420 }
421
422 #[must_use]
424 pub fn is_authenticated(&self) -> bool {
425 self.auth_tracker.is_authenticated()
426 }
427
428 pub async fn wait_until_authenticated(&self, timeout_secs: f64) -> Result<(), KrakenWsError> {
432 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
433
434 if self.auth_tracker.wait_for_authenticated(timeout).await {
435 Ok(())
436 } else {
437 Err(KrakenWsError::AuthenticationError(format!(
438 "Authentication not completed within {timeout_secs} seconds"
439 )))
440 }
441 }
442
443 pub async fn authenticate(&self) -> Result<(), KrakenWsError> {
445 if !self.config.has_api_credentials() {
446 return Err(KrakenWsError::AuthenticationError(
447 "API credentials required for authentication".to_string(),
448 ));
449 }
450
451 let _receiver = self.auth_tracker.begin();
452
453 match refresh_auth_token(&self.config).await {
454 Ok(token) => {
455 *self.auth_token.write().await = Some(token);
456 self.auth_tracker.succeed();
457 Ok(())
458 }
459 Err(e) => {
460 *self.auth_token.write().await = None;
461 self.auth_tracker.fail(e.to_string());
462 Err(e)
463 }
464 }
465 }
466
467 pub fn cancel_all_requests(&self) {
469 self.cancellation_token.cancel();
470 }
471
472 pub fn cancellation_token(&self) -> &CancellationToken {
474 &self.cancellation_token
475 }
476
477 pub async fn subscribe(
479 &self,
480 channel: KrakenWsChannel,
481 symbols: Vec<Ustr>,
482 depth: Option<u32>,
483 ) -> Result<(), KrakenWsError> {
484 let mut symbols_to_subscribe = Vec::new();
485 let channel_str = channel.as_ref();
486 for symbol in &symbols {
487 let key = format!("{channel_str}:{symbol}");
488 if self.subscriptions.add_reference(&key) {
489 self.subscriptions.mark_subscribe(&key);
490 symbols_to_subscribe.push(*symbol);
491 }
492 }
493
494 if symbols_to_subscribe.is_empty() {
495 return Ok(());
496 }
497
498 let is_private = matches!(
499 channel,
500 KrakenWsChannel::Executions | KrakenWsChannel::Balances
501 );
502 let token = if is_private {
503 Some(self.auth_token.read().await.clone().ok_or_else(|| {
504 KrakenWsError::AuthenticationError(
505 "Authentication token required for private channels. Call authenticate() first"
506 .to_string(),
507 )
508 })?)
509 } else {
510 None
511 };
512
513 let req_id = self.get_next_req_id().await;
514 let request = KrakenWsRequest {
515 method: KrakenWsMethod::Subscribe,
516 params: Some(KrakenWsParams {
517 channel,
518 symbol: Some(symbols_to_subscribe.clone()),
519 snapshot: None,
520 depth,
521 interval: None,
522 event_trigger: None,
523 token,
524 snap_orders: None,
525 snap_trades: None,
526 }),
527 req_id: Some(req_id),
528 };
529
530 let payload = self.send_command(&request).await?;
531
532 for symbol in &symbols_to_subscribe {
533 let key = format!("{channel_str}:{symbol}");
534 self.subscriptions.confirm_subscribe(&key);
535 self.subscription_payloads
536 .write()
537 .await
538 .insert(key, payload.clone());
539 }
540
541 Ok(())
542 }
543
544 async fn subscribe_with_interval(
546 &self,
547 channel: KrakenWsChannel,
548 symbols: Vec<Ustr>,
549 interval: u32,
550 ) -> Result<(), KrakenWsError> {
551 let mut symbols_to_subscribe = Vec::new();
552 let channel_str = channel.as_ref();
553 for symbol in &symbols {
554 let key = format!("{channel_str}:{symbol}:{interval}");
555 if self.subscriptions.add_reference(&key) {
556 self.subscriptions.mark_subscribe(&key);
557 symbols_to_subscribe.push(*symbol);
558 }
559 }
560
561 if symbols_to_subscribe.is_empty() {
562 return Ok(());
563 }
564
565 let req_id = self.get_next_req_id().await;
566 let request = KrakenWsRequest {
567 method: KrakenWsMethod::Subscribe,
568 params: Some(KrakenWsParams {
569 channel,
570 symbol: Some(symbols_to_subscribe.clone()),
571 snapshot: Some(false),
572 depth: None,
573 interval: Some(interval),
574 event_trigger: None,
575 token: None,
576 snap_orders: None,
577 snap_trades: None,
578 }),
579 req_id: Some(req_id),
580 };
581
582 let payload = self.send_command(&request).await?;
583
584 for symbol in &symbols_to_subscribe {
585 let key = format!("{channel_str}:{symbol}:{interval}");
586 self.subscriptions.confirm_subscribe(&key);
587 self.subscription_payloads
588 .write()
589 .await
590 .insert(key, payload.clone());
591 }
592
593 Ok(())
594 }
595
596 async fn unsubscribe_with_interval(
598 &self,
599 channel: KrakenWsChannel,
600 symbols: Vec<Ustr>,
601 interval: u32,
602 ) -> Result<(), KrakenWsError> {
603 let mut symbols_to_unsubscribe = Vec::new();
604 let channel_str = channel.as_ref();
605 for symbol in &symbols {
606 let key = format!("{channel_str}:{symbol}:{interval}");
607 if self.subscriptions.remove_reference(&key) {
608 self.subscriptions.mark_unsubscribe(&key);
609 symbols_to_unsubscribe.push(*symbol);
610 }
611 }
612
613 if symbols_to_unsubscribe.is_empty() {
614 return Ok(());
615 }
616
617 let req_id = self.get_next_req_id().await;
618 let request = KrakenWsRequest {
619 method: KrakenWsMethod::Unsubscribe,
620 params: Some(KrakenWsParams {
621 channel,
622 symbol: Some(symbols_to_unsubscribe.clone()),
623 snapshot: None,
624 depth: None,
625 interval: Some(interval),
626 event_trigger: None,
627 token: None,
628 snap_orders: None,
629 snap_trades: None,
630 }),
631 req_id: Some(req_id),
632 };
633
634 self.send_command(&request).await?;
635
636 for symbol in &symbols_to_unsubscribe {
637 let key = format!("{channel_str}:{symbol}:{interval}");
638 self.subscriptions.confirm_unsubscribe(&key);
639 self.subscription_payloads.write().await.remove(&key);
640 }
641
642 Ok(())
643 }
644
645 pub async fn unsubscribe(
647 &self,
648 channel: KrakenWsChannel,
649 symbols: Vec<Ustr>,
650 ) -> Result<(), KrakenWsError> {
651 let mut symbols_to_unsubscribe = Vec::new();
652 let channel_str = channel.as_ref();
653 for symbol in &symbols {
654 let key = format!("{channel_str}:{symbol}");
655 if self.subscriptions.remove_reference(&key) {
656 self.subscriptions.mark_unsubscribe(&key);
657 symbols_to_unsubscribe.push(*symbol);
658 } else {
659 log::debug!(
660 "Channel {channel_str} symbol {symbol} still has active subscriptions, not unsubscribing"
661 );
662 }
663 }
664
665 if symbols_to_unsubscribe.is_empty() {
666 return Ok(());
667 }
668
669 let is_private = matches!(
670 channel,
671 KrakenWsChannel::Executions | KrakenWsChannel::Balances
672 );
673 let token = if is_private {
674 Some(self.auth_token.read().await.clone().ok_or_else(|| {
675 KrakenWsError::AuthenticationError(
676 "Authentication token required for private channels. Call authenticate() first"
677 .to_string(),
678 )
679 })?)
680 } else {
681 None
682 };
683
684 let req_id = self.get_next_req_id().await;
685 let request = KrakenWsRequest {
686 method: KrakenWsMethod::Unsubscribe,
687 params: Some(KrakenWsParams {
688 channel,
689 symbol: Some(symbols_to_unsubscribe.clone()),
690 snapshot: None,
691 depth: None,
692 interval: None,
693 event_trigger: None,
694 token,
695 snap_orders: None,
696 snap_trades: None,
697 }),
698 req_id: Some(req_id),
699 };
700
701 self.send_command(&request).await?;
702
703 for symbol in &symbols_to_unsubscribe {
704 let key = format!("{channel_str}:{symbol}");
705 self.subscriptions.confirm_unsubscribe(&key);
706 self.subscription_payloads.write().await.remove(&key);
707 }
708
709 Ok(())
710 }
711
712 pub async fn send_ping(&self) -> Result<(), KrakenWsError> {
714 let req_id = self.get_next_req_id().await;
715
716 let request = KrakenWsRequest {
717 method: KrakenWsMethod::Ping,
718 params: None,
719 req_id: Some(req_id),
720 };
721
722 self.send_command(&request).await?;
723 Ok(())
724 }
725
726 async fn send_command(&self, request: &KrakenWsRequest) -> Result<String, KrakenWsError> {
727 let payload =
728 serde_json::to_string(request).map_err(|e| KrakenWsError::JsonError(e.to_string()))?;
729
730 log::trace!("Sending message: {payload}");
731
732 let cmd = match request.method {
733 KrakenWsMethod::Subscribe => SpotHandlerCommand::Subscribe {
734 payload: payload.clone(),
735 },
736 KrakenWsMethod::Unsubscribe => SpotHandlerCommand::Unsubscribe {
737 payload: payload.clone(),
738 },
739 KrakenWsMethod::Ping | KrakenWsMethod::Pong => SpotHandlerCommand::Ping {
740 payload: payload.clone(),
741 },
742 };
743
744 self.cmd_tx
745 .read()
746 .await
747 .send(cmd)
748 .map_err(|e| KrakenWsError::ConnectionError(format!("Failed to send request: {e}")))?;
749
750 Ok(payload)
751 }
752
753 pub fn is_connected(&self) -> bool {
755 let connection_mode_arc = self.connection_mode.load();
756 !ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
757 }
758
759 pub fn is_active(&self) -> bool {
761 let connection_mode_arc = self.connection_mode.load();
762 ConnectionMode::from_atomic(&connection_mode_arc).is_active()
763 && !self.signal.load(Ordering::Relaxed)
764 }
765
766 pub fn is_closed(&self) -> bool {
768 let connection_mode_arc = self.connection_mode.load();
769 ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
770 || self.signal.load(Ordering::Relaxed)
771 }
772
773 pub fn url(&self) -> &str {
775 &self.url
776 }
777
778 pub fn get_subscriptions(&self) -> Vec<String> {
780 self.subscriptions.all_topics()
781 }
782
783 pub fn set_account_id(&self, account_id: AccountId) {
785 if let Ok(mut guard) = self.account_id.write() {
786 *guard = Some(account_id);
787 }
788 }
789
790 #[must_use]
792 pub fn account_id(&self) -> Option<AccountId> {
793 self.account_id.read().ok().and_then(|g| *g)
794 }
795
796 pub fn cache_instrument(&self, instrument: InstrumentAny) {
798 self.instruments.insert(instrument.id(), instrument);
799 }
800
801 pub fn account_id_shared(&self) -> &Arc<RwLock<Option<AccountId>>> {
803 &self.account_id
804 }
805
806 pub fn truncated_id_map(&self) -> &Arc<AtomicMap<String, ClientOrderId>> {
808 &self.truncated_id_map
809 }
810
811 pub fn cache_client_order(
813 &self,
814 client_order_id: ClientOrderId,
815 _venue_order_id: Option<VenueOrderId>,
816 _instrument_id: InstrumentId,
817 _trader_id: TraderId,
818 _strategy_id: StrategyId,
819 ) {
820 let truncated = crate::common::parse::truncate_cl_ord_id(&client_order_id);
821
822 if truncated != client_order_id.as_str() {
823 self.truncated_id_map.insert(truncated, client_order_id);
824 }
825 }
826
827 pub fn stream(
835 &mut self,
836 ) -> Result<impl futures_util::Stream<Item = KrakenSpotWsMessage> + use<>, KrakenWsError> {
837 let rx = self.out_rx.take().ok_or_else(|| {
838 KrakenWsError::ChannelError(
839 "Stream receiver already taken or client not connected".to_string(),
840 )
841 })?;
842 let mut rx = Arc::try_unwrap(rx).map_err(|_| {
843 KrakenWsError::ChannelError(
844 "Cannot take ownership of stream - other client clones still hold references"
845 .to_string(),
846 )
847 })?;
848 Ok(async_stream::stream! {
849 while let Some(msg) = rx.recv().await {
850 yield msg;
851 }
852 })
853 }
854
855 pub async fn subscribe_book(
857 &self,
858 instrument_id: InstrumentId,
859 depth: Option<u32>,
860 ) -> Result<(), KrakenWsError> {
861 let symbol = to_ws_v2_symbol(instrument_id.symbol.inner());
862 self.subscribe(KrakenWsChannel::Book, vec![symbol], depth)
863 .await
864 }
865
866 pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), KrakenWsError> {
871 let symbol = to_ws_v2_symbol(instrument_id.symbol.inner());
872 let key = format!("quotes:{symbol}");
873
874 if !self.subscriptions.add_reference(&key) {
875 return Ok(());
876 }
877
878 self.subscriptions.mark_subscribe(&key);
879
880 let req_id = self.get_next_req_id().await;
881 let request = KrakenWsRequest {
882 method: KrakenWsMethod::Subscribe,
883 params: Some(KrakenWsParams {
884 channel: KrakenWsChannel::Ticker,
885 symbol: Some(vec![symbol]),
886 snapshot: None,
887 depth: None,
888 interval: None,
889 event_trigger: Some("bbo".to_string()),
890 token: None,
891 snap_orders: None,
892 snap_trades: None,
893 }),
894 req_id: Some(req_id),
895 };
896
897 let payload = self.send_command(&request).await?;
898 self.subscriptions.confirm_subscribe(&key);
899 self.subscription_payloads
900 .write()
901 .await
902 .insert(key, payload);
903 Ok(())
904 }
905
906 pub async fn subscribe_trades(&self, instrument_id: InstrumentId) -> Result<(), KrakenWsError> {
908 let symbol = to_ws_v2_symbol(instrument_id.symbol.inner());
909 self.subscribe(KrakenWsChannel::Trade, vec![symbol], None)
910 .await
911 }
912
913 pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), KrakenWsError> {
919 let symbol = to_ws_v2_symbol(bar_type.instrument_id().symbol.inner());
920 let interval = bar_type_to_ws_interval(bar_type)?;
921 self.subscribe_with_interval(KrakenWsChannel::Ohlc, vec![symbol], interval)
922 .await
923 }
924
925 pub async fn subscribe_executions(
929 &self,
930 snap_orders: bool,
931 snap_trades: bool,
932 ) -> Result<(), KrakenWsError> {
933 let req_id = self.get_next_req_id().await;
934
935 let token = self.auth_token.read().await.clone().ok_or_else(|| {
936 KrakenWsError::AuthenticationError(
937 "Authentication token required for executions channel. Call authenticate() first"
938 .to_string(),
939 )
940 })?;
941
942 let request = KrakenWsRequest {
943 method: KrakenWsMethod::Subscribe,
944 params: Some(KrakenWsParams {
945 channel: KrakenWsChannel::Executions,
946 symbol: None,
947 snapshot: None,
948 depth: None,
949 interval: None,
950 event_trigger: None,
951 token: Some(token),
952 snap_orders: Some(snap_orders),
953 snap_trades: Some(snap_trades),
954 }),
955 req_id: Some(req_id),
956 };
957
958 let payload = self.send_command(&request).await?;
959
960 let key = "executions";
961 if self.subscriptions.add_reference(key) {
962 self.subscriptions.mark_subscribe(key);
963 self.subscriptions.confirm_subscribe(key);
964 self.subscription_payloads
965 .write()
966 .await
967 .insert(key.to_string(), payload);
968 }
969
970 Ok(())
971 }
972
973 pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), KrakenWsError> {
975 let symbol = to_ws_v2_symbol(instrument_id.symbol.inner());
976 self.unsubscribe(KrakenWsChannel::Book, vec![symbol]).await
977 }
978
979 pub async fn unsubscribe_quotes(
981 &self,
982 instrument_id: InstrumentId,
983 ) -> Result<(), KrakenWsError> {
984 let symbol = to_ws_v2_symbol(instrument_id.symbol.inner());
985 let key = format!("quotes:{symbol}");
986
987 if !self.subscriptions.remove_reference(&key) {
988 return Ok(());
989 }
990
991 self.subscriptions.mark_unsubscribe(&key);
992
993 let req_id = self.get_next_req_id().await;
994 let request = KrakenWsRequest {
995 method: KrakenWsMethod::Unsubscribe,
996 params: Some(KrakenWsParams {
997 channel: KrakenWsChannel::Ticker,
998 symbol: Some(vec![symbol]),
999 snapshot: None,
1000 depth: None,
1001 interval: None,
1002 event_trigger: Some("bbo".to_string()),
1003 token: None,
1004 snap_orders: None,
1005 snap_trades: None,
1006 }),
1007 req_id: Some(req_id),
1008 };
1009
1010 self.send_command(&request).await?;
1011 self.subscriptions.confirm_unsubscribe(&key);
1012 self.subscription_payloads.write().await.remove(&key);
1013 Ok(())
1014 }
1015
1016 pub async fn unsubscribe_trades(
1018 &self,
1019 instrument_id: InstrumentId,
1020 ) -> Result<(), KrakenWsError> {
1021 let symbol = to_ws_v2_symbol(instrument_id.symbol.inner());
1022 self.unsubscribe(KrakenWsChannel::Trade, vec![symbol]).await
1023 }
1024
1025 pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), KrakenWsError> {
1031 let symbol = to_ws_v2_symbol(bar_type.instrument_id().symbol.inner());
1032 let interval = bar_type_to_ws_interval(bar_type)?;
1033 self.unsubscribe_with_interval(KrakenWsChannel::Ohlc, vec![symbol], interval)
1034 .await
1035 }
1036}
1037
1038async fn refresh_auth_token(config: &KrakenDataClientConfig) -> Result<String, KrakenWsError> {
1040 let api_key = config
1041 .api_key
1042 .clone()
1043 .ok_or_else(|| KrakenWsError::AuthenticationError("Missing API key".to_string()))?;
1044 let api_secret = config
1045 .api_secret
1046 .clone()
1047 .ok_or_else(|| KrakenWsError::AuthenticationError("Missing API secret".to_string()))?;
1048
1049 let http_client = KrakenSpotHttpClient::with_credentials(
1050 api_key,
1051 api_secret,
1052 config.environment,
1053 Some(config.http_base_url()),
1054 config.timeout_secs,
1055 None,
1056 None,
1057 None,
1058 config.proxy_url.clone(),
1059 config
1060 .max_requests_per_second
1061 .unwrap_or(KRAKEN_SPOT_DEFAULT_RATE_LIMIT_PER_SECOND),
1062 )
1063 .map_err(|e| {
1064 KrakenWsError::AuthenticationError(format!("Failed to create HTTP client: {e}"))
1065 })?;
1066
1067 let ws_token = http_client.get_websockets_token().await.map_err(|e| {
1068 KrakenWsError::AuthenticationError(format!("Failed to get WebSocket token: {e}"))
1069 })?;
1070
1071 log::debug!(
1072 "WebSocket authentication token refreshed: token_length={}, expires={}",
1073 ws_token.token.len(),
1074 ws_token.expires
1075 );
1076
1077 Ok(ws_token.token)
1078}
1079
1080fn update_auth_token_in_payload(payload: &str, new_token: &str) -> Result<String, KrakenWsError> {
1081 let mut value: serde_json::Value =
1082 serde_json::from_str(payload).map_err(|e| KrakenWsError::JsonError(e.to_string()))?;
1083
1084 if let Some(params) = value.get_mut("params") {
1085 params["token"] = serde_json::Value::String(new_token.to_string());
1086 }
1087
1088 serde_json::to_string(&value).map_err(|e| KrakenWsError::JsonError(e.to_string()))
1089}
1090
1091#[inline]
1092fn to_ws_v2_symbol(symbol: Ustr) -> Ustr {
1093 Ustr::from(&normalize_spot_symbol(symbol.as_str()))
1094}
1095
1096fn bar_type_to_ws_interval(bar_type: BarType) -> Result<u32, KrakenWsError> {
1097 const VALID_INTERVALS: [u32; 9] = [1, 5, 15, 30, 60, 240, 1440, 10080, 21600];
1098
1099 let spec = bar_type.spec();
1100 let step = spec.step.get() as u32;
1101
1102 let base_minutes = match spec.aggregation {
1103 BarAggregation::Minute => 1,
1104 BarAggregation::Hour => 60,
1105 BarAggregation::Day => 1440,
1106 BarAggregation::Week => 10080,
1107 other => {
1108 return Err(KrakenWsError::SubscriptionError(format!(
1109 "Unsupported bar aggregation for Kraken OHLC streaming: {other:?}"
1110 )));
1111 }
1112 };
1113
1114 let interval = base_minutes * step;
1115
1116 if !VALID_INTERVALS.contains(&interval) {
1117 return Err(KrakenWsError::SubscriptionError(format!(
1118 "Invalid bar interval {interval} minutes for Kraken OHLC streaming. \
1119 Supported intervals: 1, 5, 15, 30, 60, 240, 1440, 10080, 21600"
1120 )));
1121 }
1122
1123 Ok(interval)
1124}
1125
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    #[rstest]
    #[case("XBT/EUR", "BTC/EUR")]
    #[case("XBT/USD", "BTC/USD")]
    #[case("XBT/USDT", "BTC/USDT")]
    #[case("ETH/USD", "ETH/USD")]
    #[case("ETH/XBT", "ETH/BTC")]
    #[case("SOL/XBT", "SOL/BTC")]
    #[case("SOL/USD", "SOL/USD")]
    #[case("BTC/USD", "BTC/USD")]
    #[case("ETH/BTC", "ETH/BTC")]
    fn test_to_kraken_ws_v2_symbol(#[case] input: &str, #[case] expected: &str) {
        let symbol = Ustr::from(input);
        let result = to_ws_v2_symbol(symbol);
        assert_eq!(result.as_str(), expected);
    }

    /// Builds a client with default config (no API credentials) that has never connected.
    fn test_client_without_credentials() -> KrakenSpotWebSocketClient {
        KrakenSpotWebSocketClient::new(
            KrakenDataClientConfig::default(),
            CancellationToken::new(),
            None,
        )
    }

    #[rstest]
    #[tokio::test]
    async fn test_authenticate_without_credentials_errors() {
        let client = test_client_without_credentials();

        let err = client.authenticate().await.expect_err("should fail");
        assert!(
            matches!(err, KrakenWsError::AuthenticationError(ref msg) if msg.contains("API credentials required")),
            "unexpected error: {err:?}"
        );
        assert!(!client.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_until_authenticated_times_out() {
        let client = test_client_without_credentials();

        let err = client
            .wait_until_authenticated(0.05)
            .await
            .expect_err("should time out");
        assert!(matches!(err, KrakenWsError::AuthenticationError(_)));
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_until_authenticated_resolves_after_succeed() {
        let client = test_client_without_credentials();

        let tracker = client.auth_tracker.clone();
        let _rx = tracker.begin();

        tokio::spawn(async move {
            tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
            tracker.succeed();
        });

        client
            .wait_until_authenticated(1.0)
            .await
            .expect("should resolve once tracker succeeds");
        assert!(client.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_is_authenticated_flips_on_fail() {
        let client = test_client_without_credentials();

        let _rx = client.auth_tracker.begin();
        client.auth_tracker.succeed();
        assert!(client.is_authenticated());

        client.auth_tracker.fail("test failure");
        assert!(!client.is_authenticated());
    }

    #[rstest]
    fn test_new_client_is_closed_without_subscriptions() {
        let client = test_client_without_credentials();

        assert!(client.get_subscriptions().is_empty());
        assert!(client.is_closed());
        assert!(!client.is_active());
    }

    #[rstest]
    #[tokio::test]
    async fn test_send_ping_before_connect_errors() {
        let client = test_client_without_credentials();

        // The command receiver created in `new` is already dropped, so queuing must fail
        let err = client
            .send_ping()
            .await
            .expect_err("no handler task is running");
        assert!(matches!(err, KrakenWsError::ConnectionError(_)));
    }

    #[rstest]
    #[tokio::test]
    async fn test_subscribe_private_channel_without_token_errors() {
        let client = test_client_without_credentials();

        let err = client
            .subscribe(
                KrakenWsChannel::Executions,
                vec![Ustr::from("BTC/USD")],
                None,
            )
            .await
            .expect_err("private channel requires a token");
        assert!(matches!(err, KrakenWsError::AuthenticationError(_)));
    }

    #[rstest]
    fn test_update_auth_token_in_payload_replaces_token() {
        let payload =
            r#"{"method":"subscribe","params":{"channel":"executions","token":"old"},"req_id":7}"#;

        let updated =
            update_auth_token_in_payload(payload, "new-token").expect("payload is valid JSON");
        let value: serde_json::Value =
            serde_json::from_str(&updated).expect("output is valid JSON");

        assert_eq!(value["params"]["token"], "new-token");
        assert_eq!(value["params"]["channel"], "executions");
        assert_eq!(value["req_id"], 7);
    }

    #[rstest]
    fn test_update_auth_token_in_payload_invalid_json_errors() {
        let err = update_auth_token_in_payload("not json", "token").expect_err("should fail");
        assert!(matches!(err, KrakenWsError::JsonError(_)));
    }
}