Skip to main content

nautilus_architect_ax/websocket/orders/
handler.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Orders WebSocket message handler for Ax.
17
18use std::{
19    collections::VecDeque,
20    sync::{
21        Arc,
22        atomic::{AtomicBool, Ordering},
23    },
24};
25
26use ahash::AHashMap;
27use dashmap::DashMap;
28use nautilus_model::identifiers::ClientOrderId;
29use nautilus_network::websocket::{AuthTracker, WebSocketClient};
30use tokio_tungstenite::tungstenite::Message;
31use ustr::Ustr;
32
33use crate::{
34    common::enums::AxOrderRequestType,
35    websocket::{
36        messages::{
37            AxOrdersWsFrame, AxOrdersWsMessage, AxWsCancelOrder, AxWsError, AxWsGetOpenOrders,
38            AxWsOrderEvent, AxWsOrderResponse, AxWsPlaceOrder, OrderMetadata,
39        },
40        parse::parse_order_message,
41    },
42};
43
44/// Simple tracking info for pending WebSocket orders.
45#[derive(Clone, Debug)]
46pub struct WsOrderInfo {
47    /// Client order ID for correlation.
48    pub client_order_id: ClientOrderId,
49    /// Instrument symbol.
50    pub symbol: Ustr,
51}
52
53/// Commands sent from the outer client to the inner orders handler.
54#[derive(Debug)]
55pub enum HandlerCommand {
56    /// Set the WebSocket client for this handler.
57    SetClient(WebSocketClient),
58    /// Disconnect the WebSocket connection.
59    Disconnect,
60    /// Authenticate with the provided token.
61    Authenticate {
62        /// Bearer token for authentication.
63        token: String,
64    },
65    /// Place an order.
66    PlaceOrder {
67        /// Request ID for correlation.
68        request_id: i64,
69        /// Order placement message.
70        order: AxWsPlaceOrder,
71        /// Order info for tracking.
72        order_info: WsOrderInfo,
73    },
74    /// Cancel an order.
75    CancelOrder {
76        /// Request ID for correlation.
77        request_id: i64,
78        /// Order ID to cancel.
79        order_id: String,
80    },
81    /// Get open orders.
82    GetOpenOrders {
83        /// Request ID for correlation.
84        request_id: i64,
85    },
86}
87
88/// Orders feed handler that processes WebSocket messages.
89///
90/// Runs in a dedicated Tokio task and owns the WebSocket client exclusively.
91/// Emits raw venue types for downstream consumers to parse into domain events.
92pub(crate) struct AxOrdersWsFeedHandler {
93    signal: Arc<AtomicBool>,
94    inner: Option<WebSocketClient>,
95    cmd_rx: tokio::sync::mpsc::UnboundedReceiver<HandlerCommand>,
96    raw_rx: tokio::sync::mpsc::UnboundedReceiver<Message>,
97    auth_tracker: AuthTracker,
98    pending_orders: AHashMap<i64, WsOrderInfo>,
99    message_queue: VecDeque<AxOrdersWsMessage>,
100    orders_metadata: Arc<DashMap<ClientOrderId, OrderMetadata>>,
101    cid_to_client_order_id: Arc<DashMap<u64, ClientOrderId>>,
102    bearer_token: Option<String>,
103    needs_reauthentication: bool,
104}
105
106impl AxOrdersWsFeedHandler {
107    /// Creates a new [`AxOrdersWsFeedHandler`] instance.
108    #[must_use]
109    pub fn new(
110        signal: Arc<AtomicBool>,
111        cmd_rx: tokio::sync::mpsc::UnboundedReceiver<HandlerCommand>,
112        raw_rx: tokio::sync::mpsc::UnboundedReceiver<Message>,
113        auth_tracker: AuthTracker,
114        orders_metadata: Arc<DashMap<ClientOrderId, OrderMetadata>>,
115        cid_to_client_order_id: Arc<DashMap<u64, ClientOrderId>>,
116    ) -> Self {
117        Self {
118            signal,
119            inner: None,
120            cmd_rx,
121            raw_rx,
122            auth_tracker,
123            pending_orders: AHashMap::new(),
124            message_queue: VecDeque::new(),
125            orders_metadata,
126            cid_to_client_order_id,
127            bearer_token: None,
128            needs_reauthentication: false,
129        }
130    }
131
132    async fn reauthenticate(&mut self) {
133        if self.bearer_token.is_some() {
134            log::info!("Re-authenticating after reconnection");
135
136            // Ax uses Bearer token in connection headers which persist across reconnect
137            self.auth_tracker.succeed();
138            self.message_queue
139                .push_back(AxOrdersWsMessage::Authenticated);
140            log::info!("Re-authentication completed");
141        } else {
142            log::warn!("Cannot re-authenticate: no bearer token stored");
143        }
144    }
145
146    /// Returns the next message from the handler.
147    ///
148    /// This method blocks until a message is available or the handler is stopped.
149    pub async fn next(&mut self) -> Option<AxOrdersWsMessage> {
150        loop {
151            if self.needs_reauthentication && self.message_queue.is_empty() {
152                self.needs_reauthentication = false;
153                self.reauthenticate().await;
154            }
155
156            if let Some(msg) = self.message_queue.pop_front() {
157                return Some(msg);
158            }
159
160            tokio::select! {
161                Some(cmd) = self.cmd_rx.recv() => {
162                    self.handle_command(cmd).await;
163                }
164
165                () = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
166                    if self.signal.load(Ordering::Acquire) {
167                        log::debug!("Stop signal received during idle period");
168                        return None;
169                    }
170                }
171
172                msg = self.raw_rx.recv() => {
173                    let msg = match msg {
174                        Some(msg) => msg,
175                        None => {
176                            log::debug!("WebSocket stream closed");
177                            return None;
178                        }
179                    };
180
181                    if let Message::Ping(data) = &msg {
182                        log::trace!("Received ping frame with {} bytes", data.len());
183
184                        if let Some(client) = &self.inner
185                            && let Err(e) = client.send_pong(data.to_vec()).await
186                        {
187                            log::warn!("Failed to send pong frame: {e}");
188                        }
189                        continue;
190                    }
191
192                    if let Some(messages) = self.parse_raw_message(msg) {
193                        self.message_queue.extend(messages);
194                    }
195
196                    if self.signal.load(Ordering::Acquire) {
197                        log::debug!("Stop signal received");
198                        return None;
199                    }
200                }
201            }
202        }
203    }
204
205    async fn handle_command(&mut self, cmd: HandlerCommand) {
206        match cmd {
207            HandlerCommand::SetClient(client) => {
208                log::debug!("WebSocketClient received by handler");
209                self.inner = Some(client);
210            }
211            HandlerCommand::Disconnect => {
212                log::debug!("Disconnect command received");
213                self.auth_tracker.fail("Disconnected");
214
215                if let Some(inner) = self.inner.take() {
216                    inner.disconnect().await;
217                }
218            }
219            HandlerCommand::Authenticate { token } => {
220                log::debug!("Authenticate command received");
221                self.bearer_token = Some(token);
222
223                // Ax uses Bearer token in connection headers (handled at connect time)
224                self.auth_tracker.succeed();
225                self.message_queue
226                    .push_back(AxOrdersWsMessage::Authenticated);
227            }
228            HandlerCommand::PlaceOrder {
229                request_id,
230                order,
231                order_info,
232            } => {
233                log::debug!(
234                    "PlaceOrder command received: request_id={request_id}, symbol={}",
235                    order.s
236                );
237                self.pending_orders.insert(request_id, order_info.clone());
238
239                if let Err(e) = self.send_json(&order).await {
240                    log::error!("Failed to send place order message: {e}");
241                    self.pending_orders.remove(&request_id);
242                    self.orders_metadata.remove(&order_info.client_order_id);
243
244                    if let Some(cid) = order.cid {
245                        self.cid_to_client_order_id.remove(&cid);
246                    }
247                    self.message_queue
248                        .push_back(AxOrdersWsMessage::Error(AxWsError::new(format!(
249                            "Failed to send place order for {}: {e}",
250                            order_info.client_order_id
251                        ))));
252                }
253            }
254            HandlerCommand::CancelOrder {
255                request_id,
256                order_id,
257            } => {
258                log::debug!(
259                    "CancelOrder command received: request_id={request_id}, order_id={order_id}"
260                );
261                self.send_cancel_order(request_id, &order_id).await;
262            }
263            HandlerCommand::GetOpenOrders { request_id } => {
264                log::debug!("GetOpenOrders command received: request_id={request_id}");
265                self.send_get_open_orders(request_id).await;
266            }
267        }
268    }
269
270    async fn send_cancel_order(&mut self, request_id: i64, order_id: &str) {
271        let msg = AxWsCancelOrder {
272            rid: request_id,
273            t: AxOrderRequestType::CancelOrder,
274            oid: order_id.to_string(),
275        };
276
277        if let Err(e) = self.send_json(&msg).await {
278            log::error!("Failed to send cancel order message: {e}");
279            self.message_queue
280                .push_back(AxOrdersWsMessage::Error(AxWsError::new(format!(
281                    "Failed to send cancel for order {order_id}: {e}"
282                ))));
283        }
284    }
285
286    async fn send_get_open_orders(&mut self, request_id: i64) {
287        let msg = AxWsGetOpenOrders {
288            rid: request_id,
289            t: AxOrderRequestType::GetOpenOrders,
290        };
291
292        if let Err(e) = self.send_json(&msg).await {
293            log::error!("Failed to send get open orders message: {e}");
294            self.message_queue
295                .push_back(AxOrdersWsMessage::Error(AxWsError::new(format!(
296                    "Failed to send get open orders request: {e}"
297                ))));
298        }
299    }
300
301    async fn send_json<T: serde::Serialize>(&self, msg: &T) -> Result<(), String> {
302        let Some(inner) = &self.inner else {
303            return Err("No WebSocket client available".to_string());
304        };
305
306        let payload = serde_json::to_string(msg).map_err(|e| e.to_string())?;
307        log::trace!("Sending: {payload}");
308
309        inner
310            .send_text(payload, None)
311            .await
312            .map_err(|e| e.to_string())
313    }
314
315    fn parse_raw_message(&mut self, msg: Message) -> Option<Vec<AxOrdersWsMessage>> {
316        match msg {
317            Message::Text(text) => {
318                if text == nautilus_network::RECONNECTED {
319                    log::info!("Received WebSocket reconnected signal");
320                    self.auth_tracker.fail("Reconnecting");
321                    self.needs_reauthentication = true;
322                    return Some(vec![AxOrdersWsMessage::Reconnected]);
323                }
324
325                log::trace!("Raw websocket message: {text}");
326
327                let raw_msg: AxOrdersWsFrame = match parse_order_message(&text) {
328                    Ok(v) => v,
329                    Err(e) => {
330                        log::error!("Failed to parse WebSocket message: {e}: {text}");
331                        return None;
332                    }
333                };
334
335                self.handle_raw_message(raw_msg)
336            }
337            Message::Binary(data) => {
338                log::debug!("Received binary message with {} bytes", data.len());
339                None
340            }
341            Message::Close(_) => {
342                log::debug!("Received close message, waiting for reconnection");
343                None
344            }
345            _ => None,
346        }
347    }
348
349    fn handle_raw_message(&mut self, raw_msg: AxOrdersWsFrame) -> Option<Vec<AxOrdersWsMessage>> {
350        match raw_msg {
351            AxOrdersWsFrame::Error(err) => {
352                log::warn!(
353                    "Order error response: rid={} code={} msg={}",
354                    err.rid,
355                    err.err.code,
356                    err.err.msg
357                );
358
359                if let Some(order_info) = self.pending_orders.remove(&err.rid) {
360                    self.orders_metadata.remove(&order_info.client_order_id);
361                    log::debug!(
362                        "Cleaned up metadata for failed order: {}",
363                        order_info.client_order_id
364                    );
365                }
366
367                Some(vec![AxOrdersWsMessage::Error(err.into())])
368            }
369            AxOrdersWsFrame::Response(resp) => self.handle_response(resp),
370            AxOrdersWsFrame::Event(event) => self.handle_event(*event),
371        }
372    }
373
374    fn handle_response(&mut self, resp: AxWsOrderResponse) -> Option<Vec<AxOrdersWsMessage>> {
375        match resp {
376            AxWsOrderResponse::PlaceOrder(msg) => {
377                log::debug!("Place order response: rid={} oid={}", msg.rid, msg.res.oid);
378                self.pending_orders.remove(&msg.rid);
379                Some(vec![AxOrdersWsMessage::PlaceOrderResponse(msg)])
380            }
381            AxWsOrderResponse::CancelOrder(msg) => {
382                log::debug!(
383                    "Cancel order response: rid={} accepted={}",
384                    msg.rid,
385                    msg.res.cxl_rx
386                );
387                Some(vec![AxOrdersWsMessage::CancelOrderResponse(msg)])
388            }
389            AxWsOrderResponse::OpenOrders(msg) => {
390                log::debug!("Open orders response: {} orders", msg.res.len());
391                Some(vec![AxOrdersWsMessage::OpenOrdersResponse(msg)])
392            }
393            AxWsOrderResponse::List(msg) => {
394                let order_count = msg.res.o.as_ref().map_or(0, |o| o.len());
395                log::debug!(
396                    "List subscription response: rid={} li={} orders={}",
397                    msg.rid,
398                    msg.res.li,
399                    order_count
400                );
401                None
402            }
403        }
404    }
405
406    fn handle_event(&self, event: AxWsOrderEvent) -> Option<Vec<AxOrdersWsMessage>> {
407        if matches!(event, AxWsOrderEvent::Heartbeat) {
408            log::trace!("Received heartbeat");
409            return None;
410        }
411        Some(vec![AxOrdersWsMessage::Event(Box::new(event))])
412    }
413}
414
#[cfg(test)]
mod tests {
    use std::sync::{Arc, atomic::AtomicBool};

    use dashmap::DashMap;
    use nautilus_network::websocket::AuthTracker;
    use rstest::rstest;
    use ustr::Ustr;

    use super::*;
    use crate::websocket::messages::{AxWsPlaceOrderResponse, AxWsPlaceOrderResult};

    /// Builds a handler with fresh channels and empty shared maps.
    ///
    /// The sender halves are dropped on return, so the handler is only
    /// suitable for calling its synchronous helpers directly.
    fn test_handler() -> AxOrdersWsFeedHandler {
        let (_cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        let (_raw_tx, raw_rx) = tokio::sync::mpsc::unbounded_channel();
        AxOrdersWsFeedHandler::new(
            Arc::new(AtomicBool::new(false)),
            cmd_rx,
            raw_rx,
            AuthTracker::default(),
            Arc::new(DashMap::new()),
            Arc::new(DashMap::new()),
        )
    }

    /// A place-order response must be forwarded and must clear the pending entry.
    #[rstest]
    fn test_place_order_response_cleans_pending_order() {
        let mut handler = test_handler();
        let request_id = 11;
        handler.pending_orders.insert(
            request_id,
            WsOrderInfo {
                client_order_id: ClientOrderId::from("CID-11"),
                symbol: Ustr::from("EURUSD-PERP"),
            },
        );

        let response = AxWsOrderResponse::PlaceOrder(AxWsPlaceOrderResponse {
            rid: request_id,
            res: AxWsPlaceOrderResult {
                oid: "OID-11".to_string(),
            },
        });

        let messages = handler.handle_response(response).unwrap();
        assert_eq!(messages.len(), 1);
        assert!(matches!(
            messages.first(),
            Some(AxOrdersWsMessage::PlaceOrderResponse(_))
        ));
        assert!(!handler.pending_orders.contains_key(&request_id));
    }

    /// Heartbeat events carry no order information and must not be forwarded.
    #[rstest]
    fn test_handle_event_drops_heartbeat() {
        let handler = test_handler();

        let event = AxWsOrderEvent::Heartbeat;
        let result = handler.handle_event(event);
        assert!(result.is_none());
    }
}