1use std::{
19 future::Future,
20 sync::{Arc, Mutex},
21 time::{Duration, Instant},
22};
23
24use ahash::AHashMap;
25use anyhow::Context;
26use async_trait::async_trait;
27use futures_util::{StreamExt, pin_mut};
28use nautilus_common::{
29 clients::ExecutionClient,
30 live::{get_runtime, runner::get_exec_event_sender},
31 messages::execution::{
32 BatchCancelOrders, CancelAllOrders, CancelOrder, GenerateFillReports,
33 GenerateFillReportsBuilder, GenerateOrderStatusReport, GenerateOrderStatusReports,
34 GenerateOrderStatusReportsBuilder, GeneratePositionStatusReports,
35 GeneratePositionStatusReportsBuilder, ModifyOrder, QueryAccount, QueryOrder, SubmitOrder,
36 SubmitOrderList,
37 },
38};
39use nautilus_core::{
40 MUTEX_POISONED, UnixNanos,
41 params::Params,
42 time::{AtomicTime, get_atomic_clock_realtime},
43};
44use nautilus_live::{ExecutionClientCore, ExecutionEventEmitter};
45use nautilus_model::{
46 accounts::AccountAny,
47 enums::{AccountType, OmsType, OrderSide, OrderType, TimeInForce, TrailingOffsetType},
48 identifiers::{
49 AccountId, ClientId, ClientOrderId, InstrumentId, StrategyId, TraderId, Venue, VenueOrderId,
50 },
51 orders::Order,
52 reports::{ExecutionMassStatus, FillReport, OrderStatusReport, PositionStatusReport},
53 types::{AccountBalance, MarginBalance, Money, Quantity},
54};
55use rust_decimal::Decimal;
56use tokio::task::JoinHandle;
57use ustr::Ustr;
58
59use crate::{
60 common::{
61 consts::{
62 OKX_CONDITIONAL_ORDER_TYPES, OKX_SUCCESS_CODE, OKX_VENUE, OKX_WS_HEARTBEAT_SECS,
63 resolve_instrument_families,
64 },
65 enums::{OKXInstrumentType, OKXMarginMode, OKXTradeMode, is_advance_algo_order},
66 parse::{nanos_to_datetime, okx_instrument_type_from_symbol},
67 },
68 config::OKXExecClientConfig,
69 http::{client::OKXHttpClient, models::OKXCancelAlgoOrderRequest},
70 websocket::{
71 client::OKXWebSocketClient,
72 dispatch::{
73 AlgoCancelContext, OrderIdentity, WsDispatchState, dispatch_ws_message,
74 emit_algo_cancel_rejections, emit_batch_cancel_failure,
75 },
76 parse::OrderStateSnapshot,
77 },
78};
79
80fn get_param_as_string(params: &Option<Params>, key: &str) -> Option<String> {
81 params.as_ref().and_then(|p| {
82 p.get(key).and_then(|v| {
83 v.as_str()
84 .map(ToString::to_string)
85 .or_else(|| v.as_f64().map(|n| n.to_string()))
86 })
87 })
88}
89
/// Execution client for OKX, combining one HTTP client with a private and a
/// business WebSocket client.
#[derive(Debug)]
pub struct OKXExecutionClient {
    // Shared execution-client state: identifiers, cache access and lifecycle flags.
    core: ExecutionClientCore,
    // Clock used to timestamp events emitted by this client.
    clock: &'static AtomicTime,
    // Client configuration (credentials, URLs, instrument types, margin settings).
    config: OKXExecClientConfig,
    // Emits execution events, account states and reports.
    emitter: ExecutionEventEmitter,
    // HTTP client used for instruments, account state, algo orders and reports.
    http_client: OKXHttpClient,
    // WebSocket client used for order submission/cancellation and the
    // orders, fills and account subscriptions.
    ws_private: OKXWebSocketClient,
    // WebSocket client used for the algo order subscriptions.
    ws_business: OKXWebSocketClient,
    // Default trade mode derived from the account type and config at construction.
    trade_mode: OKXTradeMode,
    // Task dispatching messages from the private WebSocket stream, if running.
    ws_stream_handle: Option<JoinHandle<()>>,
    // Task dispatching messages from the business WebSocket stream, if running.
    ws_business_stream_handle: Option<JoinHandle<()>>,
    // Dispatch state shared with both WebSocket message-handling tasks.
    ws_dispatch_state: Arc<WsDispatchState>,
    // Handles of tasks started through `spawn_task`; aborted on stop/disconnect.
    pending_tasks: Mutex<Vec<JoinHandle<()>>>,
}
105
106impl OKXExecutionClient {
107 pub fn new(core: ExecutionClientCore, config: OKXExecClientConfig) -> anyhow::Result<Self> {
113 let http_client = OKXHttpClient::with_credentials(
114 config.api_key.clone(),
115 config.api_secret.clone(),
116 config.api_passphrase.clone(),
117 config.base_url_http.clone(),
118 config.http_timeout_secs,
119 config.max_retries,
120 config.retry_delay_initial_ms,
121 config.retry_delay_max_ms,
122 config.environment,
123 config.proxy_url.clone(),
124 )?;
125
126 let account_id = core.account_id;
127
128 let ws_private = OKXWebSocketClient::with_credentials(
129 Some(config.ws_private_url()),
130 config.api_key.clone(),
131 config.api_secret.clone(),
132 config.api_passphrase.clone(),
133 Some(account_id),
134 Some(OKX_WS_HEARTBEAT_SECS),
135 None,
136 config.transport_backend,
137 config.proxy_url.clone(),
138 )
139 .context("failed to construct OKX private websocket client")?;
140
141 let ws_business = OKXWebSocketClient::with_credentials(
142 Some(config.ws_business_url()),
143 config.api_key.clone(),
144 config.api_secret.clone(),
145 config.api_passphrase.clone(),
146 Some(account_id),
147 Some(OKX_WS_HEARTBEAT_SECS),
148 None,
149 config.transport_backend,
150 config.proxy_url.clone(),
151 )
152 .context("failed to construct OKX business websocket client")?;
153
154 let trade_mode = Self::derive_default_trade_mode(core.account_type, &config);
155 let clock = get_atomic_clock_realtime();
156 let emitter = ExecutionEventEmitter::new(
157 clock,
158 core.trader_id,
159 core.account_id,
160 core.account_type,
161 None,
162 );
163
164 let ws_dispatch_state = Arc::new(WsDispatchState::with_pending_maps(
165 ws_private.pending_orders.clone(),
166 ws_private.pending_cancels.clone(),
167 ws_private.pending_amends.clone(),
168 ));
169
170 Ok(Self {
171 core,
172 clock,
173 config,
174 emitter,
175 http_client,
176 ws_private,
177 ws_business,
178 trade_mode,
179 ws_stream_handle: None,
180 ws_business_stream_handle: None,
181 ws_dispatch_state,
182 pending_tasks: Mutex::new(Vec::new()),
183 })
184 }
185
186 fn derive_default_trade_mode(
187 account_type: AccountType,
188 config: &OKXExecClientConfig,
189 ) -> OKXTradeMode {
190 let is_cross_margin = config.margin_mode == Some(OKXMarginMode::Cross);
191
192 if account_type == AccountType::Cash {
193 if !config.use_spot_margin {
194 return OKXTradeMode::Cash;
195 }
196 return if is_cross_margin {
197 OKXTradeMode::Cross
198 } else {
199 OKXTradeMode::Isolated
200 };
201 }
202
203 if is_cross_margin {
204 OKXTradeMode::Cross
205 } else {
206 OKXTradeMode::Isolated
207 }
208 }
209
210 fn trade_mode_for_order(
211 &self,
212 instrument_id: InstrumentId,
213 params: &Option<Params>,
214 ) -> OKXTradeMode {
215 if let Some(td_mode_str) = get_param_as_string(params, "td_mode") {
216 match td_mode_str.parse::<OKXTradeMode>() {
217 Ok(mode) => return mode,
218 Err(_) => {
219 log::warn!("Invalid td_mode '{td_mode_str}', using derived trade mode");
220 }
221 }
222 }
223
224 derive_trade_mode_for_instrument(
225 instrument_id,
226 self.config.margin_mode,
227 self.config.use_spot_margin,
228 )
229 }
230
231 fn instrument_types(&self) -> Vec<OKXInstrumentType> {
232 if self.config.instrument_types.is_empty() {
233 vec![OKXInstrumentType::Spot]
234 } else {
235 self.config.instrument_types.clone()
236 }
237 }
238
239 fn update_account_state(&self) {
240 let http_client = self.http_client.clone();
241 let account_id = self.core.account_id;
242 let emitter = self.emitter.clone();
243
244 self.spawn_task("query_account", async move {
245 let account_state = http_client
246 .request_account_state(account_id)
247 .await
248 .context("failed to request OKX account state")?;
249 emitter.send_account_state(account_state);
250 Ok(())
251 });
252 }
253
    /// Returns `true` when `order_type` is listed in `OKX_CONDITIONAL_ORDER_TYPES`.
    fn is_conditional_order(&self, order_type: OrderType) -> bool {
        OKX_CONDITIONAL_ORDER_TYPES.contains(&order_type)
    }
257
258 fn submit_regular_order(&self, cmd: &SubmitOrder) -> anyhow::Result<()> {
259 let order = {
260 let cache = self.core.cache();
261 cache
262 .order(&cmd.client_order_id)
263 .cloned()
264 .ok_or_else(|| anyhow::anyhow!("Order not found: {}", cmd.client_order_id))?
265 };
266 let ws_private = self.ws_private.clone();
267 let trade_mode = self.trade_mode_for_order(cmd.instrument_id, &cmd.params);
268
269 let emitter = self.emitter.clone();
270 let clock = self.clock;
271 let trader_id = self.core.trader_id;
272 let client_order_id = order.client_order_id();
273 let strategy_id = order.strategy_id();
274 let instrument_id = order.instrument_id();
275
276 self.ws_dispatch_state.order_identities.insert(
277 client_order_id,
278 OrderIdentity {
279 instrument_id,
280 strategy_id,
281 order_side: order.order_side(),
282 order_type: order.order_type(),
283 },
284 );
285 let order_side = order.order_side();
286 let order_type = order.order_type();
287 let quantity = order.quantity();
288 let time_in_force = order.time_in_force();
289 let price = order.price();
290 let trigger_price = order.trigger_price();
291 let is_post_only = order.is_post_only();
292 let is_reduce_only = order.is_reduce_only();
293 let is_quote_quantity = order.is_quote_quantity();
294
295 let px_usd = get_param_as_string(&cmd.params, "px_usd");
296 let px_vol = get_param_as_string(&cmd.params, "px_vol");
297
298 self.spawn_task("submit_order", async move {
299 let result = ws_private
300 .submit_order(
301 trader_id,
302 strategy_id,
303 instrument_id,
304 trade_mode,
305 client_order_id,
306 order_side,
307 order_type,
308 quantity,
309 Some(time_in_force),
310 price,
311 trigger_price,
312 Some(is_post_only),
313 Some(is_reduce_only),
314 Some(is_quote_quantity),
315 None,
316 None,
317 px_usd,
318 px_vol,
319 )
320 .await
321 .map_err(|e| anyhow::anyhow!("Submit order failed: {e}"));
322
323 if let Err(e) = result {
324 let ts_event = clock.get_time_ns();
325 emitter.emit_order_rejected_event(
326 strategy_id,
327 instrument_id,
328 client_order_id,
329 &format!("submit-order-error: {e}"),
330 ts_event,
331 false,
332 );
333 return Err(e);
334 }
335
336 Ok(())
337 });
338
339 Ok(())
340 }
341
342 fn submit_conditional_order(&self, cmd: &SubmitOrder) -> anyhow::Result<()> {
343 let order = {
344 let cache = self.core.cache();
345 cache
346 .order(&cmd.client_order_id)
347 .cloned()
348 .ok_or_else(|| anyhow::anyhow!("Order not found: {}", cmd.client_order_id))?
349 };
350 let http_client = self.http_client.clone();
351 let trade_mode = self.trade_mode_for_order(cmd.instrument_id, &cmd.params);
352
353 let emitter = self.emitter.clone();
354 let clock = self.clock;
355 let client_order_id = order.client_order_id();
356 let strategy_id = order.strategy_id();
357 let instrument_id = order.instrument_id();
358 let order_side = order.order_side();
359 let order_type = order.order_type();
360
361 self.ws_dispatch_state.order_identities.insert(
362 client_order_id,
363 OrderIdentity {
364 instrument_id,
365 strategy_id,
366 order_side,
367 order_type,
368 },
369 );
370 let quantity = order.quantity();
371 let trigger_type = order.trigger_type();
372 let trigger_price = order.trigger_price();
373 let price = order.price();
374 let is_reduce_only = order.is_reduce_only();
375
376 let trailing_offset = order.trailing_offset();
377 let trailing_offset_type = order.trailing_offset_type();
378 let activation_price = order.activation_price();
379
380 let close_fraction = get_param_as_string(&cmd.params, "close_fraction");
381 let reduce_only = if close_fraction.is_some() {
382 Some(true)
383 } else {
384 Some(is_reduce_only)
385 };
386
387 let (callback_ratio, callback_spread) = if order_type == OrderType::TrailingStopMarket {
388 let offset = trailing_offset
389 .ok_or_else(|| anyhow::anyhow!("TrailingStopMarket requires trailing_offset"))?;
390 let offset_type = trailing_offset_type.ok_or_else(|| {
391 anyhow::anyhow!("TrailingStopMarket requires trailing_offset_type")
392 })?;
393
394 match offset_type {
395 TrailingOffsetType::BasisPoints => {
396 let ratio = offset / Decimal::from(10000);
398 (Some(ratio.to_string()), None)
399 }
400 TrailingOffsetType::Price => (None, Some(offset.to_string())),
401 _ => {
402 anyhow::bail!("Unsupported trailing_offset_type for OKX: {offset_type:?}");
403 }
404 }
405 } else {
406 (None, None)
407 };
408
409 self.spawn_task("submit_algo_order", async move {
410 let result = http_client
411 .place_algo_order_with_domain_types(
412 instrument_id,
413 trade_mode,
414 client_order_id,
415 order_side,
416 order_type,
417 quantity,
418 trigger_price,
419 trigger_type,
420 price,
421 reduce_only,
422 close_fraction,
423 callback_ratio,
424 callback_spread,
425 activation_price,
426 )
427 .await
428 .map_err(|e| anyhow::anyhow!("Submit algo order failed: {e}"));
429
430 if let Err(e) = result {
431 let ts_event = clock.get_time_ns();
432 emitter.emit_order_rejected_event(
433 strategy_id,
434 instrument_id,
435 client_order_id,
436 &format!("submit-order-error: {e}"),
437 ts_event,
438 false,
439 );
440 return Err(e);
441 }
442
443 Ok(())
444 });
445
446 Ok(())
447 }
448
449 fn cancel_ws_order(&self, cmd: &CancelOrder) {
450 self.ensure_order_identity(cmd.client_order_id, cmd.strategy_id, cmd.instrument_id);
451
452 let ws_private = self.ws_private.clone();
453 let command = cmd.clone();
454
455 let emitter = self.emitter.clone();
456 let clock = self.clock;
457
458 self.spawn_task("cancel_order", async move {
459 let result = ws_private
460 .cancel_order(
461 command.trader_id,
462 command.strategy_id,
463 command.instrument_id,
464 Some(command.client_order_id),
465 command.venue_order_id,
466 )
467 .await
468 .map_err(|e| anyhow::anyhow!("Cancel order failed: {e}"));
469
470 if let Err(e) = result {
471 let ts_event = clock.get_time_ns();
472 emitter.emit_order_cancel_rejected_event(
473 command.strategy_id,
474 command.instrument_id,
475 command.client_order_id,
476 command.venue_order_id,
477 &format!("cancel-order-error: {e}"),
478 ts_event,
479 );
480 return Err(e);
481 }
482
483 Ok(())
484 });
485 }
486
487 fn cancel_algo_order(&self, cmd: &CancelOrder) {
488 let http_client = self.http_client.clone();
489 let command = cmd.clone();
490 let emitter = self.emitter.clone();
491 let clock = self.clock;
492
493 let cache = self.core.cache();
494 let is_advance = cache
495 .order(&cmd.client_order_id)
496 .is_some_and(|o| is_advance_algo_order(o.order_type()));
497 drop(cache);
498
499 let request = OKXCancelAlgoOrderRequest {
500 inst_id: cmd.instrument_id.symbol.to_string(),
501 inst_id_code: None,
502 algo_id: cmd.venue_order_id.map(|id| id.to_string()),
503 algo_cl_ord_id: if cmd.venue_order_id.is_none() {
504 Some(cmd.client_order_id.to_string())
505 } else {
506 None
507 },
508 };
509
510 self.spawn_task("cancel_algo_order", async move {
511 let responses = if is_advance {
512 http_client
513 .cancel_advance_algo_orders(vec![request])
514 .await
515 .map_err(|e| anyhow::anyhow!("Cancel advance algo order failed: {e}"))
516 } else {
517 http_client
518 .cancel_algo_orders(vec![request])
519 .await
520 .map_err(|e| anyhow::anyhow!("Cancel algo order failed: {e}"))
521 };
522
523 let reject_reason = match &responses {
524 Err(e) => Some(format!("cancel-algo-order-error: {e}")),
525 Ok(resps) => {
526 resps.first().and_then(|r| {
528 r.s_code.as_deref().and_then(|code| {
529 if code == OKX_SUCCESS_CODE {
530 None
531 } else {
532 let msg = r.s_msg.as_deref().unwrap_or("unknown");
533 Some(format!(
534 "cancel-algo-order-rejected: s_code={code}, s_msg={msg}"
535 ))
536 }
537 })
538 })
539 }
540 };
541
542 if let Some(reason) = reject_reason {
543 let ts_event = clock.get_time_ns();
544 emitter.emit_order_cancel_rejected_event(
545 command.strategy_id,
546 command.instrument_id,
547 command.client_order_id,
548 command.venue_order_id,
549 &reason,
550 ts_event,
551 );
552 anyhow::bail!("{reason}");
553 }
554
555 Ok(())
556 });
557 }
558
559 fn mass_cancel_instrument(&self, instrument_id: InstrumentId) {
560 let ws_private = self.ws_private.clone();
561
562 self.spawn_task("mass_cancel_orders", async move {
563 ws_private.mass_cancel_orders(instrument_id).await?;
564 Ok(())
565 });
566 }
567
568 fn ensure_order_identity(
577 &self,
578 client_order_id: ClientOrderId,
579 strategy_id: StrategyId,
580 instrument_id: InstrumentId,
581 ) {
582 self.ws_dispatch_state
583 .order_identities
584 .entry(client_order_id)
585 .or_insert_with(|| {
586 let cache = self.core.cache();
587 let (order_side, order_type) = cache
588 .order(&client_order_id)
589 .map_or((OrderSide::NoOrderSide, OrderType::Market), |o| {
590 (o.order_side(), o.order_type())
591 });
592 drop(cache);
593
594 OrderIdentity {
595 instrument_id,
596 strategy_id,
597 order_side,
598 order_type,
599 }
600 });
601 }
602
603 fn spawn_task<F>(&self, description: &'static str, fut: F)
604 where
605 F: Future<Output = anyhow::Result<()>> + Send + 'static,
606 {
607 let runtime = get_runtime();
608 let handle = runtime.spawn(async move {
609 if let Err(e) = fut.await {
610 log::warn!("{description} failed: {e:?}");
611 }
612 });
613
614 let mut tasks = self.pending_tasks.lock().expect(MUTEX_POISONED);
615 tasks.retain(|handle| !handle.is_finished());
616 tasks.push(handle);
617 }
618
619 fn dispatch_algo_cancels(&self, items: Vec<(OKXCancelAlgoOrderRequest, AlgoCancelContext)>) {
622 let mut regular_requests = Vec::new();
623 let mut regular_contexts = Vec::new();
624 let mut advance_requests = Vec::new();
625 let mut advance_contexts = Vec::new();
626
627 let cache = self.core.cache();
628
629 for (request, ctx) in items {
630 let is_advance = cache
631 .order(&ctx.client_order_id)
632 .is_some_and(|o| is_advance_algo_order(o.order_type()));
633
634 if is_advance {
635 advance_requests.push(request);
636 advance_contexts.push(ctx);
637 } else {
638 regular_requests.push(request);
639 regular_contexts.push(ctx);
640 }
641 }
642
643 drop(cache);
644
645 if !regular_requests.is_empty() {
646 let client = self.http_client.clone();
647 let emitter = self.emitter.clone();
648 let clock = self.clock;
649
650 self.spawn_task("cancel_algo_orders", async move {
651 match client.cancel_algo_orders(regular_requests).await {
652 Ok(responses) => {
653 emit_algo_cancel_rejections(&responses, ®ular_contexts, &emitter, clock);
654 }
655 Err(e) => {
656 let msg = format!("{e}");
657 emit_batch_cancel_failure(®ular_contexts, &msg, &emitter, clock);
658 anyhow::bail!("{e}");
659 }
660 }
661 Ok(())
662 });
663 }
664
665 if !advance_requests.is_empty() {
666 let client = self.http_client.clone();
667 let emitter = self.emitter.clone();
668 let clock = self.clock;
669
670 self.spawn_task("cancel_advance_algo_orders", async move {
671 match client.cancel_advance_algo_orders(advance_requests).await {
672 Ok(responses) => {
673 emit_algo_cancel_rejections(&responses, &advance_contexts, &emitter, clock);
674 }
675 Err(e) => {
676 let msg = format!("{e}");
677 emit_batch_cancel_failure(&advance_contexts, &msg, &emitter, clock);
678 anyhow::bail!("{e}");
679 }
680 }
681 Ok(())
682 });
683 }
684 }
685
686 fn abort_pending_tasks(&self) {
687 let mut tasks = self.pending_tasks.lock().expect(MUTEX_POISONED);
688
689 for handle in tasks.drain(..) {
690 handle.abort();
691 }
692 }
693
694 async fn await_account_registered(&self, timeout_secs: f64) -> anyhow::Result<()> {
696 let account_id = self.core.account_id;
697
698 if self.core.cache().account(&account_id).is_some() {
699 log::info!("Account {account_id} registered");
700 return Ok(());
701 }
702
703 let start = Instant::now();
704 let timeout = Duration::from_secs_f64(timeout_secs);
705 let interval = Duration::from_millis(10);
706
707 loop {
708 tokio::time::sleep(interval).await;
709
710 if self.core.cache().account(&account_id).is_some() {
711 log::info!("Account {account_id} registered");
712 return Ok(());
713 }
714
715 if start.elapsed() >= timeout {
716 anyhow::bail!(
717 "Timeout waiting for account {account_id} to be registered after {timeout_secs}s"
718 );
719 }
720 }
721 }
722}
723
724fn derive_trade_mode_for_instrument(
725 instrument_id: InstrumentId,
726 margin_mode: Option<OKXMarginMode>,
727 use_spot_margin: bool,
728) -> OKXTradeMode {
729 let inst_type = okx_instrument_type_from_symbol(instrument_id.symbol.as_str());
730 let is_cross_margin = margin_mode == Some(OKXMarginMode::Cross);
731
732 match inst_type {
733 OKXInstrumentType::Spot => {
734 if use_spot_margin {
735 if is_cross_margin {
736 OKXTradeMode::Cross
737 } else {
738 OKXTradeMode::Isolated
739 }
740 } else {
741 OKXTradeMode::Cash
742 }
743 }
744 _ => {
745 if is_cross_margin {
746 OKXTradeMode::Cross
747 } else {
748 OKXTradeMode::Isolated
749 }
750 }
751 }
752}
753
754#[async_trait(?Send)]
755impl ExecutionClient for OKXExecutionClient {
    /// Returns whether the core reports this client as connected.
    fn is_connected(&self) -> bool {
        self.core.is_connected()
    }
759
    /// Returns the client ID stored in the core.
    fn client_id(&self) -> ClientId {
        self.core.client_id
    }
763
    /// Returns the account ID stored in the core.
    fn account_id(&self) -> AccountId {
        self.core.account_id
    }
767
    /// Returns the OKX venue, copied out of `OKX_VENUE`.
    fn venue(&self) -> Venue {
        *OKX_VENUE
    }
771
    /// Returns the OMS type stored in the core.
    fn oms_type(&self) -> OmsType {
        self.core.oms_type
    }
775
    /// Returns a clone of this client's account from the cache, or `None` when
    /// the cache holds no account under the client's account ID.
    fn get_account(&self) -> Option<AccountAny> {
        self.core.cache().account(&self.core.account_id).cloned()
    }
779
    /// Connects the client to OKX.
    ///
    /// Steps, in order: load and cache instruments (unless the core already
    /// reports them as initialized), connect the private WebSocket and start
    /// its dispatch task, connect the business WebSocket and start its
    /// dispatch task, subscribe to the order/fill/account/algo channels,
    /// request and emit the account state, then wait for the account to appear
    /// in the cache before marking the client connected.
    ///
    /// Returns `Ok(())` immediately when the client is already connected.
    ///
    /// # Errors
    ///
    /// Returns an error if an instrument request fails, no instruments are
    /// loaded, a WebSocket connect/activation or required subscription fails,
    /// the account state request fails, or the account is not registered
    /// within 30 seconds. A failed fills-channel subscription is only logged.
    async fn connect(&mut self) -> anyhow::Result<()> {
        if self.core.is_connected() {
            return Ok(());
        }

        let instrument_types = self.instrument_types();

        // Load instruments over HTTP unless the core reports them as initialized.
        if !self.core.instruments_initialized() {
            let mut all_instruments = Vec::new();
            let mut all_inst_id_codes = Vec::new();

            for instrument_type in &instrument_types {
                // Instrument types for which no families resolve are skipped.
                let Some(families) =
                    resolve_instrument_families(&self.config.instrument_families, *instrument_type)
                else {
                    continue;
                };

                if families.is_empty() {
                    // No family filter: request every instrument of this type.
                    let (instruments, inst_id_codes) = self
                        .http_client
                        .request_instruments(*instrument_type, None)
                        .await
                        .with_context(|| {
                            format!("failed to request OKX instruments for {instrument_type:?}")
                        })?;

                    if instruments.is_empty() {
                        log::warn!("No instruments returned for {instrument_type:?}");
                        continue;
                    }

                    log::info!(
                        "Loaded {} {instrument_type:?} instruments",
                        instruments.len()
                    );

                    self.http_client.cache_instruments(&instruments);
                    all_instruments.extend(instruments);
                    all_inst_id_codes.extend(inst_id_codes);
                } else {
                    // One request per resolved family.
                    for family in &families {
                        let (instruments, inst_id_codes) = self
                            .http_client
                            .request_instruments(*instrument_type, Some(family.clone()))
                            .await
                            .with_context(|| {
                                format!(
                                    "failed to request OKX instruments for {instrument_type:?} family {family}"
                                )
                            })?;

                        if instruments.is_empty() {
                            log::warn!(
                                "No instruments returned for {instrument_type:?} family {family}"
                            );
                            continue;
                        }

                        log::info!(
                            "Loaded {} {instrument_type:?} instruments for family {family}",
                            instruments.len()
                        );

                        self.http_client.cache_instruments(&instruments);
                        all_instruments.extend(instruments);
                        all_inst_id_codes.extend(inst_id_codes);
                    }
                }
            }

            if all_instruments.is_empty() {
                anyhow::bail!(
                    "No instruments loaded for configured types {instrument_types:?}, \
                     cannot initialize execution client"
                );
            }

            // Both WebSocket clients receive the same instruments and ID codes.
            self.ws_private.cache_instruments(&all_instruments);
            self.ws_private
                .cache_inst_id_codes(all_inst_id_codes.clone());
            self.ws_business.cache_instruments(&all_instruments);
            self.ws_business.cache_inst_id_codes(all_inst_id_codes);
            self.core.set_instruments_initialized();
        }

        self.ws_private.connect().await?;
        self.ws_private.wait_until_active(10.0).await?;
        log::info!("Connected to private WebSocket");

        // Start the private-stream dispatch task only if none is stored yet.
        if self.ws_stream_handle.is_none() {
            let stream = self.ws_private.stream();
            let emitter = self.emitter.clone();
            let state = Arc::clone(&self.ws_dispatch_state);
            let account_id = self.core.account_id;
            let instruments = self.ws_private.instruments_snapshot();
            let clock = self.clock;

            let handle = get_runtime().spawn(async move {
                // These caches are local to this task.
                let mut fee_cache: AHashMap<Ustr, Money> = AHashMap::new();
                let mut filled_qty_cache: AHashMap<Ustr, Quantity> = AHashMap::new();
                let mut order_state_cache: AHashMap<ClientOrderId, OrderStateSnapshot> =
                    AHashMap::new();

                pin_mut!(stream);

                while let Some(message) = stream.next().await {
                    dispatch_ws_message(
                        message,
                        &emitter,
                        &state,
                        account_id,
                        &instruments,
                        &mut fee_cache,
                        &mut filled_qty_cache,
                        &mut order_state_cache,
                        clock,
                    );
                }
            });
            self.ws_stream_handle = Some(handle);
        }

        self.ws_business.connect().await?;
        self.ws_business.wait_until_active(10.0).await?;
        log::info!("Connected to business WebSocket");

        // Same dispatch loop for the business stream, with its own local caches
        // but the shared dispatch state.
        if self.ws_business_stream_handle.is_none() {
            let stream = self.ws_business.stream();
            let emitter = self.emitter.clone();
            let state = Arc::clone(&self.ws_dispatch_state);
            let account_id = self.core.account_id;
            let instruments = self.ws_business.instruments_snapshot();
            let clock = self.clock;

            let handle = get_runtime().spawn(async move {
                let mut fee_cache: AHashMap<Ustr, Money> = AHashMap::new();
                let mut filled_qty_cache: AHashMap<Ustr, Quantity> = AHashMap::new();
                let mut order_state_cache: AHashMap<ClientOrderId, OrderStateSnapshot> =
                    AHashMap::new();

                pin_mut!(stream);

                while let Some(message) = stream.next().await {
                    dispatch_ws_message(
                        message,
                        &emitter,
                        &state,
                        account_id,
                        &instruments,
                        &mut fee_cache,
                        &mut filled_qty_cache,
                        &mut order_state_cache,
                        clock,
                    );
                }
            });

            self.ws_business_stream_handle = Some(handle);
        }

        for inst_type in &instrument_types {
            log::info!("Subscribing to orders channel for {inst_type:?}");
            self.ws_private.subscribe_orders(*inst_type).await?;

            // The fills channel is optional: a failed subscription is logged
            // and does not abort the connect.
            if self.config.use_fills_channel {
                log::info!("Subscribing to fills channel for {inst_type:?}");
                if let Err(e) = self.ws_private.subscribe_fills(*inst_type).await {
                    log::warn!("Failed to subscribe to fills channel ({inst_type:?}): {e}");
                }
            }
        }

        self.ws_private.subscribe_account().await?;

        for inst_type in &instrument_types {
            // Algo channels are subscribed for every type except Option.
            if *inst_type != OKXInstrumentType::Option {
                self.ws_business.subscribe_orders_algo(*inst_type).await?;
                self.ws_business.subscribe_algo_advance(*inst_type).await?;
            }
        }

        let account_state = self
            .http_client
            .request_account_state(self.core.account_id)
            .await
            .context("failed to request OKX account state")?;

        if !account_state.balances.is_empty() {
            log::info!(
                "Received account state with {} balance(s)",
                account_state.balances.len()
            );
        }
        self.emitter.send_account_state(account_state);

        // Wait until the account shows up in the cache before reporting connected.
        self.await_account_registered(30.0).await?;

        self.core.set_connected();
        log::info!("Connected: client_id={}", self.core.client_id);
        Ok(())
    }
984
985 async fn disconnect(&mut self) -> anyhow::Result<()> {
986 if self.core.is_disconnected() {
987 return Ok(());
988 }
989
990 self.abort_pending_tasks();
991 self.http_client.cancel_all_requests();
992
993 if let Err(e) = self.ws_private.close().await {
994 log::warn!("Error closing private websocket: {e:?}");
995 }
996
997 if let Err(e) = self.ws_business.close().await {
998 log::warn!("Error closing business websocket: {e:?}");
999 }
1000
1001 if let Some(handle) = self.ws_stream_handle.take() {
1002 handle.abort();
1003 }
1004
1005 if let Some(handle) = self.ws_business_stream_handle.take() {
1006 handle.abort();
1007 }
1008
1009 self.core.set_disconnected();
1010 log::info!("Disconnected: client_id={}", self.core.client_id);
1011 Ok(())
1012 }
1013
    /// Triggers an asynchronous account state refresh; the command itself is unused.
    fn query_account(&self, _cmd: QueryAccount) -> anyhow::Result<()> {
        self.update_account_state();
        Ok(())
    }
1018
1019 fn query_order(&self, cmd: QueryOrder) -> anyhow::Result<()> {
1020 let http_client = self.http_client.clone();
1021 let account_id = self.core.account_id;
1022 let emitter = self.emitter.clone();
1023 let instrument_id = cmd.instrument_id;
1024 let client_order_id = cmd.client_order_id;
1025 let venue_order_id = cmd.venue_order_id;
1026
1027 self.spawn_task("query_order", async move {
1028 let mut reports = match http_client
1029 .request_order_status_reports(
1030 account_id,
1031 None,
1032 Some(instrument_id),
1033 None,
1034 None,
1035 false,
1036 None,
1037 )
1038 .await
1039 {
1040 Ok(r) => r,
1041 Err(e) => {
1042 log::error!("OKX query_order failed to fetch regular orders: {e}");
1043 Vec::new()
1044 }
1045 };
1046
1047 match http_client
1050 .request_algo_order_status_reports(
1051 account_id,
1052 None,
1053 Some(instrument_id),
1054 None,
1055 Some(client_order_id),
1056 None,
1057 None,
1058 )
1059 .await
1060 {
1061 Ok(mut algo) => reports.append(&mut algo),
1062 Err(e) => {
1063 log::warn!("OKX query_order algo lookup failed for {instrument_id}: {e}");
1064 }
1065 }
1066
1067 let Some(report) = select_query_order_report(reports, client_order_id, venue_order_id)
1068 else {
1069 log::warn!(
1070 "OKX query_order found no order for client_order_id={client_order_id}, venue_order_id={venue_order_id:?}",
1071 );
1072 return Ok(());
1073 };
1074
1075 emitter.send_order_status_report(report);
1076 Ok(())
1077 });
1078 Ok(())
1079 }
1080
    /// Emits an account state built from the given balances, margins, reported
    /// flag and event timestamp through the emitter.
    fn generate_account_state(
        &self,
        balances: Vec<AccountBalance>,
        margins: Vec<MarginBalance>,
        reported: bool,
        ts_event: UnixNanos,
    ) -> anyhow::Result<()> {
        self.emitter
            .emit_account_state(balances, margins, reported, ts_event);
        Ok(())
    }
1092
1093 fn start(&mut self) -> anyhow::Result<()> {
1094 if self.core.is_started() {
1095 return Ok(());
1096 }
1097
1098 let sender = get_exec_event_sender();
1099 self.emitter.set_sender(sender);
1100 self.core.set_started();
1101
1102 let http_client = self.http_client.clone();
1103 let ws_private = self.ws_private.clone();
1104 let ws_business = self.ws_business.clone();
1105 let instrument_types = self.config.instrument_types.clone();
1106 let instrument_families = self.config.instrument_families.clone();
1107
1108 get_runtime().spawn(async move {
1109 let mut all_instruments = Vec::new();
1110 let mut all_inst_id_codes = Vec::new();
1111
1112 for instrument_type in instrument_types {
1113 let Some(families) =
1114 resolve_instrument_families(&instrument_families, instrument_type)
1115 else {
1116 continue;
1117 };
1118
1119 if families.is_empty() {
1120 match http_client.request_instruments(instrument_type, None).await {
1121 Ok((instruments, inst_id_codes)) => {
1122 if instruments.is_empty() {
1123 log::warn!("No instruments returned for {instrument_type:?}");
1124 continue;
1125 }
1126 http_client.cache_instruments(&instruments);
1127 all_instruments.extend(instruments);
1128 all_inst_id_codes.extend(inst_id_codes);
1129 }
1130 Err(e) => {
1131 log::error!(
1132 "Failed to request instruments for {instrument_type:?}: {e}"
1133 );
1134 }
1135 }
1136 } else {
1137 for family in &families {
1138 match http_client
1139 .request_instruments(instrument_type, Some(family.clone()))
1140 .await
1141 {
1142 Ok((instruments, inst_id_codes)) => {
1143 if instruments.is_empty() {
1144 log::warn!(
1145 "No instruments returned for {instrument_type:?} family {family}"
1146 );
1147 continue;
1148 }
1149 http_client.cache_instruments(&instruments);
1150 all_instruments.extend(instruments);
1151 all_inst_id_codes.extend(inst_id_codes);
1152 }
1153 Err(e) => {
1154 log::error!(
1155 "Failed to request instruments for {instrument_type:?} family {family}: {e}"
1156 );
1157 }
1158 }
1159 }
1160 }
1161 }
1162
1163 if all_instruments.is_empty() {
1164 log::error!(
1165 "Instrument bootstrap yielded no instruments, order submissions will fail"
1166 );
1167 } else {
1168 ws_private.cache_instruments(&all_instruments);
1169 ws_private.cache_inst_id_codes(all_inst_id_codes.clone());
1170 ws_business.cache_instruments(&all_instruments);
1171 ws_business.cache_inst_id_codes(all_inst_id_codes);
1172 log::info!("Instruments initialized");
1173 }
1174 });
1175
1176 log::info!(
1177 "Started: client_id={}, account_id={}, account_type={:?}, trade_mode={:?}, instrument_types={:?}, use_fills_channel={}, environment={}, proxy_url={:?}",
1178 self.core.client_id,
1179 self.core.account_id,
1180 self.core.account_type,
1181 self.trade_mode,
1182 self.config.instrument_types,
1183 self.config.use_fills_channel,
1184 self.config.environment,
1185 self.config.proxy_url,
1186 );
1187 Ok(())
1188 }
1189
1190 fn stop(&mut self) -> anyhow::Result<()> {
1191 if self.core.is_stopped() {
1192 return Ok(());
1193 }
1194
1195 self.core.set_stopped();
1196 self.core.set_disconnected();
1197
1198 if let Some(handle) = self.ws_stream_handle.take() {
1199 handle.abort();
1200 }
1201
1202 if let Some(handle) = self.ws_business_stream_handle.take() {
1203 handle.abort();
1204 }
1205 self.abort_pending_tasks();
1206 log::info!("Stopped: client_id={}", self.core.client_id);
1207 Ok(())
1208 }
1209
1210 async fn generate_order_status_report(
1211 &self,
1212 cmd: &GenerateOrderStatusReport,
1213 ) -> anyhow::Result<Option<OrderStatusReport>> {
1214 let Some(instrument_id) = cmd.instrument_id else {
1215 log::warn!("generate_order_status_report requires instrument_id: {cmd:?}");
1216 return Ok(None);
1217 };
1218
1219 let mut reports = self
1220 .http_client
1221 .request_order_status_reports(
1222 self.core.account_id,
1223 None,
1224 Some(instrument_id),
1225 None,
1226 None,
1227 false,
1228 None,
1229 )
1230 .await?;
1231
1232 match self
1237 .http_client
1238 .request_algo_order_status_reports(
1239 self.core.account_id,
1240 None,
1241 Some(instrument_id),
1242 None,
1243 cmd.client_order_id,
1244 None,
1245 None,
1246 )
1247 .await
1248 {
1249 Ok(mut algo_reports) => reports.append(&mut algo_reports),
1250 Err(e) => {
1251 log::warn!("Failed to fetch algo order status reports for {instrument_id}: {e}");
1252 }
1253 }
1254
1255 if let Some(client_order_id) = cmd.client_order_id {
1256 reports.retain(|report| report.client_order_id == Some(client_order_id));
1257 }
1258
1259 if let Some(venue_order_id) = cmd.venue_order_id {
1260 reports.retain(|report| report.venue_order_id.as_str() == venue_order_id.as_str());
1261 }
1262
1263 Ok(reports.into_iter().next())
1264 }
1265
1266 async fn generate_order_status_reports(
1267 &self,
1268 cmd: &GenerateOrderStatusReports,
1269 ) -> anyhow::Result<Vec<OrderStatusReport>> {
1270 let mut reports = Vec::new();
1271
1272 if let Some(instrument_id) = cmd.instrument_id {
1273 let mut fetched = self
1274 .http_client
1275 .request_order_status_reports(
1276 self.core.account_id,
1277 None,
1278 Some(instrument_id),
1279 None,
1280 None,
1281 false,
1282 None,
1283 )
1284 .await?;
1285 reports.append(&mut fetched);
1286
1287 match self
1293 .http_client
1294 .request_algo_order_status_reports(
1295 self.core.account_id,
1296 None,
1297 Some(instrument_id),
1298 None,
1299 None,
1300 None,
1301 None,
1302 )
1303 .await
1304 {
1305 Ok(mut algo) => reports.append(&mut algo),
1306 Err(e) => {
1307 log::warn!(
1308 "Failed to fetch algo order status reports for {instrument_id}: {e}"
1309 );
1310 }
1311 }
1312 } else {
1313 for inst_type in self.instrument_types() {
1314 let mut fetched = self
1315 .http_client
1316 .request_order_status_reports(
1317 self.core.account_id,
1318 Some(inst_type),
1319 None,
1320 None,
1321 None,
1322 false,
1323 None,
1324 )
1325 .await?;
1326 reports.append(&mut fetched);
1327
1328 match self
1329 .http_client
1330 .request_algo_order_status_reports(
1331 self.core.account_id,
1332 Some(inst_type),
1333 None,
1334 None,
1335 None,
1336 None,
1337 None,
1338 )
1339 .await
1340 {
1341 Ok(mut algo) => reports.append(&mut algo),
1342 Err(e) => log::warn!(
1343 "Failed to fetch algo order status reports for {inst_type:?}: {e}"
1344 ),
1345 }
1346 }
1347 }
1348
1349 if cmd.open_only {
1350 reports.retain(|r| r.order_status.is_open());
1351 }
1352
1353 if let Some(start) = cmd.start {
1354 reports.retain(|r| r.ts_last >= start);
1355 }
1356
1357 if let Some(end) = cmd.end {
1358 reports.retain(|r| r.ts_last <= end);
1359 }
1360
1361 Ok(reports)
1362 }
1363
1364 async fn generate_fill_reports(
1365 &self,
1366 cmd: GenerateFillReports,
1367 ) -> anyhow::Result<Vec<FillReport>> {
1368 let start_dt = nanos_to_datetime(cmd.start);
1369 let end_dt = nanos_to_datetime(cmd.end);
1370 let mut reports = Vec::new();
1371
1372 if let Some(instrument_id) = cmd.instrument_id {
1373 let mut fetched = self
1374 .http_client
1375 .request_fill_reports(
1376 self.core.account_id,
1377 None,
1378 Some(instrument_id),
1379 start_dt,
1380 end_dt,
1381 None,
1382 )
1383 .await?;
1384 reports.append(&mut fetched);
1385 } else {
1386 for inst_type in self.instrument_types() {
1387 let mut fetched = self
1388 .http_client
1389 .request_fill_reports(
1390 self.core.account_id,
1391 Some(inst_type),
1392 None,
1393 start_dt,
1394 end_dt,
1395 None,
1396 )
1397 .await?;
1398 reports.append(&mut fetched);
1399 }
1400 }
1401
1402 if let Some(venue_order_id) = cmd.venue_order_id {
1403 reports.retain(|report| report.venue_order_id.as_str() == venue_order_id.as_str());
1404 }
1405
1406 Ok(reports)
1407 }
1408
1409 async fn generate_position_status_reports(
1410 &self,
1411 cmd: &GeneratePositionStatusReports,
1412 ) -> anyhow::Result<Vec<PositionStatusReport>> {
1413 let mut reports = Vec::new();
1414
1415 if let Some(instrument_id) = cmd.instrument_id {
1418 let inst_type = okx_instrument_type_from_symbol(instrument_id.symbol.as_str());
1419 if inst_type != OKXInstrumentType::Spot && inst_type != OKXInstrumentType::Margin {
1420 let mut fetched = self
1421 .http_client
1422 .request_position_status_reports(
1423 self.core.account_id,
1424 None,
1425 Some(instrument_id),
1426 )
1427 .await?;
1428 reports.append(&mut fetched);
1429 }
1430 } else {
1431 for inst_type in self.instrument_types() {
1432 if inst_type == OKXInstrumentType::Spot || inst_type == OKXInstrumentType::Margin {
1434 continue;
1435 }
1436 let mut fetched = self
1437 .http_client
1438 .request_position_status_reports(self.core.account_id, Some(inst_type), None)
1439 .await?;
1440 reports.append(&mut fetched);
1441 }
1442 }
1443
1444 let mut margin_reports = self
1447 .http_client
1448 .request_spot_margin_position_reports(self.core.account_id)
1449 .await?;
1450
1451 if let Some(instrument_id) = cmd.instrument_id {
1452 margin_reports.retain(|report| report.instrument_id == instrument_id);
1453 }
1454
1455 reports.append(&mut margin_reports);
1456
1457 Ok(reports)
1458 }
1459
1460 async fn generate_mass_status(
1461 &self,
1462 lookback_mins: Option<u64>,
1463 ) -> anyhow::Result<Option<ExecutionMassStatus>> {
1464 log::info!("Generating ExecutionMassStatus (lookback_mins={lookback_mins:?})");
1465
1466 let ts_now = self.clock.get_time_ns();
1467
1468 let start = lookback_mins.map(|mins| {
1469 let lookback_ns = mins * 60 * 1_000_000_000;
1470 UnixNanos::from(ts_now.as_u64().saturating_sub(lookback_ns))
1471 });
1472
1473 let order_cmd = GenerateOrderStatusReportsBuilder::default()
1474 .ts_init(ts_now)
1475 .open_only(false) .start(start)
1477 .build()
1478 .map_err(|e| anyhow::anyhow!("{e}"))?;
1479
1480 let fill_cmd = GenerateFillReportsBuilder::default()
1481 .ts_init(ts_now)
1482 .start(start)
1483 .build()
1484 .map_err(|e| anyhow::anyhow!("{e}"))?;
1485
1486 let position_cmd = GeneratePositionStatusReportsBuilder::default()
1487 .ts_init(ts_now)
1488 .start(start)
1489 .build()
1490 .map_err(|e| anyhow::anyhow!("{e}"))?;
1491
1492 let (order_reports, fill_reports, position_reports) = tokio::try_join!(
1493 self.generate_order_status_reports(&order_cmd),
1494 self.generate_fill_reports(fill_cmd),
1495 self.generate_position_status_reports(&position_cmd),
1496 )?;
1497
1498 log::info!("Received {} OrderStatusReports", order_reports.len());
1499 log::info!("Received {} FillReports", fill_reports.len());
1500 log::info!("Received {} PositionReports", position_reports.len());
1501
1502 let mut mass_status = ExecutionMassStatus::new(
1503 self.core.client_id,
1504 self.core.account_id,
1505 *OKX_VENUE,
1506 ts_now,
1507 None,
1508 );
1509
1510 mass_status.add_order_reports(order_reports);
1511 mass_status.add_fill_reports(fill_reports);
1512 mass_status.add_position_reports(position_reports);
1513
1514 Ok(Some(mass_status))
1515 }
1516
1517 fn submit_order(&self, cmd: SubmitOrder) -> anyhow::Result<()> {
1518 let order_type = {
1519 let cache = self.core.cache();
1520 let order = cache
1521 .order(&cmd.client_order_id)
1522 .ok_or_else(|| anyhow::anyhow!("Order not found: {}", cmd.client_order_id))?;
1523
1524 if order.is_closed() {
1525 log::warn!("Cannot submit closed order {}", order.client_order_id());
1526 return Ok(());
1527 }
1528
1529 let order_type = order.order_type();
1530
1531 if self.is_conditional_order(order_type) {
1534 let inst_type = okx_instrument_type_from_symbol(cmd.instrument_id.symbol.as_str());
1535
1536 if inst_type == OKXInstrumentType::Option {
1537 anyhow::bail!(
1538 "Trigger/conditional orders ({order_type:?}) are not supported for OKX options"
1539 );
1540 }
1541 }
1542
1543 log::debug!("OrderSubmitted client_order_id={}", order.client_order_id());
1544 self.emitter.emit_order_submitted(order);
1545
1546 order_type
1547 };
1548
1549 if self.is_conditional_order(order_type) {
1550 self.submit_conditional_order(&cmd)
1551 } else {
1552 self.submit_regular_order(&cmd)
1553 }
1554 }
1555
1556 fn submit_order_list(&self, cmd: SubmitOrderList) -> anyhow::Result<()> {
1557 let inst_type = okx_instrument_type_from_symbol(cmd.instrument_id.symbol.as_str());
1558
1559 let cache = self.core.cache();
1561
1562 for client_order_id in &cmd.order_list.client_order_ids {
1563 let order = cache
1564 .order(client_order_id)
1565 .ok_or_else(|| anyhow::anyhow!("Order not found: {client_order_id}"))?;
1566
1567 if self.is_conditional_order(order.order_type()) {
1568 anyhow::bail!("Conditional orders not supported in order lists: {client_order_id}");
1569 }
1570
1571 if order.time_in_force() != TimeInForce::Gtc {
1572 anyhow::bail!(
1573 "Only GTC orders supported in order lists: {client_order_id} has {:?}",
1574 order.time_in_force()
1575 );
1576 }
1577 }
1578
1579 let mut batch_orders = Vec::new();
1581
1582 for client_order_id in &cmd.order_list.client_order_ids {
1583 let order = cache.order(client_order_id).expect("validated above");
1584
1585 batch_orders.push((
1586 inst_type,
1587 cmd.instrument_id,
1588 self.trade_mode_for_order(cmd.instrument_id, &cmd.params),
1589 order.client_order_id(),
1590 order.order_side(),
1591 None, order.order_type(),
1593 order.quantity(),
1594 order.price(),
1595 order.trigger_price(),
1596 Some(order.is_post_only()),
1597 Some(order.is_reduce_only()),
1598 ));
1599
1600 self.ws_dispatch_state.order_identities.insert(
1601 order.client_order_id(),
1602 OrderIdentity {
1603 instrument_id: cmd.instrument_id,
1604 strategy_id: order.strategy_id(),
1605 order_side: order.order_side(),
1606 order_type: order.order_type(),
1607 },
1608 );
1609
1610 log::debug!("OrderSubmitted client_order_id={}", order.client_order_id());
1611 self.emitter.emit_order_submitted(order);
1612 }
1613
1614 drop(cache);
1615
1616 let ws_private = self.ws_private.clone();
1617 let emitter = self.emitter.clone();
1618 let clock = self.clock;
1619 let instrument_id = cmd.instrument_id;
1620 let strategy_id = cmd.strategy_id;
1621 let client_order_ids: Vec<_> = cmd.order_list.client_order_ids;
1622 let dispatch_state = Arc::clone(&self.ws_dispatch_state);
1623
1624 self.spawn_task("batch_submit_orders", async move {
1625 let result = ws_private
1626 .batch_submit_orders(batch_orders)
1627 .await
1628 .map_err(|e| anyhow::anyhow!("Batch submit orders failed: {e}"));
1629
1630 if let Err(e) = result {
1631 let ts_event = clock.get_time_ns();
1632
1633 for cid in &client_order_ids {
1634 dispatch_state.order_identities.remove(cid);
1635 emitter.emit_order_rejected_event(
1636 strategy_id,
1637 instrument_id,
1638 *cid,
1639 &format!("batch-submit-error: {e}"),
1640 ts_event,
1641 false,
1642 );
1643 }
1644 return Err(e);
1645 }
1646
1647 Ok(())
1648 });
1649
1650 Ok(())
1651 }
1652
1653 fn modify_order(&self, cmd: ModifyOrder) -> anyhow::Result<()> {
1654 self.ensure_order_identity(cmd.client_order_id, cmd.strategy_id, cmd.instrument_id);
1655
1656 let ws_private = self.ws_private.clone();
1657 let command = cmd.clone();
1658
1659 let new_px_usd = get_param_as_string(&cmd.params, "px_usd");
1660 let new_px_vol = get_param_as_string(&cmd.params, "px_vol");
1661
1662 let emitter = self.emitter.clone();
1663 let clock = self.clock;
1664
1665 self.spawn_task("modify_order", async move {
1666 let result = ws_private
1667 .modify_order(
1668 command.trader_id,
1669 command.strategy_id,
1670 command.instrument_id,
1671 Some(command.client_order_id),
1672 command.price,
1673 command.quantity,
1674 command.venue_order_id,
1675 new_px_usd,
1676 new_px_vol,
1677 )
1678 .await
1679 .map_err(|e| anyhow::anyhow!("Modify order failed: {e}"));
1680
1681 if let Err(e) = result {
1682 let ts_event = clock.get_time_ns();
1683 emitter.emit_order_modify_rejected_event(
1684 command.strategy_id,
1685 command.instrument_id,
1686 command.client_order_id,
1687 command.venue_order_id,
1688 &format!("modify-order-error: {e}"),
1689 ts_event,
1690 );
1691 return Err(e);
1692 }
1693
1694 Ok(())
1695 });
1696
1697 Ok(())
1698 }
1699
1700 fn cancel_order(&self, cmd: CancelOrder) -> anyhow::Result<()> {
1701 let cache = self.core.cache();
1702 let is_pending_algo = cache.order(&cmd.client_order_id).is_some_and(|o| {
1703 self.is_conditional_order(o.order_type()) && o.is_triggered() != Some(true)
1704 });
1705 drop(cache);
1706
1707 if is_pending_algo {
1708 self.cancel_algo_order(&cmd);
1709 } else {
1710 self.cancel_ws_order(&cmd);
1711 }
1712 Ok(())
1713 }
1714
1715 fn cancel_all_orders(&self, cmd: CancelAllOrders) -> anyhow::Result<()> {
1716 if self.config.use_mm_mass_cancel {
1717 self.mass_cancel_instrument(cmd.instrument_id);
1719 Ok(())
1720 } else {
1721 let cache = self.core.cache();
1723 let open_orders = cache.orders_open(None, Some(&cmd.instrument_id), None, None, None);
1724
1725 if open_orders.is_empty() {
1726 log::debug!("No open orders to cancel for {}", cmd.instrument_id);
1727 return Ok(());
1728 }
1729
1730 let mut regular_payload = Vec::new();
1731 let mut regular_cancel_contexts = Vec::new();
1732 let mut algo_orders: Vec<(
1733 InstrumentId,
1734 ClientOrderId,
1735 Option<VenueOrderId>,
1736 TraderId,
1737 StrategyId,
1738 )> = Vec::new();
1739
1740 for order in &open_orders {
1741 let is_pending_algo = self.is_conditional_order(order.order_type())
1743 && order.is_triggered() != Some(true);
1744
1745 if is_pending_algo {
1746 algo_orders.push((
1747 order.instrument_id(),
1748 order.client_order_id(),
1749 order.venue_order_id(),
1750 order.trader_id(),
1751 order.strategy_id(),
1752 ));
1753 } else {
1754 self.ensure_order_identity(
1755 order.client_order_id(),
1756 order.strategy_id(),
1757 order.instrument_id(),
1758 );
1759 regular_payload.push((
1760 order.instrument_id(),
1761 Some(order.client_order_id()),
1762 order.venue_order_id(),
1763 ));
1764 regular_cancel_contexts.push((
1765 order.client_order_id(),
1766 order.instrument_id(),
1767 order.strategy_id(),
1768 ));
1769 }
1770 }
1771 drop(cache);
1772
1773 log::debug!(
1774 "Canceling {} regular orders and {} algo orders for {}",
1775 regular_payload.len(),
1776 algo_orders.len(),
1777 cmd.instrument_id
1778 );
1779
1780 if !regular_payload.is_empty() {
1781 let ws_private = self.ws_private.clone();
1782 let emitter = self.emitter.clone();
1783 let clock = self.clock;
1784
1785 self.spawn_task("batch_cancel_orders", async move {
1786 if let Err(e) = ws_private.batch_cancel_orders(regular_payload).await {
1787 let ts = clock.get_time_ns();
1788
1789 for (cid, inst_id, strat_id) in ®ular_cancel_contexts {
1790 emitter.emit_order_cancel_rejected_event(
1791 *strat_id,
1792 *inst_id,
1793 *cid,
1794 None,
1795 &format!("batch-cancel-error: {e}"),
1796 ts,
1797 );
1798 }
1799 anyhow::bail!("Batch cancel orders failed: {e}");
1800 }
1801 Ok(())
1802 });
1803 }
1804
1805 if !algo_orders.is_empty() {
1807 let items: Vec<_> = algo_orders
1808 .into_iter()
1809 .map(
1810 |(
1811 instrument_id,
1812 client_order_id,
1813 venue_order_id,
1814 _trader_id,
1815 strategy_id,
1816 )| {
1817 let request = OKXCancelAlgoOrderRequest {
1818 inst_id: instrument_id.symbol.to_string(),
1819 inst_id_code: None,
1820 algo_id: venue_order_id.map(|id| id.to_string()),
1821 algo_cl_ord_id: if venue_order_id.is_none() {
1822 Some(client_order_id.to_string())
1823 } else {
1824 None
1825 },
1826 };
1827 let ctx = AlgoCancelContext {
1828 client_order_id,
1829 instrument_id,
1830 strategy_id,
1831 venue_order_id,
1832 };
1833 (request, ctx)
1834 },
1835 )
1836 .collect();
1837 self.dispatch_algo_cancels(items);
1838 }
1839
1840 Ok(())
1841 }
1842 }
1843
1844 fn batch_cancel_orders(&self, cmd: BatchCancelOrders) -> anyhow::Result<()> {
1845 let cache = self.core.cache();
1846
1847 let mut regular_payload = Vec::new();
1848 let mut algo_orders = Vec::new();
1849
1850 for cancel in &cmd.cancels {
1851 let is_pending_algo = cache.order(&cancel.client_order_id).is_some_and(|o| {
1853 self.is_conditional_order(o.order_type()) && o.is_triggered() != Some(true)
1854 });
1855
1856 if is_pending_algo {
1857 algo_orders.push(cancel.clone());
1858 } else {
1859 self.ensure_order_identity(
1860 cancel.client_order_id,
1861 cancel.strategy_id,
1862 cancel.instrument_id,
1863 );
1864 regular_payload.push((
1865 cancel.instrument_id,
1866 Some(cancel.client_order_id),
1867 cancel.venue_order_id,
1868 ));
1869 }
1870 }
1871 drop(cache);
1872
1873 if !regular_payload.is_empty() {
1874 let ws_private = self.ws_private.clone();
1875 let emitter = self.emitter.clone();
1876 let clock = self.clock;
1877 let cancel_contexts: Vec<_> = cmd
1878 .cancels
1879 .iter()
1880 .filter(|c| {
1881 regular_payload
1882 .iter()
1883 .any(|(_, cid, _)| *cid == Some(c.client_order_id))
1884 })
1885 .map(|c| (c.client_order_id, c.instrument_id, c.strategy_id))
1886 .collect();
1887
1888 self.spawn_task("batch_cancel_orders", async move {
1889 if let Err(e) = ws_private.batch_cancel_orders(regular_payload).await {
1890 let ts = clock.get_time_ns();
1891
1892 for (cid, inst_id, strat_id) in &cancel_contexts {
1893 emitter.emit_order_cancel_rejected_event(
1894 *strat_id,
1895 *inst_id,
1896 *cid,
1897 None,
1898 &format!("batch-cancel-error: {e}"),
1899 ts,
1900 );
1901 }
1902 anyhow::bail!("Batch cancel orders failed: {e}");
1903 }
1904 Ok(())
1905 });
1906 }
1907
1908 if !algo_orders.is_empty() {
1910 let items: Vec<_> = algo_orders
1911 .into_iter()
1912 .map(|cancel| {
1913 let request = OKXCancelAlgoOrderRequest {
1914 inst_id: cancel.instrument_id.symbol.to_string(),
1915 inst_id_code: None,
1916 algo_id: cancel.venue_order_id.map(|id| id.to_string()),
1917 algo_cl_ord_id: if cancel.venue_order_id.is_none() {
1918 Some(cancel.client_order_id.to_string())
1919 } else {
1920 None
1921 },
1922 };
1923 let ctx = AlgoCancelContext {
1924 client_order_id: cancel.client_order_id,
1925 instrument_id: cancel.instrument_id,
1926 strategy_id: cancel.strategy_id,
1927 venue_order_id: cancel.venue_order_id,
1928 };
1929 (request, ctx)
1930 })
1931 .collect();
1932 self.dispatch_algo_cancels(items);
1933 }
1934
1935 Ok(())
1936 }
1937}
1938
1939fn select_query_order_report(
1951 reports: Vec<OrderStatusReport>,
1952 client_order_id: ClientOrderId,
1953 venue_order_id: Option<VenueOrderId>,
1954) -> Option<OrderStatusReport> {
1955 let mut by_vid: Option<OrderStatusReport> = None;
1956
1957 for report in reports {
1958 if report.client_order_id == Some(client_order_id) {
1959 return Some(report);
1960 }
1961
1962 if by_vid.is_none()
1963 && venue_order_id
1964 .as_ref()
1965 .is_some_and(|vid| report.venue_order_id.as_str() == vid.as_str())
1966 {
1967 by_vid = Some(report);
1968 }
1969 }
1970
1971 by_vid
1972}
1973
#[cfg(test)]
mod tests {
    //! Unit tests for trade mode derivation, order parameter handling and
    //! query order report selection.

    use nautilus_model::enums::OrderStatus;
    use rstest::rstest;
    use serde_json::Value;

    use super::*;

    /// Builds a default config with only the margin settings overridden.
    fn build_config(
        margin_mode: Option<OKXMarginMode>,
        use_spot_margin: bool,
    ) -> OKXExecClientConfig {
        OKXExecClientConfig {
            margin_mode,
            use_spot_margin,
            ..OKXExecClientConfig::default()
        }
    }

    #[rstest]
    #[case::cash_no_spot_margin(AccountType::Cash, None, false, OKXTradeMode::Cash)]
    #[case::cash_spot_margin_cross(
        AccountType::Cash,
        Some(OKXMarginMode::Cross),
        true,
        OKXTradeMode::Cross
    )]
    #[case::cash_spot_margin_isolated(
        AccountType::Cash,
        Some(OKXMarginMode::Isolated),
        true,
        OKXTradeMode::Isolated
    )]
    #[case::cash_spot_margin_none(AccountType::Cash, None, true, OKXTradeMode::Isolated)]
    #[case::margin_cross(
        AccountType::Margin,
        Some(OKXMarginMode::Cross),
        false,
        OKXTradeMode::Cross
    )]
    #[case::margin_isolated(
        AccountType::Margin,
        Some(OKXMarginMode::Isolated),
        false,
        OKXTradeMode::Isolated
    )]
    #[case::margin_none(AccountType::Margin, None, false, OKXTradeMode::Isolated)]
    fn test_derive_default_trade_mode(
        #[case] account_type: AccountType,
        #[case] margin_mode: Option<OKXMarginMode>,
        #[case] use_spot_margin: bool,
        #[case] expected: OKXTradeMode,
    ) {
        let config = build_config(margin_mode, use_spot_margin);

        let result = OKXExecutionClient::derive_default_trade_mode(account_type, &config);

        assert_eq!(result, expected);
    }

    #[rstest]
    #[case::spot_no_margin("BTC-USDT", None, false, OKXTradeMode::Cash)]
    #[case::spot_cross_margin("BTC-USDT", Some(OKXMarginMode::Cross), true, OKXTradeMode::Cross)]
    #[case::spot_isolated_margin(
        "ETH-USDT",
        Some(OKXMarginMode::Isolated),
        true,
        OKXTradeMode::Isolated
    )]
    #[case::spot_margin_no_mode("BTC-USDT", None, true, OKXTradeMode::Isolated)]
    #[case::swap_cross(
        "BTC-USDT-SWAP",
        Some(OKXMarginMode::Cross),
        false,
        OKXTradeMode::Cross
    )]
    #[case::swap_isolated(
        "BTC-USDT-SWAP",
        Some(OKXMarginMode::Isolated),
        false,
        OKXTradeMode::Isolated
    )]
    #[case::swap_no_mode("ETH-USDT-SWAP", None, false, OKXTradeMode::Isolated)]
    #[case::futures_cross(
        "BTC-USDT-250328",
        Some(OKXMarginMode::Cross),
        false,
        OKXTradeMode::Cross
    )]
    #[case::futures_isolated("BTC-USDT-250328", None, false, OKXTradeMode::Isolated)]
    #[case::option_cross(
        "BTC-USD-250328-50000-C",
        Some(OKXMarginMode::Cross),
        false,
        OKXTradeMode::Cross
    )]
    #[case::option_isolated("BTC-USD-250328-50000-C", None, false, OKXTradeMode::Isolated)]
    fn test_derive_trade_mode_for_instrument(
        #[case] symbol: &str,
        #[case] margin_mode: Option<OKXMarginMode>,
        #[case] use_spot_margin: bool,
        #[case] expected: OKXTradeMode,
    ) {
        let instrument_id = InstrumentId::from(format!("{symbol}.OKX").as_str());

        let result = derive_trade_mode_for_instrument(instrument_id, margin_mode, use_spot_margin);

        assert_eq!(result, expected);
    }

    #[rstest]
    #[case::override_to_cross("cross", OKXTradeMode::Cross)]
    #[case::override_to_cash("cash", OKXTradeMode::Cash)]
    #[case::override_to_isolated("isolated", OKXTradeMode::Isolated)]
    #[case::override_to_spot_isolated("spot_isolated", OKXTradeMode::SpotIsolated)]
    #[case::case_insensitive("CROSS", OKXTradeMode::Cross)]
    fn test_td_mode_param_override(#[case] td_mode_value: &str, #[case] expected: OKXTradeMode) {
        let mut params = Params::new();
        params.insert(
            "td_mode".to_string(),
            Value::String(td_mode_value.to_string()),
        );

        let result = get_param_as_string(&Some(params), "td_mode")
            .and_then(|s| s.parse::<OKXTradeMode>().ok());

        assert_eq!(result, Some(expected));
    }

    #[rstest]
    fn test_td_mode_param_invalid_falls_through() {
        let mut params = Params::new();
        params.insert("td_mode".to_string(), Value::String("invalid".to_string()));

        // An unparsable value yields `None`, leaving the caller on its default
        let result = get_param_as_string(&Some(params), "td_mode")
            .and_then(|s| s.parse::<OKXTradeMode>().ok());

        assert_eq!(result, None);
    }

    #[rstest]
    fn test_td_mode_param_absent_falls_through() {
        let result = get_param_as_string(&None, "td_mode");

        assert_eq!(result, None);
    }

    #[rstest]
    fn test_close_fraction_present_sets_reduce_only_true() {
        let mut params = Params::new();
        params.insert("close_fraction".to_string(), Value::String("1".to_string()));
        let params = Some(params);

        let close_fraction = get_param_as_string(&params, "close_fraction");
        let is_reduce_only = false;
        // A present close_fraction forces reduce-only regardless of the order flag
        let reduce_only = if close_fraction.is_some() {
            Some(true)
        } else {
            Some(is_reduce_only)
        };

        assert_eq!(close_fraction, Some("1".to_string()));
        assert_eq!(reduce_only, Some(true));
    }

    #[rstest]
    fn test_close_fraction_absent_preserves_reduce_only() {
        let params: Option<Params> = None;

        let close_fraction = get_param_as_string(&params, "close_fraction");
        let is_reduce_only = false;
        let reduce_only = if close_fraction.is_some() {
            Some(true)
        } else {
            Some(is_reduce_only)
        };

        assert_eq!(close_fraction, None);
        assert_eq!(reduce_only, Some(false));
    }

    #[rstest]
    fn test_close_fraction_absent_with_reduce_only_true() {
        let params: Option<Params> = None;

        let close_fraction = get_param_as_string(&params, "close_fraction");
        let is_reduce_only = true;
        let reduce_only = if close_fraction.is_some() {
            Some(true)
        } else {
            Some(is_reduce_only)
        };

        assert_eq!(close_fraction, None);
        assert_eq!(reduce_only, Some(true));
    }

    /// Builds an accepted GTC limit buy report with the given identifiers.
    fn make_query_order_report(cid: Option<&str>, vid: &str) -> OrderStatusReport {
        OrderStatusReport::new(
            AccountId::from("OKX-001"),
            InstrumentId::from("BTC-USDT.OKX"),
            cid.map(ClientOrderId::from),
            VenueOrderId::from(vid),
            OrderSide::Buy,
            OrderType::Limit,
            TimeInForce::Gtc,
            OrderStatus::Accepted,
            Quantity::new(1.0, 0),
            Quantity::zero(0),
            UnixNanos::default(),
            UnixNanos::default(),
            UnixNanos::default(),
            None,
        )
    }

    /// Returns `report` with its `linked_order_ids` set to `linked`.
    fn with_linked(mut report: OrderStatusReport, linked: &[&str]) -> OrderStatusReport {
        report.linked_order_ids = Some(linked.iter().map(|s| ClientOrderId::from(*s)).collect());
        report
    }

    #[rstest]
    fn test_select_query_order_report_matches_client_order_id() {
        let reports = vec![make_query_order_report(Some("O-001"), "V-1")];
        let selected = select_query_order_report(reports, ClientOrderId::from("O-001"), None);
        assert_eq!(
            selected.and_then(|r| r.client_order_id),
            Some(ClientOrderId::from("O-001"))
        );
    }

    #[rstest]
    fn test_select_query_order_report_client_wins_over_venue_mismatch() {
        let reports = vec![make_query_order_report(Some("O-001"), "V-1")];
        let selected = select_query_order_report(
            reports,
            ClientOrderId::from("O-001"),
            Some(VenueOrderId::from("V-OTHER")),
        );
        assert_eq!(
            selected.and_then(|r| r.client_order_id),
            Some(ClientOrderId::from("O-001"))
        );
    }

    #[rstest]
    fn test_select_query_order_report_falls_back_to_venue_order_id() {
        // The report carries a different client order ID, so only the venue
        // order ID can match
        let reports = vec![make_query_order_report(Some("O-CHILD"), "V-1")];
        let selected = select_query_order_report(
            reports,
            ClientOrderId::from("O-PARENT"),
            Some(VenueOrderId::from("V-1")),
        );
        assert_eq!(
            selected.map(|r| r.venue_order_id.as_str().to_string()),
            Some("V-1".to_string()),
        );
    }

    #[rstest]
    fn test_select_query_order_report_rejects_when_nothing_matches() {
        let reports = vec![make_query_order_report(Some("O-OTHER"), "V-OTHER")];
        let selected = select_query_order_report(
            reports,
            ClientOrderId::from("O-001"),
            Some(VenueOrderId::from("V-1")),
        );
        assert!(selected.is_none());
    }

    #[rstest]
    fn test_select_query_order_report_rejects_when_client_differs_and_no_vid_provided() {
        let reports = vec![make_query_order_report(Some("O-OTHER"), "V-1")];
        let selected = select_query_order_report(reports, ClientOrderId::from("O-001"), None);
        assert!(selected.is_none());
    }

    #[rstest]
    fn test_select_query_order_report_ignores_linked_order_ids_for_parent_with_attached_tp() {
        let child_cid = "O-CHILD-TP";
        // The queried ID appears only in the parent's linked order IDs, which
        // must not count as a match
        let reports = vec![with_linked(
            make_query_order_report(Some("O-PARENT"), "V-PARENT"),
            &[child_cid, "O-CHILD-SL"],
        )];
        let selected = select_query_order_report(reports, ClientOrderId::from(child_cid), None);
        assert!(selected.is_none());
    }

    #[rstest]
    fn test_select_query_order_report_client_match_wins_over_vid_match_elsewhere() {
        // The venue order ID match comes first in the list, but the later
        // client order ID match must still be selected
        let reports = vec![
            make_query_order_report(Some("O-OTHER"), "V-1"),
            make_query_order_report(Some("O-001"), "V-2"),
        ];
        let selected = select_query_order_report(
            reports,
            ClientOrderId::from("O-001"),
            Some(VenueOrderId::from("V-1")),
        );
        assert_eq!(
            selected.and_then(|r| r.client_order_id),
            Some(ClientOrderId::from("O-001")),
        );
    }
}