1use std::{
26 fmt::Debug,
27 num::NonZeroU32,
28 sync::{
29 Arc, LazyLock,
30 atomic::{AtomicBool, AtomicU8, AtomicU64, Ordering},
31 },
32 time::{Duration, SystemTime},
33};
34
35use ahash::{AHashMap, AHashSet};
36use arc_swap::ArcSwap;
37use dashmap::DashMap;
38use futures_util::Stream;
39use nautilus_common::live::get_runtime;
40use nautilus_core::{
41 AtomicMap,
42 consts::NAUTILUS_USER_AGENT,
43 env::{get_env_var, get_or_env_var},
44 string::secret::REDACTED,
45};
46use nautilus_model::{
47 data::BarType,
48 enums::{OrderSide, OrderType, PositionSide, TimeInForce, TriggerType},
49 identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
50 instruments::{Instrument, InstrumentAny},
51 types::{Price, Quantity},
52};
53use nautilus_network::{
54 http::USER_AGENT,
55 mode::ConnectionMode,
56 ratelimiter::quota::Quota,
57 websocket::{
58 AUTHENTICATION_TIMEOUT_SECS, AuthTracker, PingHandler, SubscriptionState, TEXT_PING,
59 TransportBackend, WebSocketClient, WebSocketConfig, channel_message_handler,
60 },
61};
62use serde_json::Value;
63use tokio_tungstenite::tungstenite::Error;
64use tokio_util::sync::CancellationToken;
65use ustr::Ustr;
66
67use super::{
68 enums::OKXWsChannel,
69 error::OKXWsError,
70 handler::{HandlerCommand, OKXWsFeedHandler},
71 messages::{
72 OKXAuthentication, OKXAuthenticationArg, OKXSubscriptionArg, OKXWsMessage, OKXWsRequest,
73 WsAmendOrderParamsBuilder, WsAttachAlgoOrdParams, WsCancelOrderParamsBuilder,
74 WsMassCancelParams, WsPostAlgoOrderParamsBuilder, WsPostOrderParamsBuilder,
75 },
76 subscription::topic_from_subscription_arg,
77};
78use crate::common::{
79 consts::{
80 OKX_NAUTILUS_BROKER_ID, OKX_SUPPORTED_ORDER_TYPES, OKX_SUPPORTED_TIME_IN_FORCE,
81 OKX_WS_PUBLIC_URL, OKX_WS_TOPIC_DELIMITER,
82 },
83 credential::Credential,
84 enums::{
85 OKXGreeksType, OKXInstrumentType, OKXOrderType, OKXPositionSide, OKXTargetCurrency,
86 OKXTradeMode, OKXTriggerType, OKXVipLevel, conditional_order_to_algo_type,
87 is_conditional_order,
88 },
89 parse::{
90 bar_spec_as_okx_channel, okx_instrument_type, okx_instrument_type_from_symbol,
91 parse_base_quote_from_symbol,
92 },
93};
94
95pub static OKX_WS_CONNECTION_QUOTA: LazyLock<Quota> = LazyLock::new(|| {
99 Quota::per_second(NonZeroU32::new(3).expect("non-zero")).expect("valid constant")
100});
101
102pub static OKX_WS_SUBSCRIPTION_QUOTA: LazyLock<Quota> =
107 LazyLock::new(|| Quota::per_hour(NonZeroU32::new(480).expect("non-zero")));
108
109pub static OKX_WS_ORDER_QUOTA: LazyLock<Quota> = LazyLock::new(|| {
114 Quota::per_second(NonZeroU32::new(250).expect("non-zero")).expect("valid constant")
115});
116
117pub static OKX_RATE_LIMIT_KEY_SUBSCRIPTION: LazyLock<[Ustr; 1]> =
122 LazyLock::new(|| [Ustr::from("subscription")]);
123
124pub static OKX_RATE_LIMIT_KEY_ORDER: LazyLock<[Ustr; 1]> = LazyLock::new(|| [Ustr::from("order")]);
129
130pub static OKX_RATE_LIMIT_KEY_CANCEL: LazyLock<[Ustr; 1]> =
136 LazyLock::new(|| [Ustr::from("cancel")]);
137
138pub static OKX_RATE_LIMIT_KEY_AMEND: LazyLock<[Ustr; 1]> = LazyLock::new(|| [Ustr::from("amend")]);
142
/// Trader, strategy and instrument identifiers associated with a pending request.
///
/// Values of this type are stored in the client's `pending_orders`, `pending_cancels` and
/// `pending_amends` maps.
// NOTE(review): those maps are keyed by `String` — presumably a client order or request ID;
// confirm against the code that populates them.
#[derive(Debug, Clone)]
#[allow(dead_code)]
pub(crate) struct PendingOrderInfo {
    /// Trader that issued the request.
    pub trader_id: TraderId,
    /// Strategy that issued the request.
    pub strategy_id: StrategyId,
    /// Instrument the request targets.
    pub instrument_id: InstrumentId,
}
153
/// WebSocket client for OKX market data, private streams and order requests.
///
/// Most mutable state is held behind `Arc`, so clones of a client share subscriptions,
/// caches and the handler command channel.
#[derive(Clone)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.okx", from_py_object)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.okx")
)]
pub struct OKXWebSocketClient {
    /// WebSocket endpoint URL.
    url: String,
    /// Account identifier (defaults to `OKX-master` in `new`).
    #[allow(dead_code)] pub(crate) account_id: AccountId,
    /// Raw `OKXVipLevel`; selects which order book channel `subscribe_book_with_depth` uses.
    vip_level: Arc<AtomicU8>,
    /// API credentials; `None` for an unauthenticated (public) client.
    credential: Option<Credential>,
    /// Heartbeat setting forwarded to `WebSocketConfig::heartbeat`.
    heartbeat: Option<u64>,
    /// Seconds `authenticate` waits for the login result.
    auth_timeout_secs: u64,
    /// Tracks the outcome of login requests.
    auth_tracker: AuthTracker,
    /// Stop flag: set by `close`, cleared at the start of `connect`.
    signal: Arc<AtomicBool>,
    /// Connection mode atomic; `connect` swaps in the live client's atomic.
    connection_mode: Arc<ArcSwap<AtomicU8>>,
    /// Sender to the feed handler; a placeholder until `connect` installs the live channel.
    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
    /// Receiver of handler output; taken by `stream`.
    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<OKXWsMessage>>>,
    /// Handle of the spawned handler task.
    task_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
    /// Tracked `instType` subscriptions per channel (replayed after reconnects).
    subscriptions_inst_type: Arc<DashMap<OKXWsChannel, AHashSet<OKXInstrumentType>>>,
    /// Tracked `instFamily` subscriptions per channel (replayed after reconnects).
    subscriptions_inst_family: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
    /// Tracked `instId` subscriptions per channel (replayed after reconnects).
    subscriptions_inst_id: Arc<DashMap<OKXWsChannel, AHashSet<Ustr>>>,
    /// Tracked channels subscribed without any instrument argument.
    subscriptions_bare: Arc<DashMap<OKXWsChannel, bool>>,
    /// Per-topic subscription state shared with the feed handler.
    subscriptions_state: SubscriptionState,
    /// Counter behind `generate_unique_request_id` (starts at 1).
    request_id_counter: Arc<AtomicU64>,
    /// Instrument definitions keyed by raw symbol.
    instruments_cache: Arc<AtomicMap<Ustr, InstrumentAny>>,
    /// Numeric instrument ID codes keyed by `instId`.
    inst_id_code_cache: Arc<AtomicMap<Ustr, u64>>,
    /// Pending order placements (key semantics: see `PendingOrderInfo` note — TODO confirm).
    pub(crate) pending_orders: Arc<DashMap<String, PendingOrderInfo>>,
    /// Pending cancellations.
    pub(crate) pending_cancels: Arc<DashMap<String, PendingOrderInfo>>,
    /// Pending amendments.
    pub(crate) pending_amends: Arc<DashMap<String, PendingOrderInfo>>,
    /// Greeks conventions requested per option instrument.
    option_greeks_subs: Arc<AtomicMap<InstrumentId, AHashSet<OKXGreeksType>>>,
    /// Subscriber count per `BASE-QUOTE` index pair; only the first subscriber sends a request.
    index_pair_subscribers: Arc<DashMap<Ustr, usize>>,
    /// Serializes index-pair count transitions in `subscribe_index_prices` (held across await).
    index_pair_transition: Arc<tokio::sync::Mutex<()>>,
    /// Transport backend forwarded to `WebSocketConfig`.
    transport_backend: TransportBackend,
    /// Optional proxy URL forwarded to `WebSocketConfig`.
    proxy_url: Option<String>,
    /// Token cancelled by `cancel_all_requests`.
    cancellation_token: CancellationToken,
}
208
209impl Default for OKXWebSocketClient {
210 fn default() -> Self {
211 Self::new(
212 None,
213 None,
214 None,
215 None,
216 None,
217 None,
218 None,
219 TransportBackend::default(),
220 None,
221 )
222 .unwrap()
223 }
224}
225
226impl Debug for OKXWebSocketClient {
227 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
228 f.debug_struct(stringify!(OKXWebSocketClient))
229 .field("url", &self.url)
230 .field("credential", &self.credential.as_ref().map(|_| REDACTED))
231 .field("heartbeat", &self.heartbeat)
232 .finish_non_exhaustive()
233 }
234}
235
236impl OKXWebSocketClient {
237 #[allow(clippy::too_many_arguments)]
243 pub fn new(
244 url: Option<String>,
245 api_key: Option<String>,
246 api_secret: Option<String>,
247 api_passphrase: Option<String>,
248 account_id: Option<AccountId>,
249 heartbeat: Option<u64>,
250 auth_timeout_secs: Option<u64>,
251 transport_backend: TransportBackend,
252 proxy_url: Option<String>,
253 ) -> anyhow::Result<Self> {
254 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
255 let account_id = account_id.unwrap_or(AccountId::from("OKX-master"));
256
257 let credential = match (api_key, api_secret, api_passphrase) {
258 (Some(key), Some(secret), Some(passphrase)) => {
259 Some(Credential::new(key, secret, passphrase))
260 }
261 (None, None, None) => None,
262 _ => anyhow::bail!(
263 "`api_key`, `api_secret`, `api_passphrase` credentials must be provided together"
264 ),
265 };
266
267 let signal = Arc::new(AtomicBool::new(false));
268 let subscriptions_inst_type = Arc::new(DashMap::new());
269 let subscriptions_inst_family = Arc::new(DashMap::new());
270 let subscriptions_inst_id = Arc::new(DashMap::new());
271 let subscriptions_bare = Arc::new(DashMap::new());
272 let subscriptions_state = SubscriptionState::new(OKX_WS_TOPIC_DELIMITER);
273
274 Ok(Self {
275 url,
276 account_id,
277 vip_level: Arc::new(AtomicU8::new(0)),
278 credential,
279 heartbeat,
280 auth_timeout_secs: auth_timeout_secs.unwrap_or(AUTHENTICATION_TIMEOUT_SECS),
281 auth_tracker: AuthTracker::new(),
282 signal,
283 connection_mode: Arc::new(ArcSwap::from_pointee(AtomicU8::new(
284 ConnectionMode::Closed.as_u8(),
285 ))),
286 cmd_tx: {
287 let (tx, _) = tokio::sync::mpsc::unbounded_channel();
289 Arc::new(tokio::sync::RwLock::new(tx))
290 },
291 out_rx: None,
292 task_handle: None,
293 subscriptions_inst_type,
294 subscriptions_inst_family,
295 subscriptions_inst_id,
296 subscriptions_bare,
297 subscriptions_state,
298 request_id_counter: Arc::new(AtomicU64::new(1)),
299 instruments_cache: Arc::new(AtomicMap::new()),
300 inst_id_code_cache: Arc::new(AtomicMap::new()),
301 pending_orders: Arc::new(DashMap::new()),
302 pending_cancels: Arc::new(DashMap::new()),
303 pending_amends: Arc::new(DashMap::new()),
304 option_greeks_subs: Arc::new(AtomicMap::new()),
305 index_pair_subscribers: Arc::new(DashMap::new()),
306 index_pair_transition: Arc::new(tokio::sync::Mutex::new(())),
307 transport_backend,
308 proxy_url,
309 cancellation_token: CancellationToken::new(),
310 })
311 }
312
313 #[allow(clippy::too_many_arguments)]
320 pub fn with_credentials(
321 url: Option<String>,
322 api_key: Option<String>,
323 api_secret: Option<String>,
324 api_passphrase: Option<String>,
325 account_id: Option<AccountId>,
326 heartbeat: Option<u64>,
327 auth_timeout_secs: Option<u64>,
328 transport_backend: TransportBackend,
329 proxy_url: Option<String>,
330 ) -> anyhow::Result<Self> {
331 let url = url.unwrap_or(OKX_WS_PUBLIC_URL.to_string());
332 let api_key = get_or_env_var(api_key, "OKX_API_KEY")?;
333 let api_secret = get_or_env_var(api_secret, "OKX_API_SECRET")?;
334 let api_passphrase = get_or_env_var(api_passphrase, "OKX_API_PASSPHRASE")?;
335
336 Self::new(
337 Some(url),
338 Some(api_key),
339 Some(api_secret),
340 Some(api_passphrase),
341 account_id,
342 heartbeat,
343 auth_timeout_secs,
344 transport_backend,
345 proxy_url,
346 )
347 }
348
349 pub fn from_env() -> anyhow::Result<Self> {
356 let url = get_env_var("OKX_WS_URL")?;
357 let api_key = get_env_var("OKX_API_KEY")?;
358 let api_secret = get_env_var("OKX_API_SECRET")?;
359 let api_passphrase = get_env_var("OKX_API_PASSPHRASE")?;
360
361 Self::new(
362 Some(url),
363 Some(api_key),
364 Some(api_secret),
365 Some(api_passphrase),
366 None,
367 None,
368 None,
369 TransportBackend::default(),
370 None,
371 )
372 }
373
374 pub fn cancel_all_requests(&self) {
376 self.cancellation_token.cancel();
377 }
378
379 pub fn cancellation_token(&self) -> &CancellationToken {
381 &self.cancellation_token
382 }
383
384 pub fn url(&self) -> &str {
386 self.url.as_str()
387 }
388
389 pub fn api_key(&self) -> Option<&str> {
391 self.credential.as_ref().map(|c| c.api_key())
392 }
393
394 #[must_use]
396 pub fn api_key_masked(&self) -> Option<String> {
397 self.credential.as_ref().map(|c| c.api_key_masked())
398 }
399
400 pub fn is_active(&self) -> bool {
402 let connection_mode_arc = self.connection_mode.load();
403 ConnectionMode::from_atomic(&connection_mode_arc).is_active()
404 && !self.signal.load(Ordering::Acquire)
405 }
406
407 pub fn is_closed(&self) -> bool {
409 let connection_mode_arc = self.connection_mode.load();
410 ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
411 || self.signal.load(Ordering::Acquire)
412 }
413
414 pub fn cache_instruments(&self, instruments: &[InstrumentAny]) {
418 self.instruments_cache.rcu(|m| {
419 for inst in instruments {
420 m.insert(inst.symbol().inner(), inst.clone());
421 }
422 });
423 }
424
425 pub fn cache_instrument(&self, instrument: InstrumentAny) {
429 self.instruments_cache
430 .insert(instrument.symbol().inner(), instrument);
431 }
432
433 pub fn instruments_snapshot(&self) -> AHashMap<Ustr, InstrumentAny> {
435 (**self.instruments_cache.load()).clone()
436 }
437
438 pub fn cache_inst_id_code(&self, inst_id: Ustr, inst_id_code: u64) {
442 self.inst_id_code_cache.insert(inst_id, inst_id_code);
443 }
444
445 pub fn cache_inst_id_codes(&self, mappings: impl IntoIterator<Item = (Ustr, u64)>) {
449 let entries: Vec<_> = mappings.into_iter().collect();
450 self.inst_id_code_cache.rcu(|m| {
451 for (inst_id, inst_id_code) in &entries {
452 m.insert(*inst_id, *inst_id_code);
453 }
454 });
455 }
456
457 #[must_use]
461 pub fn get_inst_id_code(&self, inst_id: &Ustr) -> Option<u64> {
462 self.inst_id_code_cache.load().get(inst_id).copied()
463 }
464
465 pub fn set_vip_level(&self, vip_level: OKXVipLevel) {
469 self.vip_level.store(vip_level as u8, Ordering::Relaxed);
470 }
471
472 pub fn vip_level(&self) -> OKXVipLevel {
474 let level = self.vip_level.load(Ordering::Relaxed);
475 OKXVipLevel::from(level)
476 }
477
478 pub async fn connect(&mut self) -> anyhow::Result<()> {
488 self.signal.store(false, Ordering::Release);
490
491 let (message_handler, raw_rx) = channel_message_handler();
492
493 let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {
496 });
498
499 let headers = vec![(USER_AGENT.to_string(), NAUTILUS_USER_AGENT.to_string())];
500
501 let config = WebSocketConfig {
502 url: self.url.clone(),
503 headers,
504 heartbeat: self.heartbeat,
505 heartbeat_msg: Some(TEXT_PING.to_string()),
506 reconnect_timeout_ms: Some(5_000),
507 reconnect_delay_initial_ms: None,
508 reconnect_delay_max_ms: None,
509 reconnect_backoff_factor: None,
510 reconnect_jitter_ms: None,
511 reconnect_max_attempts: None,
512 idle_timeout_ms: None,
513 backend: self.transport_backend,
514 proxy_url: self.proxy_url.clone(),
515 };
516
517 let keyed_quotas = vec![
518 (
519 OKX_RATE_LIMIT_KEY_SUBSCRIPTION[0].as_str().to_string(),
520 *OKX_WS_SUBSCRIPTION_QUOTA,
521 ),
522 (
523 OKX_RATE_LIMIT_KEY_ORDER[0].as_str().to_string(),
524 *OKX_WS_ORDER_QUOTA,
525 ),
526 (
527 OKX_RATE_LIMIT_KEY_CANCEL[0].as_str().to_string(),
528 *OKX_WS_ORDER_QUOTA,
529 ),
530 (
531 OKX_RATE_LIMIT_KEY_AMEND[0].as_str().to_string(),
532 *OKX_WS_ORDER_QUOTA,
533 ),
534 ];
535
536 let client = WebSocketClient::connect(
537 config,
538 Some(message_handler),
539 Some(ping_handler),
540 None, keyed_quotas,
542 Some(*OKX_WS_CONNECTION_QUOTA), )
544 .await?;
545
546 self.connection_mode.store(client.connection_mode_atomic());
548
549 let (msg_tx, rx) = tokio::sync::mpsc::unbounded_channel::<OKXWsMessage>();
550
551 self.out_rx = Some(Arc::new(rx));
552
553 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
554 *self.cmd_tx.write().await = cmd_tx.clone();
555
556 let signal = self.signal.clone();
557 let auth_tracker = self.auth_tracker.clone();
558 let subscriptions_state = self.subscriptions_state.clone();
559
560 let stream_handle = get_runtime().spawn({
561 let auth_tracker = auth_tracker.clone();
562 let signal = signal.clone();
563 let credential = self.credential.clone();
564 let cmd_tx_for_reconnect = cmd_tx.clone();
565 let subscriptions_bare = self.subscriptions_bare.clone();
566 let subscriptions_inst_type = self.subscriptions_inst_type.clone();
567 let subscriptions_inst_family = self.subscriptions_inst_family.clone();
568 let subscriptions_inst_id = self.subscriptions_inst_id.clone();
569 let mut has_reconnected = false;
570
571 async move {
572 let mut handler = OKXWsFeedHandler::new(
573 signal.clone(),
574 cmd_rx,
575 raw_rx,
576 msg_tx,
577 auth_tracker.clone(),
578 subscriptions_state.clone(),
579 );
580
581 let resubscribe_all = || {
583 for entry in subscriptions_inst_id.iter() {
584 let (channel, inst_ids) = entry.pair();
585 for inst_id in inst_ids {
586 let arg = OKXSubscriptionArg {
587 channel: channel.clone(),
588 inst_type: None,
589 inst_family: None,
590 inst_id: Some(*inst_id),
591 };
592
593 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
594 log::error!("Failed to send resubscribe command: error={e}");
595 }
596 }
597 }
598
599 for entry in subscriptions_bare.iter() {
600 let channel = entry.key();
601 let arg = OKXSubscriptionArg {
602 channel: channel.clone(),
603 inst_type: None,
604 inst_family: None,
605 inst_id: None,
606 };
607
608 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
609 log::error!("Failed to send resubscribe command: error={e}");
610 }
611 }
612
613 for entry in subscriptions_inst_type.iter() {
614 let (channel, inst_types) = entry.pair();
615 for inst_type in inst_types {
616 let arg = OKXSubscriptionArg {
617 channel: channel.clone(),
618 inst_type: Some(*inst_type),
619 inst_family: None,
620 inst_id: None,
621 };
622
623 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
624 log::error!("Failed to send resubscribe command: error={e}");
625 }
626 }
627 }
628
629 for entry in subscriptions_inst_family.iter() {
630 let (channel, inst_families) = entry.pair();
631 for inst_family in inst_families {
632 let arg = OKXSubscriptionArg {
633 channel: channel.clone(),
634 inst_type: None,
635 inst_family: Some(*inst_family),
636 inst_id: None,
637 };
638
639 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Subscribe { args: vec![arg] }) {
640 log::error!("Failed to send resubscribe command: error={e}");
641 }
642 }
643 }
644 };
645
646 loop {
647 match handler.next().await {
648 Some(OKXWsMessage::Reconnected) => {
649 if signal.load(Ordering::Acquire) {
650 continue;
651 }
652
653 has_reconnected = true;
654
655 let confirmed_topics_vec: Vec<String> = {
657 let confirmed = subscriptions_state.confirmed();
658 let mut topics = Vec::new();
659
660 for entry in confirmed.iter() {
661 let channel = entry.key();
662 for symbol in entry.value() {
663 if symbol.as_str() == "#" {
664 topics.push(channel.to_string());
665 } else {
666 topics.push(format!("{channel}{OKX_WS_TOPIC_DELIMITER}{symbol}"));
667 }
668 }
669 }
670 topics
671 };
672
673 if !confirmed_topics_vec.is_empty() {
674 log::debug!("Marking confirmed subscriptions as pending for replay: count={}", confirmed_topics_vec.len());
675 for topic in confirmed_topics_vec {
676 subscriptions_state.mark_failure(&topic);
677 }
678 }
679
680 if let Some(cred) = &credential {
681 log::debug!("Re-authenticating after reconnection");
682 let timestamp = std::time::SystemTime::now()
683 .duration_since(std::time::SystemTime::UNIX_EPOCH)
684 .expect("System time should be after UNIX epoch")
685 .as_secs()
686 .to_string();
687 let signature = cred.sign(×tamp, "GET", "/users/self/verify", "");
688
689 let auth_message = super::messages::OKXAuthentication {
690 op: "login",
691 args: vec![super::messages::OKXAuthenticationArg {
692 api_key: cred.api_key().to_string(),
693 passphrase: cred.api_passphrase().to_string(),
694 timestamp,
695 sign: signature,
696 }],
697 };
698
699 if let Ok(payload) = serde_json::to_string(&auth_message) {
700 if let Err(e) = cmd_tx_for_reconnect.send(HandlerCommand::Authenticate { payload }) {
701 log::error!("Failed to send reconnection auth command: error={e}");
702 }
703 } else {
704 log::error!("Failed to serialize reconnection auth message");
705 }
706 }
707
708 if credential.is_none() {
711 log::debug!("No authentication required, resubscribing immediately");
712 resubscribe_all();
713 }
714
715 if handler.send(OKXWsMessage::Reconnected).is_err() {
717 log::error!("Failed to send Reconnected through channel: receiver dropped");
718 break;
719 }
720 }
721 Some(OKXWsMessage::Authenticated) => {
722 if has_reconnected {
723 resubscribe_all();
724 }
725 }
726 Some(msg) => {
727 if handler.send(msg).is_err() {
728 log::error!(
729 "Failed to send message through channel: receiver dropped",
730 );
731 break;
732 }
733 }
734 None => {
735 if handler.is_stopped() {
736 log::debug!(
737 "Stop signal received, ending message processing",
738 );
739 break;
740 }
741 log::debug!("WebSocket stream closed");
742 break;
743 }
744 }
745 }
746
747 log::debug!("Handler task exiting");
748 }
749 });
750
751 self.task_handle = Some(Arc::new(stream_handle));
752
753 self.cmd_tx
754 .read()
755 .await
756 .send(HandlerCommand::SetClient(client))
757 .map_err(|e| {
758 OKXWsError::ClientError(format!("Failed to send WebSocket client to handler: {e}"))
759 })?;
760 log::debug!("Sent WebSocket client to handler");
761
762 if self.credential.is_some()
763 && let Err(e) = self.authenticate().await
764 {
765 anyhow::bail!("Authentication failed: {e}");
766 }
767
768 Ok(())
769 }
770
771 async fn authenticate(&self) -> Result<(), Error> {
773 let credential = self.credential.as_ref().ok_or_else(|| {
774 Error::Io(std::io::Error::other(
775 "API credentials not available to authenticate",
776 ))
777 })?;
778
779 let rx = self.auth_tracker.begin();
780
781 let timestamp = SystemTime::now()
782 .duration_since(SystemTime::UNIX_EPOCH)
783 .expect("System time should be after UNIX epoch")
784 .as_secs()
785 .to_string();
786 let signature = credential.sign(×tamp, "GET", "/users/self/verify", "");
787
788 let auth_message = OKXAuthentication {
789 op: "login",
790 args: vec![OKXAuthenticationArg {
791 api_key: credential.api_key().to_string(),
792 passphrase: credential.api_passphrase().to_string(),
793 timestamp,
794 sign: signature,
795 }],
796 };
797
798 let payload = serde_json::to_string(&auth_message).map_err(|e| {
799 Error::Io(std::io::Error::other(format!(
800 "Failed to serialize auth message: {e}"
801 )))
802 })?;
803
804 self.cmd_tx
805 .read()
806 .await
807 .send(HandlerCommand::Authenticate { payload })
808 .map_err(|e| {
809 Error::Io(std::io::Error::other(format!(
810 "Failed to send authenticate command: {e}"
811 )))
812 })?;
813
814 match self
815 .auth_tracker
816 .wait_for_result::<OKXWsError>(Duration::from_secs(self.auth_timeout_secs), rx)
817 .await
818 {
819 Ok(()) => {
820 log::info!("WebSocket authenticated");
821 Ok(())
822 }
823 Err(e) => {
824 log::error!("WebSocket authentication failed: error={e}");
825 Err(Error::Io(std::io::Error::other(e.to_string())))
826 }
827 }
828 }
829
830 pub fn stream(&mut self) -> impl Stream<Item = OKXWsMessage> + 'static {
838 let rx = self
839 .out_rx
840 .take()
841 .expect("Data stream receiver already taken or not connected");
842 let mut rx = Arc::try_unwrap(rx).expect("Cannot take ownership - other references exist");
843 async_stream::stream! {
844 while let Some(data) = rx.recv().await {
845 yield data;
846 }
847 }
848 }
849
850 pub async fn wait_until_active(&self, timeout_secs: f64) -> Result<(), OKXWsError> {
856 let timeout = tokio::time::Duration::from_secs_f64(timeout_secs);
857
858 tokio::time::timeout(timeout, async {
859 while !self.is_active() {
860 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
861 }
862 })
863 .await
864 .map_err(|_| {
865 OKXWsError::ClientError(format!(
866 "WebSocket connection timeout after {timeout_secs} seconds"
867 ))
868 })?;
869
870 Ok(())
871 }
872
873 pub async fn close(&mut self) -> Result<(), Error> {
880 log::debug!("Starting close process");
881
882 self.signal.store(true, Ordering::Release);
883
884 if let Err(e) = self.cmd_tx.read().await.send(HandlerCommand::Disconnect) {
885 log::warn!("Failed to send disconnect command to handler: {e}");
886 } else {
887 log::debug!("Sent disconnect command to handler");
888 }
889
890 if let Some(stream_handle) = self.task_handle.take() {
891 match Arc::try_unwrap(stream_handle) {
892 Ok(handle) => {
893 log::debug!("Waiting for stream handle to complete");
894 let abort_handle = handle.abort_handle();
895 match tokio::time::timeout(Duration::from_secs(2), handle).await {
896 Ok(Ok(())) => log::debug!("Stream handle completed successfully"),
897 Ok(Err(e)) => log::error!("Stream handle encountered an error: {e:?}"),
898 Err(_) => {
899 log::warn!("Timeout waiting for stream handle, aborting task");
900 abort_handle.abort();
901 }
902 }
903 }
904 Err(arc_handle) => {
905 log::debug!(
906 "Cannot take ownership of stream handle - other references exist, aborting task"
907 );
908 arc_handle.abort();
909 }
910 }
911 } else {
912 log::debug!("No stream handle to await");
913 }
914
915 self.index_pair_subscribers.clear();
919
920 log::debug!("Close process completed");
921
922 Ok(())
923 }
924
925 pub fn get_subscriptions(&self, instrument_id: InstrumentId) -> Vec<OKXWsChannel> {
927 let symbol = instrument_id.symbol.inner();
928 let mut channels = Vec::new();
929
930 for entry in self.subscriptions_inst_id.iter() {
931 let (channel, instruments) = entry.pair();
932 if instruments.contains(&symbol) {
933 channels.push(channel.clone());
934 }
935 }
936
937 channels
938 }
939
940 fn generate_unique_request_id(&self) -> String {
941 self.request_id_counter
942 .fetch_add(1, Ordering::SeqCst)
943 .to_string()
944 }
945
946 async fn subscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
947 self.cmd_tx
949 .read()
950 .await
951 .send(HandlerCommand::Subscribe { args: args.clone() })
952 .map_err(|e| {
953 OKXWsError::ClientError(format!("Failed to send subscribe command: {e}"))
954 })?;
955
956 for arg in &args {
957 let topic = topic_from_subscription_arg(arg);
958 self.subscriptions_state.mark_subscribe(&topic);
959
960 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
962 self.subscriptions_bare.insert(arg.channel.clone(), true);
963 } else {
964 if let Some(inst_type) = &arg.inst_type {
965 self.subscriptions_inst_type
966 .entry(arg.channel.clone())
967 .or_default()
968 .insert(*inst_type);
969 }
970
971 if let Some(inst_family) = &arg.inst_family {
972 self.subscriptions_inst_family
973 .entry(arg.channel.clone())
974 .or_default()
975 .insert(*inst_family);
976 }
977
978 if let Some(inst_id) = &arg.inst_id {
979 self.subscriptions_inst_id
980 .entry(arg.channel.clone())
981 .or_default()
982 .insert(*inst_id);
983 }
984 }
985 }
986
987 Ok(())
988 }
989
990 #[expect(clippy::collapsible_if)]
991 async fn unsubscribe(&self, args: Vec<OKXSubscriptionArg>) -> Result<(), OKXWsError> {
992 self.cmd_tx
994 .read()
995 .await
996 .send(HandlerCommand::Unsubscribe { args: args.clone() })
997 .map_err(|e| {
998 OKXWsError::ClientError(format!("Failed to send unsubscribe command: {e}"))
999 })?;
1000
1001 for arg in &args {
1002 let topic = topic_from_subscription_arg(arg);
1003 self.subscriptions_state.mark_unsubscribe(&topic);
1004
1005 if arg.inst_type.is_none() && arg.inst_family.is_none() && arg.inst_id.is_none() {
1006 self.subscriptions_bare.remove(&arg.channel);
1007 } else {
1008 if let Some(inst_type) = &arg.inst_type {
1009 if let Some(mut entry) = self.subscriptions_inst_type.get_mut(&arg.channel) {
1010 entry.remove(inst_type);
1011 if entry.is_empty() {
1012 drop(entry);
1013 self.subscriptions_inst_type.remove(&arg.channel);
1014 }
1015 }
1016 }
1017
1018 if let Some(inst_family) = &arg.inst_family {
1019 if let Some(mut entry) = self.subscriptions_inst_family.get_mut(&arg.channel) {
1020 entry.remove(inst_family);
1021 if entry.is_empty() {
1022 drop(entry);
1023 self.subscriptions_inst_family.remove(&arg.channel);
1024 }
1025 }
1026 }
1027
1028 if let Some(inst_id) = &arg.inst_id {
1029 if let Some(mut entry) = self.subscriptions_inst_id.get_mut(&arg.channel) {
1030 entry.remove(inst_id);
1031 if entry.is_empty() {
1032 drop(entry);
1033 self.subscriptions_inst_id.remove(&arg.channel);
1034 }
1035 }
1036 }
1037 }
1038 }
1039
1040 Ok(())
1041 }
1042
1043 async fn subscribe_inst_id(
1044 &self,
1045 channel: OKXWsChannel,
1046 inst_id: Ustr,
1047 ) -> Result<(), OKXWsError> {
1048 self.subscribe(vec![OKXSubscriptionArg {
1049 channel,
1050 inst_type: None,
1051 inst_family: None,
1052 inst_id: Some(inst_id),
1053 }])
1054 .await
1055 }
1056
1057 async fn unsubscribe_inst_id(
1058 &self,
1059 channel: OKXWsChannel,
1060 inst_id: Ustr,
1061 ) -> Result<(), OKXWsError> {
1062 self.unsubscribe(vec![OKXSubscriptionArg {
1063 channel,
1064 inst_type: None,
1065 inst_family: None,
1066 inst_id: Some(inst_id),
1067 }])
1068 .await
1069 }
1070
1071 pub async fn unsubscribe_all(&self) -> Result<(), OKXWsError> {
1080 const BATCH_SIZE: usize = 256;
1081
1082 let mut all_args = Vec::new();
1083
1084 for entry in self.subscriptions_inst_type.iter() {
1085 let (channel, inst_types) = entry.pair();
1086 for inst_type in inst_types {
1087 all_args.push(OKXSubscriptionArg {
1088 channel: channel.clone(),
1089 inst_type: Some(*inst_type),
1090 inst_family: None,
1091 inst_id: None,
1092 });
1093 }
1094 }
1095
1096 for entry in self.subscriptions_inst_family.iter() {
1097 let (channel, inst_families) = entry.pair();
1098 for inst_family in inst_families {
1099 all_args.push(OKXSubscriptionArg {
1100 channel: channel.clone(),
1101 inst_type: None,
1102 inst_family: Some(*inst_family),
1103 inst_id: None,
1104 });
1105 }
1106 }
1107
1108 for entry in self.subscriptions_inst_id.iter() {
1109 let (channel, inst_ids) = entry.pair();
1110 for inst_id in inst_ids {
1111 all_args.push(OKXSubscriptionArg {
1112 channel: channel.clone(),
1113 inst_type: None,
1114 inst_family: None,
1115 inst_id: Some(*inst_id),
1116 });
1117 }
1118 }
1119
1120 for entry in self.subscriptions_bare.iter() {
1121 let channel = entry.key();
1122 all_args.push(OKXSubscriptionArg {
1123 channel: channel.clone(),
1124 inst_type: None,
1125 inst_family: None,
1126 inst_id: None,
1127 });
1128 }
1129
1130 if all_args.is_empty() {
1131 log::debug!("No active subscriptions to unsubscribe from");
1132 return Ok(());
1133 }
1134
1135 log::debug!("Batched unsubscribe from {} channels", all_args.len());
1136
1137 for chunk in all_args.chunks(BATCH_SIZE) {
1138 self.unsubscribe(chunk.to_vec()).await?;
1139 }
1140
1141 self.index_pair_subscribers.clear();
1145
1146 Ok(())
1147 }
1148
1149 pub async fn subscribe_instruments(
1161 &self,
1162 instrument_type: OKXInstrumentType,
1163 ) -> Result<(), OKXWsError> {
1164 let arg = OKXSubscriptionArg {
1165 channel: OKXWsChannel::Instruments,
1166 inst_type: Some(instrument_type),
1167 inst_family: None,
1168 inst_id: None,
1169 };
1170 self.subscribe(vec![arg]).await
1171 }
1172
1173 pub async fn subscribe_instrument(
1187 &self,
1188 instrument_id: InstrumentId,
1189 ) -> Result<(), OKXWsError> {
1190 let inst_type = okx_instrument_type_from_symbol(instrument_id.symbol.as_str());
1191 log::debug!("Subscribing to instrument type {inst_type:?} for {instrument_id}");
1192 self.subscribe_instruments(inst_type).await
1193 }
1194
1195 pub async fn subscribe_book(&self, instrument_id: InstrumentId) -> anyhow::Result<()> {
1204 self.subscribe_book_with_depth(instrument_id, 0).await
1205 }
1206
1207 pub(crate) async fn subscribe_books_channel(
1209 &self,
1210 instrument_id: InstrumentId,
1211 ) -> Result<(), OKXWsError> {
1212 self.subscribe_inst_id(OKXWsChannel::Books, instrument_id.symbol.inner())
1213 .await
1214 }
1215
1216 pub async fn subscribe_book_depth5(
1228 &self,
1229 instrument_id: InstrumentId,
1230 ) -> Result<(), OKXWsError> {
1231 self.subscribe_inst_id(OKXWsChannel::Books5, instrument_id.symbol.inner())
1232 .await
1233 }
1234
1235 pub async fn subscribe_book50_l2_tbt(
1247 &self,
1248 instrument_id: InstrumentId,
1249 ) -> Result<(), OKXWsError> {
1250 self.subscribe_inst_id(OKXWsChannel::Books50Tbt, instrument_id.symbol.inner())
1251 .await
1252 }
1253
1254 pub async fn subscribe_book_l2_tbt(
1266 &self,
1267 instrument_id: InstrumentId,
1268 ) -> Result<(), OKXWsError> {
1269 self.subscribe_inst_id(OKXWsChannel::BooksTbt, instrument_id.symbol.inner())
1270 .await
1271 }
1272
1273 pub async fn subscribe_book_with_depth(
1287 &self,
1288 instrument_id: InstrumentId,
1289 depth: u16,
1290 ) -> anyhow::Result<()> {
1291 let vip = self.vip_level();
1292
1293 match depth {
1294 50 => {
1295 if vip < OKXVipLevel::Vip4 {
1296 anyhow::bail!(
1297 "VIP level {vip} insufficient for 50 depth subscription (requires VIP4)"
1298 );
1299 }
1300 self.subscribe_book50_l2_tbt(instrument_id)
1301 .await
1302 .map_err(|e| anyhow::anyhow!(e))
1303 }
1304 0 | 400 => {
1305 if vip >= OKXVipLevel::Vip5 {
1306 self.subscribe_book_l2_tbt(instrument_id)
1307 .await
1308 .map_err(|e| anyhow::anyhow!(e))
1309 } else {
1310 self.subscribe_books_channel(instrument_id)
1311 .await
1312 .map_err(|e| anyhow::anyhow!(e))
1313 }
1314 }
1315 _ => anyhow::bail!("Invalid depth {depth}, must be 0, 50, or 400"),
1316 }
1317 }
1318
1319 pub async fn subscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1332 self.subscribe_inst_id(OKXWsChannel::BboTbt, instrument_id.symbol.inner())
1333 .await
1334 }
1335
1336 pub async fn subscribe_trades(
1350 &self,
1351 instrument_id: InstrumentId,
1352 aggregated: bool,
1353 ) -> Result<(), OKXWsError> {
1354 let channel = if aggregated {
1355 OKXWsChannel::TradesAll
1356 } else {
1357 OKXWsChannel::Trades
1358 };
1359 self.subscribe_inst_id(channel, instrument_id.symbol.inner())
1360 .await
1361 }
1362
1363 pub async fn subscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1375 self.subscribe_inst_id(OKXWsChannel::Tickers, instrument_id.symbol.inner())
1376 .await
1377 }
1378
1379 pub async fn subscribe_mark_prices(
1391 &self,
1392 instrument_id: InstrumentId,
1393 ) -> Result<(), OKXWsError> {
1394 self.subscribe_inst_id(OKXWsChannel::MarkPrice, instrument_id.symbol.inner())
1395 .await
1396 }
1397
1398 pub async fn subscribe_index_prices(
1410 &self,
1411 instrument_id: InstrumentId,
1412 ) -> Result<(), OKXWsError> {
1413 let symbol = instrument_id.symbol.inner();
1415 let (base, quote) = parse_base_quote_from_symbol(symbol.as_str())
1416 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1417 let base_pair = Ustr::from(&format!("{base}-{quote}"));
1418
1419 let _guard = self.index_pair_transition.lock().await;
1425
1426 let is_first = {
1431 let mut count = self.index_pair_subscribers.entry(base_pair).or_insert(0);
1432 *count += 1;
1433 *count == 1
1434 };
1435
1436 if !is_first {
1437 return Ok(());
1438 }
1439
1440 let arg = OKXSubscriptionArg {
1441 channel: OKXWsChannel::IndexTickers,
1442 inst_type: None,
1443 inst_family: None,
1444 inst_id: Some(base_pair),
1445 };
1446
1447 match self.subscribe(vec![arg]).await {
1448 Ok(()) => Ok(()),
1449 Err(e) => {
1450 self.index_pair_subscribers.remove(&base_pair);
1459 Err(e)
1460 }
1461 }
1462 }
1463
1464 pub async fn subscribe_option_summary(&self, inst_family: Ustr) -> Result<(), OKXWsError> {
1477 let arg = OKXSubscriptionArg {
1478 channel: OKXWsChannel::OptionSummary,
1479 inst_type: None,
1480 inst_family: Some(inst_family),
1481 inst_id: None,
1482 };
1483 self.subscribe(vec![arg]).await
1484 }
1485
/// Returns the shared map of option greeks subscriptions.
///
/// Each entry maps an [`InstrumentId`] to the set of [`OKXGreeksType`] conventions
/// inserted by `add_option_greeks_sub` / `add_option_greeks_sub_with_conventions`.
pub fn option_greeks_subs(&self) -> &Arc<AtomicMap<InstrumentId, AHashSet<OKXGreeksType>>> {
    &self.option_greeks_subs
}
1492
1493 pub fn add_option_greeks_sub(&self, instrument_id: InstrumentId) {
1496 let both: AHashSet<OKXGreeksType> =
1497 [OKXGreeksType::Bs, OKXGreeksType::Pa].into_iter().collect();
1498 self.option_greeks_subs.insert(instrument_id, both);
1499 }
1500
1501 pub fn add_option_greeks_sub_with_conventions(
1504 &self,
1505 instrument_id: InstrumentId,
1506 conventions: AHashSet<OKXGreeksType>,
1507 ) {
1508 let set = if conventions.is_empty() {
1509 [OKXGreeksType::Bs, OKXGreeksType::Pa].into_iter().collect()
1510 } else {
1511 conventions
1512 };
1513 self.option_greeks_subs.insert(instrument_id, set);
1514 }
1515
/// Removes the option greeks convention entry for `instrument_id` from
/// `option_greeks_subs`.
pub fn remove_option_greeks_sub(&self, instrument_id: &InstrumentId) {
    self.option_greeks_subs.remove(instrument_id);
}
1520
1521 pub async fn subscribe_funding_rates(
1533 &self,
1534 instrument_id: InstrumentId,
1535 ) -> Result<(), OKXWsError> {
1536 self.subscribe_inst_id(OKXWsChannel::FundingRate, instrument_id.symbol.inner())
1537 .await
1538 }
1539
1540 pub async fn subscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1552 let channel = bar_spec_as_okx_channel(bar_type.spec())
1554 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1555 self.subscribe_inst_id(channel, bar_type.instrument_id().symbol.inner())
1556 .await
1557 }
1558
1559 pub async fn unsubscribe_instruments(
1565 &self,
1566 instrument_type: OKXInstrumentType,
1567 ) -> Result<(), OKXWsError> {
1568 let arg = OKXSubscriptionArg {
1569 channel: OKXWsChannel::Instruments,
1570 inst_type: Some(instrument_type),
1571 inst_family: None,
1572 inst_id: None,
1573 };
1574 self.unsubscribe(vec![arg]).await
1575 }
1576
/// Does nothing beyond logging: instrument updates come from a channel shared per
/// instrument type (see `unsubscribe_instruments`), so one instrument cannot be
/// unsubscribed on its own.
///
/// Always returns `Ok(())`.
pub async fn unsubscribe_instrument(
    &self,
    instrument_id: InstrumentId,
) -> Result<(), OKXWsError> {
    log::debug!("Instrument unsubscribe is a no-op (shared per-type channel): {instrument_id}");
    Ok(())
}
1593
1594 pub async fn unsubscribe_book(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1600 self.unsubscribe_inst_id(OKXWsChannel::Books, instrument_id.symbol.inner())
1601 .await
1602 }
1603
1604 pub async fn unsubscribe_book_depth5(
1610 &self,
1611 instrument_id: InstrumentId,
1612 ) -> Result<(), OKXWsError> {
1613 self.unsubscribe_inst_id(OKXWsChannel::Books5, instrument_id.symbol.inner())
1614 .await
1615 }
1616
1617 pub async fn unsubscribe_book50_l2_tbt(
1623 &self,
1624 instrument_id: InstrumentId,
1625 ) -> Result<(), OKXWsError> {
1626 self.unsubscribe_inst_id(OKXWsChannel::Books50Tbt, instrument_id.symbol.inner())
1627 .await
1628 }
1629
1630 pub async fn unsubscribe_book_l2_tbt(
1636 &self,
1637 instrument_id: InstrumentId,
1638 ) -> Result<(), OKXWsError> {
1639 self.unsubscribe_inst_id(OKXWsChannel::BooksTbt, instrument_id.symbol.inner())
1640 .await
1641 }
1642
1643 pub async fn unsubscribe_quotes(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1649 self.unsubscribe_inst_id(OKXWsChannel::BboTbt, instrument_id.symbol.inner())
1650 .await
1651 }
1652
1653 pub async fn unsubscribe_ticker(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
1659 self.unsubscribe_inst_id(OKXWsChannel::Tickers, instrument_id.symbol.inner())
1660 .await
1661 }
1662
1663 pub async fn unsubscribe_mark_prices(
1669 &self,
1670 instrument_id: InstrumentId,
1671 ) -> Result<(), OKXWsError> {
1672 self.unsubscribe_inst_id(OKXWsChannel::MarkPrice, instrument_id.symbol.inner())
1673 .await
1674 }
1675
1676 pub async fn unsubscribe_index_prices(
1689 &self,
1690 instrument_id: InstrumentId,
1691 ) -> Result<(), OKXWsError> {
1692 let symbol = instrument_id.symbol.inner();
1693 let (base, quote) = parse_base_quote_from_symbol(symbol.as_str())
1694 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1695 let base_pair = Ustr::from(&format!("{base}-{quote}"));
1696
1697 let _guard = self.index_pair_transition.lock().await;
1700
1701 let is_last = {
1702 let Some(mut count) = self.index_pair_subscribers.get_mut(&base_pair) else {
1703 return Ok(());
1705 };
1706 *count = count.saturating_sub(1);
1707 *count == 0
1708 };
1709
1710 if !is_last {
1711 return Ok(());
1712 }
1713
1714 self.index_pair_subscribers
1715 .remove_if(&base_pair, |_, count| *count == 0);
1716
1717 let arg = OKXSubscriptionArg {
1718 channel: OKXWsChannel::IndexTickers,
1719 inst_type: None,
1720 inst_family: None,
1721 inst_id: Some(base_pair),
1722 };
1723 self.unsubscribe(vec![arg]).await
1724 }
1725
1726 pub async fn unsubscribe_option_summary(&self, inst_family: Ustr) -> Result<(), OKXWsError> {
1732 let arg = OKXSubscriptionArg {
1733 channel: OKXWsChannel::OptionSummary,
1734 inst_type: None,
1735 inst_family: Some(inst_family),
1736 inst_id: None,
1737 };
1738 self.unsubscribe(vec![arg]).await
1739 }
1740
1741 pub async fn unsubscribe_funding_rates(
1747 &self,
1748 instrument_id: InstrumentId,
1749 ) -> Result<(), OKXWsError> {
1750 self.unsubscribe_inst_id(OKXWsChannel::FundingRate, instrument_id.symbol.inner())
1751 .await
1752 }
1753
1754 pub async fn unsubscribe_trades(
1760 &self,
1761 instrument_id: InstrumentId,
1762 aggregated: bool,
1763 ) -> Result<(), OKXWsError> {
1764 let channel = if aggregated {
1765 OKXWsChannel::TradesAll
1766 } else {
1767 OKXWsChannel::Trades
1768 };
1769 self.unsubscribe_inst_id(channel, instrument_id.symbol.inner())
1770 .await
1771 }
1772
1773 pub async fn unsubscribe_bars(&self, bar_type: BarType) -> Result<(), OKXWsError> {
1779 let channel = bar_spec_as_okx_channel(bar_type.spec())
1780 .map_err(|e| OKXWsError::ClientError(e.to_string()))?;
1781 self.unsubscribe_inst_id(channel, bar_type.instrument_id().symbol.inner())
1782 .await
1783 }
1784
1785 pub async fn subscribe_orders(
1791 &self,
1792 instrument_type: OKXInstrumentType,
1793 ) -> Result<(), OKXWsError> {
1794 let arg = OKXSubscriptionArg {
1795 channel: OKXWsChannel::Orders,
1796 inst_type: Some(instrument_type),
1797 inst_family: None,
1798 inst_id: None,
1799 };
1800 self.subscribe(vec![arg]).await
1801 }
1802
1803 pub async fn unsubscribe_orders(
1809 &self,
1810 instrument_type: OKXInstrumentType,
1811 ) -> Result<(), OKXWsError> {
1812 let arg = OKXSubscriptionArg {
1813 channel: OKXWsChannel::Orders,
1814 inst_type: Some(instrument_type),
1815 inst_family: None,
1816 inst_id: None,
1817 };
1818 self.unsubscribe(vec![arg]).await
1819 }
1820
1821 pub async fn subscribe_orders_algo(
1827 &self,
1828 instrument_type: OKXInstrumentType,
1829 ) -> Result<(), OKXWsError> {
1830 let arg = OKXSubscriptionArg {
1831 channel: OKXWsChannel::OrdersAlgo,
1832 inst_type: Some(instrument_type),
1833 inst_family: None,
1834 inst_id: None,
1835 };
1836 self.subscribe(vec![arg]).await
1837 }
1838
1839 pub async fn unsubscribe_orders_algo(
1845 &self,
1846 instrument_type: OKXInstrumentType,
1847 ) -> Result<(), OKXWsError> {
1848 let arg = OKXSubscriptionArg {
1849 channel: OKXWsChannel::OrdersAlgo,
1850 inst_type: Some(instrument_type),
1851 inst_family: None,
1852 inst_id: None,
1853 };
1854 self.unsubscribe(vec![arg]).await
1855 }
1856
1857 pub async fn subscribe_algo_advance(
1863 &self,
1864 instrument_type: OKXInstrumentType,
1865 ) -> Result<(), OKXWsError> {
1866 let arg = OKXSubscriptionArg {
1867 channel: OKXWsChannel::AlgoAdvance,
1868 inst_type: Some(instrument_type),
1869 inst_family: None,
1870 inst_id: None,
1871 };
1872 self.subscribe(vec![arg]).await
1873 }
1874
1875 pub async fn unsubscribe_algo_advance(
1881 &self,
1882 instrument_type: OKXInstrumentType,
1883 ) -> Result<(), OKXWsError> {
1884 let arg = OKXSubscriptionArg {
1885 channel: OKXWsChannel::AlgoAdvance,
1886 inst_type: Some(instrument_type),
1887 inst_family: None,
1888 inst_id: None,
1889 };
1890 self.unsubscribe(vec![arg]).await
1891 }
1892
1893 pub async fn subscribe_fills(
1899 &self,
1900 instrument_type: OKXInstrumentType,
1901 ) -> Result<(), OKXWsError> {
1902 let arg = OKXSubscriptionArg {
1903 channel: OKXWsChannel::Fills,
1904 inst_type: Some(instrument_type),
1905 inst_family: None,
1906 inst_id: None,
1907 };
1908 self.subscribe(vec![arg]).await
1909 }
1910
1911 pub async fn unsubscribe_fills(
1917 &self,
1918 instrument_type: OKXInstrumentType,
1919 ) -> Result<(), OKXWsError> {
1920 let arg = OKXSubscriptionArg {
1921 channel: OKXWsChannel::Fills,
1922 inst_type: Some(instrument_type),
1923 inst_family: None,
1924 inst_id: None,
1925 };
1926 self.unsubscribe(vec![arg]).await
1927 }
1928
1929 pub async fn subscribe_account(&self) -> Result<(), OKXWsError> {
1935 let arg = OKXSubscriptionArg {
1936 channel: OKXWsChannel::Account,
1937 inst_type: None,
1938 inst_family: None,
1939 inst_id: None,
1940 };
1941 self.subscribe(vec![arg]).await
1942 }
1943
1944 pub async fn unsubscribe_account(&self) -> Result<(), OKXWsError> {
1950 let arg = OKXSubscriptionArg {
1951 channel: OKXWsChannel::Account,
1952 inst_type: None,
1953 inst_family: None,
1954 inst_id: None,
1955 };
1956 self.unsubscribe(vec![arg]).await
1957 }
1958
1959 pub async fn subscribe_positions(
1969 &self,
1970 inst_type: OKXInstrumentType,
1971 ) -> Result<(), OKXWsError> {
1972 let arg = OKXSubscriptionArg {
1973 channel: OKXWsChannel::Positions,
1974 inst_type: Some(inst_type),
1975 inst_family: None,
1976 inst_id: None,
1977 };
1978 self.subscribe(vec![arg]).await
1979 }
1980
1981 pub async fn unsubscribe_positions(
1987 &self,
1988 inst_type: OKXInstrumentType,
1989 ) -> Result<(), OKXWsError> {
1990 let arg = OKXSubscriptionArg {
1991 channel: OKXWsChannel::Positions,
1992 inst_type: Some(inst_type),
1993 inst_family: None,
1994 inst_id: None,
1995 };
1996 self.unsubscribe(vec![arg]).await
1997 }
1998
1999 async fn ws_batch_place_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
2005 let request_id = self.generate_unique_request_id();
2006 let request = OKXWsRequest::<Value> {
2007 id: Some(request_id.clone()),
2008 op: super::enums::OKXWsOperation::BatchOrders,
2009 exp_time: None,
2010 args,
2011 };
2012
2013 let payload = serde_json::to_string(&request)
2014 .map_err(|e| OKXWsError::JsonError(format!("Failed to serialize batch orders: {e}")))?;
2015
2016 let cmd = HandlerCommand::Send {
2017 payload,
2018 rate_limit_keys: Some(OKX_RATE_LIMIT_KEY_ORDER.to_vec()),
2019 request_id: Some(request_id),
2020 client_order_id: None,
2021 op: Some(super::enums::OKXWsOperation::BatchOrders),
2022 };
2023
2024 self.send_cmd(cmd).await
2025 }
2026
2027 async fn ws_batch_cancel_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
2033 let request_id = self.generate_unique_request_id();
2034 let request = OKXWsRequest::<Value> {
2035 id: Some(request_id.clone()),
2036 op: super::enums::OKXWsOperation::BatchCancelOrders,
2037 exp_time: None,
2038 args,
2039 };
2040
2041 let payload = serde_json::to_string(&request)
2042 .map_err(|e| OKXWsError::JsonError(format!("Failed to serialize batch cancel: {e}")))?;
2043
2044 let cmd = HandlerCommand::Send {
2045 payload,
2046 rate_limit_keys: Some(OKX_RATE_LIMIT_KEY_CANCEL.to_vec()),
2047 request_id: Some(request_id),
2048 client_order_id: None,
2049 op: Some(super::enums::OKXWsOperation::BatchCancelOrders),
2050 };
2051
2052 self.send_cmd(cmd).await
2053 }
2054
2055 async fn ws_batch_amend_orders(&self, args: Vec<Value>) -> Result<(), OKXWsError> {
2061 let request_id = self.generate_unique_request_id();
2062 let request = OKXWsRequest::<Value> {
2063 id: Some(request_id.clone()),
2064 op: super::enums::OKXWsOperation::BatchAmendOrders,
2065 exp_time: None,
2066 args,
2067 };
2068
2069 let payload = serde_json::to_string(&request)
2070 .map_err(|e| OKXWsError::JsonError(format!("Failed to serialize batch amend: {e}")))?;
2071
2072 let cmd = HandlerCommand::Send {
2073 payload,
2074 rate_limit_keys: Some(OKX_RATE_LIMIT_KEY_AMEND.to_vec()),
2075 request_id: Some(request_id),
2076 client_order_id: None,
2077 op: Some(super::enums::OKXWsOperation::BatchAmendOrders),
2078 };
2079
2080 self.send_cmd(cmd).await
2081 }
2082
2083 #[expect(clippy::too_many_arguments)]
2095 pub async fn submit_order(
2096 &self,
2097 trader_id: TraderId,
2098 strategy_id: StrategyId,
2099 instrument_id: InstrumentId,
2100 td_mode: OKXTradeMode,
2101 client_order_id: ClientOrderId,
2102 order_side: OrderSide,
2103 order_type: OrderType,
2104 quantity: Quantity,
2105 time_in_force: Option<TimeInForce>,
2106 price: Option<Price>,
2107 trigger_price: Option<Price>,
2108 post_only: Option<bool>,
2109 reduce_only: Option<bool>,
2110 quote_quantity: Option<bool>,
2111 position_side: Option<PositionSide>,
2112 attach_algo_ords: Option<Vec<WsAttachAlgoOrdParams>>,
2113 px_usd: Option<String>,
2114 px_vol: Option<String>,
2115 ) -> Result<(), OKXWsError> {
2116 if !OKX_SUPPORTED_ORDER_TYPES.contains(&order_type) {
2117 return Err(OKXWsError::ClientError(format!(
2118 "Unsupported order type: {order_type:?}",
2119 )));
2120 }
2121
2122 if let Some(tif) = time_in_force
2123 && !OKX_SUPPORTED_TIME_IN_FORCE.contains(&tif)
2124 {
2125 return Err(OKXWsError::ClientError(format!(
2126 "Unsupported time in force: {tif:?}",
2127 )));
2128 }
2129
2130 let mut builder = WsPostOrderParamsBuilder::default();
2131
2132 let inst_id_code = self
2133 .get_inst_id_code(&instrument_id.symbol.inner())
2134 .ok_or_else(|| {
2135 OKXWsError::ClientError(format!(
2136 "No instIdCode cached for {instrument_id}, cannot submit order"
2137 ))
2138 })?;
2139 builder.inst_id_code(inst_id_code);
2140
2141 builder.td_mode(td_mode);
2142 builder.cl_ord_id(client_order_id.as_str());
2143
2144 let instrument = self
2145 .instruments_cache
2146 .get_cloned(&instrument_id.symbol.inner())
2147 .ok_or_else(|| {
2148 OKXWsError::ClientError(format!("Unknown instrument {instrument_id}"))
2149 })?;
2150
2151 let instrument_type =
2152 okx_instrument_type(&instrument).map_err(|e| OKXWsError::ClientError(e.to_string()))?;
2153 let quote_currency = instrument.quote_currency();
2154
2155 if instrument_type == OKXInstrumentType::Option
2157 && matches!(order_type, OrderType::Market | OrderType::MarketToLimit)
2158 {
2159 return Err(OKXWsError::ClientError(
2160 "Market orders are not supported for OKX options, use Limit orders instead"
2161 .to_string(),
2162 ));
2163 }
2164
2165 match instrument_type {
2166 OKXInstrumentType::Spot => {
2167 builder.ccy(quote_currency.to_string());
2169 }
2170 OKXInstrumentType::Margin => {
2171 builder.ccy(quote_currency.to_string());
2172
2173 if let Some(ro) = reduce_only
2174 && ro
2175 {
2176 builder.reduce_only(ro);
2177 }
2178 }
2179 OKXInstrumentType::Swap | OKXInstrumentType::Futures => {
2180 builder.ccy(quote_currency.to_string());
2182
2183 if position_side.is_none() {
2186 builder.pos_side(OKXPositionSide::Net);
2187 }
2188 }
2189 OKXInstrumentType::Option => {
2190 builder.ccy(quote_currency.to_string());
2191
2192 if position_side.is_none() {
2193 builder.pos_side(OKXPositionSide::Net);
2194 }
2195 }
2197 _ => {
2198 builder.ccy(quote_currency.to_string());
2199
2200 if position_side.is_none() {
2201 builder.pos_side(OKXPositionSide::Net);
2202 }
2203
2204 if let Some(ro) = reduce_only
2205 && ro
2206 {
2207 builder.reduce_only(ro);
2208 }
2209 }
2210 }
2211
2212 if let Some(attach_algo_ords) = attach_algo_ords {
2213 builder.attach_algo_ords(attach_algo_ords);
2214 }
2215
2216 if instrument_type == OKXInstrumentType::Spot
2223 && order_type == OrderType::Market
2224 && td_mode == OKXTradeMode::Cash
2225 {
2226 match quote_quantity {
2227 Some(true) => {
2228 builder.tgt_ccy(OKXTargetCurrency::QuoteCcy);
2229 }
2230 Some(false) if order_side == OrderSide::Buy => {
2232 builder.tgt_ccy(OKXTargetCurrency::BaseCcy);
2233 }
2234 Some(false) | None => {}
2236 }
2237 }
2238
2239 builder.side(order_side.as_specified());
2240
2241 if let Some(pos_side) = position_side {
2242 builder.pos_side(pos_side);
2243 }
2244
2245 let (okx_ord_type, price) = if post_only.unwrap_or(false) {
2249 (OKXOrderType::PostOnly, price)
2250 } else if let Some(tif) = time_in_force {
2251 match (order_type, tif) {
2252 (OrderType::Market, TimeInForce::Fok) => {
2253 return Err(OKXWsError::ClientError(
2254 "Market orders with FOK time-in-force are not supported by OKX. Use Limit order with FOK instead.".to_string()
2255 ));
2256 }
2257 (OrderType::Market, TimeInForce::Ioc) => {
2258 if matches!(
2260 instrument_type,
2261 OKXInstrumentType::Spot | OKXInstrumentType::Option
2262 ) {
2263 (OKXOrderType::Market, price)
2264 } else {
2265 (OKXOrderType::OptimalLimitIoc, price)
2266 }
2267 }
2268 (OrderType::Limit, TimeInForce::Fok) => {
2269 if instrument_type == OKXInstrumentType::Option {
2271 (OKXOrderType::OpFok, price)
2272 } else {
2273 (OKXOrderType::Fok, price)
2274 }
2275 }
2276 (OrderType::Limit, TimeInForce::Ioc) => (OKXOrderType::Ioc, price),
2277 _ => (OKXOrderType::from(order_type), price),
2278 }
2279 } else {
2280 (OKXOrderType::from(order_type), price)
2281 };
2282
2283 log::debug!(
2284 "Order type mapping: order_type={order_type:?}, time_in_force={time_in_force:?}, post_only={post_only:?} -> okx_ord_type={okx_ord_type:?}"
2285 );
2286
2287 builder.ord_type(okx_ord_type);
2288 builder.sz(quantity.to_string());
2289
2290 if let Some(usd) = px_usd {
2292 builder.px_usd(usd);
2293 } else if let Some(vol) = px_vol {
2294 builder.px_vol(vol);
2295 } else if let Some(tp) = trigger_price {
2296 builder.px(tp.to_string());
2297 } else if let Some(p) = price {
2298 builder.px(p.to_string());
2299 }
2300
2301 builder.tag(OKX_NAUTILUS_BROKER_ID);
2302
2303 let params = builder
2304 .build()
2305 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2306
2307 let request_id = self.generate_unique_request_id();
2308 let request = OKXWsRequest {
2309 id: Some(request_id.clone()),
2310 op: super::enums::OKXWsOperation::Order,
2311 exp_time: None,
2312 args: vec![params],
2313 };
2314
2315 let payload = serde_json::to_string(&request)
2316 .map_err(|e| OKXWsError::JsonError(format!("Failed to serialize order: {e}")))?;
2317
2318 let cl_ord_key = client_order_id.to_string();
2319 self.pending_orders.insert(
2320 cl_ord_key.clone(),
2321 PendingOrderInfo {
2322 trader_id,
2323 strategy_id,
2324 instrument_id,
2325 },
2326 );
2327
2328 let cmd = HandlerCommand::Send {
2329 payload,
2330 rate_limit_keys: Some(OKX_RATE_LIMIT_KEY_ORDER.to_vec()),
2331 request_id: Some(request_id),
2332 client_order_id: Some(client_order_id),
2333 op: Some(super::enums::OKXWsOperation::Order),
2334 };
2335
2336 let result = self.send_cmd(cmd).await;
2337
2338 if result.is_err() {
2339 self.pending_orders.remove(&cl_ord_key);
2340 }
2341
2342 result
2343 }
2344
2345 #[expect(clippy::too_many_arguments)]
2361 pub async fn modify_order(
2362 &self,
2363 trader_id: TraderId,
2364 strategy_id: StrategyId,
2365 instrument_id: InstrumentId,
2366 client_order_id: Option<ClientOrderId>,
2367 price: Option<Price>,
2368 quantity: Option<Quantity>,
2369 venue_order_id: Option<VenueOrderId>,
2370 new_px_usd: Option<String>,
2371 new_px_vol: Option<String>,
2372 ) -> Result<(), OKXWsError> {
2373 let mut builder = WsAmendOrderParamsBuilder::default();
2374
2375 let inst_id_code = self
2376 .get_inst_id_code(&instrument_id.symbol.inner())
2377 .ok_or_else(|| {
2378 OKXWsError::ClientError(format!(
2379 "No instIdCode cached for {instrument_id}, cannot amend order"
2380 ))
2381 })?;
2382 builder.inst_id_code(inst_id_code);
2383
2384 if let Some(venue_order_id) = venue_order_id {
2385 builder.ord_id(venue_order_id.as_str());
2386 }
2387
2388 let cl_ord_key = client_order_id.map(|id| id.to_string());
2389
2390 if let Some(client_order_id) = client_order_id {
2391 builder.cl_ord_id(client_order_id.as_str());
2392 self.pending_amends.insert(
2393 client_order_id.to_string(),
2394 PendingOrderInfo {
2395 trader_id,
2396 strategy_id,
2397 instrument_id,
2398 },
2399 );
2400 }
2401
2402 if let Some(usd) = new_px_usd {
2404 builder.new_px_usd(usd);
2405 } else if let Some(vol) = new_px_vol {
2406 builder.new_px_vol(vol);
2407 } else if let Some(price) = price {
2408 builder.new_px(price.to_string());
2409 }
2410
2411 if let Some(quantity) = quantity {
2412 builder.new_sz(quantity.to_string());
2413 }
2414
2415 let params = builder
2416 .build()
2417 .map_err(|e| OKXWsError::ClientError(format!("Build amend params error: {e}")))?;
2418
2419 let request_id = self.generate_unique_request_id();
2420 let request = OKXWsRequest {
2421 id: Some(request_id.clone()),
2422 op: super::enums::OKXWsOperation::AmendOrder,
2423 exp_time: None,
2424 args: vec![params],
2425 };
2426
2427 let payload = serde_json::to_string(&request)
2428 .map_err(|e| OKXWsError::JsonError(format!("Failed to serialize amend: {e}")))?;
2429
2430 let cmd = HandlerCommand::Send {
2431 payload,
2432 rate_limit_keys: Some(OKX_RATE_LIMIT_KEY_AMEND.to_vec()),
2433 request_id: Some(request_id),
2434 client_order_id,
2435 op: Some(super::enums::OKXWsOperation::AmendOrder),
2436 };
2437
2438 let result = self.send_cmd(cmd).await;
2439
2440 if let (Err(_), Some(key)) = (&result, &cl_ord_key) {
2441 self.pending_amends.remove(key);
2442 }
2443
2444 result
2445 }
2446
2447 pub async fn cancel_order(
2458 &self,
2459 trader_id: TraderId,
2460 strategy_id: StrategyId,
2461 instrument_id: InstrumentId,
2462 client_order_id: Option<ClientOrderId>,
2463 venue_order_id: Option<VenueOrderId>,
2464 ) -> Result<(), OKXWsError> {
2465 let mut builder = WsCancelOrderParamsBuilder::default();
2466
2467 let inst_id_code = self
2468 .get_inst_id_code(&instrument_id.symbol.inner())
2469 .ok_or_else(|| {
2470 OKXWsError::ClientError(format!(
2471 "No instIdCode cached for {instrument_id}, cannot cancel order"
2472 ))
2473 })?;
2474 builder.inst_id_code(inst_id_code);
2475
2476 if let Some(venue_order_id) = venue_order_id {
2477 builder.ord_id(venue_order_id.as_str());
2478 }
2479
2480 let cl_ord_key = client_order_id.map(|id| id.to_string());
2481
2482 if let Some(client_order_id) = client_order_id {
2483 builder.cl_ord_id(client_order_id.as_str());
2484 self.pending_cancels.insert(
2485 client_order_id.to_string(),
2486 PendingOrderInfo {
2487 trader_id,
2488 strategy_id,
2489 instrument_id,
2490 },
2491 );
2492 }
2493
2494 let params = builder
2495 .build()
2496 .map_err(|e| OKXWsError::ClientError(format!("Build cancel params error: {e}")))?;
2497
2498 let request_id = self.generate_unique_request_id();
2499 let request = OKXWsRequest {
2500 id: Some(request_id.clone()),
2501 op: super::enums::OKXWsOperation::CancelOrder,
2502 exp_time: None,
2503 args: vec![params],
2504 };
2505
2506 let payload = serde_json::to_string(&request)
2507 .map_err(|e| OKXWsError::JsonError(format!("Failed to serialize cancel: {e}")))?;
2508
2509 let cmd = HandlerCommand::Send {
2510 payload,
2511 rate_limit_keys: Some(OKX_RATE_LIMIT_KEY_CANCEL.to_vec()),
2512 request_id: Some(request_id),
2513 client_order_id,
2514 op: Some(super::enums::OKXWsOperation::CancelOrder),
2515 };
2516
2517 let result = self.send_cmd(cmd).await;
2518
2519 if let (Err(_), Some(key)) = (&result, &cl_ord_key) {
2520 self.pending_cancels.remove(key);
2521 }
2522
2523 result
2524 }
2525
2526 pub async fn mass_cancel_orders(&self, instrument_id: InstrumentId) -> Result<(), OKXWsError> {
2536 let instrument = self
2537 .instruments_cache
2538 .get_cloned(&instrument_id.symbol.inner())
2539 .ok_or_else(|| {
2540 OKXWsError::ClientError(format!("Unknown instrument {instrument_id}"))
2541 })?;
2542
2543 let inst_type =
2544 okx_instrument_type(&instrument).map_err(|e| OKXWsError::ClientError(e.to_string()))?;
2545
2546 let symbol = instrument.symbol().inner();
2547 let inst_family = match &instrument {
2548 InstrumentAny::CurrencyPair(_) => symbol.as_str().to_string(),
2549 InstrumentAny::CryptoPerpetual(_) => symbol
2550 .as_str()
2551 .strip_suffix("-SWAP")
2552 .unwrap_or(symbol.as_str())
2553 .to_string(),
2554 InstrumentAny::CryptoFuture(_) => {
2555 let s = symbol.as_str();
2556 if let Some(idx) = s.rfind('-') {
2557 s[..idx].to_string()
2558 } else {
2559 s.to_string()
2560 }
2561 }
2562 _ => {
2563 return Err(OKXWsError::ClientError(
2564 "Unsupported instrument type for mass cancel".to_string(),
2565 ));
2566 }
2567 };
2568 drop(instrument);
2569
2570 let params = WsMassCancelParams {
2571 inst_type,
2572 inst_family: Ustr::from(&inst_family),
2573 };
2574
2575 let request_id = self.generate_unique_request_id();
2576 let request = OKXWsRequest {
2577 id: Some(request_id.clone()),
2578 op: super::enums::OKXWsOperation::MassCancel,
2579 exp_time: None,
2580 args: vec![
2581 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?,
2582 ],
2583 };
2584
2585 let payload = serde_json::to_string(&request)
2586 .map_err(|e| OKXWsError::JsonError(format!("Failed to serialize mass cancel: {e}")))?;
2587
2588 let cmd = HandlerCommand::Send {
2589 payload,
2590 rate_limit_keys: Some(OKX_RATE_LIMIT_KEY_CANCEL.to_vec()),
2591 request_id: Some(request_id),
2592 client_order_id: None,
2593 op: Some(super::enums::OKXWsOperation::MassCancel),
2594 };
2595
2596 self.send_cmd(cmd).await
2597 }
2598
2599 #[expect(clippy::type_complexity)]
2606 pub async fn batch_submit_orders(
2607 &self,
2608 orders: Vec<(
2609 OKXInstrumentType,
2610 InstrumentId,
2611 OKXTradeMode,
2612 ClientOrderId,
2613 OrderSide,
2614 Option<PositionSide>,
2615 OrderType,
2616 Quantity,
2617 Option<Price>,
2618 Option<Price>,
2619 Option<bool>,
2620 Option<bool>,
2621 )>,
2622 ) -> Result<(), OKXWsError> {
2623 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2624
2625 for (
2626 inst_type,
2627 inst_id,
2628 td_mode,
2629 cl_ord_id,
2630 ord_side,
2631 pos_side,
2632 ord_type,
2633 qty,
2634 pr,
2635 tp,
2636 post_only,
2637 reduce_only,
2638 ) in orders
2639 {
2640 let mut builder = WsPostOrderParamsBuilder::default();
2641
2642 let inst_id_code = self
2643 .get_inst_id_code(&inst_id.symbol.inner())
2644 .ok_or_else(|| {
2645 OKXWsError::ClientError(format!(
2646 "No instIdCode cached for {inst_id}, cannot submit order"
2647 ))
2648 })?;
2649 builder.inst_id_code(inst_id_code);
2650
2651 builder.td_mode(td_mode);
2652 builder.cl_ord_id(cl_ord_id.as_str());
2653 builder.side(ord_side.as_specified());
2654
2655 if let Some(instrument) = self.instruments_cache.get_cloned(&inst_id.symbol.inner()) {
2656 builder.ccy(instrument.quote_currency().to_string());
2657 }
2658
2659 if let Some(ps) = pos_side {
2660 builder.pos_side(OKXPositionSide::from(ps));
2661 } else if !matches!(inst_type, OKXInstrumentType::Spot) {
2662 builder.pos_side(OKXPositionSide::Net);
2663 }
2664
2665 let okx_ord_type = if post_only.unwrap_or(false) {
2666 OKXOrderType::PostOnly
2667 } else {
2668 match ord_type {
2669 OrderType::Market => OKXOrderType::Market,
2670 OrderType::Limit => OKXOrderType::Limit,
2671 OrderType::MarketToLimit => OKXOrderType::Ioc,
2672 _ => {
2673 return Err(OKXWsError::ClientError(format!(
2674 "Unsupported order type for batch submit: {ord_type:?}"
2675 )));
2676 }
2677 }
2678 };
2679
2680 builder.ord_type(okx_ord_type);
2681 builder.sz(qty.to_string());
2682
2683 if let Some(p) = pr {
2684 builder.px(p.to_string());
2685 } else if let Some(p) = tp {
2686 builder.px(p.to_string());
2687 }
2688
2689 if let Some(ro) = reduce_only {
2690 builder.reduce_only(ro);
2691 }
2692
2693 builder.tag(OKX_NAUTILUS_BROKER_ID);
2694
2695 let params = builder
2696 .build()
2697 .map_err(|e| OKXWsError::ClientError(format!("Build order params error: {e}")))?;
2698 let val =
2699 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2700 args.push(val);
2701 }
2702
2703 self.ws_batch_place_orders(args).await
2704 }
2705
2706 #[expect(clippy::type_complexity)]
2713 pub async fn batch_modify_orders(
2714 &self,
2715 orders: Vec<(
2716 OKXInstrumentType,
2717 InstrumentId,
2718 ClientOrderId,
2719 ClientOrderId,
2720 Option<Price>,
2721 Option<Quantity>,
2722 )>,
2723 ) -> Result<(), OKXWsError> {
2724 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2725 for (_inst_type, inst_id, cl_ord_id, new_cl_ord_id, pr, sz) in orders {
2726 let mut builder = WsAmendOrderParamsBuilder::default();
2727
2728 let inst_id_code = self
2729 .get_inst_id_code(&inst_id.symbol.inner())
2730 .ok_or_else(|| {
2731 OKXWsError::ClientError(format!(
2732 "No instIdCode cached for {inst_id}, cannot amend order"
2733 ))
2734 })?;
2735 builder.inst_id_code(inst_id_code);
2736
2737 builder.cl_ord_id(cl_ord_id.as_str());
2738 builder.new_cl_ord_id(new_cl_ord_id.as_str());
2739
2740 if let Some(p) = pr {
2741 builder.new_px(p.to_string());
2742 }
2743
2744 if let Some(q) = sz {
2745 builder.new_sz(q.to_string());
2746 }
2747
2748 let params = builder.build().map_err(|e| {
2749 OKXWsError::ClientError(format!("Build amend batch params error: {e}"))
2750 })?;
2751 let val =
2752 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2753 args.push(val);
2754 }
2755
2756 self.ws_batch_amend_orders(args).await
2757 }
2758
2759 pub async fn batch_cancel_orders(
2772 &self,
2773 orders: Vec<(InstrumentId, Option<ClientOrderId>, Option<VenueOrderId>)>,
2774 ) -> Result<(), OKXWsError> {
2775 let mut args: Vec<Value> = Vec::with_capacity(orders.len());
2776 for (inst_id, cl_ord_id, ord_id) in orders {
2777 let mut builder = WsCancelOrderParamsBuilder::default();
2778
2779 let inst_id_code = self
2780 .get_inst_id_code(&inst_id.symbol.inner())
2781 .ok_or_else(|| {
2782 OKXWsError::ClientError(format!(
2783 "No instIdCode cached for {inst_id}, cannot cancel order"
2784 ))
2785 })?;
2786 builder.inst_id_code(inst_id_code);
2787
2788 if let Some(c) = cl_ord_id {
2789 builder.cl_ord_id(c.as_str());
2790 }
2791
2792 if let Some(o) = ord_id {
2793 builder.ord_id(o.as_str());
2794 }
2795
2796 let params = builder.build().map_err(|e| {
2797 OKXWsError::ClientError(format!("Build cancel batch params error: {e}"))
2798 })?;
2799 let val =
2800 serde_json::to_value(params).map_err(|e| OKXWsError::JsonError(e.to_string()))?;
2801 args.push(val);
2802 }
2803
2804 self.ws_batch_cancel_orders(args).await
2805 }
2806
2807 #[expect(clippy::too_many_arguments)]
2818 pub async fn submit_algo_order(
2819 &self,
2820 _trader_id: TraderId,
2821 _strategy_id: StrategyId,
2822 instrument_id: InstrumentId,
2823 td_mode: OKXTradeMode,
2824 client_order_id: ClientOrderId,
2825 order_side: OrderSide,
2826 order_type: OrderType,
2827 quantity: Quantity,
2828 trigger_price: Option<Price>,
2829 trigger_type: Option<TriggerType>,
2830 limit_price: Option<Price>,
2831 reduce_only: Option<bool>,
2832 callback_ratio: Option<String>,
2833 callback_spread: Option<String>,
2834 activation_price: Option<Price>,
2835 ) -> Result<(), OKXWsError> {
2836 if !is_conditional_order(order_type) {
2837 return Err(OKXWsError::ClientError(format!(
2838 "Order type {order_type:?} is not a conditional order"
2839 )));
2840 }
2841
2842 let mut builder = WsPostAlgoOrderParamsBuilder::default();
2843
2844 if !matches!(order_side, OrderSide::Buy | OrderSide::Sell) {
2845 return Err(OKXWsError::ClientError(
2846 "Invalid order side for OKX".to_string(),
2847 ));
2848 }
2849
2850 let inst_id_code = self
2851 .get_inst_id_code(&instrument_id.symbol.inner())
2852 .ok_or_else(|| {
2853 OKXWsError::ClientError(format!(
2854 "No instIdCode cached for {instrument_id}, cannot submit algo order"
2855 ))
2856 })?;
2857 builder.inst_id_code(inst_id_code);
2858
2859 builder.td_mode(td_mode);
2860 builder.cl_ord_id(client_order_id.as_str());
2861 builder.side(order_side.as_specified());
2862 builder.ord_type(
2863 conditional_order_to_algo_type(order_type)
2864 .map_err(|e| OKXWsError::ClientError(e.to_string()))?,
2865 );
2866 builder.sz(quantity.to_string());
2867
2868 if let Some(tp) = trigger_price {
2869 builder.trigger_px(tp.to_string());
2870 }
2871
2872 let okx_trigger_type = trigger_type.map_or(OKXTriggerType::Last, Into::into);
2874 builder.trigger_px_type(okx_trigger_type);
2875
2876 if matches!(order_type, OrderType::StopLimit | OrderType::LimitIfTouched)
2878 && let Some(price) = limit_price
2879 {
2880 builder.order_px(price.to_string());
2881 }
2882
2883 if let Some(reduce) = reduce_only {
2884 builder.reduce_only(reduce);
2885 }
2886
2887 if let Some(ratio) = callback_ratio {
2888 builder.callback_ratio(ratio);
2889 }
2890
2891 if let Some(spread) = callback_spread {
2892 builder.callback_spread(spread);
2893 }
2894
2895 if let Some(active) = activation_price {
2896 builder.active_px(active.to_string());
2897 }
2898
2899 builder.tag(OKX_NAUTILUS_BROKER_ID);
2900
2901 let params = builder
2902 .build()
2903 .map_err(|e| OKXWsError::ClientError(format!("Build algo order params error: {e}")))?;
2904
2905 let request_id = self.generate_unique_request_id();
2906 let request = OKXWsRequest {
2907 id: Some(request_id.clone()),
2908 op: super::enums::OKXWsOperation::OrderAlgo,
2909 exp_time: None,
2910 args: vec![params],
2911 };
2912
2913 let payload = serde_json::to_string(&request)
2914 .map_err(|e| OKXWsError::JsonError(format!("Failed to serialize algo order: {e}")))?;
2915
2916 let cmd = HandlerCommand::Send {
2917 payload,
2918 rate_limit_keys: Some(OKX_RATE_LIMIT_KEY_ORDER.to_vec()),
2919 request_id: Some(request_id),
2920 client_order_id: Some(client_order_id),
2921 op: Some(super::enums::OKXWsOperation::OrderAlgo),
2922 };
2923
2924 self.send_cmd(cmd).await
2925 }
2926
2927 pub async fn cancel_algo_order(
2938 &self,
2939 _trader_id: TraderId,
2940 _strategy_id: StrategyId,
2941 instrument_id: InstrumentId,
2942 client_order_id: Option<ClientOrderId>,
2943 algo_order_id: Option<String>,
2944 ) -> Result<(), OKXWsError> {
2945 let mut builder = super::messages::WsCancelAlgoOrderParamsBuilder::default();
2946
2947 let inst_id_code = self
2948 .get_inst_id_code(&instrument_id.symbol.inner())
2949 .ok_or_else(|| {
2950 OKXWsError::ClientError(format!(
2951 "No instIdCode cached for {instrument_id}, cannot cancel algo order"
2952 ))
2953 })?;
2954 builder.inst_id_code(inst_id_code);
2955
2956 if let Some(algo_id) = algo_order_id {
2957 builder.algo_id(algo_id);
2958 }
2959
2960 if let Some(cl_ord_id) = client_order_id {
2961 builder.algo_cl_ord_id(cl_ord_id.to_string());
2962 }
2963
2964 let params = builder
2965 .build()
2966 .map_err(|e| OKXWsError::ClientError(format!("Build cancel algo params error: {e}")))?;
2967
2968 let request_id = self.generate_unique_request_id();
2969 let request = OKXWsRequest {
2970 id: Some(request_id.clone()),
2971 op: super::enums::OKXWsOperation::CancelAlgos,
2972 exp_time: None,
2973 args: vec![params],
2974 };
2975
2976 let payload = serde_json::to_string(&request)
2977 .map_err(|e| OKXWsError::JsonError(format!("Failed to serialize cancel algo: {e}")))?;
2978
2979 let cmd = HandlerCommand::Send {
2980 payload,
2981 rate_limit_keys: Some(OKX_RATE_LIMIT_KEY_CANCEL.to_vec()),
2982 request_id: Some(request_id),
2983 client_order_id,
2984 op: Some(super::enums::OKXWsOperation::CancelAlgos),
2985 };
2986
2987 self.send_cmd(cmd).await
2988 }
2989
2990 async fn send_cmd(&self, cmd: HandlerCommand) -> Result<(), OKXWsError> {
2992 self.cmd_tx
2993 .read()
2994 .await
2995 .send(cmd)
2996 .map_err(|e| OKXWsError::ClientError(format!("Handler not available: {e}")))
2997 }
2998}
2999
// Unit tests for the OKX WebSocket client. Covered areas:
// - construction and credential validation;
// - the request-id counter;
// - option-greeks subscription bookkeeping;
// - raw-frame parsing in the feed handler;
// - post-only auto-cancel detection;
// - instIdCode-gated order operations;
// - subscription-state race scenarios.
#[cfg(test)]
mod tests {
    use nautilus_core::time::get_atomic_clock_realtime;
    use nautilus_network::RECONNECTED;
    use rstest::rstest;
    use tokio_tungstenite::tungstenite::Message;

    use super::*;
    use crate::{
        common::{
            consts::OKX_POST_ONLY_CANCEL_SOURCE,
            enums::{
                OKXExecType, OKXOrderCategory, OKXOrderStatus, OKXPriceType, OKXQuickMarginType,
                OKXSelfTradePreventionMode, OKXSide,
            },
        },
        websocket::{
            handler::is_post_only_auto_cancel,
            messages::{OKXOrderMsg, OKXWebSocketError, OKXWsFrame},
        },
    };

    // Checks the shape of a seconds-resolution Unix timestamp string: it parses as u64, has
    // 10 characters, and contains only ASCII digits.
    // NOTE(review): the timestamp is built locally here rather than by the client's auth
    // code, so this test does not exercise the client itself.
    #[rstest]
    fn test_timestamp_format_for_websocket_auth() {
        let timestamp = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .expect("System time should be after UNIX epoch")
            .as_secs()
            .to_string();

        assert!(timestamp.parse::<u64>().is_ok());
        assert_eq!(timestamp.len(), 10);
        assert!(timestamp.chars().all(|c| c.is_ascii_digit()));
    }

    // A default-constructed client has no credential and reports no API key.
    #[rstest]
    fn test_new_without_credentials() {
        let client = OKXWebSocketClient::default();
        assert!(client.credential.is_none());
        assert_eq!(client.api_key(), None);
    }

    // Registering an option-greeks subscription without explicit conventions stores both
    // the `Bs` and `Pa` conventions.
    #[rstest]
    fn test_add_option_greeks_sub_defaults_to_both_conventions() {
        let client = OKXWebSocketClient::default();
        let instrument_id = InstrumentId::from("BTC-USD-250328-92000-C.OKX");

        client.add_option_greeks_sub(instrument_id);

        let subs = client.option_greeks_subs().load();
        let stored = subs.get(&instrument_id).expect("instrument not registered");
        assert_eq!(stored.len(), 2);
        assert!(stored.contains(&OKXGreeksType::Bs));
        assert!(stored.contains(&OKXGreeksType::Pa));
    }

    // A non-empty, explicit convention set is stored exactly as requested.
    #[rstest]
    #[case::bs_only(vec![OKXGreeksType::Bs])]
    #[case::pa_only(vec![OKXGreeksType::Pa])]
    #[case::both(vec![OKXGreeksType::Bs, OKXGreeksType::Pa])]
    fn test_add_option_greeks_sub_with_conventions_stores_requested_set(
        #[case] conventions: Vec<OKXGreeksType>,
    ) {
        let client = OKXWebSocketClient::default();
        let instrument_id = InstrumentId::from("BTC-USD-250328-92000-C.OKX");
        let set: AHashSet<OKXGreeksType> = conventions.iter().copied().collect();

        client.add_option_greeks_sub_with_conventions(instrument_id, set.clone());

        let subs = client.option_greeks_subs().load();
        let stored = subs.get(&instrument_id).expect("instrument not registered");
        assert_eq!(stored, &set);
    }

    // An empty convention set falls back to two stored conventions. Only the count is
    // asserted here, not which conventions were stored.
    #[rstest]
    fn test_add_option_greeks_sub_with_empty_conventions_falls_back_to_both() {
        let client = OKXWebSocketClient::default();
        let instrument_id = InstrumentId::from("BTC-USD-250328-92000-C.OKX");

        client.add_option_greeks_sub_with_conventions(instrument_id, AHashSet::new());

        let subs = client.option_greeks_subs().load();
        let stored = subs.get(&instrument_id).expect("instrument not registered");
        assert_eq!(stored.len(), 2);
    }

    // Removing an option-greeks subscription drops the instrument's entry entirely.
    #[rstest]
    fn test_remove_option_greeks_sub_clears_entry() {
        let client = OKXWebSocketClient::default();
        let instrument_id = InstrumentId::from("BTC-USD-250328-92000-C.OKX");

        client.add_option_greeks_sub(instrument_id);
        client.remove_option_greeks_sub(&instrument_id);

        let subs = client.option_greeks_subs().load();
        assert!(!subs.contains_key(&instrument_id));
    }

    // Supplying the key, secret and passphrase together produces a credential, and
    // `api_key()` returns the supplied key.
    #[rstest]
    fn test_new_with_credentials() {
        let client = OKXWebSocketClient::new(
            None,
            Some("test_key".to_string()),
            Some("test_secret".to_string()),
            Some("test_passphrase".to_string()),
            None,
            None,
            None,
            TransportBackend::default(),
            None,
        )
        .unwrap();
        assert!(client.credential.is_some());
        assert_eq!(client.api_key(), Some("test_key"));
    }

    // A key and passphrase without a secret is rejected by the constructor.
    #[rstest]
    fn test_new_partial_credentials_fails() {
        let result = OKXWebSocketClient::new(
            None,
            Some("test_key".to_string()),
            None,
            Some("test_passphrase".to_string()),
            None,
            None,
            None,
            TransportBackend::default(),
            None,
        );
        assert!(result.is_err());
    }

    // `request_id_counter` advances by one per `fetch_add`.
    // NOTE(review): this drives the atomic directly rather than through
    // `generate_unique_request_id`.
    #[rstest]
    fn test_request_id_generation() {
        let client = OKXWebSocketClient::default();

        let initial_counter = client.request_id_counter.load(Ordering::SeqCst);

        let id1 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);
        let id2 = client.request_id_counter.fetch_add(1, Ordering::SeqCst);

        assert_eq!(id1, initial_counter);
        assert_eq!(id2, initial_counter + 1);
        assert_eq!(
            client.request_id_counter.load(Ordering::SeqCst),
            initial_counter + 2
        );
    }

    // A fresh client is closed and inactive, and a heartbeat passed to `new` is stored on
    // the client.
    #[rstest]
    fn test_client_state_management() {
        let client = OKXWebSocketClient::default();

        assert!(client.is_closed());
        assert!(!client.is_active());

        let client_with_heartbeat = OKXWebSocketClient::new(
            None,
            None,
            None,
            None,
            None,
            Some(30),
            None,
            TransportBackend::default(),
            None,
        )
        .unwrap();

        assert!(client_with_heartbeat.heartbeat.is_some());
        assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);
    }

    // An `OKXWebSocketError` keeps its fields and survives wrapping in `OKXWsMessage::Error`.
    #[rstest]
    fn test_websocket_error_handling() {
        let clock = get_atomic_clock_realtime();
        let ts = clock.get_time_ns().as_u64();

        let error = OKXWebSocketError {
            code: "60012".to_string(),
            message: "Invalid request".to_string(),
            conn_id: None,
            timestamp: ts,
        };

        assert_eq!(error.code, "60012");
        assert_eq!(error.message, "Invalid request");
        assert_eq!(error.timestamp, ts);

        let nautilus_msg = OKXWsMessage::Error(error);
        match nautilus_msg {
            OKXWsMessage::Error(e) => {
                assert_eq!(e.code, "60012");
                assert_eq!(e.message, "Invalid request");
            }
            _ => panic!("Expected Error variant"),
        }
    }

    // Ten consecutive `fetch_add` calls return consecutive values starting at the initial
    // counter value.
    #[rstest]
    fn test_request_id_generation_sequence() {
        let client = OKXWebSocketClient::default();

        let initial_counter = client
            .request_id_counter
            .load(std::sync::atomic::Ordering::SeqCst);
        let mut ids = Vec::new();

        for _ in 0..10 {
            let id = client
                .request_id_counter
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            ids.push(id);
        }

        for (i, &id) in ids.iter().enumerate() {
            assert_eq!(id, initial_counter + i as u64);
        }

        assert_eq!(
            client
                .request_id_counter
                .load(std::sync::atomic::Ordering::SeqCst),
            initial_counter + 10
        );
    }

    // Constructor options are reflected on the client: the heartbeat and the account id.
    // NOTE(review): the closed/inactive and heartbeat checks duplicate
    // `test_client_state_management`.
    #[rstest]
    fn test_client_state_transitions() {
        let client = OKXWebSocketClient::default();

        assert!(client.is_closed());
        assert!(!client.is_active());

        let client_with_heartbeat = OKXWebSocketClient::new(
            None,
            None,
            None,
            None,
            None,
            Some(30),
            None,
            TransportBackend::default(),
            None,
        )
        .unwrap();

        assert!(client_with_heartbeat.heartbeat.is_some());
        assert_eq!(client_with_heartbeat.heartbeat.unwrap(), 30);

        let account_id = AccountId::from("test-account-123");
        let client_with_account = OKXWebSocketClient::new(
            None,
            None,
            None,
            None,
            Some(account_id),
            None,
            None,
            TransportBackend::default(),
            None,
        )
        .unwrap();

        assert_eq!(client_with_account.account_id, account_id);
    }

    // Table-driven variant of the error test. Each (code, message, conn_id) triple
    // round-trips through `OKXWebSocketError` and `OKXWsMessage::Error`.
    #[rstest]
    fn test_websocket_error_scenarios() {
        let clock = get_atomic_clock_realtime();
        let ts = clock.get_time_ns().as_u64();

        let error_scenarios = vec![
            ("60012", "Invalid request", None),
            ("60009", "Invalid API key", Some("conn-123".to_string())),
            ("60014", "Too many requests", None),
            ("50001", "Order not found", None),
        ];

        for (code, message, conn_id) in error_scenarios {
            let error = OKXWebSocketError {
                code: code.to_string(),
                message: message.to_string(),
                conn_id: conn_id.clone(),
                timestamp: ts,
            };

            assert_eq!(error.code, code);
            assert_eq!(error.message, message);
            assert_eq!(error.conn_id, conn_id);
            assert_eq!(error.timestamp, ts);

            let nautilus_msg = OKXWsMessage::Error(error);
            match nautilus_msg {
                OKXWsMessage::Error(e) => {
                    assert_eq!(e.code, code);
                    assert_eq!(e.message, message);
                    assert_eq!(e.conn_id, conn_id);
                }
                _ => panic!("Expected Error variant"),
            }
        }
    }

    // The `RECONNECTED` sentinel text is parsed as `OKXWsFrame::Reconnected`.
    #[rstest]
    fn test_feed_handler_reconnection_detection() {
        let msg = Message::Text(RECONNECTED.to_string().into());
        let result = OKXWsFeedHandler::parse_raw_message(msg);
        assert!(matches!(result, Some(OKXWsFrame::Reconnected)));
    }

    // A text ping is parsed as `OKXWsFrame::Ping`, and a `subscribe` event acknowledgement
    // is parsed as `OKXWsFrame::Subscription`.
    #[rstest]
    fn test_feed_handler_normal_message_processing() {
        let ping_msg = Message::Text(TEXT_PING.to_string().into());
        let result = OKXWsFeedHandler::parse_raw_message(ping_msg);
        assert!(matches!(result, Some(OKXWsFrame::Ping)));

        let sub_msg = r#"{
            "event": "subscribe",
            "arg": {
                "channel": "tickers",
                "instType": "SPOT"
            },
            "connId": "a4d3ae55"
        }"#;

        let sub_result =
            OKXWsFeedHandler::parse_raw_message(Message::Text(sub_msg.to_string().into()));
        assert!(matches!(sub_result, Some(OKXWsFrame::Subscription { .. })));
    }

    // A close frame yields no parsed frame.
    #[rstest]
    fn test_feed_handler_close_message() {
        let result = OKXWsFeedHandler::parse_raw_message(Message::Close(None));
        assert!(result.is_none());
    }

    // Pins the literal value of the reconnection sentinel.
    #[rstest]
    fn test_reconnection_message_constant() {
        assert_eq!(RECONNECTED, "__RECONNECTED__");
    }

    // Parsing the reconnection sentinel repeatedly gives the same result each time.
    #[rstest]
    fn test_multiple_reconnection_signals() {
        for _ in 0..3 {
            let msg = Message::Text(RECONNECTED.to_string().into());
            let result = OKXWsFeedHandler::parse_raw_message(msg);
            assert!(matches!(result, Some(OKXWsFrame::Reconnected)));
        }
    }

    // With no connection established, `wait_until_active` errors out and the client stays
    // inactive. The argument 0.1 is presumably a timeout in seconds (TODO confirm).
    #[tokio::test]
    async fn test_wait_until_active_timeout() {
        let client = OKXWebSocketClient::new(
            None,
            Some("test_key".to_string()),
            Some("test_secret".to_string()),
            Some("test_passphrase".to_string()),
            Some(AccountId::from("test-account")),
            None,
            None,
            TransportBackend::default(),
            None,
        )
        .unwrap();

        let result = client.wait_until_active(0.1).await;

        assert!(result.is_err());
        assert!(!client.is_active());
    }

    /// Builds a canceled `ETH-USDT-SWAP` limit-order message with no cancel-source markers.
    /// Tests set individual fields on it to exercise specific paths.
    fn sample_canceled_order_msg() -> OKXOrderMsg {
        OKXOrderMsg {
            acc_fill_sz: Some("0".to_string()),
            avg_px: "0".to_string(),
            c_time: 0,
            cancel_source: None,
            cancel_source_reason: None,
            category: OKXOrderCategory::Normal,
            ccy: Ustr::from("USDT"),
            cl_ord_id: "order-1".to_string(),
            algo_cl_ord_id: None,
            attach_algo_cl_ord_id: None,
            attach_algo_ords: Vec::new(),
            fee: None,
            fee_ccy: Ustr::from("USDT"),
            fill_px: "0".to_string(),
            fill_sz: "0".to_string(),
            fill_time: 0,
            inst_id: Ustr::from("ETH-USDT-SWAP"),
            inst_type: OKXInstrumentType::Swap,
            lever: "1".to_string(),
            ord_id: Ustr::from("123456"),
            ord_type: OKXOrderType::Limit,
            pnl: "0".to_string(),
            pos_side: OKXPositionSide::Net,
            px: "0".to_string(),
            reduce_only: "false".to_string(),
            side: OKXSide::Buy,
            state: OKXOrderStatus::Canceled,
            exec_type: OKXExecType::None,
            sz: "1".to_string(),
            td_mode: OKXTradeMode::Cross,
            tgt_ccy: None,
            trade_id: String::new(),
            algo_id: None,
            fill_fee: None,
            fill_fee_ccy: None,
            fill_mark_px: None,
            fill_mark_vol: None,
            fill_px_vol: None,
            fill_px_usd: None,
            fill_fwd_px: None,
            fill_notional_usd: None,
            fill_pnl: None,
            is_tp_limit: None,
            linked_algo_ord: None,
            notional_usd: None,
            px_type: OKXPriceType::None,
            px_usd: None,
            px_vol: None,
            quick_mgn_type: OKXQuickMarginType::None,
            rebate: None,
            rebate_ccy: None,
            sl_ord_px: None,
            sl_trigger_px: None,
            sl_trigger_px_type: None,
            source: None,
            stp_id: None,
            stp_mode: OKXSelfTradePreventionMode::None,
            tag: None,
            tp_ord_px: None,
            tp_trigger_px: None,
            tp_trigger_px_type: None,
            amend_result: None,
            req_id: None,
            code: None,
            msg: None,
            u_time: 0,
        }
    }

    // A cancel source equal to `OKX_POST_ONLY_CANCEL_SOURCE` marks a post-only auto-cancel.
    #[rstest]
    fn test_is_post_only_auto_cancel_detects_cancel_source() {
        let mut msg = sample_canceled_order_msg();
        msg.cancel_source = Some(OKX_POST_ONLY_CANCEL_SOURCE.to_string());

        assert!(is_post_only_auto_cancel(&msg));
    }

    // A POST_ONLY cancel-source reason also marks a post-only auto-cancel.
    #[rstest]
    fn test_is_post_only_auto_cancel_detects_reason() {
        let mut msg = sample_canceled_order_msg();
        msg.cancel_source_reason = Some("POST_ONLY would take liquidity".to_string());

        assert!(is_post_only_auto_cancel(&msg));
    }

    // Without either marker, a canceled order is not treated as a post-only auto-cancel.
    #[rstest]
    fn test_is_post_only_auto_cancel_false_without_markers() {
        let msg = sample_canceled_order_msg();

        assert!(!is_post_only_auto_cancel(&msg));
    }

    // A post-only order type alone, with no cancel markers, is not enough for detection.
    #[rstest]
    fn test_is_post_only_auto_cancel_false_for_order_type_only() {
        let mut msg = sample_canceled_order_msg();
        msg.ord_type = OKXOrderType::PostOnly;

        assert!(!is_post_only_auto_cancel(&msg));
    }

    // Batch cancel with two orders is expected to fail.
    // NOTE(review): no instIdCode is cached on this fresh client, so this presumably fails
    // at the instIdCode lookup before the multi-order path is exercised. Consider caching a
    // code and asserting on the error text.
    #[tokio::test]
    async fn test_batch_cancel_orders_with_multiple_orders() {
        use nautilus_model::identifiers::{ClientOrderId, InstrumentId, VenueOrderId};

        let client = OKXWebSocketClient::new(
            Some("wss://test.okx.com".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
            TransportBackend::default(),
            None,
        )
        .expect("Failed to create client");

        let instrument_id = InstrumentId::from("BTC-USDT.OKX");
        let client_order_id1 = ClientOrderId::new("order1");
        let client_order_id2 = ClientOrderId::new("order2");
        let venue_order_id1 = VenueOrderId::new("venue1");
        let venue_order_id2 = VenueOrderId::new("venue2");

        let orders = vec![
            (instrument_id, Some(client_order_id1), Some(venue_order_id1)),
            (instrument_id, Some(client_order_id2), Some(venue_order_id2)),
        ];

        let result = client.batch_cancel_orders(orders).await;
        assert!(result.is_err());
    }

    // Batch cancel with only a client order id is expected to fail.
    // NOTE(review): as above, this presumably fails at the instIdCode lookup rather than
    // exercising the client-id-only path.
    #[tokio::test]
    async fn test_batch_cancel_orders_with_only_client_order_id() {
        use nautilus_model::identifiers::{ClientOrderId, InstrumentId};

        let client = OKXWebSocketClient::new(
            Some("wss://test.okx.com".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
            TransportBackend::default(),
            None,
        )
        .expect("Failed to create client");

        let instrument_id = InstrumentId::from("BTC-USDT.OKX");
        let client_order_id = ClientOrderId::new("order1");

        let orders = vec![(instrument_id, Some(client_order_id), None)];

        let result = client.batch_cancel_orders(orders).await;

        assert!(result.is_err());
    }

    // Batch cancel with only a venue order id is expected to fail.
    // NOTE(review): as above, this presumably fails at the instIdCode lookup rather than
    // exercising the venue-id-only path.
    #[tokio::test]
    async fn test_batch_cancel_orders_with_only_venue_order_id() {
        use nautilus_model::identifiers::{InstrumentId, VenueOrderId};

        let client = OKXWebSocketClient::new(
            Some("wss://test.okx.com".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
            TransportBackend::default(),
            None,
        )
        .expect("Failed to create client");

        let instrument_id = InstrumentId::from("BTC-USDT.OKX");
        let venue_order_id = VenueOrderId::new("venue1");

        let orders = vec![(instrument_id, None, Some(venue_order_id))];

        let result = client.batch_cancel_orders(orders).await;

        assert!(result.is_err());
    }

    // Batch cancel with both ids is expected to fail.
    // NOTE(review): as above, this presumably fails at the instIdCode lookup rather than
    // exercising the both-ids path.
    #[tokio::test]
    async fn test_batch_cancel_orders_with_both_ids() {
        use nautilus_model::identifiers::{ClientOrderId, InstrumentId, VenueOrderId};

        let client = OKXWebSocketClient::new(
            Some("wss://test.okx.com".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
            TransportBackend::default(),
            None,
        )
        .expect("Failed to create client");

        let instrument_id = InstrumentId::from("BTC-USDT-SWAP.OKX");
        let client_order_id = ClientOrderId::new("order1");
        let venue_order_id = VenueOrderId::new("venue1");

        let orders = vec![(instrument_id, Some(client_order_id), Some(venue_order_id))];

        let result = client.batch_cancel_orders(orders).await;

        assert!(result.is_err());
    }

    // `cancel_order` without a cached instIdCode fails with the instIdCode error message.
    #[tokio::test]
    async fn test_cancel_order_fails_without_inst_id_code() {
        use nautilus_model::identifiers::{ClientOrderId, InstrumentId, StrategyId, TraderId};

        let client = OKXWebSocketClient::default();
        let instrument_id = InstrumentId::from("BTC-USDT-SWAP.OKX");

        let result = client
            .cancel_order(
                TraderId::from("TESTER-001"),
                StrategyId::from("S-001"),
                instrument_id,
                Some(ClientOrderId::new("O-001")),
                None,
            )
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("No instIdCode cached for BTC-USDT-SWAP.OKX"),
            "Expected instIdCode error, found: {err}"
        );
    }

    // `submit_order` without a cached instIdCode fails with the instIdCode error message.
    #[tokio::test]
    async fn test_submit_order_fails_without_inst_id_code() {
        use nautilus_model::{
            enums::{OrderSide, OrderType},
            identifiers::{ClientOrderId, InstrumentId, StrategyId, TraderId},
            types::Quantity,
        };

        use crate::common::enums::OKXTradeMode;

        let client = OKXWebSocketClient::default();
        let instrument_id = InstrumentId::from("ETH-USDT-SWAP.OKX");

        let result = client
            .submit_order(
                TraderId::from("TESTER-001"),
                StrategyId::from("S-001"),
                instrument_id,
                OKXTradeMode::Cross,
                ClientOrderId::new("O-001"),
                OrderSide::Buy,
                OrderType::Limit,
                Quantity::from("0.01"),
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
                None,
            )
            .await;

        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("No instIdCode cached for ETH-USDT-SWAP.OKX"),
            "Expected instIdCode error, found: {err}"
        );
    }

    // After caching an instIdCode for the symbol, `cancel_order` gets past the lookup.
    #[tokio::test]
    async fn test_cancel_order_passes_inst_id_code_lookup_when_cached() {
        use nautilus_model::identifiers::{ClientOrderId, InstrumentId, StrategyId, TraderId};
        use ustr::Ustr;

        let client = OKXWebSocketClient::default();
        let instrument_id = InstrumentId::from("BTC-USDT-SWAP.OKX");

        // The cache key is the bare symbol, without the `.OKX` venue suffix.
        client.cache_inst_id_code(Ustr::from("BTC-USDT-SWAP"), 10459);

        let result = client
            .cancel_order(
                TraderId::from("TESTER-001"),
                StrategyId::from("S-001"),
                instrument_id,
                Some(ClientOrderId::new("O-001")),
                None,
            )
            .await;

        // The call still fails, presumably because no handler is connected (TODO confirm).
        // The check is only that the failure is not the instIdCode lookup error.
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            !err.contains("No instIdCode cached"),
            "Should pass instIdCode lookup, found: {err}"
        );
    }

    // Subscribe is confirmed, then unsubscribe is marked. The unsubscribe is then confirmed
    // and the topic is subscribed again. The topic ends up confirmed, with nothing pending in
    // either direction.
    #[rstest]
    fn test_race_unsubscribe_failure_recovery() {
        let client = OKXWebSocketClient::new(
            Some("wss://test.okx.com".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
            TransportBackend::default(),
            None,
        )
        .expect("Failed to create client");

        let topic = "trades:BTC-USDT-SWAP";

        // Subscribe and confirm: one active topic.
        client.subscriptions_state.mark_subscribe(topic);
        client.subscriptions_state.confirm_subscribe(topic);
        assert_eq!(client.subscriptions_state.len(), 1);

        // Marking unsubscribe removes it from the active count and records it as pending.
        client.subscriptions_state.mark_unsubscribe(topic);
        assert_eq!(client.subscriptions_state.len(), 0);
        assert_eq!(
            client.subscriptions_state.pending_unsubscribe_topics(),
            vec![topic]
        );

        // Complete the unsubscribe, then subscribe and confirm again.
        client.subscriptions_state.confirm_unsubscribe(topic);
        client.subscriptions_state.mark_subscribe(topic);
        client.subscriptions_state.confirm_subscribe(topic);
        assert_eq!(client.subscriptions_state.len(), 1);
        assert!(
            client
                .subscriptions_state
                .pending_unsubscribe_topics()
                .is_empty()
        );
        assert!(
            client
                .subscriptions_state
                .pending_subscribe_topics()
                .is_empty()
        );

        // The topic is the only one reported for restoration.
        let all = client.subscriptions_state.all_topics();
        assert_eq!(all.len(), 1);
        assert!(all.contains(&topic.to_string()));
    }

    // A re-subscribe is marked while an unsubscribe is still awaiting its ack. The unsubscribe
    // ack clears only the pending unsubscribe and keeps the pending subscribe, which then
    // confirms normally.
    #[rstest]
    fn test_race_resubscribe_before_unsubscribe_ack() {
        let client = OKXWebSocketClient::new(
            Some("wss://test.okx.com".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
            TransportBackend::default(),
            None,
        )
        .expect("Failed to create client");

        let topic = "books:BTC-USDT";

        // Subscribe and confirm: one active topic.
        client.subscriptions_state.mark_subscribe(topic);
        client.subscriptions_state.confirm_subscribe(topic);
        assert_eq!(client.subscriptions_state.len(), 1);

        // Start unsubscribing.
        client.subscriptions_state.mark_unsubscribe(topic);
        assert_eq!(client.subscriptions_state.len(), 0);
        assert_eq!(
            client.subscriptions_state.pending_unsubscribe_topics(),
            vec![topic]
        );

        // Re-subscribe before the unsubscribe ack arrives.
        client.subscriptions_state.mark_subscribe(topic);
        assert_eq!(
            client.subscriptions_state.pending_subscribe_topics(),
            vec![topic]
        );

        // The unsubscribe ack clears the pending unsubscribe but not the pending subscribe.
        client.subscriptions_state.confirm_unsubscribe(topic);
        assert!(
            client
                .subscriptions_state
                .pending_unsubscribe_topics()
                .is_empty()
        );
        assert_eq!(
            client.subscriptions_state.pending_subscribe_topics(),
            vec![topic]
        );

        // The subscribe ack makes the topic active again with nothing pending.
        client.subscriptions_state.confirm_subscribe(topic);
        assert_eq!(client.subscriptions_state.len(), 1);
        assert!(
            client
                .subscriptions_state
                .pending_subscribe_topics()
                .is_empty()
        );

        let all = client.subscriptions_state.all_topics();
        assert_eq!(all.len(), 1);
        assert!(all.contains(&topic.to_string()));
    }

    // A subscribe is marked and then an unsubscribe is marked before the subscribe ack. The
    // late subscribe ack must not reactivate the topic, and the pending unsubscribe survives
    // until its own ack.
    #[rstest]
    fn test_race_late_subscribe_confirmation_after_unsubscribe() {
        let client = OKXWebSocketClient::new(
            Some("wss://test.okx.com".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
            TransportBackend::default(),
            None,
        )
        .expect("Failed to create client");

        let topic = "tickers:ETH-USDT";

        // Subscribe requested but not yet confirmed.
        client.subscriptions_state.mark_subscribe(topic);
        assert_eq!(
            client.subscriptions_state.pending_subscribe_topics(),
            vec![topic]
        );

        // Unsubscribing replaces the pending subscribe with a pending unsubscribe.
        client.subscriptions_state.mark_unsubscribe(topic);
        assert!(
            client
                .subscriptions_state
                .pending_subscribe_topics()
                .is_empty()
        );
        assert_eq!(
            client.subscriptions_state.pending_unsubscribe_topics(),
            vec![topic]
        );

        // The late subscribe ack leaves the topic inactive and still pending unsubscribe.
        client.subscriptions_state.confirm_subscribe(topic);
        assert_eq!(client.subscriptions_state.len(), 0);
        assert_eq!(
            client.subscriptions_state.pending_unsubscribe_topics(),
            vec![topic]
        );

        client.subscriptions_state.confirm_unsubscribe(topic);

        // Once the unsubscribe is acknowledged, no state remains for the topic.
        assert!(client.subscriptions_state.is_empty());
        assert!(client.subscriptions_state.all_topics().is_empty());
    }

    // `all_topics()` is the set used to restore subscriptions after a reconnect. It includes
    // confirmed topics and topics still pending subscribe. It excludes topics pending
    // unsubscribe.
    #[rstest]
    fn test_race_reconnection_with_pending_states() {
        let client = OKXWebSocketClient::new(
            Some("wss://test.okx.com".to_string()),
            Some("test_key".to_string()),
            Some("test_secret".to_string()),
            Some("test_passphrase".to_string()),
            Some(AccountId::new("OKX-TEST")),
            None,
            None,
            TransportBackend::default(),
            None,
        )
        .expect("Failed to create client");

        // Confirmed subscription.
        let trade_btc = "trades:BTC-USDT-SWAP";
        client.subscriptions_state.mark_subscribe(trade_btc);
        client.subscriptions_state.confirm_subscribe(trade_btc);

        // Subscription still pending confirmation.
        let trade_eth = "trades:ETH-USDT-SWAP";
        client.subscriptions_state.mark_subscribe(trade_eth);

        // Confirmed subscription that is now pending unsubscribe.
        let book_btc = "books:BTC-USDT";
        client.subscriptions_state.mark_subscribe(book_btc);
        client.subscriptions_state.confirm_subscribe(book_btc);
        client.subscriptions_state.mark_unsubscribe(book_btc);

        let topics_to_restore = client.subscriptions_state.all_topics();

        assert_eq!(topics_to_restore.len(), 2);
        assert!(topics_to_restore.contains(&trade_btc.to_string()));
        assert!(topics_to_restore.contains(&trade_eth.to_string()));
        assert!(!topics_to_restore.contains(&book_btc.to_string()));
    }

    // Re-marking and re-confirming an already active topic does not add a pending entry
    // or a duplicate topic.
    #[rstest]
    fn test_race_duplicate_subscribe_messages_idempotent() {
        let client = OKXWebSocketClient::new(
            Some("wss://test.okx.com".to_string()),
            None,
            None,
            None,
            None,
            None,
            None,
            TransportBackend::default(),
            None,
        )
        .expect("Failed to create client");

        let topic = "trades:BTC-USDT-SWAP";

        // Subscribe and confirm: one active topic.
        client.subscriptions_state.mark_subscribe(topic);
        client.subscriptions_state.confirm_subscribe(topic);
        assert_eq!(client.subscriptions_state.len(), 1);

        // A duplicate subscribe request for an active topic is not recorded as pending.
        client.subscriptions_state.mark_subscribe(topic);
        assert!(
            client
                .subscriptions_state
                .pending_subscribe_topics()
                .is_empty()
        );
        assert_eq!(client.subscriptions_state.len(), 1);

        // A duplicate ack leaves the active count unchanged.
        client.subscriptions_state.confirm_subscribe(topic);
        assert_eq!(client.subscriptions_state.len(), 1);

        let all = client.subscriptions_state.all_topics();
        assert_eq!(all.len(), 1);
        assert_eq!(all[0], topic);
    }
}