nautilus_architect_ax/websocket/orders/
handler.rs1use std::{
19 collections::VecDeque,
20 sync::{
21 Arc,
22 atomic::{AtomicBool, Ordering},
23 },
24};
25
26use ahash::AHashMap;
27use dashmap::DashMap;
28use nautilus_model::identifiers::ClientOrderId;
29use nautilus_network::websocket::{AuthTracker, WebSocketClient};
30use tokio_tungstenite::tungstenite::Message;
31use ustr::Ustr;
32
33use crate::{
34 common::enums::AxOrderRequestType,
35 websocket::{
36 messages::{
37 AxOrdersWsFrame, AxOrdersWsMessage, AxWsCancelOrder, AxWsError, AxWsGetOpenOrders,
38 AxWsOrderEvent, AxWsOrderResponse, AxWsPlaceOrder, OrderMetadata,
39 },
40 parse::parse_order_message,
41 },
42};
43
44#[derive(Clone, Debug)]
46pub struct WsOrderInfo {
47 pub client_order_id: ClientOrderId,
49 pub symbol: Ustr,
51}
52
53#[derive(Debug)]
55pub enum HandlerCommand {
56 SetClient(WebSocketClient),
58 Disconnect,
60 Authenticate {
62 token: String,
64 },
65 PlaceOrder {
67 request_id: i64,
69 order: AxWsPlaceOrder,
71 order_info: WsOrderInfo,
73 },
74 CancelOrder {
76 request_id: i64,
78 order_id: String,
80 },
81 GetOpenOrders {
83 request_id: i64,
85 },
86}
87
88pub(crate) struct AxOrdersWsFeedHandler {
93 signal: Arc<AtomicBool>,
94 inner: Option<WebSocketClient>,
95 cmd_rx: tokio::sync::mpsc::UnboundedReceiver<HandlerCommand>,
96 raw_rx: tokio::sync::mpsc::UnboundedReceiver<Message>,
97 auth_tracker: AuthTracker,
98 pending_orders: AHashMap<i64, WsOrderInfo>,
99 message_queue: VecDeque<AxOrdersWsMessage>,
100 orders_metadata: Arc<DashMap<ClientOrderId, OrderMetadata>>,
101 cid_to_client_order_id: Arc<DashMap<u64, ClientOrderId>>,
102 bearer_token: Option<String>,
103 needs_reauthentication: bool,
104}
105
106impl AxOrdersWsFeedHandler {
107 #[must_use]
109 pub fn new(
110 signal: Arc<AtomicBool>,
111 cmd_rx: tokio::sync::mpsc::UnboundedReceiver<HandlerCommand>,
112 raw_rx: tokio::sync::mpsc::UnboundedReceiver<Message>,
113 auth_tracker: AuthTracker,
114 orders_metadata: Arc<DashMap<ClientOrderId, OrderMetadata>>,
115 cid_to_client_order_id: Arc<DashMap<u64, ClientOrderId>>,
116 ) -> Self {
117 Self {
118 signal,
119 inner: None,
120 cmd_rx,
121 raw_rx,
122 auth_tracker,
123 pending_orders: AHashMap::new(),
124 message_queue: VecDeque::new(),
125 orders_metadata,
126 cid_to_client_order_id,
127 bearer_token: None,
128 needs_reauthentication: false,
129 }
130 }
131
132 async fn reauthenticate(&mut self) {
133 if self.bearer_token.is_some() {
134 log::info!("Re-authenticating after reconnection");
135
136 self.auth_tracker.succeed();
138 self.message_queue
139 .push_back(AxOrdersWsMessage::Authenticated);
140 log::info!("Re-authentication completed");
141 } else {
142 log::warn!("Cannot re-authenticate: no bearer token stored");
143 }
144 }
145
146 pub async fn next(&mut self) -> Option<AxOrdersWsMessage> {
150 loop {
151 if self.needs_reauthentication && self.message_queue.is_empty() {
152 self.needs_reauthentication = false;
153 self.reauthenticate().await;
154 }
155
156 if let Some(msg) = self.message_queue.pop_front() {
157 return Some(msg);
158 }
159
160 tokio::select! {
161 Some(cmd) = self.cmd_rx.recv() => {
162 self.handle_command(cmd).await;
163 }
164
165 () = tokio::time::sleep(std::time::Duration::from_millis(100)) => {
166 if self.signal.load(Ordering::Acquire) {
167 log::debug!("Stop signal received during idle period");
168 return None;
169 }
170 }
171
172 msg = self.raw_rx.recv() => {
173 let msg = match msg {
174 Some(msg) => msg,
175 None => {
176 log::debug!("WebSocket stream closed");
177 return None;
178 }
179 };
180
181 if let Message::Ping(data) = &msg {
182 log::trace!("Received ping frame with {} bytes", data.len());
183
184 if let Some(client) = &self.inner
185 && let Err(e) = client.send_pong(data.to_vec()).await
186 {
187 log::warn!("Failed to send pong frame: {e}");
188 }
189 continue;
190 }
191
192 if let Some(messages) = self.parse_raw_message(msg) {
193 self.message_queue.extend(messages);
194 }
195
196 if self.signal.load(Ordering::Acquire) {
197 log::debug!("Stop signal received");
198 return None;
199 }
200 }
201 }
202 }
203 }
204
205 async fn handle_command(&mut self, cmd: HandlerCommand) {
206 match cmd {
207 HandlerCommand::SetClient(client) => {
208 log::debug!("WebSocketClient received by handler");
209 self.inner = Some(client);
210 }
211 HandlerCommand::Disconnect => {
212 log::debug!("Disconnect command received");
213 self.auth_tracker.fail("Disconnected");
214
215 if let Some(inner) = self.inner.take() {
216 inner.disconnect().await;
217 }
218 }
219 HandlerCommand::Authenticate { token } => {
220 log::debug!("Authenticate command received");
221 self.bearer_token = Some(token);
222
223 self.auth_tracker.succeed();
225 self.message_queue
226 .push_back(AxOrdersWsMessage::Authenticated);
227 }
228 HandlerCommand::PlaceOrder {
229 request_id,
230 order,
231 order_info,
232 } => {
233 log::debug!(
234 "PlaceOrder command received: request_id={request_id}, symbol={}",
235 order.s
236 );
237 self.pending_orders.insert(request_id, order_info.clone());
238
239 if let Err(e) = self.send_json(&order).await {
240 log::error!("Failed to send place order message: {e}");
241 self.pending_orders.remove(&request_id);
242 self.orders_metadata.remove(&order_info.client_order_id);
243
244 if let Some(cid) = order.cid {
245 self.cid_to_client_order_id.remove(&cid);
246 }
247 self.message_queue
248 .push_back(AxOrdersWsMessage::Error(AxWsError::new(format!(
249 "Failed to send place order for {}: {e}",
250 order_info.client_order_id
251 ))));
252 }
253 }
254 HandlerCommand::CancelOrder {
255 request_id,
256 order_id,
257 } => {
258 log::debug!(
259 "CancelOrder command received: request_id={request_id}, order_id={order_id}"
260 );
261 self.send_cancel_order(request_id, &order_id).await;
262 }
263 HandlerCommand::GetOpenOrders { request_id } => {
264 log::debug!("GetOpenOrders command received: request_id={request_id}");
265 self.send_get_open_orders(request_id).await;
266 }
267 }
268 }
269
270 async fn send_cancel_order(&mut self, request_id: i64, order_id: &str) {
271 let msg = AxWsCancelOrder {
272 rid: request_id,
273 t: AxOrderRequestType::CancelOrder,
274 oid: order_id.to_string(),
275 };
276
277 if let Err(e) = self.send_json(&msg).await {
278 log::error!("Failed to send cancel order message: {e}");
279 self.message_queue
280 .push_back(AxOrdersWsMessage::Error(AxWsError::new(format!(
281 "Failed to send cancel for order {order_id}: {e}"
282 ))));
283 }
284 }
285
286 async fn send_get_open_orders(&mut self, request_id: i64) {
287 let msg = AxWsGetOpenOrders {
288 rid: request_id,
289 t: AxOrderRequestType::GetOpenOrders,
290 };
291
292 if let Err(e) = self.send_json(&msg).await {
293 log::error!("Failed to send get open orders message: {e}");
294 self.message_queue
295 .push_back(AxOrdersWsMessage::Error(AxWsError::new(format!(
296 "Failed to send get open orders request: {e}"
297 ))));
298 }
299 }
300
301 async fn send_json<T: serde::Serialize>(&self, msg: &T) -> Result<(), String> {
302 let Some(inner) = &self.inner else {
303 return Err("No WebSocket client available".to_string());
304 };
305
306 let payload = serde_json::to_string(msg).map_err(|e| e.to_string())?;
307 log::trace!("Sending: {payload}");
308
309 inner
310 .send_text(payload, None)
311 .await
312 .map_err(|e| e.to_string())
313 }
314
315 fn parse_raw_message(&mut self, msg: Message) -> Option<Vec<AxOrdersWsMessage>> {
316 match msg {
317 Message::Text(text) => {
318 if text == nautilus_network::RECONNECTED {
319 log::info!("Received WebSocket reconnected signal");
320 self.auth_tracker.fail("Reconnecting");
321 self.needs_reauthentication = true;
322 return Some(vec![AxOrdersWsMessage::Reconnected]);
323 }
324
325 log::trace!("Raw websocket message: {text}");
326
327 let raw_msg: AxOrdersWsFrame = match parse_order_message(&text) {
328 Ok(v) => v,
329 Err(e) => {
330 log::error!("Failed to parse WebSocket message: {e}: {text}");
331 return None;
332 }
333 };
334
335 self.handle_raw_message(raw_msg)
336 }
337 Message::Binary(data) => {
338 log::debug!("Received binary message with {} bytes", data.len());
339 None
340 }
341 Message::Close(_) => {
342 log::debug!("Received close message, waiting for reconnection");
343 None
344 }
345 _ => None,
346 }
347 }
348
349 fn handle_raw_message(&mut self, raw_msg: AxOrdersWsFrame) -> Option<Vec<AxOrdersWsMessage>> {
350 match raw_msg {
351 AxOrdersWsFrame::Error(err) => {
352 log::warn!(
353 "Order error response: rid={} code={} msg={}",
354 err.rid,
355 err.err.code,
356 err.err.msg
357 );
358
359 if let Some(order_info) = self.pending_orders.remove(&err.rid) {
360 self.orders_metadata.remove(&order_info.client_order_id);
361 log::debug!(
362 "Cleaned up metadata for failed order: {}",
363 order_info.client_order_id
364 );
365 }
366
367 Some(vec![AxOrdersWsMessage::Error(err.into())])
368 }
369 AxOrdersWsFrame::Response(resp) => self.handle_response(resp),
370 AxOrdersWsFrame::Event(event) => self.handle_event(*event),
371 }
372 }
373
374 fn handle_response(&mut self, resp: AxWsOrderResponse) -> Option<Vec<AxOrdersWsMessage>> {
375 match resp {
376 AxWsOrderResponse::PlaceOrder(msg) => {
377 log::debug!("Place order response: rid={} oid={}", msg.rid, msg.res.oid);
378 self.pending_orders.remove(&msg.rid);
379 Some(vec![AxOrdersWsMessage::PlaceOrderResponse(msg)])
380 }
381 AxWsOrderResponse::CancelOrder(msg) => {
382 log::debug!(
383 "Cancel order response: rid={} accepted={}",
384 msg.rid,
385 msg.res.cxl_rx
386 );
387 Some(vec![AxOrdersWsMessage::CancelOrderResponse(msg)])
388 }
389 AxWsOrderResponse::OpenOrders(msg) => {
390 log::debug!("Open orders response: {} orders", msg.res.len());
391 Some(vec![AxOrdersWsMessage::OpenOrdersResponse(msg)])
392 }
393 AxWsOrderResponse::List(msg) => {
394 let order_count = msg.res.o.as_ref().map_or(0, |o| o.len());
395 log::debug!(
396 "List subscription response: rid={} li={} orders={}",
397 msg.rid,
398 msg.res.li,
399 order_count
400 );
401 None
402 }
403 }
404 }
405
406 fn handle_event(&self, event: AxWsOrderEvent) -> Option<Vec<AxOrdersWsMessage>> {
407 if matches!(event, AxWsOrderEvent::Heartbeat) {
408 log::trace!("Received heartbeat");
409 return None;
410 }
411 Some(vec![AxOrdersWsMessage::Event(Box::new(event))])
412 }
413}
414
#[cfg(test)]
mod tests {
    use std::sync::{Arc, atomic::AtomicBool};

    use dashmap::DashMap;
    use nautilus_network::websocket::AuthTracker;
    use rstest::rstest;
    use ustr::Ustr;

    use super::*;
    use crate::websocket::messages::{AxWsPlaceOrderResponse, AxWsPlaceOrderResult};

    /// Builds a handler with fresh channels and empty shared maps.
    ///
    /// The sender halves are dropped on return, which is fine for tests that
    /// call the synchronous handlers directly and never poll the channels.
    fn test_handler() -> AxOrdersWsFeedHandler {
        let (_cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel();
        let (_raw_tx, raw_rx) = tokio::sync::mpsc::unbounded_channel();
        AxOrdersWsFeedHandler::new(
            Arc::new(AtomicBool::new(false)),
            cmd_rx,
            raw_rx,
            AuthTracker::default(),
            Arc::new(DashMap::new()),
            Arc::new(DashMap::new()),
        )
    }

    #[rstest]
    fn test_place_order_response_cleans_pending_order() {
        let mut handler = test_handler();
        let request_id = 11;
        handler.pending_orders.insert(
            request_id,
            WsOrderInfo {
                client_order_id: ClientOrderId::from("CID-11"),
                symbol: Ustr::from("EURUSD-PERP"),
            },
        );

        let response = AxWsOrderResponse::PlaceOrder(AxWsPlaceOrderResponse {
            rid: request_id,
            res: AxWsPlaceOrderResult {
                oid: "OID-11".to_string(),
            },
        });

        let messages = handler.handle_response(response).unwrap();
        assert_eq!(messages.len(), 1);
        assert!(matches!(
            messages[0],
            AxOrdersWsMessage::PlaceOrderResponse(_)
        ));
        assert!(!handler.pending_orders.contains_key(&request_id));
    }

    // Renamed from `test_handle_event_forwards_venue_event`: the test asserts
    // that a heartbeat is dropped, not that an event is forwarded.
    #[rstest]
    fn test_handle_event_drops_heartbeat() {
        let handler = test_handler();

        let event = AxWsOrderEvent::Heartbeat;
        let result = handler.handle_event(event);
        assert!(result.is_none());
    }
}