1use std::{
87 any::{Any, TypeId},
88 cell::RefCell,
89 collections::HashMap,
90 hash::{Hash, Hasher},
91 rc::Rc,
92};
93
94use ahash::{AHashMap, AHashSet};
95use indexmap::IndexMap;
96use nautilus_core::{UUID4, correctness::FAILED};
97use nautilus_model::{
98 data::{
99 Bar, Data, FundingRateUpdate, GreeksData, IndexPriceUpdate, MarkPriceUpdate,
100 OrderBookDeltas, OrderBookDepth10, QuoteTick, TradeTick,
101 option_chain::{OptionChainSlice, OptionGreeks},
102 },
103 events::{AccountState, OrderEventAny, PositionEvent},
104 identifiers::TraderId,
105 orderbook::OrderBook,
106 orders::OrderAny,
107 position::Position,
108};
109use smallvec::SmallVec;
110use ustr::Ustr;
111
112use super::{
113 ShareableMessageHandler,
114 matching::is_matching_backtracking,
115 mstr::{Endpoint, MStr, Pattern, Topic},
116 set_message_bus,
117 switchboard::MessagingSwitchboard,
118 typed_endpoints::{EndpointMap, IntoEndpointMap},
119 typed_router::TopicRouter,
120};
121use crate::messages::{
122 data::{DataCommand, DataResponse},
123 execution::{ExecutionReport, TradingCommand},
124};
125
/// Binds a message handler to a topic pattern.
///
/// Equality and hashing look only at `pattern` and `handler_id`; ordering
/// additionally ranks by `priority`, with higher priorities sorting first.
#[derive(Clone, Debug)]
pub struct Subscription {
    /// The handler associated with this subscription.
    pub handler: ShareableMessageHandler,
    /// Identifier of `handler`, taken from `handler.0.id()` at construction.
    pub handler_id: Ustr,
    /// The pattern this subscription was registered under.
    pub pattern: MStr<Pattern>,
    /// Ordering priority; higher values sort first, `0` when not specified.
    pub priority: u8,
}
144
145impl Subscription {
146 #[must_use]
148 pub fn new(
149 pattern: MStr<Pattern>,
150 handler: ShareableMessageHandler,
151 priority: Option<u8>,
152 ) -> Self {
153 Self {
154 handler_id: handler.0.id(),
155 pattern,
156 handler,
157 priority: priority.unwrap_or(0),
158 }
159 }
160}
161
162impl PartialEq<Self> for Subscription {
163 fn eq(&self, other: &Self) -> bool {
164 self.pattern == other.pattern && self.handler_id == other.handler_id
165 }
166}
167
// Marker impl: equality over (pattern, handler_id) is a full equivalence relation
impl Eq for Subscription {}
169
170impl PartialOrd for Subscription {
171 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
172 Some(self.cmp(other))
173 }
174}
175
176impl Ord for Subscription {
177 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
178 other
179 .priority
180 .cmp(&self.priority)
181 .then_with(|| self.pattern.cmp(&other.pattern))
182 .then_with(|| self.handler_id.cmp(&other.handler_id))
183 }
184}
185
186impl Hash for Subscription {
187 fn hash<H: Hasher>(&self, state: &mut H) {
188 self.pattern.hash(state);
189 self.handler_id.hash(state);
190 }
191}
192
/// A message bus holding subscriptions, endpoints, response handlers and
/// per-type routers/endpoint maps.
#[derive(Debug)]
pub struct MessageBus {
    /// Trader ID supplied at construction.
    pub trader_id: TraderId,
    /// Instance ID supplied at construction.
    pub instance_id: UUID4,
    /// Name of the bus; `new` falls back to `"MessageBus"` when none is given.
    pub name: String,
    /// Initialised to `false` by `new`.
    // NOTE(review): presumably flags a persistent backing store - confirm against callers
    pub has_backing: bool,
    /// Switchboard instance, default-constructed in `new`.
    pub(crate) switchboard: MessagingSwitchboard,
    /// Active subscriptions, unique by (pattern, handler ID).
    pub(crate) subscriptions: AHashSet<Subscription>,
    /// Cache of topic -> sorted matching subscriptions, filled lazily by matching lookups.
    pub(crate) topics: IndexMap<MStr<Topic>, Vec<Subscription>>,
    /// Handlers keyed by endpoint.
    pub(crate) endpoints: IndexMap<MStr<Endpoint>, ShareableMessageHandler>,
    /// Response handlers keyed by correlation ID.
    pub(crate) correlation_index: AHashMap<UUID4, ShareableMessageHandler>,
    // Statically typed topic routers, one per message type
    pub(crate) router_quotes: TopicRouter<QuoteTick>,
    pub(crate) router_trades: TopicRouter<TradeTick>,
    pub(crate) router_bars: TopicRouter<Bar>,
    pub(crate) router_deltas: TopicRouter<OrderBookDeltas>,
    pub(crate) router_depth10: TopicRouter<OrderBookDepth10>,
    pub(crate) router_book_snapshots: TopicRouter<OrderBook>,
    pub(crate) router_mark_prices: TopicRouter<MarkPriceUpdate>,
    pub(crate) router_index_prices: TopicRouter<IndexPriceUpdate>,
    pub(crate) router_funding_rates: TopicRouter<FundingRateUpdate>,
    pub(crate) router_order_events: TopicRouter<OrderEventAny>,
    pub(crate) router_position_events: TopicRouter<PositionEvent>,
    pub(crate) router_account_state: TopicRouter<AccountState>,
    pub(crate) router_orders: TopicRouter<OrderAny>,
    pub(crate) router_positions: TopicRouter<Position>,
    pub(crate) router_greeks: TopicRouter<GreeksData>,
    pub(crate) router_option_greeks: TopicRouter<OptionGreeks>,
    pub(crate) router_option_chain: TopicRouter<OptionChainSlice>,
    // Routers and endpoint map only compiled with the `defi` feature
    #[cfg(feature = "defi")]
    pub(crate) router_defi_blocks: TopicRouter<nautilus_model::defi::Block>,
    #[cfg(feature = "defi")]
    pub(crate) router_defi_pools: TopicRouter<nautilus_model::defi::Pool>,
    #[cfg(feature = "defi")]
    pub(crate) router_defi_swaps: TopicRouter<nautilus_model::defi::PoolSwap>,
    #[cfg(feature = "defi")]
    pub(crate) router_defi_liquidity: TopicRouter<nautilus_model::defi::PoolLiquidityUpdate>,
    #[cfg(feature = "defi")]
    pub(crate) router_defi_collects: TopicRouter<nautilus_model::defi::PoolFeeCollect>,
    #[cfg(feature = "defi")]
    pub(crate) router_defi_flash: TopicRouter<nautilus_model::defi::PoolFlash>,
    #[cfg(feature = "defi")]
    pub(crate) endpoints_defi_data: IntoEndpointMap<nautilus_model::defi::DefiData>,
    // Statically typed endpoint maps, one per message type
    pub(crate) endpoints_quotes: EndpointMap<QuoteTick>,
    pub(crate) endpoints_trades: EndpointMap<TradeTick>,
    pub(crate) endpoints_bars: EndpointMap<Bar>,
    pub(crate) endpoints_account_state: EndpointMap<AccountState>,
    pub(crate) endpoints_trading_commands: IntoEndpointMap<TradingCommand>,
    pub(crate) endpoints_data_commands: IntoEndpointMap<DataCommand>,
    pub(crate) endpoints_data_responses: IntoEndpointMap<DataResponse>,
    pub(crate) endpoints_exec_reports: IntoEndpointMap<ExecutionReport>,
    pub(crate) endpoints_order_events: IntoEndpointMap<OrderEventAny>,
    pub(crate) endpoints_data: IntoEndpointMap<Data>,
    /// Routers for arbitrary types, keyed by `TypeId` and created on demand by `router`.
    routers_typed: AHashMap<TypeId, Box<dyn Any>>,
    /// Endpoint maps for arbitrary types, keyed by `TypeId` and created on demand by `endpoint_map`.
    endpoints_typed: AHashMap<TypeId, Box<dyn Any>>,
}
272
273impl Default for MessageBus {
274 fn default() -> Self {
276 Self::new(TraderId::from("TRADER-001"), UUID4::new(), None, None)
277 }
278}
279
280impl MessageBus {
281 #[must_use]
283 pub fn new(
284 trader_id: TraderId,
285 instance_id: UUID4,
286 name: Option<String>,
287 _config: Option<HashMap<String, serde_json::Value>>,
288 ) -> Self {
289 Self {
290 trader_id,
291 instance_id,
292 name: name.unwrap_or(stringify!(MessageBus).to_owned()),
293 switchboard: MessagingSwitchboard::default(),
294 subscriptions: AHashSet::new(),
295 topics: IndexMap::new(),
296 endpoints: IndexMap::new(),
297 correlation_index: AHashMap::new(),
298 has_backing: false,
299 router_quotes: TopicRouter::new(),
300 router_trades: TopicRouter::new(),
301 router_bars: TopicRouter::new(),
302 router_deltas: TopicRouter::new(),
303 router_depth10: TopicRouter::new(),
304 router_book_snapshots: TopicRouter::new(),
305 router_mark_prices: TopicRouter::new(),
306 router_index_prices: TopicRouter::new(),
307 router_funding_rates: TopicRouter::new(),
308 router_order_events: TopicRouter::new(),
309 router_position_events: TopicRouter::new(),
310 router_account_state: TopicRouter::new(),
311 router_orders: TopicRouter::new(),
312 router_positions: TopicRouter::new(),
313 router_greeks: TopicRouter::new(),
314 router_option_greeks: TopicRouter::new(),
315 router_option_chain: TopicRouter::new(),
316 #[cfg(feature = "defi")]
317 router_defi_blocks: TopicRouter::new(),
318 #[cfg(feature = "defi")]
319 router_defi_pools: TopicRouter::new(),
320 #[cfg(feature = "defi")]
321 router_defi_swaps: TopicRouter::new(),
322 #[cfg(feature = "defi")]
323 router_defi_liquidity: TopicRouter::new(),
324 #[cfg(feature = "defi")]
325 router_defi_collects: TopicRouter::new(),
326 #[cfg(feature = "defi")]
327 router_defi_flash: TopicRouter::new(),
328 #[cfg(feature = "defi")]
329 endpoints_defi_data: IntoEndpointMap::new(),
330 endpoints_quotes: EndpointMap::new(),
331 endpoints_trades: EndpointMap::new(),
332 endpoints_bars: EndpointMap::new(),
333 endpoints_account_state: EndpointMap::new(),
334 endpoints_trading_commands: IntoEndpointMap::new(),
335 endpoints_data_commands: IntoEndpointMap::new(),
336 endpoints_data_responses: IntoEndpointMap::new(),
337 endpoints_exec_reports: IntoEndpointMap::new(),
338 endpoints_order_events: IntoEndpointMap::new(),
339 endpoints_data: IntoEndpointMap::new(),
340 routers_typed: AHashMap::new(),
341 endpoints_typed: AHashMap::new(),
342 }
343 }
344
345 pub fn register_message_bus(self) -> Rc<RefCell<Self>> {
347 let msgbus = Rc::new(RefCell::new(self));
348 set_message_bus(msgbus.clone());
349 msgbus
350 }
351
352 pub fn router<T: 'static>(&mut self) -> &mut TopicRouter<T> {
358 self.routers_typed
359 .entry(TypeId::of::<T>())
360 .or_insert_with(|| Box::new(TopicRouter::<T>::new()))
361 .downcast_mut::<TopicRouter<T>>()
362 .expect("TopicRouter type mismatch - this is a bug")
363 }
364
365 pub fn endpoint_map<T: 'static>(&mut self) -> &mut EndpointMap<T> {
371 self.endpoints_typed
372 .entry(TypeId::of::<T>())
373 .or_insert_with(|| Box::new(EndpointMap::<T>::new()))
374 .downcast_mut::<EndpointMap<T>>()
375 .expect("EndpointMap type mismatch - this is a bug")
376 }
377
    /// Clears every subscription, cached topic match, endpoint, response
    /// handler, typed router and typed endpoint map held by the bus.
    ///
    /// Identity fields (`trader_id`, `instance_id`, `name`), `has_backing` and
    /// the switchboard are left untouched.
    pub fn dispose(&mut self) {
        // Untyped registrations and the topic match cache
        self.subscriptions.clear();
        self.topics.clear();
        self.endpoints.clear();
        self.correlation_index.clear();

        // Statically typed routers
        self.router_quotes.clear();
        self.router_trades.clear();
        self.router_bars.clear();
        self.router_deltas.clear();
        self.router_depth10.clear();
        self.router_book_snapshots.clear();
        self.router_mark_prices.clear();
        self.router_index_prices.clear();
        self.router_funding_rates.clear();
        self.router_order_events.clear();
        self.router_position_events.clear();
        self.router_account_state.clear();
        self.router_orders.clear();
        self.router_positions.clear();
        self.router_greeks.clear();
        self.router_option_greeks.clear();
        self.router_option_chain.clear();

        // These fields only exist when the `defi` feature is enabled
        #[cfg(feature = "defi")]
        {
            self.router_defi_blocks.clear();
            self.router_defi_pools.clear();
            self.router_defi_swaps.clear();
            self.router_defi_liquidity.clear();
            self.router_defi_collects.clear();
            self.router_defi_flash.clear();
            self.endpoints_defi_data.clear();
        }

        // Statically typed endpoint maps
        self.endpoints_quotes.clear();
        self.endpoints_trades.clear();
        self.endpoints_bars.clear();
        self.endpoints_account_state.clear();
        self.endpoints_trading_commands.clear();
        self.endpoints_data_commands.clear();
        self.endpoints_data_responses.clear();
        self.endpoints_exec_reports.clear();
        self.endpoints_order_events.clear();
        self.endpoints_data.clear();

        // Routers and endpoint maps created on demand, keyed by `TypeId`
        self.routers_typed.clear();
        self.endpoints_typed.clear();
    }
429
430 #[must_use]
432 pub fn mem_address(&self) -> String {
433 format!("{self:p}")
434 }
435
436 #[must_use]
438 pub fn switchboard(&self) -> &MessagingSwitchboard {
439 &self.switchboard
440 }
441
442 #[must_use]
444 pub fn endpoints(&self) -> Vec<&str> {
445 self.endpoints.iter().map(|e| e.0.as_str()).collect()
446 }
447
448 #[must_use]
450 pub fn patterns(&self) -> Vec<&str> {
451 self.subscriptions
452 .iter()
453 .map(|s| s.pattern.as_str())
454 .collect()
455 }
456
457 pub fn has_subscribers<T: AsRef<str>>(&self, topic: T) -> bool {
459 self.subscriptions_count(topic) > 0
460 }
461
462 #[must_use]
468 pub fn subscriptions_count<T: AsRef<str>>(&self, topic: T) -> usize {
469 let topic = MStr::<Topic>::topic(topic).expect(FAILED);
470 self.topics
471 .get(&topic)
472 .map_or_else(|| self.find_topic_matches(topic).len(), |subs| subs.len())
473 }
474
475 #[must_use]
477 pub fn subscriptions(&self) -> Vec<&Subscription> {
478 self.subscriptions.iter().collect()
479 }
480
481 #[must_use]
483 pub fn subscription_handler_ids(&self) -> Vec<&str> {
484 self.subscriptions
485 .iter()
486 .map(|s| s.handler_id.as_str())
487 .collect()
488 }
489
490 #[must_use]
496 pub fn is_registered<T: Into<MStr<Endpoint>>>(&self, endpoint: T) -> bool {
497 let endpoint: MStr<Endpoint> = endpoint.into();
498 self.endpoints.contains_key(&endpoint)
499 }
500
501 #[must_use]
503 pub fn is_subscribed<T: AsRef<str>>(
504 &self,
505 pattern: T,
506 handler: ShareableMessageHandler,
507 ) -> bool {
508 let pattern = MStr::<Pattern>::pattern(pattern);
509 let sub = Subscription::new(pattern, handler, None);
510 self.subscriptions.contains(&sub)
511 }
512
    /// Closes the message bus.
    ///
    /// Currently performs no work and always returns `Ok(())`.
    pub const fn close(&self) -> anyhow::Result<()> {
        Ok(())
    }
522
523 #[must_use]
525 pub fn get_endpoint(&self, endpoint: MStr<Endpoint>) -> Option<&ShareableMessageHandler> {
526 self.endpoints.get(&endpoint)
527 }
528
529 #[must_use]
531 pub fn get_response_handler(&self, correlation_id: &UUID4) -> Option<&ShareableMessageHandler> {
532 self.correlation_index.get(correlation_id)
533 }
534
535 pub(crate) fn find_topic_matches(&self, topic: MStr<Topic>) -> Vec<Subscription> {
537 self.subscriptions
538 .iter()
539 .filter_map(|sub| {
540 if is_matching_backtracking(topic, sub.pattern) {
541 Some(sub.clone())
542 } else {
543 None
544 }
545 })
546 .collect()
547 }
548
549 #[must_use]
552 pub fn matching_subscriptions<T: Into<MStr<Topic>>>(&mut self, topic: T) -> Vec<Subscription> {
553 self.inner_matching_subscriptions(topic.into())
554 }
555
556 pub(crate) fn inner_matching_subscriptions(&mut self, topic: MStr<Topic>) -> Vec<Subscription> {
557 self.topics.get(&topic).cloned().unwrap_or_else(|| {
558 let mut matches = self.find_topic_matches(topic);
559 matches.sort();
560 self.topics.insert(topic, matches.clone());
561 matches
562 })
563 }
564
565 pub(crate) fn fill_matching_any_handlers(
567 &mut self,
568 topic: MStr<Topic>,
569 buf: &mut SmallVec<[ShareableMessageHandler; 64]>,
570 ) {
571 if let Some(subs) = self.topics.get(&topic) {
572 for sub in subs {
573 buf.push(sub.handler.clone());
574 }
575 } else {
576 let mut matches = self.find_topic_matches(topic);
577 matches.sort();
578
579 for sub in &matches {
580 buf.push(sub.handler.clone());
581 }
582
583 self.topics.insert(topic, matches);
584 }
585 }
586
587 pub fn register_response_handler(
593 &mut self,
594 correlation_id: &UUID4,
595 handler: ShareableMessageHandler,
596 ) -> anyhow::Result<()> {
597 if self.correlation_index.contains_key(correlation_id) {
598 anyhow::bail!("Correlation ID <{correlation_id}> already has a registered handler");
599 }
600
601 self.correlation_index.insert(*correlation_id, handler);
602
603 Ok(())
604 }
605}
606
607#[cfg(test)]
608mod tests {
609 use rand::{RngExt, SeedableRng, rngs::StdRng};
610 use rstest::rstest;
611 use ustr::Ustr;
612
613 use super::*;
614 use crate::msgbus::{
615 self, ShareableMessageHandler, get_message_bus,
616 matching::is_matching_backtracking,
617 stubs::{get_call_check_handler, get_stub_shareable_handler},
618 subscriptions_count_any,
619 };
620
621 #[rstest]
622 fn test_new() {
623 let trader_id = TraderId::default();
624 let msgbus = MessageBus::new(trader_id, UUID4::new(), None, None);
625
626 assert_eq!(msgbus.trader_id, trader_id);
627 assert_eq!(msgbus.name, stringify!(MessageBus));
628 }
629
630 #[rstest]
631 fn test_endpoints_when_no_endpoints() {
632 let msgbus = get_message_bus();
633 assert!(msgbus.borrow().endpoints().is_empty());
634 }
635
636 #[rstest]
637 fn test_topics_when_no_subscriptions() {
638 let msgbus = get_message_bus();
639 assert!(msgbus.borrow().patterns().is_empty());
640 assert!(!msgbus.borrow().has_subscribers("my-topic"));
641 }
642
643 #[rstest]
644 fn test_is_subscribed_when_no_subscriptions() {
645 let msgbus = get_message_bus();
646 let handler = get_stub_shareable_handler(None);
647
648 assert!(!msgbus.borrow().is_subscribed("my-topic", handler));
649 }
650
651 #[rstest]
652 fn test_get_response_handler_when_no_handler() {
653 let msgbus = get_message_bus();
654 let msgbus_ref = msgbus.borrow();
655 let handler = msgbus_ref.get_response_handler(&UUID4::new());
656 assert!(handler.is_none());
657 }
658
659 #[rstest]
660 fn test_get_response_handler_when_already_registered() {
661 let msgbus = get_message_bus();
662 let mut msgbus_ref = msgbus.borrow_mut();
663 let handler = get_stub_shareable_handler(None);
664
665 let request_id = UUID4::new();
666 msgbus_ref
667 .register_response_handler(&request_id, handler.clone())
668 .unwrap();
669
670 let result = msgbus_ref.register_response_handler(&request_id, handler);
671 assert!(result.is_err());
672 }
673
674 #[rstest]
675 fn test_get_response_handler_when_registered() {
676 let msgbus = get_message_bus();
677 let mut msgbus_ref = msgbus.borrow_mut();
678 let handler = get_stub_shareable_handler(None);
679
680 let request_id = UUID4::new();
681 msgbus_ref
682 .register_response_handler(&request_id, handler)
683 .unwrap();
684
685 let handler = msgbus_ref.get_response_handler(&request_id).unwrap();
686 assert_eq!(handler.id(), handler.id());
687 }
688
689 #[rstest]
690 fn test_is_registered_when_no_registrations() {
691 let msgbus = get_message_bus();
692 assert!(!msgbus.borrow().is_registered("MyEndpoint"));
693 }
694
695 #[rstest]
696 fn test_register_endpoint() {
697 let msgbus = get_message_bus();
698 let endpoint = "MyEndpoint".into();
699 let handler = get_stub_shareable_handler(None);
700
701 msgbus::register_any(endpoint, handler);
702
703 assert_eq!(msgbus.borrow().endpoints(), vec![endpoint.to_string()]);
704 assert!(msgbus.borrow().get_endpoint(endpoint).is_some());
705 }
706
707 #[rstest]
708 fn test_endpoint_send() {
709 let msgbus = get_message_bus();
710 let endpoint = "MyEndpoint".into();
711 let (handler, checker) = get_call_check_handler(None);
712
713 msgbus::register_any(endpoint, handler);
714 assert!(msgbus.borrow().get_endpoint(endpoint).is_some());
715 assert!(!checker.was_called());
716
717 msgbus::send_any(endpoint, &"Test Message");
719 assert!(checker.was_called());
720 }
721
722 #[rstest]
723 fn test_deregsiter_endpoint() {
724 let msgbus = get_message_bus();
725 let endpoint = "MyEndpoint".into();
726 let handler = get_stub_shareable_handler(None);
727
728 msgbus::register_any(endpoint, handler);
729 msgbus::deregister_any(endpoint);
730
731 assert!(msgbus.borrow().endpoints().is_empty());
732 }
733
734 #[rstest]
735 fn test_subscribe() {
736 let msgbus = get_message_bus();
737 let topic = "my-topic";
738 let handler = get_stub_shareable_handler(None);
739
740 msgbus::subscribe_any(topic.into(), handler, Some(1));
741
742 assert!(msgbus.borrow().has_subscribers(topic));
743 assert_eq!(msgbus.borrow().patterns(), vec![topic]);
744 }
745
746 #[rstest]
747 fn test_unsubscribe() {
748 let msgbus = get_message_bus();
749 let topic = "my-topic";
750 let handler = get_stub_shareable_handler(None);
751
752 msgbus::subscribe_any(topic.into(), handler.clone(), None);
753 msgbus::unsubscribe_any(topic.into(), &handler);
754
755 assert!(!msgbus.borrow().has_subscribers(topic));
756 assert!(msgbus.borrow().patterns().is_empty());
757 }
758
759 #[rstest]
760 fn test_matching_subscriptions() {
761 let msgbus = get_message_bus();
762 let pattern = "my-pattern";
763
764 let handler_id1 = Ustr::from("1");
765 let handler1 = get_stub_shareable_handler(Some(handler_id1));
766
767 let handler_id2 = Ustr::from("2");
768 let handler2 = get_stub_shareable_handler(Some(handler_id2));
769
770 let handler_id3 = Ustr::from("3");
771 let handler3 = get_stub_shareable_handler(Some(handler_id3));
772
773 let handler_id4 = Ustr::from("4");
774 let handler4 = get_stub_shareable_handler(Some(handler_id4));
775
776 msgbus::subscribe_any(pattern.into(), handler1, None);
777 msgbus::subscribe_any(pattern.into(), handler2, None);
778 msgbus::subscribe_any(pattern.into(), handler3, Some(1));
779 msgbus::subscribe_any(pattern.into(), handler4, Some(2));
780
781 assert_eq!(
782 msgbus.borrow().patterns(),
783 vec![pattern, pattern, pattern, pattern]
784 );
785 assert_eq!(subscriptions_count_any(pattern), 4);
786
787 let topic = pattern;
788 let subs = msgbus.borrow_mut().matching_subscriptions(topic);
789 assert_eq!(subs.len(), 4);
790 assert_eq!(subs[0].handler_id, handler_id4);
791 assert_eq!(subs[1].handler_id, handler_id3);
792 assert_eq!(subs[2].handler_id, handler_id1);
793 assert_eq!(subs[3].handler_id, handler_id2);
794 }
795
796 #[rstest]
797 fn test_subscription_pattern_matching() {
798 let msgbus = get_message_bus();
799 let handler1 = get_stub_shareable_handler(Some(Ustr::from("1")));
800 let handler2 = get_stub_shareable_handler(Some(Ustr::from("2")));
801 let handler3 = get_stub_shareable_handler(Some(Ustr::from("3")));
802
803 msgbus::subscribe_any("data.quotes.*".into(), handler1, None);
804 msgbus::subscribe_any("data.trades.*".into(), handler2, None);
805 msgbus::subscribe_any("data.*.BINANCE.*".into(), handler3, None);
806 assert_eq!(msgbus.borrow().subscriptions().len(), 3);
807
808 let topic = "data.quotes.BINANCE.ETHUSDT";
809 assert_eq!(msgbus.borrow().find_topic_matches(topic.into()).len(), 2);
810
811 let matches = msgbus.borrow_mut().matching_subscriptions(topic);
812 assert_eq!(matches.len(), 2);
813 assert_eq!(matches[0].handler_id, Ustr::from("3"));
814 assert_eq!(matches[1].handler_id, Ustr::from("1"));
815 }
816
    /// Minimal reference model of the bus's subscription set, used by the fuzz
    /// test to cross-check the real implementation.
    struct SimpleSubscriptionModel {
        /// `(pattern, handler_id)` pairs in insertion order, without duplicates.
        subscriptions: Vec<(String, String)>,
    }
822
823 impl SimpleSubscriptionModel {
824 fn new() -> Self {
825 Self {
826 subscriptions: Vec::new(),
827 }
828 }
829
830 fn subscribe(&mut self, pattern: &str, handler_id: &str) {
831 let subscription = (pattern.to_string(), handler_id.to_string());
832 if !self.subscriptions.contains(&subscription) {
833 self.subscriptions.push(subscription);
834 }
835 }
836
837 fn unsubscribe(&mut self, pattern: &str, handler_id: &str) -> bool {
838 let subscription = (pattern.to_string(), handler_id.to_string());
839 if let Some(idx) = self.subscriptions.iter().position(|s| s == &subscription) {
840 self.subscriptions.remove(idx);
841 true
842 } else {
843 false
844 }
845 }
846
847 fn is_subscribed(&self, pattern: &str, handler_id: &str) -> bool {
848 self.subscriptions
849 .contains(&(pattern.to_string(), handler_id.to_string()))
850 }
851
852 fn matching_subscriptions(&self, topic: &str) -> Vec<(String, String)> {
853 let topic = topic.into();
854
855 self.subscriptions
856 .iter()
857 .filter(|(pat, _)| is_matching_backtracking(topic, pat.into()))
858 .map(|(pat, id)| (pat.clone(), id.clone()))
859 .collect()
860 }
861
862 fn subscription_count(&self) -> usize {
863 self.subscriptions.len()
864 }
865 }
866
867 #[rstest]
868 fn subscription_model_fuzz_testing() {
869 let mut rng = StdRng::seed_from_u64(42);
870
871 let msgbus = get_message_bus();
872 let mut model = SimpleSubscriptionModel::new();
873
874 let mut handlers: Vec<(String, ShareableMessageHandler)> = Vec::new();
876
877 let patterns = generate_test_patterns(&mut rng);
879
880 let handler_ids: Vec<String> = (0..50).map(|i| format!("handler_{i}")).collect();
882
883 for id in &handler_ids {
885 let handler = get_stub_shareable_handler(Some(Ustr::from(id)));
886 handlers.push((id.clone(), handler));
887 }
888
889 let num_operations = 50_000;
890 for op_num in 0..num_operations {
891 let operation = rng.random_range(0..4);
892
893 match operation {
894 0 => {
896 let pattern_idx = rng.random_range(0..patterns.len());
897 let handler_idx = rng.random_range(0..handlers.len());
898 let pattern = &patterns[pattern_idx];
899 let (handler_id, handler) = &handlers[handler_idx];
900
901 model.subscribe(pattern, handler_id);
903
904 msgbus::subscribe_any(pattern.as_str().into(), handler.clone(), None);
906
907 assert_eq!(
908 model.subscription_count(),
909 msgbus.borrow().subscriptions().len()
910 );
911
912 assert!(
913 msgbus.borrow().is_subscribed(pattern, handler.clone()),
914 "Op {op_num}: is_subscribed should return true after subscribe"
915 );
916 }
917
918 1 => {
920 if model.subscription_count() > 0 {
921 let sub_idx = rng.random_range(0..model.subscription_count());
922 let (pattern, handler_id) = model.subscriptions[sub_idx].clone();
923
924 model.unsubscribe(&pattern, &handler_id);
926
927 let handler = handlers
929 .iter()
930 .find(|(id, _)| id == &handler_id)
931 .map(|(_, h)| h.clone())
932 .unwrap();
933
934 msgbus::unsubscribe_any(pattern.as_str().into(), &handler);
936
937 assert_eq!(
938 model.subscription_count(),
939 msgbus.borrow().subscriptions().len()
940 );
941 assert!(
942 !msgbus.borrow().is_subscribed(pattern, handler.clone()),
943 "Op {op_num}: is_subscribed should return false after unsubscribe"
944 );
945 }
946 }
947
948 2 => {
950 let pattern_idx = rng.random_range(0..patterns.len());
952 let handler_idx = rng.random_range(0..handlers.len());
953 let pattern = &patterns[pattern_idx];
954 let (handler_id, handler) = &handlers[handler_idx];
955
956 let expected = model.is_subscribed(pattern, handler_id);
957 let actual = msgbus.borrow().is_subscribed(pattern, handler.clone());
958
959 assert_eq!(
960 expected, actual,
961 "Op {op_num}: Subscription state mismatch for pattern '{pattern}', handler '{handler_id}': expected={expected}, actual={actual}"
962 );
963 }
964
965 3 => {
967 let topic = create_topic(&mut rng);
969
970 let actual_matches = msgbus.borrow_mut().matching_subscriptions(topic);
971 let expected_matches = model.matching_subscriptions(&topic);
972
973 assert_eq!(
974 expected_matches.len(),
975 actual_matches.len(),
976 "Op {}: Match count mismatch for topic '{}': expected={}, actual={}",
977 op_num,
978 topic,
979 expected_matches.len(),
980 actual_matches.len()
981 );
982
983 for sub in &actual_matches {
984 assert!(
985 expected_matches
986 .contains(&(sub.pattern.to_string(), sub.handler_id.to_string())),
987 "Op {}: Expected match not found: pattern='{}', handler_id='{}'",
988 op_num,
989 sub.pattern,
990 sub.handler_id
991 );
992 }
993 }
994 _ => unreachable!(),
995 }
996 }
997 }
998
999 fn generate_pattern_from_topic(topic: &str, rng: &mut StdRng) -> String {
1000 let mut pattern = String::new();
1001
1002 for c in topic.chars() {
1003 let val: f64 = rng.random();
1004 if val < 0.1 {
1005 pattern.push('*');
1006 } else if val < 0.3 {
1007 pattern.push('?');
1008 } else if val >= 0.5 {
1009 pattern.push(c);
1010 }
1011 }
1012
1013 pattern
1014 }
1015
1016 fn generate_test_patterns(rng: &mut StdRng) -> Vec<String> {
1017 let mut patterns = vec![
1018 "data.*.*.*".to_string(),
1019 "*.*.BINANCE.*".to_string(),
1020 "events.order.*".to_string(),
1021 "data.*.*.?USDT".to_string(),
1022 "*.trades.*.BTC*".to_string(),
1023 "*.*.*.*".to_string(),
1024 ];
1025
1026 for _ in 0..50 {
1028 match rng.random_range(0..10) {
1029 0..=1 => {
1031 let idx = rng.random_range(0..patterns.len());
1032 patterns.push(patterns[idx].clone());
1033 }
1034 _ => {
1036 let topic = create_topic(rng);
1037 let pattern = generate_pattern_from_topic(&topic, rng);
1038 patterns.push(pattern);
1039 }
1040 }
1041 }
1042
1043 patterns
1044 }
1045
1046 fn create_topic(rng: &mut StdRng) -> Ustr {
1047 let cat = ["data", "info", "order"];
1048 let model = ["quotes", "trades", "orderbooks", "depths"];
1049 let venue = ["BINANCE", "BYBIT", "OKX", "FTX", "KRAKEN"];
1050 let instrument = ["BTCUSDT", "ETHUSDT", "SOLUSDT", "XRPUSDT", "DOGEUSDT"];
1051
1052 let cat = cat[rng.random_range(0..cat.len())];
1053 let model = model[rng.random_range(0..model.len())];
1054 let venue = venue[rng.random_range(0..venue.len())];
1055 let instrument = instrument[rng.random_range(0..instrument.len())];
1056 Ustr::from(&format!("{cat}.{model}.{venue}.{instrument}"))
1057 }
1058}