Skip to main content

nautilus_architect_ax/websocket/orders/
client.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Orders WebSocket client for Ax.
17
18use std::{
19    fmt::Debug,
20    sync::{
21        Arc,
22        atomic::{AtomicBool, AtomicI64, AtomicU8, Ordering},
23    },
24    time::Duration,
25};
26
27use arc_swap::ArcSwap;
28use dashmap::DashMap;
29use nautilus_common::live::get_runtime;
30use nautilus_core::{
31    AtomicMap,
32    consts::NAUTILUS_USER_AGENT,
33    nanos::UnixNanos,
34    time::{AtomicTime, get_atomic_clock_realtime},
35};
36use nautilus_model::{
37    enums::{OrderSide, OrderType, TimeInForce},
38    identifiers::{AccountId, ClientOrderId, InstrumentId, StrategyId, TraderId, VenueOrderId},
39    instruments::{Instrument, InstrumentAny},
40    types::{Price, Quantity},
41};
42use nautilus_network::{
43    backoff::ExponentialBackoff,
44    mode::ConnectionMode,
45    websocket::{
46        AuthTracker, PingHandler, TransportBackend, WebSocketClient, WebSocketConfig,
47        channel_message_handler,
48    },
49};
50use ustr::Ustr;
51
52use super::handler::{AxOrdersWsFeedHandler, HandlerCommand, WsOrderInfo};
53use crate::{
54    common::{
55        consts::AX_NAUTILUS_TAG,
56        enums::{AxOrderRequestType, AxOrderSide, AxOrderType, AxTimeInForce},
57        parse::{client_order_id_to_cid, quantity_to_contracts},
58    },
59    websocket::messages::{AxOrdersWsMessage, AxWsPlaceOrder, OrderMetadata},
60};
61
62/// Result type for Ax orders WebSocket operations.
63pub type AxOrdersWsResult<T> = Result<T, AxOrdersWsClientError>;
64
65/// Shared caches for order state tracking between the client and consumers.
66#[derive(Debug, Clone)]
67pub struct OrdersCaches {
68    /// Maps client order IDs to order metadata.
69    pub orders_metadata: Arc<DashMap<ClientOrderId, OrderMetadata>>,
70    /// Maps venue order IDs to client order IDs.
71    pub venue_to_client_id: Arc<DashMap<VenueOrderId, ClientOrderId>>,
72    /// Maps AX cid values to client order IDs.
73    pub cid_to_client_order_id: Arc<DashMap<u64, ClientOrderId>>,
74}
75
76impl Default for OrdersCaches {
77    fn default() -> Self {
78        Self {
79            orders_metadata: Arc::new(DashMap::new()),
80            venue_to_client_id: Arc::new(DashMap::new()),
81            cid_to_client_order_id: Arc::new(DashMap::new()),
82        }
83    }
84}
85
/// Failure categories reported by the Ax orders WebSocket client.
///
/// Each variant carries a human-readable detail message. The `Display` form is
/// `"<Category> error: <detail>"`.
#[derive(Debug, Clone)]
pub enum AxOrdersWsClientError {
    /// Transport/connection error.
    Transport(String),
    /// Channel send error.
    ChannelError(String),
    /// Authentication error.
    AuthenticationError(String),
    /// Client-side validation error.
    ClientError(String),
}

impl core::fmt::Display for AxOrdersWsClientError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Resolve the category label once, then render through a single format string.
        let (category, detail) = match self {
            Self::Transport(detail) => ("Transport", detail),
            Self::ChannelError(detail) => ("Channel", detail),
            Self::AuthenticationError(detail) => ("Authentication", detail),
            Self::ClientError(detail) => ("Client", detail),
        };
        write!(f, "{category} error: {detail}")
    }
}

impl std::error::Error for AxOrdersWsClientError {}

/// Lets static error strings (e.g. from `TryFrom` conversions) propagate with `?`,
/// classifying them as client-side validation errors.
impl From<&'static str> for AxOrdersWsClientError {
    fn from(msg: &'static str) -> Self {
        Self::ClientError(String::from(msg))
    }
}
117
118/// Orders WebSocket client for Ax.
119///
120/// Provides authenticated order management including placing, canceling,
121/// and monitoring order status via WebSocket.
122pub struct AxOrdersWebSocketClient {
123    clock: &'static AtomicTime,
124    url: String,
125    heartbeat: Option<u64>,
126    connection_mode: Arc<ArcSwap<AtomicU8>>,
127    cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
128    out_rx: Option<Arc<tokio::sync::mpsc::UnboundedReceiver<AxOrdersWsMessage>>>,
129    signal: Arc<AtomicBool>,
130    task_handle: Option<tokio::task::JoinHandle<()>>,
131    auth_tracker: AuthTracker,
132    instruments_cache: Arc<AtomicMap<Ustr, InstrumentAny>>,
133    caches: OrdersCaches,
134    request_id_counter: Arc<AtomicI64>,
135    account_id: AccountId,
136    trader_id: TraderId,
137    transport_backend: TransportBackend,
138    proxy_url: Option<String>,
139}
140
141impl Debug for AxOrdersWebSocketClient {
142    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
143        f.debug_struct(stringify!(AxOrdersWebSocketClient))
144            .field("url", &self.url)
145            .field("heartbeat", &self.heartbeat)
146            .field("account_id", &self.account_id)
147            .finish()
148    }
149}
150
151impl Clone for AxOrdersWebSocketClient {
152    fn clone(&self) -> Self {
153        Self {
154            clock: self.clock,
155            url: self.url.clone(),
156            heartbeat: self.heartbeat,
157            connection_mode: Arc::clone(&self.connection_mode),
158            cmd_tx: Arc::clone(&self.cmd_tx),
159            out_rx: None, // Each clone gets its own receiver
160            signal: Arc::clone(&self.signal),
161            task_handle: None,
162            auth_tracker: self.auth_tracker.clone(),
163            instruments_cache: Arc::clone(&self.instruments_cache),
164            caches: self.caches.clone(),
165            request_id_counter: Arc::clone(&self.request_id_counter),
166            account_id: self.account_id,
167            trader_id: self.trader_id,
168            transport_backend: self.transport_backend,
169            proxy_url: self.proxy_url.clone(),
170        }
171    }
172}
173
174impl AxOrdersWebSocketClient {
175    /// Creates a new Ax orders WebSocket client.
176    #[must_use]
177    pub fn new(
178        url: String,
179        account_id: AccountId,
180        trader_id: TraderId,
181        heartbeat: u64,
182        transport_backend: TransportBackend,
183        proxy_url: Option<String>,
184    ) -> Self {
185        let (cmd_tx, _cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
186
187        let initial_mode = AtomicU8::new(ConnectionMode::Closed.as_u8());
188        let connection_mode = Arc::new(ArcSwap::from_pointee(initial_mode));
189
190        Self {
191            clock: get_atomic_clock_realtime(),
192            url,
193            heartbeat: Some(heartbeat),
194            connection_mode,
195            cmd_tx: Arc::new(tokio::sync::RwLock::new(cmd_tx)),
196            out_rx: None,
197            signal: Arc::new(AtomicBool::new(false)),
198            task_handle: None,
199            auth_tracker: AuthTracker::default(),
200            instruments_cache: Arc::new(AtomicMap::new()),
201            caches: OrdersCaches::default(),
202            request_id_counter: Arc::new(AtomicI64::new(1)),
203            account_id,
204            trader_id,
205            transport_backend,
206            proxy_url,
207        }
208    }
209
210    fn generate_ts_init(&self) -> UnixNanos {
211        self.clock.get_time_ns()
212    }
213
214    /// Returns the WebSocket URL.
215    #[must_use]
216    pub fn url(&self) -> &str {
217        &self.url
218    }
219
220    /// Returns the account ID.
221    #[must_use]
222    pub fn account_id(&self) -> AccountId {
223        self.account_id
224    }
225
226    /// Returns whether the client is currently connected and active.
227    #[must_use]
228    pub fn is_active(&self) -> bool {
229        let connection_mode_arc = self.connection_mode.load();
230        ConnectionMode::from_atomic(&connection_mode_arc).is_active()
231            && !self.signal.load(Ordering::Acquire)
232    }
233
234    /// Returns whether the client is closed.
235    #[must_use]
236    pub fn is_closed(&self) -> bool {
237        let connection_mode_arc = self.connection_mode.load();
238        ConnectionMode::from_atomic(&connection_mode_arc).is_closed()
239            || self.signal.load(Ordering::Acquire)
240    }
241
242    /// Generates a unique request ID.
243    fn next_request_id(&self) -> i64 {
244        self.request_id_counter.fetch_add(1, Ordering::Relaxed)
245    }
246
247    /// Caches an instrument for use during message parsing.
248    pub fn cache_instrument(&self, instrument: InstrumentAny) {
249        let symbol = instrument.symbol().inner();
250        self.instruments_cache.insert(symbol, instrument);
251    }
252
253    /// Caches multiple instruments for use during message parsing.
254    pub fn cache_instruments(&self, instruments: &[InstrumentAny]) {
255        self.instruments_cache.rcu(|m| {
256            for inst in instruments {
257                m.insert(inst.symbol().inner(), inst.clone());
258            }
259        });
260    }
261
262    /// Returns a cached instrument by symbol.
263    #[must_use]
264    pub fn get_cached_instrument(&self, symbol: &Ustr) -> Option<InstrumentAny> {
265        self.instruments_cache.get_cloned(symbol)
266    }
267
268    /// Returns the shared order caches.
269    #[must_use]
270    pub fn caches(&self) -> &OrdersCaches {
271        &self.caches
272    }
273
274    /// Returns the instruments cache.
275    #[must_use]
276    pub fn instruments_cache(&self) -> Arc<AtomicMap<Ustr, InstrumentAny>> {
277        Arc::clone(&self.instruments_cache)
278    }
279
280    /// Returns the orders metadata cache.
281    #[must_use]
282    pub fn orders_metadata(&self) -> &Arc<DashMap<ClientOrderId, OrderMetadata>> {
283        &self.caches.orders_metadata
284    }
285
286    /// Returns the cid to client order ID mapping for order correlation.
287    #[must_use]
288    pub fn cid_to_client_order_id(&self) -> &Arc<DashMap<u64, ClientOrderId>> {
289        &self.caches.cid_to_client_order_id
290    }
291
292    /// Resolves a cid to a ClientOrderId if the mapping exists.
293    #[must_use]
294    pub fn resolve_cid(&self, cid: u64) -> Option<ClientOrderId> {
295        self.caches.cid_to_client_order_id.get(&cid).map(|v| *v)
296    }
297
298    /// Registers an external order with the WebSocket handler for event tracking.
299    ///
300    /// This allows the handler to create proper events (e.g., OrderCanceled, OrderFilled)
301    /// for orders that were reconciled externally and not submitted through this client.
302    ///
303    /// Returns `false` if the instrument is not cached (registration skipped).
304    pub fn register_external_order(
305        &self,
306        client_order_id: ClientOrderId,
307        venue_order_id: VenueOrderId,
308        instrument_id: InstrumentId,
309        strategy_id: StrategyId,
310    ) -> bool {
311        if self.caches.orders_metadata.contains_key(&client_order_id) {
312            return true;
313        }
314
315        // Required for correct precision on fills
316        let symbol = instrument_id.symbol.inner();
317        let Some(instrument) = self.get_cached_instrument(&symbol) else {
318            log::warn!(
319                "Cannot register external order {client_order_id}: \
320                 instrument {instrument_id} not in cache"
321            );
322            return false;
323        };
324
325        let metadata = OrderMetadata {
326            trader_id: self.trader_id,
327            strategy_id,
328            instrument_id,
329            client_order_id,
330            venue_order_id: Some(venue_order_id),
331            ts_init: self.generate_ts_init(),
332            size_precision: instrument.size_precision(),
333            price_precision: instrument.price_precision(),
334            quote_currency: instrument.quote_currency(),
335            pending_trigger_price: None,
336        };
337
338        self.caches
339            .orders_metadata
340            .insert(client_order_id, metadata);
341        self.caches
342            .venue_to_client_id
343            .insert(venue_order_id, client_order_id);
344
345        log::debug!(
346            "Registered external order {client_order_id} ({venue_order_id}) for {instrument_id} [{strategy_id}]"
347        );
348
349        true
350    }
351
352    /// Establishes the WebSocket connection with authentication.
353    ///
354    /// # Arguments
355    ///
356    /// * `bearer_token` - The bearer token for authentication.
357    ///
358    /// # Errors
359    ///
360    /// Returns an error if the connection cannot be established.
361    pub async fn connect(&mut self, bearer_token: &str) -> AxOrdersWsResult<()> {
362        const MAX_RETRIES: u32 = 5;
363        const CONNECTION_TIMEOUT_SECS: u64 = 10;
364
365        self.signal.store(false, Ordering::Release);
366
367        let (raw_handler, raw_rx) = channel_message_handler();
368
369        // No-op ping handler: handler owns the WebSocketClient and responds to pings directly
370        let ping_handler: PingHandler = Arc::new(move |_payload: Vec<u8>| {
371            // Handler responds to pings internally via select! loop
372        });
373
374        let config = WebSocketConfig {
375            url: self.url.clone(),
376            headers: vec![
377                ("User-Agent".to_string(), NAUTILUS_USER_AGENT.to_string()),
378                (
379                    "Authorization".to_string(),
380                    format!("Bearer {bearer_token}"),
381                ),
382            ],
383            heartbeat: self.heartbeat,
384            heartbeat_msg: None, // Ax server sends heartbeats
385            reconnect_timeout_ms: Some(5_000),
386            reconnect_delay_initial_ms: Some(500),
387            reconnect_delay_max_ms: Some(5_000),
388            reconnect_backoff_factor: Some(1.5),
389            reconnect_jitter_ms: Some(250),
390            reconnect_max_attempts: None,
391            idle_timeout_ms: None,
392            backend: self.transport_backend,
393            proxy_url: self.proxy_url.clone(),
394        };
395
396        // Retry initial connection with exponential backoff
397        let mut backoff = ExponentialBackoff::new(
398            Duration::from_millis(500),
399            Duration::from_millis(5000),
400            2.0,
401            250,
402            false,
403        )
404        .map_err(|e| AxOrdersWsClientError::Transport(e.to_string()))?;
405
406        let mut last_error: String;
407        let mut attempt = 0;
408
409        let client = loop {
410            attempt += 1;
411
412            match tokio::time::timeout(
413                Duration::from_secs(CONNECTION_TIMEOUT_SECS),
414                WebSocketClient::connect(
415                    config.clone(),
416                    Some(raw_handler.clone()),
417                    Some(ping_handler.clone()),
418                    None,
419                    vec![],
420                    None,
421                ),
422            )
423            .await
424            {
425                Ok(Ok(client)) => {
426                    if attempt > 1 {
427                        log::info!("WebSocket connection established after {attempt} attempts");
428                    }
429                    break client;
430                }
431                Ok(Err(e)) => {
432                    last_error = e.to_string();
433                    log::warn!(
434                        "WebSocket connection attempt failed: attempt={attempt}, max_retries={MAX_RETRIES}, url={}, error={last_error}",
435                        self.url
436                    );
437                }
438                Err(_) => {
439                    last_error = format!("Connection timeout after {CONNECTION_TIMEOUT_SECS}s");
440                    log::warn!(
441                        "WebSocket connection attempt timed out: attempt={attempt}, max_retries={MAX_RETRIES}, url={}",
442                        self.url
443                    );
444                }
445            }
446
447            if attempt >= MAX_RETRIES {
448                return Err(AxOrdersWsClientError::Transport(format!(
449                    "Failed to connect to {} after {MAX_RETRIES} attempts: {}",
450                    self.url,
451                    if last_error.is_empty() {
452                        "unknown error"
453                    } else {
454                        &last_error
455                    }
456                )));
457            }
458
459            let delay = backoff.next_duration();
460            log::debug!(
461                "Retrying in {delay:?} (attempt {}/{MAX_RETRIES})",
462                attempt + 1
463            );
464            tokio::time::sleep(delay).await;
465        };
466
467        self.connection_mode.store(client.connection_mode_atomic());
468
469        let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<AxOrdersWsMessage>();
470        self.out_rx = Some(Arc::new(out_rx));
471
472        let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
473        *self.cmd_tx.write().await = cmd_tx.clone();
474
475        self.send_cmd(HandlerCommand::SetClient(client)).await?;
476
477        // Bearer token is passed in connection headers
478        self.send_cmd(HandlerCommand::Authenticate {
479            token: bearer_token.to_string(),
480        })
481        .await?;
482
483        let signal = Arc::clone(&self.signal);
484        let auth_tracker = self.auth_tracker.clone();
485        let orders_metadata = Arc::clone(&self.caches.orders_metadata);
486        let cid_to_client_order_id = Arc::clone(&self.caches.cid_to_client_order_id);
487
488        let stream_handle = get_runtime().spawn(async move {
489            let mut handler = AxOrdersWsFeedHandler::new(
490                signal.clone(),
491                cmd_rx,
492                raw_rx,
493                auth_tracker.clone(),
494                orders_metadata,
495                cid_to_client_order_id,
496            );
497
498            while let Some(msg) = handler.next().await {
499                if matches!(msg, AxOrdersWsMessage::Reconnected) {
500                    log::info!("WebSocket reconnected, authentication will be restored");
501                }
502
503                if out_tx.send(msg).is_err() {
504                    log::debug!("Output channel closed");
505                    break;
506                }
507            }
508
509            log::debug!("Handler loop exited");
510        });
511
512        self.task_handle = Some(stream_handle);
513
514        Ok(())
515    }
516
517    /// Submits an order using Nautilus domain types.
518    ///
519    /// This method handles conversion from Nautilus domain types to AX-specific
520    /// types and stores order metadata for event correlation.
521    ///
522    /// # Errors
523    ///
524    /// Returns an error if:
525    /// - The order type is not supported (only MARKET (simulated), LIMIT and STOP_LIMIT).
526    /// - The time-in-force is not supported.
527    /// - The instrument is not found in the cache.
528    /// - A limit order is missing a price.
529    /// - A stop-loss order is missing a trigger price.
530    /// - The order command cannot be sent.
531    #[expect(clippy::too_many_arguments)]
532    pub async fn submit_order(
533        &self,
534        trader_id: TraderId,
535        strategy_id: StrategyId,
536        instrument_id: InstrumentId,
537        client_order_id: ClientOrderId,
538        order_side: OrderSide,
539        order_type: OrderType,
540        quantity: Quantity,
541        time_in_force: TimeInForce,
542        price: Option<Price>,
543        trigger_price: Option<Price>,
544        post_only: bool,
545    ) -> AxOrdersWsResult<i64> {
546        if !matches!(
547            order_type,
548            OrderType::Market | OrderType::Limit | OrderType::StopLimit
549        ) {
550            return Err(AxOrdersWsClientError::ClientError(format!(
551                "Unsupported order type: {order_type:?}. AX supports MARKET, LIMIT and STOP_LIMIT."
552            )));
553        }
554
555        // Get instrument from cache for precision
556        let symbol = instrument_id.symbol.inner();
557        let instrument = self.get_cached_instrument(&symbol).ok_or_else(|| {
558            AxOrdersWsClientError::ClientError(format!(
559                "Instrument {instrument_id} not found in cache"
560            ))
561        })?;
562
563        let ax_side = AxOrderSide::try_from(order_side)?;
564
565        let qty_contracts = quantity_to_contracts(quantity)
566            .map_err(|e| AxOrdersWsClientError::ClientError(e.to_string()))?;
567
568        // Market orders are simulated as IOC limit orders with aggressive pricing
569        // because Architect does not support native market orders
570        let request_id = self.next_request_id();
571
572        let (ax_price, ax_tif, ax_post_only, ax_order_type, ax_trigger_price) = match order_type {
573            OrderType::Market => {
574                let market_price = price.ok_or_else(|| {
575                    AxOrdersWsClientError::ClientError(
576                        "Market order requires price (calculated from quote)".to_string(),
577                    )
578                })?;
579                (
580                    market_price.as_decimal(),
581                    AxTimeInForce::Ioc,
582                    false,
583                    None,
584                    None,
585                )
586            }
587            OrderType::Limit => {
588                let ax_tif = AxTimeInForce::try_from(time_in_force)?;
589                let limit_price = price.ok_or_else(|| {
590                    AxOrdersWsClientError::ClientError("Limit order requires price".to_string())
591                })?;
592                (limit_price.as_decimal(), ax_tif, post_only, None, None)
593            }
594            OrderType::StopLimit => {
595                let ax_tif = AxTimeInForce::try_from(time_in_force)?;
596                let limit_price = price.ok_or_else(|| {
597                    AxOrdersWsClientError::ClientError(
598                        "Stop-limit order requires price".to_string(),
599                    )
600                })?;
601                let stop_price = trigger_price.ok_or_else(|| {
602                    AxOrdersWsClientError::ClientError(
603                        "Stop-limit order requires trigger price".to_string(),
604                    )
605                })?;
606                (
607                    limit_price.as_decimal(),
608                    ax_tif,
609                    false,
610                    Some(AxOrderType::StopLossLimit),
611                    Some(stop_price.as_decimal()),
612                )
613            }
614            _ => {
615                return Err(AxOrdersWsClientError::ClientError(format!(
616                    "Unsupported order type: {order_type:?}"
617                )));
618            }
619        };
620
621        // Store order metadata for event correlation (after validation to avoid stale entries)
622        let metadata = OrderMetadata {
623            trader_id,
624            strategy_id,
625            instrument_id,
626            client_order_id,
627            venue_order_id: None,
628            ts_init: self.generate_ts_init(),
629            size_precision: instrument.size_precision(),
630            price_precision: instrument.price_precision(),
631            quote_currency: instrument.quote_currency(),
632            pending_trigger_price: None,
633        };
634        self.caches
635            .orders_metadata
636            .insert(client_order_id, metadata);
637
638        // Store cid -> client_order_id mapping for correlation
639        let cid = client_order_id_to_cid(&client_order_id);
640        self.caches
641            .cid_to_client_order_id
642            .insert(cid, client_order_id);
643
644        let order = AxWsPlaceOrder {
645            rid: request_id,
646            t: AxOrderRequestType::PlaceOrder,
647            s: symbol,
648            d: ax_side,
649            q: qty_contracts,
650            p: ax_price,
651            tif: ax_tif,
652            po: ax_post_only,
653            tag: Some(AX_NAUTILUS_TAG.to_string()),
654            cid: Some(cid),
655            order_type: ax_order_type,
656            trigger_price: ax_trigger_price,
657        };
658
659        let order_info = WsOrderInfo {
660            client_order_id,
661            symbol,
662        };
663
664        let result = self
665            .send_cmd(HandlerCommand::PlaceOrder {
666                request_id,
667                order,
668                order_info,
669            })
670            .await;
671
672        if result.is_err() {
673            self.caches.orders_metadata.remove(&client_order_id);
674            self.caches.cid_to_client_order_id.remove(&cid);
675        }
676
677        result?;
678        Ok(request_id)
679    }
680
681    /// Cancels an order via WebSocket.
682    ///
683    /// Requires a known `venue_order_id`.
684    ///
685    /// # Errors
686    ///
687    /// Returns an error if the cancel command cannot be sent.
688    pub async fn cancel_order(
689        &self,
690        client_order_id: ClientOrderId,
691        venue_order_id: Option<VenueOrderId>,
692    ) -> AxOrdersWsResult<i64> {
693        let order_id = venue_order_id.map(|v| v.to_string()).ok_or_else(|| {
694            AxOrdersWsClientError::ClientError(format!(
695                "Cannot cancel order {client_order_id}: missing venue_order_id"
696            ))
697        })?;
698
699        let request_id = self.next_request_id();
700
701        self.send_cmd(HandlerCommand::CancelOrder {
702            request_id,
703            order_id,
704        })
705        .await?;
706
707        Ok(request_id)
708    }
709
710    /// Requests open orders via WebSocket.
711    ///
712    /// # Errors
713    ///
714    /// Returns an error if the request command cannot be sent.
715    pub async fn get_open_orders(&self) -> AxOrdersWsResult<i64> {
716        let request_id = self.next_request_id();
717
718        self.send_cmd(HandlerCommand::GetOpenOrders { request_id })
719            .await?;
720
721        Ok(request_id)
722    }
723
724    /// Returns a stream of WebSocket messages.
725    ///
726    /// # Panics
727    ///
728    /// Panics if called before `connect()` or if the stream has already been taken.
729    pub fn stream(&mut self) -> impl futures_util::Stream<Item = AxOrdersWsMessage> + 'static {
730        let rx = self
731            .out_rx
732            .take()
733            .expect("Stream receiver already taken or client not connected - stream() can only be called once");
734        let mut rx = Arc::try_unwrap(rx).expect(
735            "Cannot take ownership of stream - client was cloned and other references exist",
736        );
737        async_stream::stream! {
738            while let Some(msg) = rx.recv().await {
739                yield msg;
740            }
741        }
742    }
743
744    /// Disconnects the WebSocket connection gracefully.
745    pub async fn disconnect(&self) {
746        log::debug!("Disconnecting WebSocket");
747        let _ = self.send_cmd(HandlerCommand::Disconnect).await;
748    }
749
750    /// Closes the WebSocket connection and cleans up resources.
751    pub async fn close(&mut self) {
752        log::debug!("Closing WebSocket client");
753
754        // Send disconnect first to allow graceful cleanup before signal
755        let _ = self.send_cmd(HandlerCommand::Disconnect).await;
756        tokio::time::sleep(Duration::from_millis(50)).await;
757        self.signal.store(true, Ordering::Release);
758
759        if let Some(handle) = self.task_handle.take() {
760            const CLOSE_TIMEOUT: Duration = Duration::from_secs(2);
761            let abort_handle = handle.abort_handle();
762
763            match tokio::time::timeout(CLOSE_TIMEOUT, handle).await {
764                Ok(Ok(())) => log::debug!("Handler task completed gracefully"),
765                Ok(Err(e)) => log::warn!("Handler task panicked: {e}"),
766                Err(_) => {
767                    log::warn!("Handler task did not complete within timeout, aborting");
768                    abort_handle.abort();
769                }
770            }
771        }
772    }
773
774    async fn send_cmd(&self, cmd: HandlerCommand) -> AxOrdersWsResult<()> {
775        let guard = self.cmd_tx.read().await;
776        guard
777            .send(cmd)
778            .map_err(|e| AxOrdersWsClientError::ChannelError(e.to_string()))
779    }
780}
781
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Builds a client against a dummy URL; no connection is attempted.
    fn make_client() -> AxOrdersWebSocketClient {
        AxOrdersWebSocketClient::new(
            "wss://example.com/orders/ws".to_string(),
            AccountId::from("AX-001"),
            TraderId::from("TRADER-001"),
            30,
            TransportBackend::default(),
            None,
        )
    }

    #[tokio::test]
    async fn test_cancel_order_rejects_without_venue_order_id() {
        let client = make_client();

        let result = client
            .cancel_order(ClientOrderId::from("CID-123"), None)
            .await;

        match result {
            Err(AxOrdersWsClientError::ClientError(msg)) => {
                assert!(
                    msg.contains("missing venue_order_id"),
                    "unexpected message: {msg}"
                );
            }
            other => panic!("expected ClientError, was {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_cancel_order_sends_known_venue_order_id() {
        let mut client = make_client();

        // Swap the placeholder command channel for one whose receiver the test keeps
        let (cmd_tx, mut cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
        client.cmd_tx = Arc::new(tokio::sync::RwLock::new(cmd_tx));

        let request_id = client
            .cancel_order(
                ClientOrderId::from("CID-456"),
                Some(VenueOrderId::from("V-ORDER-789")),
            )
            .await
            .unwrap();
        assert_eq!(request_id, 1);

        match cmd_rx.recv().await {
            Some(HandlerCommand::CancelOrder {
                request_id,
                order_id,
            }) => {
                assert_eq!(request_id, 1);
                assert_eq!(order_id, "V-ORDER-789");
            }
            other => panic!("unexpected command: {other:?}"),
        }
    }
}