nautilus_network/websocket/
auth.rs1use std::{
42 sync::{
43 Arc, Mutex,
44 atomic::{AtomicU8, Ordering},
45 },
46 time::Duration,
47};
48
/// Sending half of the one-shot channel that carries the outcome of a single
/// authentication attempt (`Ok(())` on success, `Err(message)` otherwise).
pub type AuthResultSender = tokio::sync::oneshot::Sender<Result<(), String>>;
/// Receiving half returned by `AuthTracker::begin` and consumed by
/// `AuthTracker::wait_for_result`.
pub type AuthResultReceiver = tokio::sync::oneshot::Receiver<Result<(), String>>;
51
/// Authentication status recorded by an `AuthTracker`.
///
/// The explicit `u8` discriminants are what the tracker stores in its atomic,
/// so they must stay in sync with `AuthState::from_u8`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum AuthState {
    /// No successful authentication is recorded. This is the initial state and
    /// is restored by `begin` and `invalidate`.
    #[default]
    Unauthenticated = 0,
    /// The most recent outcome reported via `succeed`.
    Authenticated = 1,
    /// The most recent outcome reported via `fail`.
    Failed = 2,
}
64
65impl AuthState {
66 #[inline]
67 #[must_use]
68 #[expect(
69 clippy::match_same_arms,
70 reason = "explicit variant listing is clearer than collapsing 0 with wildcard"
71 )]
72 fn from_u8(value: u8) -> Self {
73 match value {
74 0 => Self::Unauthenticated,
75 1 => Self::Authenticated,
76 2 => Self::Failed,
77 _ => Self::Unauthenticated,
78 }
79 }
80
81 #[inline]
82 #[must_use]
83 const fn as_u8(self) -> u8 {
84 self as u8
85 }
86}
87
88#[derive(Clone, Debug)]
111pub struct AuthTracker {
112 tx: Arc<Mutex<Option<AuthResultSender>>>,
113 state: Arc<AtomicU8>,
114 state_notify: Arc<tokio::sync::Notify>,
115}
116
117impl AuthTracker {
118 #[must_use]
120 pub fn new() -> Self {
121 Self {
122 tx: Arc::new(Mutex::new(None)),
123 state: Arc::new(AtomicU8::new(AuthState::Unauthenticated.as_u8())),
124 state_notify: Arc::new(tokio::sync::Notify::new()),
125 }
126 }
127
128 #[must_use]
130 pub fn auth_state(&self) -> AuthState {
131 AuthState::from_u8(self.state.load(Ordering::Acquire))
132 }
133
134 #[must_use]
136 pub fn is_authenticated(&self) -> bool {
137 self.auth_state() == AuthState::Authenticated
138 }
139
140 pub fn invalidate(&self) {
145 self.state
146 .store(AuthState::Unauthenticated.as_u8(), Ordering::Release);
147 self.state_notify.notify_waiters();
148 }
149
150 #[allow(
159 clippy::must_use_candidate,
160 reason = "callers use this for side effects"
161 )]
162 pub fn begin(&self) -> AuthResultReceiver {
163 let (sender, receiver) = tokio::sync::oneshot::channel();
164 self.state
165 .store(AuthState::Unauthenticated.as_u8(), Ordering::Release);
166
167 if let Ok(mut guard) = self.tx.lock() {
168 if let Some(old) = guard.take() {
169 log::warn!("New authentication request superseding previous pending request");
170 let _ = old.send(Err("Authentication attempt superseded".to_string()));
171 } else {
172 log::debug!("Starting new authentication request");
173 }
174 *guard = Some(sender);
175 }
176
177 receiver
178 }
179
180 pub fn succeed(&self) {
189 self.state
190 .store(AuthState::Authenticated.as_u8(), Ordering::Release);
191 self.state_notify.notify_waiters();
192
193 if let Ok(mut guard) = self.tx.lock()
194 && let Some(sender) = guard.take()
195 {
196 let _ = sender.send(Ok(()));
197 }
198 }
199
200 pub fn fail(&self, error: impl Into<String>) {
209 self.state
210 .store(AuthState::Failed.as_u8(), Ordering::Release);
211 self.state_notify.notify_waiters();
212 let message = error.into();
213
214 if let Ok(mut guard) = self.tx.lock()
215 && let Some(sender) = guard.take()
216 {
217 let _ = sender.send(Err(message));
218 }
219 }
220
221 pub async fn wait_for_result<E>(
238 &self,
239 timeout: Duration,
240 receiver: AuthResultReceiver,
241 ) -> Result<(), E>
242 where
243 E: From<String>,
244 {
245 match tokio::time::timeout(timeout, receiver).await {
246 Ok(Ok(Ok(()))) => Ok(()),
247 Ok(Ok(Err(msg))) => Err(E::from(msg)),
248 Ok(Err(_)) => Err(E::from("Authentication channel closed".to_string())),
249 Err(_) => {
250 Err(E::from("Authentication timed out".to_string()))
254 }
255 }
256 }
257
258 pub async fn wait_for_authenticated(&self, timeout: Duration) -> bool {
272 if self.is_authenticated() {
273 return true;
274 }
275
276 tokio::time::timeout(timeout, async {
277 loop {
278 let notified = self.state_notify.notified();
279
280 match self.auth_state() {
281 AuthState::Authenticated => return true,
282 AuthState::Failed => return false,
283 AuthState::Unauthenticated => notified.await,
284 }
285 }
286 })
287 .await
288 .unwrap_or(false)
289 }
290}
291
292impl Default for AuthTracker {
293 fn default() -> Self {
294 Self::new()
295 }
296}
297
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{AtomicBool, Ordering},
        time::Duration,
    };

    use rstest::rstest;

    use super::*;

    /// Minimal error type satisfying the `E: From<String>` bound of `wait_for_result`.
    #[derive(Debug, PartialEq)]
    struct TestError(String);

    impl From<String> for TestError {
        fn from(msg: String) -> Self {
            Self(msg)
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_successful_authentication() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.succeed();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;

        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_failed_authentication() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.fail("Invalid credentials");

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;

        assert_eq!(
            result.unwrap_err(),
            TestError("Invalid credentials".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_authentication_timeout() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        // Nobody resolves the attempt, so the wait must time out.
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_millis(50), rx).await;

        assert_eq!(
            result.unwrap_err(),
            TestError("Authentication timed out".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_begin_supersedes_previous_sender() {
        let tracker = AuthTracker::new();

        let first = tracker.begin();
        let second = tracker.begin();

        let result = first.await.expect("oneshot closed unexpectedly");
        assert_eq!(result, Err("Authentication attempt superseded".to_string()));

        tracker.succeed();

        let result: Result<(), TestError> = tracker
            .wait_for_result(Duration::from_secs(1), second)
            .await;

        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_succeed_without_pending_auth() {
        let tracker = AuthTracker::new();

        // No pending attempt: must not panic, and the state is still updated.
        tracker.succeed();

        assert!(tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_fail_without_pending_auth() {
        let tracker = AuthTracker::new();

        // No pending attempt: must not panic, and the state is still updated.
        tracker.fail("Some error");

        assert_eq!(tracker.auth_state(), AuthState::Failed);
    }

    #[rstest]
    #[tokio::test]
    async fn test_multiple_sequential_authentications() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        tracker.succeed();
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_ok());

        let rx2 = tracker.begin();
        tracker.fail("Credentials expired");
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert_eq!(
            result2.unwrap_err(),
            TestError("Credentials expired".to_string())
        );

        let rx3 = tracker.begin();
        tracker.succeed();
        let result3: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx3).await;
        assert!(result3.is_ok());
    }

    /// Despite its name, this exercises the *superseded* path: a second `begin`
    /// resolves the first receiver before any result arrives. The genuine
    /// closed-channel path is covered by
    /// `test_wait_for_result_reports_closed_channel`.
    #[rstest]
    #[tokio::test]
    async fn test_channel_closed_before_result() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.begin();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;

        assert_eq!(
            result.unwrap_err(),
            TestError("Authentication attempt superseded".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_for_result_reports_closed_channel() {
        // Drop every handle to the tracker that issued `rx`, so its sender is
        // dropped without ever sending a result.
        let rx = {
            let issuer = AuthTracker::new();
            issuer.begin()
        };
        let waiter = AuthTracker::new();

        let result: Result<(), TestError> =
            waiter.wait_for_result(Duration::from_secs(1), rx).await;

        assert_eq!(
            result.unwrap_err(),
            TestError("Authentication channel closed".to_string())
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_concurrent_auth_attempts() {
        let tracker = Arc::new(AuthTracker::new());
        let mut handles = vec![];

        for i in 0..10 {
            let tracker_clone = Arc::clone(&tracker);
            let handle = tokio::spawn(async move {
                let rx = tracker_clone.begin();

                // Only the last attempt is resolved; the others are superseded.
                if i == 9 {
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    tracker_clone.succeed();
                }

                let result: Result<(), TestError> = tracker_clone
                    .wait_for_result(Duration::from_secs(1), rx)
                    .await;

                (i, result)
            });
            handles.push(handle);
        }

        let mut successes = 0;
        let mut superseded = 0;

        for handle in handles {
            let (i, result) = handle.await.unwrap();
            match result {
                Ok(()) => {
                    assert_eq!(i, 9);
                    successes += 1;
                }
                Err(TestError(msg)) if msg.contains("superseded") => {
                    superseded += 1;
                }
                Err(e) => panic!("Unexpected error: {e:?}"),
            }
        }

        assert_eq!(successes, 1);
        assert_eq!(superseded, 9);
    }

    #[rstest]
    fn test_default_trait() {
        let tracker = AuthTracker::default();
        assert_eq!(tracker.auth_state(), AuthState::Unauthenticated);
    }

    #[rstest]
    #[tokio::test]
    async fn test_clone_trait() {
        let tracker = AuthTracker::new();
        let cloned = tracker.clone();

        let rx = tracker.begin();

        // Resolving through the clone reaches the receiver issued by the original.
        cloned.succeed();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    #[rstest]
    fn test_debug_trait() {
        let tracker = AuthTracker::new();
        let debug_str = format!("{tracker:?}");
        assert!(debug_str.contains("AuthTracker"));
    }

    #[rstest]
    #[tokio::test]
    async fn test_timeout_clears_sender() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        let result1: Result<(), TestError> = tracker
            .wait_for_result(Duration::from_millis(50), rx1)
            .await;
        assert_eq!(
            result1.unwrap_err(),
            TestError("Authentication timed out".to_string())
        );

        // A timed-out attempt must not block the next one.
        let rx2 = tracker.begin();
        tracker.succeed();
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_fail_clears_sender() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        tracker.fail("Bad credentials");
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_err());

        let rx2 = tracker.begin();
        tracker.succeed();
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_succeed_clears_sender() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        tracker.succeed();
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_ok());

        let rx2 = tracker.begin();
        tracker.succeed();
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_rapid_begin_succeed_cycles() {
        let tracker = AuthTracker::new();

        for _ in 0..100 {
            let rx = tracker.begin();
            tracker.succeed();
            let result: Result<(), TestError> =
                tracker.wait_for_result(Duration::from_secs(1), rx).await;
            assert!(result.is_ok());
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_double_succeed_is_safe() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.succeed();
        // The second call finds no pending sender and must not panic.
        tracker.succeed();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_double_fail_is_safe() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.fail("Error 1");
        // The second call finds no pending sender; the first message wins.
        tracker.fail("Error 2");

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert_eq!(result.unwrap_err(), TestError("Error 1".to_string()));
    }

    /// The *receiver* ignores the late `succeed` (it already got the failure),
    /// even though the tracker state itself is last-writer-wins.
    #[rstest]
    #[tokio::test]
    async fn test_succeed_after_fail_is_ignored() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.fail("Auth failed");
        tracker.succeed();

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert_eq!(result.unwrap_err(), TestError("Auth failed".to_string()));
    }

    /// The *receiver* ignores the late `fail` (it already got the success),
    /// even though the tracker state itself is last-writer-wins.
    #[rstest]
    #[tokio::test]
    async fn test_fail_after_succeed_is_ignored() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();

        tracker.succeed();
        tracker.fail("Auth failed");

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    /// Simulates a reconnect: resubscription waits for the auth result and only
    /// proceeds once authentication succeeds.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_flow_waits_for_auth() {
        let tracker = Arc::new(AuthTracker::new());
        let subscribed = Arc::new(tokio::sync::Notify::new());
        let auth_completed = Arc::new(tokio::sync::Notify::new());

        let tracker_reconnect = Arc::clone(&tracker);
        let subscribed_reconnect = Arc::clone(&subscribed);
        let auth_completed_reconnect = Arc::clone(&auth_completed);

        let reconnect_task = tokio::spawn(async move {
            let rx = tracker_reconnect.begin();

            let tracker_resub = Arc::clone(&tracker_reconnect);
            let subscribed_resub = Arc::clone(&subscribed_reconnect);
            let auth_completed_resub = Arc::clone(&auth_completed_reconnect);

            let resub_task = tokio::spawn(async move {
                let result: Result<(), TestError> = tracker_resub
                    .wait_for_result(Duration::from_secs(5), rx)
                    .await;

                if result.is_ok() {
                    auth_completed_resub.notify_one();
                    tokio::time::sleep(Duration::from_millis(10)).await;
                    subscribed_resub.notify_one();
                }
            });

            resub_task.await.unwrap();
        });

        tokio::time::sleep(Duration::from_millis(100)).await;
        tracker.succeed();

        reconnect_task.await.unwrap();

        // `notify_one` stores a permit, so these complete even though the
        // notifications were sent before we started waiting.
        tokio::select! {
            () = auth_completed.notified() => {}
            () = tokio::time::sleep(Duration::from_secs(1)) => {
                panic!("Auth never completed");
            }
        }

        tokio::select! {
            () = subscribed.notified() => {}
            () = tokio::time::sleep(Duration::from_secs(1)) => {
                panic!("Subscription never completed");
            }
        }
    }

    /// Simulates a reconnect whose authentication fails: resubscription must not run.
    #[rstest]
    #[tokio::test]
    async fn test_reconnect_flow_blocks_on_auth_failure() {
        let tracker = Arc::new(AuthTracker::new());
        let subscribed = Arc::new(AtomicBool::new(false));

        let tracker_reconnect = Arc::clone(&tracker);
        let subscribed_reconnect = Arc::clone(&subscribed);

        let reconnect_task = tokio::spawn(async move {
            let rx = tracker_reconnect.begin();

            let tracker_resub = Arc::clone(&tracker_reconnect);
            let subscribed_resub = Arc::clone(&subscribed_reconnect);

            let resub_task = tokio::spawn(async move {
                let result: Result<(), TestError> = tracker_resub
                    .wait_for_result(Duration::from_secs(5), rx)
                    .await;

                if result.is_ok() {
                    subscribed_resub.store(true, Ordering::Relaxed);
                }
            });

            resub_task.await.unwrap();
        });

        tokio::time::sleep(Duration::from_millis(50)).await;
        tracker.fail("Invalid credentials");

        reconnect_task.await.unwrap();

        tokio::time::sleep(Duration::from_millis(100)).await;
        assert!(!subscribed.load(Ordering::Relaxed));
    }

    /// Walks through success, failure, timeout and supersede in sequence.
    #[rstest]
    #[tokio::test]
    async fn test_state_machine_transitions() {
        let tracker = AuthTracker::new();

        let rx1 = tracker.begin();
        tracker.succeed();
        let result1: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
        assert!(result1.is_ok());

        let rx2 = tracker.begin();
        tracker.fail("Error");
        let result2: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
        assert!(result2.is_err());

        let rx3 = tracker.begin();
        let result3: Result<(), TestError> = tracker
            .wait_for_result(Duration::from_millis(50), rx3)
            .await;
        assert_eq!(
            result3.unwrap_err(),
            TestError("Authentication timed out".to_string())
        );

        let rx4 = tracker.begin();
        let rx5 = tracker.begin();
        let result4: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx4).await;
        assert_eq!(
            result4.unwrap_err(),
            TestError("Authentication attempt superseded".to_string())
        );

        tracker.succeed();
        let result5: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx5).await;
        assert!(result5.is_ok());
    }

    /// Many timed-out attempts must not prevent a later attempt from succeeding.
    #[rstest]
    #[tokio::test]
    async fn test_no_sender_leaks() {
        let tracker = AuthTracker::new();

        for _ in 0..100 {
            let rx = tracker.begin();
            let _result: Result<(), TestError> =
                tracker.wait_for_result(Duration::from_millis(1), rx).await;
        }

        let rx = tracker.begin();
        tracker.succeed();
        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;
        assert!(result.is_ok());
    }

    #[rstest]
    #[tokio::test]
    async fn test_concurrent_succeed_fail_calls() {
        let tracker = Arc::new(AuthTracker::new());
        let rx = tracker.begin();

        let mut handles = vec![];

        for _ in 0..50 {
            let tracker_clone = Arc::clone(&tracker);
            handles.push(tokio::spawn(async move {
                tracker_clone.succeed();
            }));
        }

        for _ in 0..50 {
            let tracker_clone = Arc::clone(&tracker);
            handles.push(tokio::spawn(async move {
                tracker_clone.fail("Error");
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_secs(1), rx).await;

        // Whichever call resolved the attempt first wins. Either outcome is
        // valid, but it must be one of the two (never a timeout or closed channel),
        // and the state must reflect a resolved attempt.
        assert!(
            result == Ok(()) || result == Err(TestError("Error".to_string())),
            "unexpected result: {result:?}"
        );
        assert!(matches!(
            tracker.auth_state(),
            AuthState::Authenticated | AuthState::Failed
        ));
    }

    #[rstest]
    fn test_is_authenticated_initial_state() {
        let tracker = AuthTracker::new();
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_is_authenticated_after_succeed() {
        let tracker = AuthTracker::new();
        assert!(!tracker.is_authenticated());

        let _rx = tracker.begin();
        assert!(!tracker.is_authenticated());

        tracker.succeed();
        assert!(tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_is_authenticated_after_fail() {
        let tracker = AuthTracker::new();
        let _rx = tracker.begin();
        tracker.fail("error");
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_invalidate_clears_auth_state() {
        let tracker = AuthTracker::new();
        let _rx = tracker.begin();
        tracker.succeed();
        assert!(tracker.is_authenticated());

        tracker.invalidate();
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_begin_clears_auth_state() {
        let tracker = AuthTracker::new();
        let _rx1 = tracker.begin();
        tracker.succeed();
        assert!(tracker.is_authenticated());

        let _rx2 = tracker.begin();
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    fn test_is_authenticated_shared_across_clones() {
        let tracker = AuthTracker::new();
        let cloned = tracker.clone();

        let _rx = tracker.begin();
        tracker.succeed();

        assert!(cloned.is_authenticated());
    }

    #[rstest]
    fn test_invalidate_shared_across_clones() {
        let tracker = AuthTracker::new();
        let cloned = tracker.clone();

        let _rx = tracker.begin();
        tracker.succeed();
        assert!(tracker.is_authenticated());

        cloned.invalidate();
        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    fn test_succeed_without_begin_still_updates_auth_state() {
        let tracker = AuthTracker::new();
        assert!(!tracker.is_authenticated());

        tracker.succeed();

        assert!(tracker.is_authenticated());
    }

    #[rstest]
    fn test_fail_without_begin_still_updates_auth_state() {
        let tracker = AuthTracker::new();
        tracker.succeed();
        assert!(tracker.is_authenticated());

        tracker.fail("error");

        assert!(!tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_auth_state_false_after_timeout_until_late_response() {
        let tracker = AuthTracker::new();
        let rx = tracker.begin();
        assert!(!tracker.is_authenticated());

        let result: Result<(), TestError> =
            tracker.wait_for_result(Duration::from_millis(10), rx).await;

        assert!(result.is_err());
        assert!(!tracker.is_authenticated());

        // A late server response still updates the shared state.
        tracker.succeed();

        assert!(tracker.is_authenticated());
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_for_authenticated_already_authenticated() {
        let tracker = AuthTracker::new();
        let _rx = tracker.begin();
        tracker.succeed();

        assert!(
            tracker
                .wait_for_authenticated(Duration::from_millis(50))
                .await
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_for_authenticated_succeeds_after_delay() {
        let tracker = AuthTracker::new();
        let _rx = tracker.begin();

        let tracker_clone = tracker.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            tracker_clone.succeed();
        });

        assert!(tracker.wait_for_authenticated(Duration::from_secs(1)).await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_for_authenticated_returns_false_on_failure() {
        let tracker = AuthTracker::new();
        let _rx = tracker.begin();

        let tracker_clone = tracker.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            tracker_clone.fail("rejected");
        });

        let start = tokio::time::Instant::now();
        let result = tracker.wait_for_authenticated(Duration::from_secs(5)).await;
        let elapsed = start.elapsed();

        // Failure must short-circuit the wait rather than run out the timeout.
        assert!(!result);
        assert!(elapsed < Duration::from_secs(1));
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_for_authenticated_times_out() {
        let tracker = AuthTracker::new();
        let _rx = tracker.begin();

        assert!(
            !tracker
                .wait_for_authenticated(Duration::from_millis(50))
                .await
        );
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_for_authenticated_begin_clears_failed() {
        let tracker = AuthTracker::new();
        let _rx = tracker.begin();
        tracker.fail("first attempt");

        assert!(
            !tracker
                .wait_for_authenticated(Duration::from_millis(10))
                .await
        );

        // A new attempt resets `Failed`, so the waiter now blocks until the outcome.
        let _rx = tracker.begin();

        let tracker_clone = tracker.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            tracker_clone.succeed();
        });

        assert!(tracker.wait_for_authenticated(Duration::from_secs(1)).await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_for_authenticated_invalidate_does_not_return_false() {
        let tracker = AuthTracker::new();
        let _rx = tracker.begin();

        let tracker_clone = tracker.clone();

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(20)).await;
            // Invalidation wakes the waiter but must not end the wait.
            tracker_clone.invalidate();
            tokio::time::sleep(Duration::from_millis(20)).await;
            tracker_clone.succeed();
        });

        assert!(tracker.wait_for_authenticated(Duration::from_secs(1)).await);
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_for_authenticated_concurrent_waiters() {
        let tracker = Arc::new(AuthTracker::new());
        let _rx = tracker.begin();

        let mut handles = vec![];

        for _ in 0..10 {
            let t = Arc::clone(&tracker);
            handles.push(tokio::spawn(async move {
                t.wait_for_authenticated(Duration::from_secs(1)).await
            }));
        }

        tokio::time::sleep(Duration::from_millis(50)).await;
        tracker.succeed();

        for handle in handles {
            assert!(handle.await.unwrap());
        }
    }

    #[rstest]
    #[tokio::test]
    async fn test_wait_for_authenticated_not_authenticated_initially() {
        let tracker = AuthTracker::new();

        // With no attempt in flight nothing will ever resolve, so this times out.
        assert!(
            !tracker
                .wait_for_authenticated(Duration::from_millis(50))
                .await
        );
    }
}
1156
#[cfg(test)]
mod proptest_tests {
    use std::{sync::Arc, time::Duration};

    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// Applies one encoded operation to `tracker`:
    /// `0` = begin (receiver dropped immediately), `1` = succeed, `2` = fail,
    /// `3` = invalidate.
    fn apply_op(tracker: &AuthTracker, op: u8) {
        match op {
            0 => drop(tracker.begin()),
            1 => tracker.succeed(),
            2 => tracker.fail("test"),
            3 => tracker.invalidate(),
            _ => unreachable!(),
        }
    }

    proptest! {
        #[rstest]
        fn test_state_consistency_after_random_operations(
            ops in proptest::collection::vec(0u8..4, 1..50)
        ) {
            let tracker = AuthTracker::new();
            ops.iter().for_each(|&op| apply_op(&tracker, op));

            // Every op overwrites the state and only `succeed` (op 1) leaves it
            // authenticated, so the final answer depends solely on the last op.
            let expected_auth = ops.last() == Some(&1);
            prop_assert_eq!(tracker.is_authenticated(), expected_auth);
        }

        #[rstest]
        fn test_begin_always_clears_failed(
            prior_ops in proptest::collection::vec(0u8..4, 0..20)
        ) {
            let tracker = AuthTracker::new();
            for &op in &prior_ops {
                apply_op(&tracker, op);
            }

            let _rx = tracker.begin();
            prop_assert_eq!(tracker.auth_state(), AuthState::Unauthenticated);
        }

        #[rstest]
        fn test_succeed_always_sets_authenticated(
            prior_ops in proptest::collection::vec(0u8..4, 0..20)
        ) {
            let tracker = AuthTracker::new();
            for &op in &prior_ops {
                apply_op(&tracker, op);
            }

            tracker.succeed();
            prop_assert_eq!(tracker.auth_state(), AuthState::Authenticated);
        }
    }

    /// Both outcomes must wake `wait_for_authenticated` promptly, long before its timeout.
    #[rstest]
    #[tokio::test]
    async fn test_wait_responds_within_bounded_time() {
        for auth_result in [true, false] {
            let tracker = Arc::new(AuthTracker::new());
            let _rx = tracker.begin();

            let responder = Arc::clone(&tracker);
            tokio::spawn(async move {
                tokio::time::sleep(Duration::from_millis(30)).await;

                if auth_result {
                    responder.succeed();
                } else {
                    responder.fail("rejected");
                }
            });

            let started = tokio::time::Instant::now();
            let result = tracker
                .wait_for_authenticated(Duration::from_secs(10))
                .await;
            let elapsed = started.elapsed();

            assert_eq!(result, auth_result);
            assert!(
                elapsed < Duration::from_millis(500),
                "wait_for_authenticated took {elapsed:?} for auth_result={auth_result}"
            );
        }
    }
}