1use std::{
27 collections::VecDeque,
28 sync::{
29 Arc,
30 atomic::{AtomicBool, Ordering},
31 },
32};
33
34use nautilus_model::identifiers::ClientOrderId;
35use nautilus_network::{
36 RECONNECTED,
37 retry::{RetryManager, create_websocket_retry_manager},
38 websocket::{AuthTracker, SubscriptionState, TEXT_PING, TEXT_PONG, WebSocketClient},
39};
40use serde_json::Value;
41use tokio_tungstenite::tungstenite::Message;
42use ustr::Ustr;
43
44use super::{
45 enums::{OKXSubscriptionEvent, OKXWsChannel, OKXWsOperation},
46 error::OKXWsError,
47 messages::{
48 OKXAlgoOrderMsg, OKXOrderMsg, OKXSubscription, OKXSubscriptionArg, OKXWebSocketArg,
49 OKXWebSocketError, OKXWsFrame, OKXWsMessage,
50 },
51 subscription::topic_from_websocket_arg,
52};
53use crate::{
54 common::{
55 consts::{OKX_FIELD_SCODE, OKX_FIELD_SMSG, OKX_SUCCESS_CODE, should_retry_error_code},
56 models::OKXInstrument,
57 },
58 websocket::client::OKX_RATE_LIMIT_KEY_SUBSCRIPTION,
59};
60
61#[derive(Debug)]
63pub enum HandlerCommand {
64 SetClient(WebSocketClient),
66 Disconnect,
68 Authenticate { payload: String },
70 Subscribe { args: Vec<OKXSubscriptionArg> },
72 Unsubscribe { args: Vec<OKXSubscriptionArg> },
74 Send {
76 payload: String,
77 rate_limit_keys: Option<Vec<Ustr>>,
78 request_id: Option<String>,
79 client_order_id: Option<ClientOrderId>,
80 op: Option<OKXWsOperation>,
81 },
82}
83
84pub(super) struct OKXWsFeedHandler {
85 signal: Arc<AtomicBool>,
86 inner: Option<WebSocketClient>,
87 cmd_rx: tokio::sync::mpsc::UnboundedReceiver<HandlerCommand>,
88 raw_rx: tokio::sync::mpsc::UnboundedReceiver<Message>,
89 out_tx: tokio::sync::mpsc::UnboundedSender<OKXWsMessage>,
90 auth_tracker: AuthTracker,
91 subscriptions_state: SubscriptionState,
92 retry_manager: RetryManager<OKXWsError>,
93 pending_messages: VecDeque<OKXWsMessage>,
94}
95
96impl OKXWsFeedHandler {
97 pub fn new(
99 signal: Arc<AtomicBool>,
100 cmd_rx: tokio::sync::mpsc::UnboundedReceiver<HandlerCommand>,
101 raw_rx: tokio::sync::mpsc::UnboundedReceiver<Message>,
102 out_tx: tokio::sync::mpsc::UnboundedSender<OKXWsMessage>,
103 auth_tracker: AuthTracker,
104 subscriptions_state: SubscriptionState,
105 ) -> Self {
106 Self {
107 signal,
108 inner: None,
109 cmd_rx,
110 raw_rx,
111 out_tx,
112 auth_tracker,
113 subscriptions_state,
114 retry_manager: create_websocket_retry_manager(),
115 pending_messages: VecDeque::new(),
116 }
117 }
118
119 pub(super) fn is_stopped(&self) -> bool {
120 self.signal.load(Ordering::Acquire)
121 }
122
123 pub(super) fn send(&self, msg: OKXWsMessage) -> Result<(), ()> {
124 self.out_tx.send(msg).map_err(|_| ())
125 }
126
127 async fn send_with_retry(
128 &self,
129 payload: String,
130 rate_limit_keys: Option<&[Ustr]>,
131 ) -> Result<(), OKXWsError> {
132 if let Some(client) = &self.inner {
133 let keys_owned: Option<Vec<Ustr>> = rate_limit_keys.map(|k| k.to_vec());
134 self.retry_manager
135 .execute_with_retry(
136 "websocket_send",
137 || {
138 let payload = payload.clone();
139 let keys = keys_owned.clone();
140 async move {
141 client
142 .send_text(payload, keys.as_deref())
143 .await
144 .map_err(|e| OKXWsError::ClientError(format!("Send failed: {e}")))
145 }
146 },
147 should_retry_okx_error,
148 create_okx_timeout_error,
149 )
150 .await
151 } else {
152 Err(OKXWsError::ClientError(
153 "No active WebSocket client".to_string(),
154 ))
155 }
156 }
157
158 pub(super) async fn send_pong(&self) -> anyhow::Result<()> {
159 match self.send_with_retry(TEXT_PONG.to_string(), None).await {
160 Ok(()) => {
161 log::trace!("Sent pong response to OKX text ping");
162 Ok(())
163 }
164 Err(e) => {
165 log::warn!("Failed to send pong after retries: error={e}");
166 Err(anyhow::anyhow!("Failed to send pong: {e}"))
167 }
168 }
169 }
170
171 pub(super) async fn next(&mut self) -> Option<OKXWsMessage> {
172 if let Some(message) = self.pending_messages.pop_front() {
173 return Some(message);
174 }
175
176 loop {
177 tokio::select! {
178 Some(cmd) = self.cmd_rx.recv() => {
179 match cmd {
180 HandlerCommand::SetClient(client) => {
181 log::debug!("Handler received WebSocket client");
182 self.inner = Some(client);
183 }
184 HandlerCommand::Disconnect => {
185 log::debug!("Handler disconnecting WebSocket client");
186 self.inner = None;
187 return None;
188 }
189 HandlerCommand::Authenticate { payload } => {
190 if let Err(e) = self.send_with_retry(
191 payload,
192 Some(OKX_RATE_LIMIT_KEY_SUBSCRIPTION.as_slice()),
193 ).await {
194 log::error!(
195 "Failed to send authentication message after retries: error={e}"
196 );
197 }
198 }
199 HandlerCommand::Subscribe { args } => {
200 if let Err(e) = self.handle_subscribe(args).await {
201 log::error!("Failed to handle subscribe command: error={e}");
202 }
203 }
204 HandlerCommand::Unsubscribe { args } => {
205 if let Err(e) = self.handle_unsubscribe(args).await {
206 log::error!("Failed to handle unsubscribe command: error={e}");
207 }
208 }
209 HandlerCommand::Send {
210 payload,
211 rate_limit_keys,
212 request_id,
213 client_order_id,
214 op,
215 } => {
216 if let Err(e) = self.send_with_retry(
217 payload,
218 rate_limit_keys.as_deref(),
219 ).await {
220 log::error!("Failed to send message after retries: error={e}");
221
222 if let Some(request_id) = request_id {
223 self.pending_messages.push_back(OKXWsMessage::SendFailed {
224 request_id,
225 client_order_id,
226 op,
227 error: format!("{e}"),
228 });
229 }
230 }
231 }
232 }
233 }
234
235 () = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
236 if self.signal.load(Ordering::Acquire) {
237 log::debug!("Stop signal received during idle period");
238 return None;
239 }
240 }
241
242 msg = self.raw_rx.recv() => {
243 let event = match msg {
244 Some(msg) => match Self::parse_raw_message(msg) {
245 Some(event) => event,
246 None => continue,
247 },
248 None => {
249 log::debug!("WebSocket stream closed");
250 return None;
251 }
252 };
253
254 match event {
255 OKXWsFrame::Ping => {
256 if let Err(e) = self.send_pong().await {
257 log::warn!("Failed to send pong response: error={e}");
258 }
259 }
260 OKXWsFrame::Login {
261 code, msg, conn_id, ..
262 } => {
263 if code == OKX_SUCCESS_CODE {
264 self.auth_tracker.succeed();
265 return Some(OKXWsMessage::Authenticated);
266 }
267
268 log::error!("WebSocket authentication failed: error={msg}");
269 self.auth_tracker.fail(msg.clone());
270
271 let error = OKXWebSocketError {
272 code,
273 message: msg,
274 conn_id: Some(conn_id),
275 timestamp: nautilus_core::time::get_atomic_clock_realtime()
276 .get_time_ns()
277 .as_u64(),
278 };
279 self.pending_messages.push_back(OKXWsMessage::Error(error));
280 }
281 OKXWsFrame::BookData { arg, action, data } => {
282 return Some(OKXWsMessage::BookData { arg, action, data });
283 }
284 OKXWsFrame::OrderResponse {
285 id, op, code, msg, data,
286 } => {
287 return Some(OKXWsMessage::OrderResponse {
288 id, op, code, msg, data,
289 });
290 }
291 OKXWsFrame::Data { arg, data } => {
292 if let Some(output) = self.route_data_message(arg, data) {
293 return Some(output);
294 }
295 }
296 OKXWsFrame::Error { code, msg } => {
297 let error = OKXWebSocketError {
298 code,
299 message: msg,
300 conn_id: None,
301 timestamp: nautilus_core::time::get_atomic_clock_realtime()
302 .get_time_ns()
303 .as_u64(),
304 };
305 return Some(OKXWsMessage::Error(error));
306 }
307 OKXWsFrame::Reconnected => {
308 self.auth_tracker.invalidate();
309 return Some(OKXWsMessage::Reconnected);
310 }
311 OKXWsFrame::Subscription {
312 event, arg, code, msg, ..
313 } => {
314 self.handle_subscription_ack(&event, &arg, code.as_deref(), msg.as_deref());
315 }
316 OKXWsFrame::ChannelConnCount { .. } => {}
317 }
318 }
319
320 else => {
321 log::debug!("Handler shutting down: stream ended or command channel closed");
322 return None;
323 }
324 }
325 }
326 }
327
328 fn route_data_message(&self, arg: OKXWebSocketArg, data: Value) -> Option<OKXWsMessage> {
329 let OKXWebSocketArg {
330 channel, inst_id, ..
331 } = arg;
332
333 match channel {
334 OKXWsChannel::Account => Some(OKXWsMessage::Account(data)),
335 OKXWsChannel::Positions => Some(OKXWsMessage::Positions(data)),
336 OKXWsChannel::Orders => match serde_json::from_value::<Vec<OKXOrderMsg>>(data) {
337 Ok(orders) => Some(OKXWsMessage::Orders(orders)),
338 Err(e) => {
339 log::error!("Failed to parse orders data: {e}");
340 None
341 }
342 },
343 OKXWsChannel::OrdersAlgo | OKXWsChannel::AlgoAdvance => {
344 match serde_json::from_value::<Vec<OKXAlgoOrderMsg>>(data) {
345 Ok(orders) => Some(OKXWsMessage::AlgoOrders(orders)),
346 Err(e) => {
347 log::error!("Failed to parse algo orders data: {e}");
348 None
349 }
350 }
351 }
352 OKXWsChannel::Instruments => match serde_json::from_value::<Vec<OKXInstrument>>(data) {
353 Ok(instruments) => Some(OKXWsMessage::Instruments(instruments)),
354 Err(e) => {
355 log::error!("Failed to parse instruments data: {e}");
356 None
357 }
358 },
359 _ => Some(OKXWsMessage::ChannelData {
360 channel,
361 inst_id,
362 data,
363 }),
364 }
365 }
366
367 fn handle_subscription_ack(
368 &self,
369 event: &OKXSubscriptionEvent,
370 arg: &OKXWebSocketArg,
371 code: Option<&str>,
372 msg: Option<&str>,
373 ) {
374 let topic = topic_from_websocket_arg(arg);
375 let success = code.is_none_or(|c| c == OKX_SUCCESS_CODE);
376
377 match event {
378 OKXSubscriptionEvent::Subscribe => {
379 if success {
380 self.subscriptions_state.confirm_subscribe(&topic);
381 } else {
382 log::warn!(
383 "Subscription failed: topic={topic:?}, error={msg:?}, code={code:?}"
384 );
385 self.subscriptions_state.mark_failure(&topic);
386 }
387 }
388 OKXSubscriptionEvent::Unsubscribe => {
389 if success {
390 self.subscriptions_state.confirm_unsubscribe(&topic);
391 } else {
392 log::warn!(
393 "Unsubscription failed - restoring subscription: \
394 topic={topic:?}, error={msg:?}, code={code:?}"
395 );
396 self.subscriptions_state.confirm_unsubscribe(&topic);
397 self.subscriptions_state.mark_subscribe(&topic);
398 self.subscriptions_state.confirm_subscribe(&topic);
399 }
400 }
401 }
402 }
403
404 async fn handle_subscribe(&self, args: Vec<OKXSubscriptionArg>) -> anyhow::Result<()> {
405 for arg in &args {
406 log::debug!(
407 "Subscribing to channel: channel={:?}, inst_id={:?}",
408 arg.channel,
409 arg.inst_id
410 );
411 }
412
413 let message = OKXSubscription {
414 op: OKXWsOperation::Subscribe,
415 args,
416 };
417
418 let json_txt = serde_json::to_string(&message)
419 .map_err(|e| anyhow::anyhow!("Failed to serialize subscription: {e}"))?;
420
421 self.send_with_retry(json_txt, Some(OKX_RATE_LIMIT_KEY_SUBSCRIPTION.as_slice()))
422 .await
423 .map_err(|e| anyhow::anyhow!("Failed to send subscription after retries: {e}"))?;
424 Ok(())
425 }
426
427 async fn handle_unsubscribe(&self, args: Vec<OKXSubscriptionArg>) -> anyhow::Result<()> {
428 for arg in &args {
429 log::debug!(
430 "Unsubscribing from channel: channel={:?}, inst_id={:?}",
431 arg.channel,
432 arg.inst_id
433 );
434 }
435
436 let message = OKXSubscription {
437 op: OKXWsOperation::Unsubscribe,
438 args,
439 };
440
441 let json_txt = serde_json::to_string(&message)
442 .map_err(|e| anyhow::anyhow!("Failed to serialize unsubscription: {e}"))?;
443
444 self.send_with_retry(json_txt, Some(OKX_RATE_LIMIT_KEY_SUBSCRIPTION.as_slice()))
445 .await
446 .map_err(|e| anyhow::anyhow!("Failed to send unsubscription after retries: {e}"))?;
447 Ok(())
448 }
449
450 pub(crate) fn parse_raw_message(
451 msg: tokio_tungstenite::tungstenite::Message,
452 ) -> Option<OKXWsFrame> {
453 match msg {
454 tokio_tungstenite::tungstenite::Message::Text(text) => {
455 if text == TEXT_PONG {
456 log::trace!("Received pong from OKX");
457 return None;
458 }
459
460 if text == TEXT_PING {
461 log::trace!("Received ping from OKX (text)");
462 return Some(OKXWsFrame::Ping);
463 }
464
465 if text == RECONNECTED {
466 log::debug!("Received WebSocket reconnection signal");
467 return Some(OKXWsFrame::Reconnected);
468 }
469 log::trace!("Received WebSocket message: {text}");
470
471 match serde_json::from_str(&text) {
472 Ok(ws_event) => match &ws_event {
473 OKXWsFrame::Error { code, msg } => {
474 log::error!("WebSocket error: {code} - {msg}");
475 Some(ws_event)
476 }
477 OKXWsFrame::Login {
478 event,
479 code,
480 msg,
481 conn_id,
482 } => {
483 if code == OKX_SUCCESS_CODE {
484 log::info!("WebSocket authenticated: conn_id={conn_id}");
485 } else {
486 log::error!(
487 "WebSocket authentication failed: \
488 event={event}, code={code}, error={msg}"
489 );
490 }
491 Some(ws_event)
492 }
493 OKXWsFrame::Subscription {
494 event,
495 arg,
496 conn_id,
497 ..
498 } => {
499 let channel_str = serde_json::to_string(&arg.channel)
500 .expect("Invalid OKX websocket channel")
501 .trim_matches('"')
502 .to_string();
503 log::debug!("{event}d: channel={channel_str}, conn_id={conn_id}");
504 Some(ws_event)
505 }
506 OKXWsFrame::ChannelConnCount {
507 channel,
508 conn_count,
509 conn_id,
510 ..
511 } => {
512 let channel_str = serde_json::to_string(channel)
513 .expect("Invalid OKX websocket channel")
514 .trim_matches('"')
515 .to_string();
516 log::debug!(
517 "Channel connection status: \
518 channel={channel_str}, connections={conn_count}, conn_id={conn_id}",
519 );
520 None
521 }
522 OKXWsFrame::Ping => {
523 log::trace!("Ignoring ping event parsed from text payload");
524 None
525 }
526 OKXWsFrame::Data { .. } | OKXWsFrame::BookData { .. } => Some(ws_event),
527 OKXWsFrame::OrderResponse {
528 id, op, code, data, ..
529 } => {
530 if code == OKX_SUCCESS_CODE {
531 log::debug!(
532 "Order operation successful: id={id:?}, op={op}, code={code}"
533 );
534
535 if let Some(order_data) = data.first() {
536 let success_msg = order_data
537 .get(OKX_FIELD_SMSG)
538 .and_then(|s| s.as_str())
539 .unwrap_or("Order operation successful");
540 log::debug!("Order success details: {success_msg}");
541 }
542 }
543 Some(ws_event)
544 }
545 OKXWsFrame::Reconnected => {
546 log::warn!("Unexpected Reconnected event from deserialization");
547 None
548 }
549 },
550 Err(e) => {
551 log::error!("Failed to parse message: {e}: {text}");
552 None
553 }
554 }
555 }
556 Message::Ping(_payload) => {
557 log::trace!("Received binary ping frame from OKX");
558 Some(OKXWsFrame::Ping)
559 }
560 Message::Pong(payload) => {
561 log::trace!("Received pong frame from OKX ({} bytes)", payload.len());
562 None
563 }
564 Message::Binary(msg) => {
565 log::debug!("Raw binary: {msg:?}");
566 None
567 }
568 Message::Close(_) => {
569 log::debug!("Received close message");
570 None
571 }
572 msg => {
573 log::warn!("Unexpected message: {msg}");
574 None
575 }
576 }
577 }
578}
579
580pub fn is_post_only_rejection(code: &str, data: &[Value]) -> bool {
582 use crate::common::consts::OKX_POST_ONLY_ERROR_CODE;
583
584 if code == OKX_POST_ONLY_ERROR_CODE {
585 return true;
586 }
587
588 for entry in data {
589 if let Some(s_code) = entry.get(OKX_FIELD_SCODE).and_then(|value| value.as_str())
590 && s_code == OKX_POST_ONLY_ERROR_CODE
591 {
592 return true;
593 }
594
595 if let Some(inner_code) = entry.get("code").and_then(|value| value.as_str())
596 && inner_code == OKX_POST_ONLY_ERROR_CODE
597 {
598 return true;
599 }
600 }
601
602 false
603}
604
605pub fn is_post_only_auto_cancel(msg: &OKXOrderMsg) -> bool {
607 use crate::common::{consts::OKX_POST_ONLY_CANCEL_SOURCE, enums::OKXOrderStatus};
608
609 if msg.state != OKXOrderStatus::Canceled {
610 return false;
611 }
612
613 let cancel_source_matches = matches!(
614 msg.cancel_source.as_deref(),
615 Some(source) if source == OKX_POST_ONLY_CANCEL_SOURCE
616 );
617
618 let reason_matches = matches!(
619 msg.cancel_source_reason.as_deref(),
620 Some(reason) if reason.contains("POST_ONLY")
621 );
622
623 if !(cancel_source_matches || reason_matches) {
624 return false;
625 }
626
627 msg.acc_fill_sz
628 .as_ref()
629 .is_none_or(|filled| filled == "0" || filled.is_empty())
630}
631
/// Returns `true` if `needle` occurs in `haystack`, comparing ASCII letters
/// case-insensitively (non-ASCII bytes must match exactly).
///
/// An empty `needle` matches any haystack, mirroring `str::contains("")`. The
/// early return also avoids `slice::windows(0)`, which panics.
#[inline]
fn contains_ignore_ascii_case(haystack: &str, needle: &str) -> bool {
    let needle = needle.as_bytes();
    if needle.is_empty() {
        return true;
    }

    haystack
        .as_bytes()
        .windows(needle.len())
        .any(|window| window.eq_ignore_ascii_case(needle))
}
639
640fn should_retry_okx_error(error: &OKXWsError) -> bool {
641 match error {
642 OKXWsError::OkxError { error_code, .. } => should_retry_error_code(error_code),
643 OKXWsError::TungsteniteError(_) => true,
644 OKXWsError::ClientError(msg) => {
645 contains_ignore_ascii_case(msg, "timeout")
646 || contains_ignore_ascii_case(msg, "timed out")
647 || contains_ignore_ascii_case(msg, "connection")
648 || contains_ignore_ascii_case(msg, "network")
649 }
650 OKXWsError::AuthenticationError(_)
651 | OKXWsError::JsonError(_)
652 | OKXWsError::ParsingError(_) => false,
653 }
654}
655
656fn create_okx_timeout_error(msg: String) -> OKXWsError {
657 OKXWsError::ClientError(msg)
658}
659
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use serde_json::json;

    use super::*;

    #[rstest]
    fn test_is_post_only_rejection_detects_by_code() {
        assert!(is_post_only_rejection("51019", &[]));
    }

    #[rstest]
    fn test_is_post_only_rejection_detects_by_s_code() {
        let data = [json!({ "sCode": "51019" })];
        assert!(is_post_only_rejection("50000", &data));
    }

    // Covers the second per-entry branch, which checks the `code` field.
    #[rstest]
    fn test_is_post_only_rejection_detects_by_entry_code_field() {
        let data = [json!({ "code": "51019" })];
        assert!(is_post_only_rejection("50000", &data));
    }

    // Only string-valued codes are considered.
    #[rstest]
    fn test_is_post_only_rejection_ignores_non_string_codes() {
        let data = [json!({ "sCode": 51019, "code": 51019 })];
        assert!(!is_post_only_rejection("50000", &data));
    }

    #[rstest]
    fn test_is_post_only_rejection_false_for_unrelated_error() {
        let data = [json!({ "sMsg": "Insufficient balance" })];
        assert!(!is_post_only_rejection("50000", &data));
    }

    #[rstest]
    #[case("Connection reset by peer", "connection", true)]
    #[case("Request TIMED OUT", "timed out", true)]
    #[case("Send failed", "network", false)]
    #[case("net", "network", false)]
    fn test_contains_ignore_ascii_case(
        #[case] haystack: &str,
        #[case] needle: &str,
        #[case] expected: bool,
    ) {
        assert_eq!(contains_ignore_ascii_case(haystack, needle), expected);
    }
}