nautilus_polymarket/websocket/
client.rs1use std::sync::{
19 Arc,
20 atomic::{AtomicBool, AtomicU8, Ordering},
21};
22
23use nautilus_common::live::get_runtime;
24use nautilus_network::{
25 mode::ConnectionMode,
26 websocket::{
27 AuthTracker, SubscriptionState, TransportBackend, WebSocketClient, WebSocketConfig,
28 channel_message_handler,
29 },
30};
31
32use super::{
33 handler::{FeedHandler, HandlerCommand},
34 messages::PolymarketWsMessage,
35};
36use crate::common::{
37 credential::Credential,
38 urls::{clob_ws_market_url, clob_ws_user_url},
39};
40
/// Heartbeat interval, in seconds, configured on every Polymarket WebSocket connection
/// (passed as `WebSocketConfig::heartbeat` when connecting).
const POLYMARKET_HEARTBEAT_SECS: u64 = 30;
42
/// Which Polymarket CLOB WebSocket feed a client is bound to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum WsChannel {
    /// Public market-data channel (created via `new_market()`).
    Market,
    /// Authenticated per-user channel (created via `new_user()`).
    User,
}
49
50fn idle_timeout_ms_for(channel: WsChannel) -> u64 {
54 match channel {
55 WsChannel::Market => 60_000,
56 WsChannel::User => 300_000,
57 }
58}
59
60#[derive(Clone, Debug)]
64pub struct WsSubscriptionHandle {
65 cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
66}
67
68impl WsSubscriptionHandle {
69 pub async fn subscribe_market(&self, asset_ids: Vec<String>) -> anyhow::Result<()> {
71 self.cmd_tx
72 .read()
73 .await
74 .send(HandlerCommand::SubscribeMarket(asset_ids))
75 .map_err(|e| anyhow::anyhow!("Failed to send SubscribeMarket: {e}"))
76 }
77
78 pub async fn unsubscribe_market(&self, asset_ids: Vec<String>) -> anyhow::Result<()> {
80 self.cmd_tx
81 .read()
82 .await
83 .send(HandlerCommand::UnsubscribeMarket(asset_ids))
84 .map_err(|e| anyhow::anyhow!("Failed to send UnsubscribeMarket: {e}"))
85 }
86
87 #[cfg(test)]
91 pub(crate) fn from_sender(sender: tokio::sync::mpsc::UnboundedSender<HandlerCommand>) -> Self {
92 Self {
93 cmd_tx: Arc::new(tokio::sync::RwLock::new(sender)),
94 }
95 }
96}
97
98#[derive(Debug)]
104pub struct PolymarketWebSocketClient {
105 channel: WsChannel,
106 url: String,
107 connection_mode: Arc<AtomicU8>,
108 signal: Arc<AtomicBool>,
109 cmd_tx: Arc<tokio::sync::RwLock<tokio::sync::mpsc::UnboundedSender<HandlerCommand>>>,
110 out_rx: Option<tokio::sync::mpsc::UnboundedReceiver<PolymarketWsMessage>>,
111 credential: Option<Credential>,
112 subscriptions: SubscriptionState,
113 auth_tracker: AuthTracker,
114 user_subscribed: Arc<AtomicBool>,
117 task_handle: Option<tokio::task::JoinHandle<()>>,
118 subscribe_new_markets: bool,
119 transport_backend: TransportBackend,
120}
121
122impl PolymarketWebSocketClient {
123 #[must_use]
127 pub fn new_market(
128 base_url: Option<String>,
129 subscribe_new_markets: bool,
130 transport_backend: TransportBackend,
131 ) -> Self {
132 let url = base_url.unwrap_or_else(|| clob_ws_market_url().to_string());
133 Self::new_inner(
134 WsChannel::Market,
135 url,
136 None,
137 subscribe_new_markets,
138 transport_backend,
139 )
140 }
141
142 #[must_use]
146 pub fn new_user(
147 base_url: Option<String>,
148 credential: Credential,
149 transport_backend: TransportBackend,
150 ) -> Self {
151 let url = base_url.unwrap_or_else(|| clob_ws_user_url().to_string());
152 Self::new_inner(
153 WsChannel::User,
154 url,
155 Some(credential),
156 false,
157 transport_backend,
158 )
159 }
160
161 fn new_inner(
162 channel: WsChannel,
163 url: String,
164 credential: Option<Credential>,
165 subscribe_new_markets: bool,
166 transport_backend: TransportBackend,
167 ) -> Self {
168 let (placeholder_tx, _) = tokio::sync::mpsc::unbounded_channel();
169 Self {
170 channel,
171 url,
172 connection_mode: Arc::new(AtomicU8::new(ConnectionMode::Closed.as_u8())),
173 signal: Arc::new(AtomicBool::new(false)),
174 cmd_tx: Arc::new(tokio::sync::RwLock::new(placeholder_tx)),
175 out_rx: None,
176 credential,
177 subscriptions: SubscriptionState::new(':'),
178 auth_tracker: AuthTracker::new(),
179 user_subscribed: Arc::new(AtomicBool::new(false)),
180 task_handle: None,
181 subscribe_new_markets,
182 transport_backend,
183 }
184 }
185
186 pub async fn connect(&mut self) -> anyhow::Result<()> {
191 let mode = ConnectionMode::from_atomic(&self.connection_mode);
192 if mode.is_active() || mode.is_reconnect() {
193 log::warn!("Polymarket WebSocket already connected or reconnecting");
194 return Ok(());
195 }
196
197 let (message_handler, raw_rx) = channel_message_handler();
198 let cfg = WebSocketConfig {
199 url: self.url.clone(),
200 headers: vec![],
201 heartbeat: Some(POLYMARKET_HEARTBEAT_SECS),
202 heartbeat_msg: None,
203 reconnect_timeout_ms: Some(15_000),
204 reconnect_delay_initial_ms: Some(250),
205 reconnect_delay_max_ms: Some(5_000),
206 reconnect_backoff_factor: Some(2.0),
207 reconnect_jitter_ms: Some(200),
208 reconnect_max_attempts: None,
209 idle_timeout_ms: Some(idle_timeout_ms_for(self.channel)),
210 backend: self.transport_backend,
211 proxy_url: None,
212 };
213
214 let client =
215 WebSocketClient::connect(cfg, Some(message_handler), None, None, vec![], None).await?;
216
217 let (cmd_tx, cmd_rx) = tokio::sync::mpsc::unbounded_channel::<HandlerCommand>();
218 let (out_tx, out_rx) = tokio::sync::mpsc::unbounded_channel::<PolymarketWsMessage>();
219
220 *self.cmd_tx.write().await = cmd_tx.clone();
221 self.out_rx = Some(out_rx);
222
223 let client_mode = client.connection_mode_atomic();
224 self.connection_mode = client_mode;
225
226 log::info!("Polymarket WebSocket connected: {}", self.url);
227
228 cmd_tx
229 .send(HandlerCommand::SetClient(client))
230 .map_err(|e| anyhow::anyhow!("Failed to send SetClient: {e}"))?;
231
232 match self.channel {
236 WsChannel::Market => {
237 let topics = self.subscriptions.all_topics();
238 if !topics.is_empty() {
239 log::info!(
240 "Replaying {} market subscription(s) onto new session",
241 topics.len()
242 );
243 cmd_tx
244 .send(HandlerCommand::SubscribeMarket(topics))
245 .map_err(|e| anyhow::anyhow!("Failed to replay SubscribeMarket: {e}"))?;
246 }
247 }
248 WsChannel::User => {
249 if self.user_subscribed.load(Ordering::Relaxed) {
250 log::info!("Replaying user subscribe onto new session");
251 cmd_tx
252 .send(HandlerCommand::SubscribeUser)
253 .map_err(|e| anyhow::anyhow!("Failed to replay SubscribeUser: {e}"))?;
254 }
255 }
256 }
257
258 let signal = Arc::clone(&self.signal);
259 let channel = self.channel;
260 let credential = self.credential.clone();
261 let subscriptions = self.subscriptions.clone();
262 let auth_tracker = self.auth_tracker.clone();
263 let user_subscribed = self.user_subscribed.load(Ordering::Relaxed);
264 let subscribe_new_markets = self.subscribe_new_markets;
265
266 let stream_handle = get_runtime().spawn(async move {
267 let mut handler = FeedHandler::new(
268 signal,
269 channel,
270 cmd_rx,
271 raw_rx,
272 out_tx,
273 credential,
274 subscriptions,
275 auth_tracker,
276 user_subscribed,
277 subscribe_new_markets,
278 );
279
280 loop {
281 match handler.next().await {
282 Some(PolymarketWsMessage::Reconnected) => {
283 log::info!("Polymarket WebSocket reconnected");
284 }
285 Some(msg) => {
286 if handler.send(msg).is_err() {
287 log::error!("Output channel closed, stopping handler");
288 break;
289 }
290 }
291 None => {
292 if handler.is_stopped() {
293 log::debug!("Stop signal received, ending handler task");
294 } else {
295 log::warn!("Polymarket WebSocket stream ended unexpectedly");
296 }
297 break;
298 }
299 }
300 }
301 log::debug!("Polymarket WebSocket handler task completed");
302 });
303 self.task_handle = Some(stream_handle);
304 Ok(())
305 }
306
307 pub(crate) fn abort(&mut self) {
310 self.signal.store(true, Ordering::Relaxed);
311 self.connection_mode
312 .store(ConnectionMode::Closed.as_u8(), Ordering::SeqCst);
313
314 if let Some(handle) = self.task_handle.take() {
315 handle.abort();
316 }
317 self.auth_tracker.invalidate();
318 }
319
320 pub async fn disconnect(&mut self) -> anyhow::Result<()> {
322 log::info!("Disconnecting Polymarket WebSocket");
323 self.signal.store(true, Ordering::Relaxed);
324
325 if let Err(e) = self.cmd_tx.read().await.send(HandlerCommand::Disconnect) {
326 log::debug!("Failed to send disconnect (handler may already be shut down): {e}");
327 }
328
329 if let Some(handle) = self.task_handle.take() {
330 let abort_handle = handle.abort_handle();
331 tokio::select! {
332 result = handle => {
333 match result {
334 Ok(()) => log::debug!("Handler task completed"),
335 Err(e) if e.is_cancelled() => log::debug!("Handler task was cancelled"),
336 Err(e) => log::error!("Handler task error: {e:?}"),
337 }
338 }
339 () = tokio::time::sleep(tokio::time::Duration::from_secs(2)) => {
340 log::warn!("Timeout waiting for handler task, aborting");
341 abort_handle.abort();
342 }
343 }
344 }
345 self.auth_tracker.invalidate();
348 log::debug!("Polymarket WebSocket disconnected");
349 Ok(())
350 }
351
352 #[must_use]
354 pub fn is_active(&self) -> bool {
355 ConnectionMode::from_atomic(&self.connection_mode).is_active()
356 }
357
358 #[must_use]
360 pub fn url(&self) -> &str {
361 &self.url
362 }
363
364 #[must_use]
366 pub fn subscription_count(&self) -> usize {
367 self.subscriptions.all_topics().len()
368 }
369
370 #[must_use]
372 pub fn is_authenticated(&self) -> bool {
373 self.auth_tracker.is_authenticated()
374 }
375
376 pub async fn subscribe_market(&self, asset_ids: Vec<String>) -> anyhow::Result<()> {
385 if self.channel != WsChannel::Market {
386 anyhow::bail!(
387 "subscribe_market() requires a market-channel client (created with new_market())"
388 );
389 }
390 self.cmd_tx
391 .read()
392 .await
393 .send(HandlerCommand::SubscribeMarket(asset_ids))
394 .map_err(|e| anyhow::anyhow!("Failed to send SubscribeMarket: {e}"))
395 }
396
397 pub async fn unsubscribe_market(&self, asset_ids: Vec<String>) -> anyhow::Result<()> {
406 if self.channel != WsChannel::Market {
407 anyhow::bail!(
408 "unsubscribe_market() requires a market-channel client (created with new_market())"
409 );
410 }
411 self.cmd_tx
412 .read()
413 .await
414 .send(HandlerCommand::UnsubscribeMarket(asset_ids))
415 .map_err(|e| anyhow::anyhow!("Failed to send UnsubscribeMarket: {e}"))
416 }
417
418 pub async fn subscribe_user(&self) -> anyhow::Result<()> {
424 if self.channel != WsChannel::User {
425 anyhow::bail!(
426 "subscribe_user() requires a user-channel client (created with new_user())"
427 );
428 }
429 self.cmd_tx
430 .read()
431 .await
432 .send(HandlerCommand::SubscribeUser)
433 .map_err(|e| anyhow::anyhow!("Failed to send SubscribeUser: {e}"))?;
434 self.user_subscribed.store(true, Ordering::Relaxed);
437 Ok(())
438 }
439
440 #[must_use]
442 pub fn clone_subscription_handle(&self) -> WsSubscriptionHandle {
443 WsSubscriptionHandle {
444 cmd_tx: Arc::clone(&self.cmd_tx),
445 }
446 }
447
448 #[must_use]
454 pub fn take_message_receiver(
455 &mut self,
456 ) -> Option<tokio::sync::mpsc::UnboundedReceiver<PolymarketWsMessage>> {
457 self.out_rx.take()
458 }
459
460 pub async fn next_message(&mut self) -> Option<PolymarketWsMessage> {
465 if let Some(ref mut rx) = self.out_rx {
466 rx.recv().await
467 } else {
468 None
469 }
470 }
471}
472
#[cfg(test)]
mod tests {
    //! Unit tests for the per-channel idle-timeout mapping.

    use rstest::rstest;

    use super::{WsChannel, idle_timeout_ms_for};

    #[rstest]
    #[case::market(WsChannel::Market, 60_000)]
    #[case::user(WsChannel::User, 300_000)]
    fn test_idle_timeout_ms_for_channel(#[case] channel: WsChannel, #[case] expected: u64) {
        let actual_ms = idle_timeout_ms_for(channel);
        assert_eq!(actual_ms, expected);
    }
}