1use std::{
38 num::NonZeroUsize,
39 sync::{Arc, LazyLock},
40};
41
42use ahash::AHashSet;
43use dashmap::DashMap;
44use ustr::Ustr;
45
46pub(crate) static CHANNEL_LEVEL_MARKER: LazyLock<Ustr> = LazyLock::new(|| Ustr::from(""));
51
/// Tracks the lifecycle of topic subscriptions (pending subscribe, confirmed, pending
/// unsubscribe) together with per-topic reference counts.
///
/// Topic strings are split on `delimiter` into a channel key and a symbol. Each map
/// stores, per channel, the set of symbols in that state. A channel-level topic (one
/// with no delimiter) is stored under the empty-string sentinel `CHANNEL_LEVEL_MARKER`.
///
/// Every map is held in an `Arc`, so clones of this struct (via the derived `Clone`)
/// share the same underlying state.
#[derive(Clone, Debug)]
pub struct SubscriptionState {
    /// Topics moved here by `confirm_subscribe` (channel -> symbols).
    confirmed: Arc<DashMap<Ustr, AHashSet<Ustr>>>,
    /// Topics marked for subscription but not yet confirmed (channel -> symbols).
    pending_subscribe: Arc<DashMap<Ustr, AHashSet<Ustr>>>,
    /// Topics marked for unsubscription but not yet confirmed (channel -> symbols).
    pending_unsubscribe: Arc<DashMap<Ustr, AHashSet<Ustr>>>,
    /// Reference count per full topic string; an absent key means a count of zero.
    reference_counts: Arc<DashMap<Ustr, NonZeroUsize>>,
    /// Character that separates the channel from the symbol in a topic string.
    delimiter: char,
}
81
82impl SubscriptionState {
83 #[must_use]
85 pub fn new(delimiter: char) -> Self {
86 Self {
87 confirmed: Arc::new(DashMap::new()),
88 pending_subscribe: Arc::new(DashMap::new()),
89 pending_unsubscribe: Arc::new(DashMap::new()),
90 reference_counts: Arc::new(DashMap::new()),
91 delimiter,
92 }
93 }
94
95 #[must_use]
97 pub fn delimiter(&self) -> char {
98 self.delimiter
99 }
100
101 #[must_use]
103 pub fn confirmed(&self) -> Arc<DashMap<Ustr, AHashSet<Ustr>>> {
104 Arc::clone(&self.confirmed)
105 }
106
107 #[must_use]
109 pub fn pending_subscribe(&self) -> Arc<DashMap<Ustr, AHashSet<Ustr>>> {
110 Arc::clone(&self.pending_subscribe)
111 }
112
113 #[must_use]
115 pub fn pending_unsubscribe(&self) -> Arc<DashMap<Ustr, AHashSet<Ustr>>> {
116 Arc::clone(&self.pending_unsubscribe)
117 }
118
119 #[must_use]
123 pub fn len(&self) -> usize {
124 self.confirmed.iter().map(|entry| entry.value().len()).sum()
125 }
126
127 #[must_use]
129 pub fn is_empty(&self) -> bool {
130 self.confirmed.is_empty()
131 && self.pending_subscribe.is_empty()
132 && self.pending_unsubscribe.is_empty()
133 }
134
135 #[must_use]
137 pub fn is_subscribed(&self, channel: &Ustr, symbol: &Ustr) -> bool {
138 if let Some(symbols) = self.confirmed.get(channel)
139 && symbols.contains(symbol)
140 {
141 return true;
142 }
143
144 if let Some(symbols) = self.pending_subscribe.get(channel)
145 && symbols.contains(symbol)
146 {
147 return true;
148 }
149 false
150 }
151
152 pub fn mark_subscribe(&self, topic: &str) {
158 let (channel, symbol) = split_topic(topic, self.delimiter);
159
160 if is_tracked(&self.confirmed, channel, symbol) {
162 return;
163 }
164
165 untrack_topic(&self.pending_unsubscribe, channel, symbol);
167
168 track_topic(&self.pending_subscribe, channel, symbol);
169 }
170
171 pub fn try_mark_subscribe(&self, topic: &str) -> bool {
178 let (channel, symbol) = split_topic(topic, self.delimiter);
179
180 if is_tracked(&self.confirmed, channel, symbol) {
182 return false;
183 }
184
185 let channel_ustr = Ustr::from(channel);
187 let symbol_ustr = symbol.map_or(*CHANNEL_LEVEL_MARKER, Ustr::from);
188
189 let mut entry = self.pending_subscribe.entry(channel_ustr).or_default();
190 let inserted = entry.insert(symbol_ustr);
191
192 if inserted {
194 untrack_topic(&self.pending_unsubscribe, channel, symbol);
195 }
196
197 inserted
198 }
199
200 pub fn mark_unsubscribe(&self, topic: &str) {
206 let (channel, symbol) = split_topic(topic, self.delimiter);
207 track_topic(&self.pending_unsubscribe, channel, symbol);
208 untrack_topic(&self.confirmed, channel, symbol);
209 untrack_topic(&self.pending_subscribe, channel, symbol);
210 }
211
212 pub fn confirm_subscribe(&self, topic: &str) {
218 let (channel, symbol) = split_topic(topic, self.delimiter);
219
220 if is_tracked(&self.pending_unsubscribe, channel, symbol) {
222 return;
223 }
224
225 untrack_topic(&self.pending_subscribe, channel, symbol);
226 track_topic(&self.confirmed, channel, symbol);
227 }
228
229 pub fn confirm_unsubscribe(&self, topic: &str) {
240 let (channel, symbol) = split_topic(topic, self.delimiter);
241
242 if !is_tracked(&self.pending_unsubscribe, channel, symbol) {
245 return; }
247
248 untrack_topic(&self.pending_unsubscribe, channel, symbol);
249 untrack_topic(&self.confirmed, channel, symbol);
250 }
252
253 pub fn mark_failure(&self, topic: &str) {
258 let (channel, symbol) = split_topic(topic, self.delimiter);
259
260 if is_tracked(&self.pending_unsubscribe, channel, symbol) {
262 return;
263 }
264
265 untrack_topic(&self.confirmed, channel, symbol);
266 track_topic(&self.pending_subscribe, channel, symbol);
267 }
268
269 #[must_use]
271 pub fn pending_subscribe_topics(&self) -> Vec<String> {
272 self.topics_from_map(&self.pending_subscribe)
273 }
274
275 #[must_use]
277 pub fn pending_unsubscribe_topics(&self) -> Vec<String> {
278 self.topics_from_map(&self.pending_unsubscribe)
279 }
280
281 #[must_use]
288 pub fn all_topics(&self) -> Vec<String> {
289 let mut topics = Vec::new();
290 topics.extend(self.topics_from_map(&self.confirmed));
291 topics.extend(self.topics_from_map(&self.pending_subscribe));
292 topics
293 }
294
295 fn topics_from_map(&self, map: &DashMap<Ustr, AHashSet<Ustr>>) -> Vec<String> {
297 let mut topics = Vec::new();
298 let marker = *CHANNEL_LEVEL_MARKER;
299
300 for entry in map {
301 let channel = entry.key();
302 let symbols = entry.value();
303
304 if symbols.contains(&marker) {
306 topics.push(channel.to_string());
307 }
308
309 for symbol in symbols {
311 if *symbol != marker {
312 topics.push(format!(
313 "{}{}{}",
314 channel.as_str(),
315 self.delimiter,
316 symbol.as_str()
317 ));
318 }
319 }
320 }
321
322 topics
323 }
324
325 #[allow(
334 clippy::must_use_candidate,
335 reason = "callers use this for side effects"
336 )]
337 pub fn add_reference(&self, topic: &str) -> bool {
338 let mut should_subscribe = false;
339 let topic_ustr = Ustr::from(topic);
340
341 self.reference_counts
342 .entry(topic_ustr)
343 .and_modify(|count| {
344 *count = NonZeroUsize::new(count.get() + 1).expect("reference count overflow");
345 })
346 .or_insert_with(|| {
347 should_subscribe = true;
348 NonZeroUsize::new(1).expect("NonZeroUsize::new(1) should never fail")
349 });
350
351 should_subscribe
352 }
353
354 #[allow(
364 clippy::must_use_candidate,
365 reason = "callers use this for side effects"
366 )]
367 pub fn remove_reference(&self, topic: &str) -> bool {
368 let topic_ustr = Ustr::from(topic);
369
370 if let dashmap::mapref::entry::Entry::Occupied(mut entry) =
373 self.reference_counts.entry(topic_ustr)
374 {
375 let current = entry.get().get();
376
377 if current == 1 {
378 entry.remove();
379 return true;
380 }
381
382 *entry.get_mut() = NonZeroUsize::new(current - 1)
383 .expect("reference count should never reach zero here");
384 }
385
386 false
387 }
388
389 #[must_use]
393 pub fn get_reference_count(&self, topic: &str) -> usize {
394 let topic_ustr = Ustr::from(topic);
395 self.reference_counts
396 .get(&topic_ustr)
397 .map_or(0, |count| count.get())
398 }
399
400 pub fn clear(&self) {
404 self.confirmed.clear();
405 self.pending_subscribe.clear();
406 self.pending_unsubscribe.clear();
407 self.reference_counts.clear();
408 }
409}
410
/// Splits `topic` at the first occurrence of `delimiter` into `(channel, symbol)`.
///
/// When the delimiter is absent, the whole topic is the channel and the symbol is `None`.
/// Any further delimiters stay inside the symbol, so `"a.b.c"` yields `("a", Some("b.c"))`.
#[must_use]
pub fn split_topic(topic: &str, delimiter: char) -> (&str, Option<&str>) {
    match topic.split_once(delimiter) {
        Some((channel, symbol)) => (channel, Some(symbol)),
        None => (topic, None),
    }
}
418
419fn track_topic(map: &DashMap<Ustr, AHashSet<Ustr>>, channel: &str, symbol: Option<&str>) {
424 let channel_ustr = Ustr::from(channel);
425 let mut entry = map.entry(channel_ustr).or_default();
426
427 if let Some(symbol) = symbol {
428 entry.insert(Ustr::from(symbol));
429 } else {
430 entry.insert(*CHANNEL_LEVEL_MARKER);
431 }
432}
433
434fn untrack_topic(map: &DashMap<Ustr, AHashSet<Ustr>>, channel: &str, symbol: Option<&str>) {
438 let channel_ustr = Ustr::from(channel);
439 let symbol_to_remove = if let Some(symbol) = symbol {
440 Ustr::from(symbol)
441 } else {
442 *CHANNEL_LEVEL_MARKER
443 };
444
445 if let dashmap::mapref::entry::Entry::Occupied(mut entry) = map.entry(channel_ustr) {
448 entry.get_mut().remove(&symbol_to_remove);
449 if entry.get().is_empty() {
450 entry.remove();
451 }
452 }
453}
454
455fn is_tracked(map: &DashMap<Ustr, AHashSet<Ustr>>, channel: &str, symbol: Option<&str>) -> bool {
457 let channel_ustr = Ustr::from(channel);
458 let symbol_to_check = if let Some(symbol) = symbol {
459 Ustr::from(symbol)
460 } else {
461 *CHANNEL_LEVEL_MARKER
462 };
463
464 if let Some(entry) = map.get(&channel_ustr) {
465 entry.contains(&symbol_to_check)
466 } else {
467 false
468 }
469}
470
471#[cfg(test)]
472mod tests {
473 use rstest::rstest;
474
475 use super::*;
476
477 #[rstest]
478 fn test_split_topic_with_symbol() {
479 let (channel, symbol) = split_topic("tickers.BTCUSDT", '.');
480 assert_eq!(channel, "tickers");
481 assert_eq!(symbol, Some("BTCUSDT"));
482
483 let (channel, symbol) = split_topic("orderBookL2:XBTUSD", ':');
484 assert_eq!(channel, "orderBookL2");
485 assert_eq!(symbol, Some("XBTUSD"));
486 }
487
488 #[rstest]
489 fn test_split_topic_without_symbol() {
490 let (channel, symbol) = split_topic("orderbook", '.');
491 assert_eq!(channel, "orderbook");
492 assert_eq!(symbol, None);
493 }
494
495 #[rstest]
496 fn test_new_state_is_empty() {
497 let state = SubscriptionState::new('.');
498 assert!(state.is_empty());
499 assert_eq!(state.len(), 0);
500 }
501
502 #[rstest]
503 fn test_mark_subscribe() {
504 let state = SubscriptionState::new('.');
505 state.mark_subscribe("tickers.BTCUSDT");
506
507 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
508 assert_eq!(state.len(), 0); }
510
511 #[rstest]
512 fn test_confirm_subscribe() {
513 let state = SubscriptionState::new('.');
514 state.mark_subscribe("tickers.BTCUSDT");
515 state.confirm_subscribe("tickers.BTCUSDT");
516
517 assert!(state.pending_subscribe_topics().is_empty());
518 assert_eq!(state.len(), 1);
519 }
520
521 #[rstest]
522 fn test_is_subscribed_empty_state() {
523 let state = SubscriptionState::new('.');
524 let channel = Ustr::from("tickers");
525 let symbol = Ustr::from("BTCUSDT");
526
527 assert!(!state.is_subscribed(&channel, &symbol));
528 }
529
530 #[rstest]
531 fn test_is_subscribed_pending() {
532 let state = SubscriptionState::new('.');
533 let channel = Ustr::from("tickers");
534 let symbol = Ustr::from("BTCUSDT");
535
536 state.mark_subscribe("tickers.BTCUSDT");
537
538 assert!(state.is_subscribed(&channel, &symbol));
539 }
540
541 #[rstest]
542 fn test_is_subscribed_confirmed() {
543 let state = SubscriptionState::new('.');
544 let channel = Ustr::from("tickers");
545 let symbol = Ustr::from("BTCUSDT");
546
547 state.mark_subscribe("tickers.BTCUSDT");
548 state.confirm_subscribe("tickers.BTCUSDT");
549
550 assert!(state.is_subscribed(&channel, &symbol));
551 }
552
553 #[rstest]
554 fn test_is_subscribed_after_unsubscribe() {
555 let state = SubscriptionState::new('.');
556 let channel = Ustr::from("tickers");
557 let symbol = Ustr::from("BTCUSDT");
558
559 state.mark_subscribe("tickers.BTCUSDT");
560 state.confirm_subscribe("tickers.BTCUSDT");
561 state.mark_unsubscribe("tickers.BTCUSDT");
562
563 assert!(!state.is_subscribed(&channel, &symbol));
565 }
566
567 #[rstest]
568 fn test_is_subscribed_after_confirm_unsubscribe() {
569 let state = SubscriptionState::new('.');
570 let channel = Ustr::from("tickers");
571 let symbol = Ustr::from("BTCUSDT");
572
573 state.mark_subscribe("tickers.BTCUSDT");
574 state.confirm_subscribe("tickers.BTCUSDT");
575 state.mark_unsubscribe("tickers.BTCUSDT");
576 state.confirm_unsubscribe("tickers.BTCUSDT");
577
578 assert!(!state.is_subscribed(&channel, &symbol));
579 }
580
581 #[rstest]
582 fn test_mark_unsubscribe() {
583 let state = SubscriptionState::new('.');
584 state.mark_subscribe("tickers.BTCUSDT");
585 state.confirm_subscribe("tickers.BTCUSDT");
586 state.mark_unsubscribe("tickers.BTCUSDT");
587
588 assert_eq!(state.len(), 0); assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
590 }
591
592 #[rstest]
593 fn test_confirm_unsubscribe() {
594 let state = SubscriptionState::new('.');
595 state.mark_subscribe("tickers.BTCUSDT");
596 state.confirm_subscribe("tickers.BTCUSDT");
597 state.mark_unsubscribe("tickers.BTCUSDT");
598 state.confirm_unsubscribe("tickers.BTCUSDT");
599
600 assert!(state.is_empty());
601 }
602
603 #[rstest]
604 fn test_resubscribe_before_unsubscribe_ack() {
605 let state = SubscriptionState::new('.');
609
610 state.mark_subscribe("tickers.BTCUSDT");
611 state.confirm_subscribe("tickers.BTCUSDT");
612 assert_eq!(state.len(), 1);
613
614 state.mark_unsubscribe("tickers.BTCUSDT");
615 assert_eq!(state.len(), 0);
616 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
617
618 state.mark_subscribe("tickers.BTCUSDT");
620 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
621
622 state.confirm_unsubscribe("tickers.BTCUSDT");
624 assert!(state.pending_unsubscribe_topics().is_empty());
625 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]); state.confirm_subscribe("tickers.BTCUSDT");
629 assert_eq!(state.len(), 1);
630 assert!(state.pending_subscribe_topics().is_empty());
631
632 let all = state.all_topics();
634 assert_eq!(all.len(), 1);
635 assert!(all.contains(&"tickers.BTCUSDT".to_string()));
636 }
637
638 #[rstest]
639 fn test_stale_unsubscribe_ack_after_resubscribe_confirmed() {
640 let state = SubscriptionState::new('.');
645
646 state.mark_subscribe("tickers.BTCUSDT");
648 state.confirm_subscribe("tickers.BTCUSDT");
649 assert_eq!(state.len(), 1);
650
651 state.mark_unsubscribe("tickers.BTCUSDT");
653 assert_eq!(state.len(), 0);
654 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
655
656 state.mark_subscribe("tickers.BTCUSDT");
658 assert!(state.pending_unsubscribe_topics().is_empty()); assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
660
661 state.confirm_subscribe("tickers.BTCUSDT");
663 assert_eq!(state.len(), 1); assert!(state.pending_subscribe_topics().is_empty());
665
666 state.confirm_unsubscribe("tickers.BTCUSDT");
669
670 assert_eq!(state.len(), 1); assert!(state.pending_unsubscribe_topics().is_empty());
673 assert!(state.pending_subscribe_topics().is_empty());
674
675 let all = state.all_topics();
677 assert_eq!(all.len(), 1);
678 assert!(all.contains(&"tickers.BTCUSDT".to_string()));
679 }
680
681 #[rstest]
682 fn test_mark_failure() {
683 let state = SubscriptionState::new('.');
684 state.mark_subscribe("tickers.BTCUSDT");
685 state.confirm_subscribe("tickers.BTCUSDT");
686 state.mark_failure("tickers.BTCUSDT");
687
688 assert_eq!(state.len(), 0);
689 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
690 }
691
692 #[rstest]
693 fn test_all_topics_includes_confirmed_and_pending_subscribe() {
694 let state = SubscriptionState::new('.');
695 state.mark_subscribe("tickers.BTCUSDT");
696 state.confirm_subscribe("tickers.BTCUSDT");
697 state.mark_subscribe("tickers.ETHUSDT");
698
699 let topics = state.all_topics();
700 assert_eq!(topics.len(), 2);
701 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
702 assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
703 }
704
705 #[rstest]
706 fn test_all_topics_excludes_pending_unsubscribe() {
707 let state = SubscriptionState::new('.');
708 state.mark_subscribe("tickers.BTCUSDT");
709 state.confirm_subscribe("tickers.BTCUSDT");
710 state.mark_unsubscribe("tickers.BTCUSDT");
711
712 let topics = state.all_topics();
713 assert!(topics.is_empty());
714 }
715
716 #[rstest]
717 fn test_reference_counting_single_topic() {
718 let state = SubscriptionState::new('.');
719
720 assert!(state.add_reference("tickers.BTCUSDT"));
721 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 1);
722
723 assert!(!state.add_reference("tickers.BTCUSDT"));
724 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 2);
725
726 assert!(!state.remove_reference("tickers.BTCUSDT"));
727 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 1);
728
729 assert!(state.remove_reference("tickers.BTCUSDT"));
730 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 0);
731 }
732
733 #[rstest]
734 fn test_reference_counting_multiple_topics() {
735 let state = SubscriptionState::new('.');
736
737 assert!(state.add_reference("tickers.BTCUSDT"));
738 assert!(state.add_reference("tickers.ETHUSDT"));
739
740 assert!(!state.add_reference("tickers.BTCUSDT"));
741 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 2);
742 assert_eq!(state.get_reference_count("tickers.ETHUSDT"), 1);
743
744 assert!(!state.remove_reference("tickers.BTCUSDT"));
745 assert!(state.remove_reference("tickers.ETHUSDT"));
746 }
747
748 #[rstest]
749 fn test_topic_without_symbol() {
750 let state = SubscriptionState::new('.');
751 state.mark_subscribe("orderbook");
752 state.confirm_subscribe("orderbook");
753
754 assert_eq!(state.len(), 1);
755 assert_eq!(state.all_topics(), vec!["orderbook"]);
756 }
757
758 #[rstest]
759 fn test_different_delimiters() {
760 let state_dot = SubscriptionState::new('.');
761 state_dot.mark_subscribe("tickers.BTCUSDT");
762 assert_eq!(
763 state_dot.pending_subscribe_topics(),
764 vec!["tickers.BTCUSDT"]
765 );
766
767 let state_colon = SubscriptionState::new(':');
768 state_colon.mark_subscribe("orderBookL2:XBTUSD");
769 assert_eq!(
770 state_colon.pending_subscribe_topics(),
771 vec!["orderBookL2:XBTUSD"]
772 );
773 }
774
775 #[rstest]
776 fn test_clear() {
777 let state = SubscriptionState::new('.');
778 state.mark_subscribe("tickers.BTCUSDT");
779 state.confirm_subscribe("tickers.BTCUSDT");
780 state.add_reference("tickers.BTCUSDT");
781
782 state.clear();
783
784 assert!(state.is_empty());
785 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 0);
786 }
787
788 #[rstest]
789 fn test_multiple_symbols_same_channel() {
790 let state = SubscriptionState::new('.');
791 state.mark_subscribe("tickers.BTCUSDT");
792 state.mark_subscribe("tickers.ETHUSDT");
793 state.confirm_subscribe("tickers.BTCUSDT");
794 state.confirm_subscribe("tickers.ETHUSDT");
795
796 assert_eq!(state.len(), 2);
797 let topics = state.all_topics();
798 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
799 assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
800 }
801
802 #[rstest]
803 fn test_mixed_channel_and_symbol_subscriptions() {
804 let state = SubscriptionState::new('.');
805
806 state.mark_subscribe("tickers");
808 state.confirm_subscribe("tickers");
809 assert_eq!(state.len(), 1);
810 assert_eq!(state.all_topics(), vec!["tickers"]);
811
812 state.mark_subscribe("tickers.BTCUSDT");
814 state.confirm_subscribe("tickers.BTCUSDT");
815 assert_eq!(state.len(), 2);
816
817 let topics = state.all_topics();
819 assert_eq!(topics.len(), 2);
820 assert!(topics.contains(&"tickers".to_string()));
821 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
822
823 state.mark_subscribe("tickers.ETHUSDT");
825 state.confirm_subscribe("tickers.ETHUSDT");
826 assert_eq!(state.len(), 3);
827
828 let topics = state.all_topics();
829 assert_eq!(topics.len(), 3);
830 assert!(topics.contains(&"tickers".to_string()));
831 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
832 assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
833
834 state.mark_unsubscribe("tickers");
836 state.confirm_unsubscribe("tickers");
837 assert_eq!(state.len(), 2);
838
839 let topics = state.all_topics();
840 assert_eq!(topics.len(), 2);
841 assert!(!topics.contains(&"tickers".to_string()));
842 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
843 assert!(topics.contains(&"tickers.ETHUSDT".to_string()));
844 }
845
846 #[rstest]
847 fn test_symbol_subscription_before_channel() {
848 let state = SubscriptionState::new('.');
849
850 state.mark_subscribe("tickers.BTCUSDT");
852 state.confirm_subscribe("tickers.BTCUSDT");
853 assert_eq!(state.len(), 1);
854
855 state.mark_subscribe("tickers");
857 state.confirm_subscribe("tickers");
858 assert_eq!(state.len(), 2);
859
860 let topics = state.all_topics();
862 assert_eq!(topics.len(), 2);
863 assert!(topics.contains(&"tickers".to_string()));
864 assert!(topics.contains(&"tickers.BTCUSDT".to_string()));
865 }
866
867 #[rstest]
868 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
869 async fn test_concurrent_subscribe_same_topic() {
870 let state = Arc::new(SubscriptionState::new('.'));
871 let mut handles = vec![];
872
873 for _ in 0..10 {
875 let state_clone = Arc::clone(&state);
876 let handle = tokio::spawn(async move {
877 state_clone.add_reference("tickers.BTCUSDT");
878 state_clone.mark_subscribe("tickers.BTCUSDT");
879 state_clone.confirm_subscribe("tickers.BTCUSDT");
880 });
881 handles.push(handle);
882 }
883
884 for handle in handles {
885 handle.await.unwrap();
886 }
887
888 assert_eq!(state.get_reference_count("tickers.BTCUSDT"), 10);
890 assert_eq!(state.len(), 1);
891 }
892
893 #[rstest]
894 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
895 async fn test_concurrent_subscribe_unsubscribe() {
896 let state = Arc::new(SubscriptionState::new('.'));
897 let mut handles = vec![];
898
899 for i in 0..20 {
902 let state_clone = Arc::clone(&state);
903
904 let handle = tokio::spawn(async move {
905 let topic = format!("tickers.SYMBOL{i}");
906 state_clone.add_reference(&topic);
908 state_clone.add_reference(&topic);
909 state_clone.mark_subscribe(&topic);
910 state_clone.confirm_subscribe(&topic);
911
912 state_clone.remove_reference(&topic);
914 });
915 handles.push(handle);
916 }
917
918 for handle in handles {
919 handle.await.unwrap();
920 }
921
922 for i in 0..20 {
924 let topic = format!("tickers.SYMBOL{i}");
925 assert_eq!(state.get_reference_count(&topic), 1);
926 }
927
928 assert_eq!(state.len(), 20);
930 assert!(!state.is_empty());
931 }
932
933 #[rstest]
934 #[tokio::test(flavor = "multi_thread", worker_threads = 4)]
935 async fn test_concurrent_reference_counting_same_topic() {
936 let state = Arc::new(SubscriptionState::new('.'));
937 let topic = "tickers.BTCUSDT";
938 let mut handles = vec![];
939
940 for _ in 0..10 {
942 let state_clone = Arc::clone(&state);
943
944 let handle = tokio::spawn(async move {
945 for _ in 0..10 {
946 state_clone.add_reference(topic);
947 }
948 });
949 handles.push(handle);
950 }
951
952 for handle in handles {
953 handle.await.unwrap();
954 }
955
956 assert_eq!(state.get_reference_count(topic), 100);
958
959 for _ in 0..50 {
961 state.remove_reference(topic);
962 }
963
964 assert_eq!(state.get_reference_count(topic), 50);
966 }
967
968 #[rstest]
969 fn test_reconnection_scenario() {
970 let state = SubscriptionState::new('.');
971
972 state.add_reference("tickers.BTCUSDT");
974 state.mark_subscribe("tickers.BTCUSDT");
975 state.confirm_subscribe("tickers.BTCUSDT");
976
977 state.add_reference("tickers.ETHUSDT");
978 state.mark_subscribe("tickers.ETHUSDT");
979 state.confirm_subscribe("tickers.ETHUSDT");
980
981 state.add_reference("orderbook");
982 state.mark_subscribe("orderbook");
983 state.confirm_subscribe("orderbook");
984
985 assert_eq!(state.len(), 3);
986
987 let topics_to_resubscribe = state.all_topics();
989 assert_eq!(topics_to_resubscribe.len(), 3);
990 assert!(topics_to_resubscribe.contains(&"tickers.BTCUSDT".to_string()));
991 assert!(topics_to_resubscribe.contains(&"tickers.ETHUSDT".to_string()));
992 assert!(topics_to_resubscribe.contains(&"orderbook".to_string()));
993
994 for topic in &topics_to_resubscribe {
996 state.mark_subscribe(topic);
997 }
998
999 for topic in &topics_to_resubscribe {
1001 state.confirm_subscribe(topic);
1002 }
1003
1004 assert_eq!(state.len(), 3);
1006 assert_eq!(state.all_topics().len(), 3);
1007 }
1008
1009 #[rstest]
1010 fn test_state_machine_invalid_transitions() {
1011 let state = SubscriptionState::new('.');
1012
1013 state.confirm_subscribe("tickers.BTCUSDT");
1015 assert_eq!(state.len(), 1); state.confirm_unsubscribe("tickers.ETHUSDT");
1019 assert_eq!(state.len(), 1); state.mark_subscribe("orderbook");
1023 state.confirm_subscribe("orderbook");
1024 state.confirm_subscribe("orderbook"); assert_eq!(state.len(), 2);
1026
1027 state.mark_unsubscribe("nonexistent");
1029 state.confirm_unsubscribe("nonexistent");
1030 assert_eq!(state.len(), 2); }
1032
1033 #[rstest]
1034 fn test_mark_failure_moves_to_pending() {
1035 let state = SubscriptionState::new('.');
1036
1037 state.mark_subscribe("tickers.BTCUSDT");
1039 state.confirm_subscribe("tickers.BTCUSDT");
1040 assert_eq!(state.len(), 1);
1041 assert!(state.pending_subscribe_topics().is_empty());
1042
1043 state.mark_failure("tickers.BTCUSDT");
1045
1046 assert_eq!(state.len(), 0);
1048 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
1049
1050 assert_eq!(state.all_topics(), vec!["tickers.BTCUSDT"]);
1052 }
1053
1054 #[rstest]
1055 fn test_pending_subscribe_excludes_pending_unsubscribe() {
1056 let state = SubscriptionState::new('.');
1057
1058 state.mark_subscribe("tickers.BTCUSDT");
1060 state.confirm_subscribe("tickers.BTCUSDT");
1061
1062 state.mark_unsubscribe("tickers.BTCUSDT");
1064
1065 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1067 assert!(state.all_topics().is_empty());
1068 assert_eq!(state.len(), 0);
1069 }
1070
1071 #[rstest]
1072 fn test_remove_reference_nonexistent_topic() {
1073 let state = SubscriptionState::new('.');
1074
1075 let should_unsubscribe = state.remove_reference("nonexistent");
1077
1078 assert!(!should_unsubscribe);
1080 assert_eq!(state.get_reference_count("nonexistent"), 0);
1081 }
1082
1083 #[rstest]
1084 fn test_edge_case_empty_channel_name() {
1085 let state = SubscriptionState::new('.');
1086
1087 state.mark_subscribe("");
1089 state.confirm_subscribe("");
1090
1091 assert_eq!(state.len(), 1);
1092 assert_eq!(state.all_topics(), vec![""]);
1093 }
1094
1095 #[rstest]
1096 fn test_special_characters_in_topics() {
1097 let state = SubscriptionState::new('.');
1098
1099 let special_topics = vec![
1101 "channel.symbol-with-dash",
1102 "channel.SYMBOL_WITH_UNDERSCORE",
1103 "channel.symbol123",
1104 "channel.symbol@special",
1105 ];
1106
1107 for topic in &special_topics {
1108 state.mark_subscribe(topic);
1109 state.confirm_subscribe(topic);
1110 }
1111
1112 assert_eq!(state.len(), special_topics.len());
1113
1114 let all_topics = state.all_topics();
1115
1116 for topic in &special_topics {
1117 assert!(
1118 all_topics.contains(&(*topic).to_string()),
1119 "Missing topic: {topic}"
1120 );
1121 }
1122 }
1123
1124 #[rstest]
1125 fn test_clear_resets_all_state() {
1126 let state = SubscriptionState::new('.');
1127
1128 for i in 0..10 {
1130 let topic = format!("channel{i}.SYMBOL");
1131 state.add_reference(&topic);
1132 state.add_reference(&topic); state.mark_subscribe(&topic);
1134 state.confirm_subscribe(&topic);
1135 }
1136
1137 assert_eq!(state.len(), 10);
1138 assert!(!state.is_empty());
1139
1140 state.clear();
1142
1143 assert_eq!(state.len(), 0);
1145 assert!(state.is_empty());
1146 assert!(state.all_topics().is_empty());
1147 assert!(state.pending_subscribe_topics().is_empty());
1148 assert!(state.pending_unsubscribe_topics().is_empty());
1149
1150 for i in 0..10 {
1152 let topic = format!("channel{i}.SYMBOL");
1153 assert_eq!(state.get_reference_count(&topic), 0);
1154 }
1155 }
1156
1157 #[rstest]
1158 fn test_different_delimiter_does_not_affect_storage() {
1159 let state_dot = SubscriptionState::new('.');
1161 let state_colon = SubscriptionState::new(':');
1162
1163 state_dot.mark_subscribe("channel.SYMBOL");
1165 state_colon.mark_subscribe("channel:SYMBOL");
1166
1167 assert_eq!(state_dot.pending_subscribe_topics(), vec!["channel.SYMBOL"]);
1169 assert_eq!(
1170 state_colon.pending_subscribe_topics(),
1171 vec!["channel:SYMBOL"]
1172 );
1173 }
1174
1175 #[rstest]
1176 fn test_unsubscribe_before_subscribe_confirmed() {
1177 let state = SubscriptionState::new('.');
1178
1179 state.mark_subscribe("tickers.BTCUSDT");
1181 assert_eq!(state.pending_subscribe_topics(), vec!["tickers.BTCUSDT"]);
1182
1183 state.mark_unsubscribe("tickers.BTCUSDT");
1185
1186 assert!(state.pending_subscribe_topics().is_empty());
1188 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1189
1190 state.confirm_unsubscribe("tickers.BTCUSDT");
1192
1193 assert!(state.is_empty());
1195 assert!(state.all_topics().is_empty());
1196 assert_eq!(state.len(), 0);
1197 }
1198
1199 #[rstest]
1200 fn test_late_subscribe_confirmation_after_unsubscribe() {
1201 let state = SubscriptionState::new('.');
1202
1203 state.mark_subscribe("tickers.BTCUSDT");
1205
1206 state.mark_unsubscribe("tickers.BTCUSDT");
1208
1209 state.confirm_subscribe("tickers.BTCUSDT");
1211
1212 assert_eq!(state.len(), 0);
1214 assert!(state.pending_subscribe_topics().is_empty());
1215
1216 state.confirm_unsubscribe("tickers.BTCUSDT");
1218
1219 assert!(state.is_empty());
1221 assert!(state.all_topics().is_empty());
1222 }
1223
1224 #[rstest]
1225 fn test_unsubscribe_clears_all_states() {
1226 let state = SubscriptionState::new('.');
1227
1228 state.mark_subscribe("tickers.BTCUSDT");
1230 state.confirm_subscribe("tickers.BTCUSDT");
1231 assert_eq!(state.len(), 1);
1232
1233 state.mark_unsubscribe("tickers.BTCUSDT");
1235
1236 assert_eq!(state.len(), 0);
1238 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1239
1240 state.confirm_subscribe("tickers.BTCUSDT");
1242
1243 state.confirm_unsubscribe("tickers.BTCUSDT");
1245
1246 assert!(state.is_empty());
1248 assert_eq!(state.len(), 0);
1249 assert!(state.pending_subscribe_topics().is_empty());
1250 assert!(state.pending_unsubscribe_topics().is_empty());
1251 assert!(state.all_topics().is_empty());
1252 }
1253
1254 #[rstest]
1255 fn test_mark_failure_respects_pending_unsubscribe() {
1256 let state = SubscriptionState::new('.');
1257
1258 state.mark_subscribe("tickers.BTCUSDT");
1260 state.confirm_subscribe("tickers.BTCUSDT");
1261 assert_eq!(state.len(), 1);
1262
1263 state.mark_unsubscribe("tickers.BTCUSDT");
1265 assert_eq!(state.len(), 0);
1266 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1267
1268 state.mark_failure("tickers.BTCUSDT");
1270
1271 assert!(state.pending_subscribe_topics().is_empty());
1273 assert_eq!(state.pending_unsubscribe_topics(), vec!["tickers.BTCUSDT"]);
1274
1275 assert!(state.all_topics().is_empty());
1277
1278 state.confirm_unsubscribe("tickers.BTCUSDT");
1280 assert!(state.is_empty());
1281 }
1282
1283 #[rstest]
1284 #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1285 async fn test_concurrent_stress_mixed_operations() {
1286 let state = Arc::new(SubscriptionState::new('.'));
1287 let mut handles = vec![];
1288
1289 for i in 0..50 {
1291 let state_clone = Arc::clone(&state);
1292
1293 let handle = tokio::spawn(async move {
1294 let topic1 = format!("channel.SYMBOL{i}");
1295 let topic2 = format!("channel.SYMBOL{}", i + 100);
1296
1297 state_clone.add_reference(&topic1);
1299 state_clone.add_reference(&topic2);
1300
1301 state_clone.mark_subscribe(&topic1);
1303 state_clone.confirm_subscribe(&topic1);
1304 state_clone.mark_subscribe(&topic2);
1305
1306 if i % 3 == 0 {
1308 state_clone.mark_unsubscribe(&topic1);
1309 state_clone.confirm_unsubscribe(&topic1);
1310 }
1311
1312 state_clone.add_reference(&topic2);
1314 state_clone.remove_reference(&topic2);
1315
1316 state_clone.confirm_subscribe(&topic2);
1318 });
1319 handles.push(handle);
1320 }
1321
1322 for handle in handles {
1323 handle.await.unwrap();
1324 }
1325
1326 let all = state.all_topics();
1328 let confirmed_count = state.len();
1329
1330 assert!(confirmed_count > 50); assert!(confirmed_count <= 100); assert_eq!(
1335 all.len(),
1336 confirmed_count + state.pending_subscribe_topics().len()
1337 );
1338 }
1339
1340 #[rstest]
1341 fn test_edge_case_malformed_topics() {
1342 let state = SubscriptionState::new('.');
1343
1344 state.mark_subscribe("channel.symbol.extra");
1346 state.confirm_subscribe("channel.symbol.extra");
1347 let topics = state.all_topics();
1348 assert!(topics.contains(&"channel.symbol.extra".to_string()));
1349
1350 state.mark_subscribe(".channel");
1352 state.confirm_subscribe(".channel");
1353 assert_eq!(state.len(), 2);
1354
1355 state.mark_subscribe("channel.");
1358 state.confirm_subscribe("channel.");
1359 assert_eq!(state.len(), 3);
1360
1361 state.mark_subscribe("tickers");
1363 state.confirm_subscribe("tickers");
1364 assert_eq!(state.len(), 4);
1365
1366 let all = state.all_topics();
1368 assert_eq!(all.len(), 4);
1369 assert!(all.contains(&"channel.symbol.extra".to_string()));
1370 assert!(all.contains(&".channel".to_string()));
1371 assert!(all.contains(&"channel".to_string())); assert!(all.contains(&"tickers".to_string()));
1373 }
1374
1375 #[rstest]
1376 fn test_reference_count_underflow_safety() {
1377 let state = SubscriptionState::new('.');
1378
1379 assert!(!state.remove_reference("never.added"));
1381 assert_eq!(state.get_reference_count("never.added"), 0);
1382
1383 state.add_reference("once.added");
1385 assert_eq!(state.get_reference_count("once.added"), 1);
1386
1387 assert!(state.remove_reference("once.added")); assert_eq!(state.get_reference_count("once.added"), 0);
1389
1390 assert!(!state.remove_reference("once.added")); assert!(!state.remove_reference("once.added")); assert_eq!(state.get_reference_count("once.added"), 0);
1393
1394 assert!(state.add_reference("once.added"));
1396 assert_eq!(state.get_reference_count("once.added"), 1);
1397 }
1398
1399 #[rstest]
1400 fn test_reconnection_with_partial_state() {
1401 let state = SubscriptionState::new('.');
1402
1403 state.mark_subscribe("confirmed.BTCUSDT");
1406 state.confirm_subscribe("confirmed.BTCUSDT");
1407
1408 state.mark_subscribe("pending.ETHUSDT");
1410
1411 state.mark_subscribe("cancelled.XRPUSDT");
1413 state.confirm_subscribe("cancelled.XRPUSDT");
1414 state.mark_unsubscribe("cancelled.XRPUSDT");
1415
1416 assert_eq!(state.len(), 1); let all = state.all_topics();
1419 assert_eq!(all.len(), 2); assert!(all.contains(&"confirmed.BTCUSDT".to_string()));
1421 assert!(all.contains(&"pending.ETHUSDT".to_string()));
1422 assert!(!all.contains(&"cancelled.XRPUSDT".to_string())); let topics_to_resubscribe = state.all_topics();
1426
1427 state.confirmed().clear();
1429
1430 for topic in &topics_to_resubscribe {
1432 state.mark_subscribe(topic);
1433 }
1434
1435 for topic in &topics_to_resubscribe {
1437 state.confirm_subscribe(topic);
1438 }
1439
1440 assert_eq!(state.len(), 2); let final_topics = state.all_topics();
1443 assert_eq!(final_topics.len(), 2);
1444 assert!(final_topics.contains(&"confirmed.BTCUSDT".to_string()));
1445 assert!(final_topics.contains(&"pending.ETHUSDT".to_string()));
1446 assert!(!final_topics.contains(&"cancelled.XRPUSDT".to_string()));
1447 }
1448
1449 fn check_invariants(state: &SubscriptionState, label: &str) {
1460 let confirmed_topics: AHashSet<String> = state
1462 .topics_from_map(&state.confirmed)
1463 .into_iter()
1464 .collect();
1465 let pending_sub_topics: AHashSet<String> =
1466 state.pending_subscribe_topics().into_iter().collect();
1467 let pending_unsub_topics: AHashSet<String> =
1468 state.pending_unsubscribe_topics().into_iter().collect();
1469
1470 let confirmed_and_pending_sub: Vec<_> =
1472 confirmed_topics.intersection(&pending_sub_topics).collect();
1473 assert!(
1474 confirmed_and_pending_sub.is_empty(),
1475 "{label}: Topic in both confirmed and pending_subscribe: {confirmed_and_pending_sub:?}"
1476 );
1477
1478 let confirmed_and_pending_unsub: Vec<_> = confirmed_topics
1479 .intersection(&pending_unsub_topics)
1480 .collect();
1481 assert!(
1482 confirmed_and_pending_unsub.is_empty(),
1483 "{label}: Topic in both confirmed and pending_unsubscribe: {confirmed_and_pending_unsub:?}"
1484 );
1485
1486 let pending_sub_and_unsub: Vec<_> = pending_sub_topics
1487 .intersection(&pending_unsub_topics)
1488 .collect();
1489 assert!(
1490 pending_sub_and_unsub.is_empty(),
1491 "{label}: Topic in both pending_subscribe and pending_unsubscribe: {pending_sub_and_unsub:?}"
1492 );
1493
1494 let all_topics: AHashSet<String> = state.all_topics().into_iter().collect();
1496 let expected_all: AHashSet<String> = confirmed_topics
1497 .union(&pending_sub_topics)
1498 .cloned()
1499 .collect();
1500 assert_eq!(
1501 all_topics, expected_all,
1502 "{label}: all_topics() doesn't match confirmed ∪ pending_subscribe"
1503 );
1504
1505 for topic in &pending_unsub_topics {
1507 assert!(
1508 !all_topics.contains(topic),
1509 "{label}: pending_unsubscribe topic {topic} incorrectly in all_topics()"
1510 );
1511 }
1512
1513 let expected_len: usize = state
1515 .confirmed
1516 .iter()
1517 .map(|entry| entry.value().len())
1518 .sum();
1519 assert_eq!(
1520 state.len(),
1521 expected_len,
1522 "{label}: len() mismatch. Expected {expected_len}, was {}",
1523 state.len()
1524 );
1525
1526 let should_be_empty = state.confirmed.is_empty()
1528 && pending_sub_topics.is_empty()
1529 && pending_unsub_topics.is_empty();
1530 assert_eq!(
1531 state.is_empty(),
1532 should_be_empty,
1533 "{label}: is_empty() inconsistent. Maps empty: {should_be_empty}, is_empty(): {}",
1534 state.is_empty()
1535 );
1536
1537 for entry in state.reference_counts.iter() {
1539 let count = entry.value().get();
1540 assert!(
1541 count > 0,
1542 "{label}: Reference count should be NonZeroUsize (> 0), was {count} for {:?}",
1543 entry.key()
1544 );
1545 }
1546 }
1547
1548 fn check_topic_exclusivity(state: &SubscriptionState, topic: &str, label: &str) {
1550 let (channel, symbol) = split_topic(topic, state.delimiter);
1551
1552 let in_confirmed = is_tracked(&state.confirmed, channel, symbol);
1553 let in_pending_sub = is_tracked(&state.pending_subscribe, channel, symbol);
1554 let in_pending_unsub = is_tracked(&state.pending_unsubscribe, channel, symbol);
1555
1556 let count = [in_confirmed, in_pending_sub, in_pending_unsub]
1557 .iter()
1558 .filter(|&&x| x)
1559 .count();
1560
1561 assert!(
1562 count <= 1,
1563 "{label}: Topic {topic} in {count} states (should be 0 or 1). \
1564 confirmed: {in_confirmed}, pending_sub: {in_pending_sub}, pending_unsub: {in_pending_unsub}"
1565 );
1566 }
1567
#[cfg(test)]
mod property_tests {
    use proptest::prelude::*;

    use super::*;

    /// One state-machine action that proptest can generate and replay.
    #[derive(Debug, Clone)]
    enum Operation {
        MarkSubscribe(String),
        ConfirmSubscribe(String),
        MarkUnsubscribe(String),
        ConfirmUnsubscribe(String),
        MarkFailure(String),
        AddReference(String),
        RemoveReference(String),
        Clear,
    }

    /// Generates topics from a deliberately small pool.
    ///
    /// The pool is 5 channels × 10 symbols, plus bare channel-level topics, so
    /// generated operations frequently collide on the same topic.
    fn topic_strategy() -> impl Strategy<Value = String> {
        prop_oneof![
            (any::<u8>(), any::<u8>())
                .prop_map(|(ch, sym)| { format!("channel{}.SYMBOL{}", ch % 5, sym % 10) }),
            any::<u8>().prop_map(|ch| format!("channel{}", ch % 5)),
        ]
    }

    /// Generates any [`Operation`] over a topic drawn from [`topic_strategy`].
    fn operation_strategy() -> impl Strategy<Value = Operation> {
        topic_strategy().prop_flat_map(|topic| {
            prop_oneof![
                Just(Operation::MarkSubscribe(topic.clone())),
                Just(Operation::ConfirmSubscribe(topic.clone())),
                Just(Operation::MarkUnsubscribe(topic.clone())),
                Just(Operation::ConfirmUnsubscribe(topic.clone())),
                Just(Operation::MarkFailure(topic.clone())),
                Just(Operation::AddReference(topic.clone())),
                Just(Operation::RemoveReference(topic)),
                Just(Operation::Clear),
            ]
        })
    }

    /// Applies `op` to `state`, discarding any boolean result.
    fn apply_operation(state: &SubscriptionState, op: &Operation) {
        match op {
            Operation::MarkSubscribe(topic) => state.mark_subscribe(topic),
            Operation::ConfirmSubscribe(topic) => state.confirm_subscribe(topic),
            Operation::MarkUnsubscribe(topic) => state.mark_unsubscribe(topic),
            Operation::ConfirmUnsubscribe(topic) => state.confirm_unsubscribe(topic),
            Operation::MarkFailure(topic) => state.mark_failure(topic),
            Operation::AddReference(topic) => {
                state.add_reference(topic);
            }
            Operation::RemoveReference(topic) => {
                state.remove_reference(topic);
            }
            Operation::Clear => state.clear(),
        }
    }

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(500))]

        #[rstest]
        fn prop_invariants_hold_after_operations(
            operations in prop::collection::vec(operation_strategy(), 1..50)
        ) {
            let state = SubscriptionState::new('.');

            for (i, op) in operations.iter().enumerate() {
                apply_operation(&state, op);

                check_invariants(&state, &format!("After op {i}: {op:?}"));
            }

            check_invariants(&state, "Final state");
        }

        #[rstest]
        fn prop_reference_counting_consistency(
            ops in prop::collection::vec(
                topic_strategy().prop_flat_map(|t| {
                    prop_oneof![
                        Just(Operation::AddReference(t.clone())),
                        Just(Operation::RemoveReference(t)),
                    ]
                }),
                1..100
            )
        ) {
            let state = SubscriptionState::new('.');

            // Reference model for the expected count per topic:
            // - add increments the count;
            // - remove decrements it but never goes below zero, matching
            //   `test_reference_count_underflow_safety`.
            let mut expected: std::collections::HashMap<String, usize> =
                std::collections::HashMap::new();

            for op in &ops {
                apply_operation(&state, op);

                match op {
                    Operation::AddReference(topic) => {
                        *expected.entry(topic.clone()).or_insert(0) += 1;
                    }
                    Operation::RemoveReference(topic) => {
                        let count = expected.entry(topic.clone()).or_insert(0);
                        *count = count.saturating_sub(1);
                    }
                    _ => unreachable!("strategy only yields reference operations"),
                }

                // The real counts must track the model exactly for every touched topic.
                for (topic, &count) in &expected {
                    assert_eq!(
                        state.get_reference_count(topic),
                        count,
                        "reference count mismatch for {topic} after {op:?}"
                    );
                }

                for entry in state.reference_counts.iter() {
                    assert!(entry.value().get() > 0);
                }
            }
        }

        #[rstest]
        fn prop_all_topics_is_union(
            operations in prop::collection::vec(operation_strategy(), 1..50)
        ) {
            let state = SubscriptionState::new('.');

            for op in &operations {
                apply_operation(&state, op);

                let all_topics: AHashSet<String> = state.all_topics().into_iter().collect();
                let confirmed: AHashSet<String> = state.topics_from_map(&state.confirmed).into_iter().collect();
                let pending_sub: AHashSet<String> = state.pending_subscribe_topics().into_iter().collect();
                let expected: AHashSet<String> = confirmed.union(&pending_sub).cloned().collect();

                assert_eq!(all_topics, expected);

                let pending_unsub: AHashSet<String> = state.pending_unsubscribe_topics().into_iter().collect();
                for topic in pending_unsub {
                    assert!(!all_topics.contains(&topic));
                }
            }
        }

        #[rstest]
        fn prop_clear_resets_completely(
            operations in prop::collection::vec(operation_strategy(), 1..30)
        ) {
            let state = SubscriptionState::new('.');

            for op in &operations {
                apply_operation(&state, op);
            }

            state.clear();

            assert!(state.is_empty());
            assert_eq!(state.len(), 0);
            assert!(state.all_topics().is_empty());
            assert!(state.pending_subscribe_topics().is_empty());
            assert!(state.pending_unsubscribe_topics().is_empty());
            assert!(state.confirmed.is_empty());
            assert!(state.pending_subscribe.is_empty());
            assert!(state.pending_unsubscribe.is_empty());
            assert!(state.reference_counts.is_empty());
        }

        #[rstest]
        fn prop_topic_mutual_exclusivity(
            operations in prop::collection::vec(operation_strategy(), 1..50),
            topic in topic_strategy()
        ) {
            let state = SubscriptionState::new('.');

            for (i, op) in operations.iter().enumerate() {
                apply_operation(&state, op);
                check_topic_exclusivity(&state, &topic, &format!("After op {i}: {op:?}"));
            }
        }
    }
}
1745
1746 #[rstest]
1747 fn test_exhaustive_two_step_transitions() {
1748 let operations = [
1749 "mark_subscribe",
1750 "confirm_subscribe",
1751 "mark_unsubscribe",
1752 "confirm_unsubscribe",
1753 "mark_failure",
1754 ];
1755
1756 for &op1 in &operations {
1757 for &op2 in &operations {
1758 let state = SubscriptionState::new('.');
1759 let topic = "test.TOPIC";
1760
1761 apply_op(&state, op1, topic);
1763 apply_op(&state, op2, topic);
1764
1765 check_invariants(&state, &format!("{op1} → {op2}"));
1767 check_topic_exclusivity(&state, topic, &format!("{op1} → {op2}"));
1768 }
1769 }
1770 }
1771
1772 fn apply_op(state: &SubscriptionState, op: &str, topic: &str) {
1773 match op {
1774 "mark_subscribe" => state.mark_subscribe(topic),
1775 "confirm_subscribe" => state.confirm_subscribe(topic),
1776 "mark_unsubscribe" => state.mark_unsubscribe(topic),
1777 "confirm_unsubscribe" => state.confirm_unsubscribe(topic),
1778 "mark_failure" => state.mark_failure(topic),
1779 _ => panic!("Unknown operation: {op}"),
1780 }
1781 }
1782
1783 #[rstest]
1784 #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1785 async fn test_stress_rapid_resubscribe_pattern() {
1786 let state = Arc::new(SubscriptionState::new('.'));
1788 let mut handles = vec![];
1789
1790 for i in 0..100 {
1791 let state_clone = Arc::clone(&state);
1792
1793 let handle = tokio::spawn(async move {
1794 let topic = format!("rapid.SYMBOL{}", i % 10); state_clone.mark_subscribe(&topic);
1798 state_clone.confirm_subscribe(&topic);
1799
1800 state_clone.mark_unsubscribe(&topic);
1802 state_clone.mark_subscribe(&topic);
1804 state_clone.confirm_unsubscribe(&topic);
1806 state_clone.confirm_subscribe(&topic);
1808 });
1809 handles.push(handle);
1810 }
1811
1812 for handle in handles {
1813 handle.await.unwrap();
1814 }
1815
1816 check_invariants(&state, "After rapid resubscribe stress test");
1817 }
1818
1819 #[rstest]
1820 #[tokio::test(flavor = "multi_thread", worker_threads = 8)]
1821 async fn test_stress_failure_recovery_loop() {
1822 let state = Arc::new(SubscriptionState::new('.'));
1825 let mut handles = vec![];
1826
1827 for i in 0..30 {
1828 let state_clone = Arc::clone(&state);
1829
1830 let handle = tokio::spawn(async move {
1831 let topic = format!("failure.SYMBOL{i}"); state_clone.mark_subscribe(&topic);
1835 state_clone.confirm_subscribe(&topic);
1836
1837 for _ in 0..5 {
1839 state_clone.mark_failure(&topic);
1840 state_clone.confirm_subscribe(&topic); }
1842 });
1843 handles.push(handle);
1844 }
1845
1846 for handle in handles {
1847 handle.await.unwrap();
1848 }
1849
1850 check_invariants(&state, "After failure recovery loops");
1851
1852 assert_eq!(state.len(), 30);
1854 }
1855}