Skip to main content

nautilus_network/websocket/
auth.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Authentication state tracking for WebSocket clients.
17//!
//! This module provides an authentication tracker that coordinates login attempts and
//! ensures each attempt is resolved with its own success or failure signal before gated
//! operations resume. The design follows a pattern already proven in production.
21//!
22//! # Key Features
23//!
24//! - **Three-state model**: `Unauthenticated`, `Authenticated`, `Failed` via `AuthState` enum.
25//! - **Oneshot signaling**: Each auth attempt gets a dedicated channel for result notification.
26//! - **Superseding logic**: New authentication requests cancel pending ones.
27//! - **Timeout handling**: Configurable timeout for authentication responses.
28//! - **Generic error mapping**: Adapters can map to their specific error types.
//! - **Auth-gated waiting**: `wait_for_authenticated()` waits (without blocking the thread)
//!   until authentication succeeds, fails, or the given timeout elapses.
30//!
31//! # Recommended Integration Pattern
32//!
//! The recommended integration pattern, based on production usage, is:
//!
//! 1. **Order operations**: Call `wait_for_authenticated()` before private operations.
//!    This waits for re-authentication after a reconnection instead of rejecting immediately.
//! 2. **Reconnection flow**: Authenticate BEFORE resubscribing to topics.
//! 3. **Event propagation**: Send auth failures through event channels to consumers.
//! 4. **State lifecycle**: Call `invalidate()` on disconnect, and call `succeed()` or `fail()`
//!    when the server responds to an authentication request.
40
41use std::{
42    sync::{
43        Arc, Mutex,
44        atomic::{AtomicU8, Ordering},
45    },
46    time::Duration,
47};
48
49pub type AuthResultSender = tokio::sync::oneshot::Sender<Result<(), String>>;
50pub type AuthResultReceiver = tokio::sync::oneshot::Receiver<Result<(), String>>;
51
/// Authentication state of a WebSocket session, as tracked by [`AuthTracker`].
///
/// The `#[repr(u8)]` discriminant lets the tracker keep the state in an `AtomicU8`
/// and read it without taking a lock.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[repr(u8)]
pub enum AuthState {
    /// No confirmed authentication: the initial state, and the state after
    /// `invalidate()` or `begin()`.
    #[default]
    Unauthenticated = 0,
    /// The server confirmed authentication (set by `succeed()`).
    Authenticated = 1,
    /// The server explicitly rejected authentication (set by `fail()`).
    Failed = 2,
}

impl AuthState {
    /// Decodes a discriminant produced by [`AuthState::as_u8`].
    ///
    /// Any byte that is not a known discriminant decodes to `Unauthenticated`,
    /// the conservative choice for an authentication gate.
    #[inline]
    #[must_use]
    fn from_u8(value: u8) -> Self {
        match value {
            1 => Self::Authenticated,
            2 => Self::Failed,
            // `0` and every unrecognized byte share the safe default.
            _ => Self::Unauthenticated,
        }
    }

    /// Encodes this state as its `#[repr(u8)]` discriminant for atomic storage.
    #[inline]
    #[must_use]
    const fn as_u8(self) -> u8 {
        self as u8
    }
}
87
88/// Generic authentication state tracker for WebSocket connections.
89///
90/// Coordinates authentication attempts by providing a channel-based signaling
91/// mechanism. Each authentication attempt receives a dedicated oneshot channel
92/// that will be resolved when the server responds.
93///
94/// # State Management
95///
96/// The tracker maintains a three-state machine:
97/// - `Unauthenticated`: after `begin()`, `invalidate()`, or initial construction.
98/// - `Authenticated`: after `succeed()`. Queryable via `is_authenticated()`.
99/// - `Failed`: after `fail()`. Causes `wait_for_authenticated()` to return early.
100///
101/// # Superseding Behavior
102///
103/// If a new authentication attempt begins while a previous one is pending,
104/// the old attempt is automatically cancelled with an error. This prevents
105/// auth response race conditions during rapid reconnections.
106///
107/// # Thread Safety
108///
109/// All operations are thread-safe and can be called concurrently from multiple tasks.
110#[derive(Clone, Debug)]
111pub struct AuthTracker {
112    tx: Arc<Mutex<Option<AuthResultSender>>>,
113    state: Arc<AtomicU8>,
114    state_notify: Arc<tokio::sync::Notify>,
115}
116
117impl AuthTracker {
118    /// Creates a new authentication tracker.
119    #[must_use]
120    pub fn new() -> Self {
121        Self {
122            tx: Arc::new(Mutex::new(None)),
123            state: Arc::new(AtomicU8::new(AuthState::Unauthenticated.as_u8())),
124            state_notify: Arc::new(tokio::sync::Notify::new()),
125        }
126    }
127
128    /// Returns the current authentication state.
129    #[must_use]
130    pub fn auth_state(&self) -> AuthState {
131        AuthState::from_u8(self.state.load(Ordering::Acquire))
132    }
133
134    /// Returns whether the client is currently authenticated.
135    #[must_use]
136    pub fn is_authenticated(&self) -> bool {
137        self.auth_state() == AuthState::Authenticated
138    }
139
140    /// Clears the authentication state without affecting pending auth attempts.
141    ///
142    /// Call this on disconnect or when the connection is closed to ensure
143    /// operations requiring authentication are properly guarded.
144    pub fn invalidate(&self) {
145        self.state
146            .store(AuthState::Unauthenticated.as_u8(), Ordering::Release);
147        self.state_notify.notify_waiters();
148    }
149
150    /// Begins a new authentication attempt.
151    ///
152    /// Returns a receiver that will be notified when authentication completes.
153    /// If a previous authentication attempt is still pending, it will be cancelled
154    /// with an error message indicating it was superseded.
155    ///
156    /// Transitions to `Unauthenticated` since a new attempt invalidates any
157    /// previous status.
158    #[allow(
159        clippy::must_use_candidate,
160        reason = "callers use this for side effects"
161    )]
162    pub fn begin(&self) -> AuthResultReceiver {
163        let (sender, receiver) = tokio::sync::oneshot::channel();
164        self.state
165            .store(AuthState::Unauthenticated.as_u8(), Ordering::Release);
166
167        if let Ok(mut guard) = self.tx.lock() {
168            if let Some(old) = guard.take() {
169                log::warn!("New authentication request superseding previous pending request");
170                let _ = old.send(Err("Authentication attempt superseded".to_string()));
171            } else {
172                log::debug!("Starting new authentication request");
173            }
174            *guard = Some(sender);
175        }
176
177        receiver
178    }
179
180    /// Marks the current authentication attempt as successful.
181    ///
182    /// Transitions to `Authenticated` and notifies any waiting receiver
183    /// with `Ok(())`. This should be called when the server sends a successful
184    /// authentication response.
185    ///
186    /// The state is always updated even if no receiver is waiting (e.g., after
187    /// a timeout), since the server has confirmed authentication.
188    pub fn succeed(&self) {
189        self.state
190            .store(AuthState::Authenticated.as_u8(), Ordering::Release);
191        self.state_notify.notify_waiters();
192
193        if let Ok(mut guard) = self.tx.lock()
194            && let Some(sender) = guard.take()
195        {
196            let _ = sender.send(Ok(()));
197        }
198    }
199
200    /// Marks the current authentication attempt as failed.
201    ///
202    /// Transitions to `Failed` and notifies any waiting receiver
203    /// with `Err(message)`. This should be called when the server sends an
204    /// authentication error response.
205    ///
206    /// The state is always updated even if no receiver is waiting, since the
207    /// server has rejected authentication.
208    pub fn fail(&self, error: impl Into<String>) {
209        self.state
210            .store(AuthState::Failed.as_u8(), Ordering::Release);
211        self.state_notify.notify_waiters();
212        let message = error.into();
213
214        if let Ok(mut guard) = self.tx.lock()
215            && let Some(sender) = guard.take()
216        {
217            let _ = sender.send(Err(message));
218        }
219    }
220
221    /// Waits for the authentication result with a timeout.
222    ///
223    /// Returns `Ok(())` if authentication succeeds, or an error if it fails,
224    /// times out, or the channel is closed.
225    ///
226    /// # Type Parameters
227    ///
228    /// - `E`: Error type that implements `From<String>` for error message conversion
229    ///
230    /// # Errors
231    ///
232    /// Returns an error in the following cases:
233    /// - Authentication fails (server rejects credentials)
234    /// - Authentication times out (no response within timeout duration)
235    /// - Authentication channel closes unexpectedly
236    /// - Authentication attempt is superseded by a new attempt
237    pub async fn wait_for_result<E>(
238        &self,
239        timeout: Duration,
240        receiver: AuthResultReceiver,
241    ) -> Result<(), E>
242    where
243        E: From<String>,
244    {
245        match tokio::time::timeout(timeout, receiver).await {
246            Ok(Ok(Ok(()))) => Ok(()),
247            Ok(Ok(Err(msg))) => Err(E::from(msg)),
248            Ok(Err(_)) => Err(E::from("Authentication channel closed".to_string())),
249            Err(_) => {
250                // Don't clear the sender: a concurrent begin() may have replaced it,
251                // and guard.take() would cancel the newer sender. The next begin()
252                // call cleans up any stale sender.
253                Err(E::from("Authentication timed out".to_string()))
254            }
255        }
256    }
257
258    /// Waits for the tracker to enter the authenticated state.
259    ///
260    /// Returns `true` if authenticated within the timeout, `false` if the timeout
261    /// expires or authentication explicitly fails. Uses event-driven notification
262    /// from `succeed()` / `fail()` / `invalidate()` to avoid polling.
263    ///
264    /// Returns early with `false` when `fail()` is called (e.g., the exchange
265    /// rejects credentials), so callers are not blocked for the full timeout
266    /// on a definitive auth rejection.
267    ///
268    /// This is intended for callers on a separate task who need to gate operations
269    /// on authentication state (e.g., order sends that must wait for re-authentication
270    /// after a WebSocket reconnection).
271    pub async fn wait_for_authenticated(&self, timeout: Duration) -> bool {
272        if self.is_authenticated() {
273            return true;
274        }
275
276        tokio::time::timeout(timeout, async {
277            loop {
278                let notified = self.state_notify.notified();
279
280                match self.auth_state() {
281                    AuthState::Authenticated => return true,
282                    AuthState::Failed => return false,
283                    AuthState::Unauthenticated => notified.await,
284                }
285            }
286        })
287        .await
288        .unwrap_or(false)
289    }
290}
291
292impl Default for AuthTracker {
293    fn default() -> Self {
294        Self::new()
295    }
296}
297
298#[cfg(test)]
299mod tests {
300    use std::{
301        sync::atomic::{AtomicBool, Ordering},
302        time::Duration,
303    };
304
305    use rstest::rstest;
306
307    use super::*;
308
309    #[derive(Debug, PartialEq)]
310    struct TestError(String);
311
312    impl From<String> for TestError {
313        fn from(msg: String) -> Self {
314            Self(msg)
315        }
316    }
317
318    #[rstest]
319    #[tokio::test]
320    async fn test_successful_authentication() {
321        let tracker = AuthTracker::new();
322        let rx = tracker.begin();
323
324        tracker.succeed();
325
326        let result: Result<(), TestError> =
327            tracker.wait_for_result(Duration::from_secs(1), rx).await;
328
329        assert!(result.is_ok());
330    }
331
332    #[rstest]
333    #[tokio::test]
334    async fn test_failed_authentication() {
335        let tracker = AuthTracker::new();
336        let rx = tracker.begin();
337
338        tracker.fail("Invalid credentials");
339
340        let result: Result<(), TestError> =
341            tracker.wait_for_result(Duration::from_secs(1), rx).await;
342
343        assert_eq!(
344            result.unwrap_err(),
345            TestError("Invalid credentials".to_string())
346        );
347    }
348
349    #[rstest]
350    #[tokio::test]
351    async fn test_authentication_timeout() {
352        let tracker = AuthTracker::new();
353        let rx = tracker.begin();
354
355        // Don't call succeed or fail - let it timeout
356
357        let result: Result<(), TestError> =
358            tracker.wait_for_result(Duration::from_millis(50), rx).await;
359
360        assert_eq!(
361            result.unwrap_err(),
362            TestError("Authentication timed out".to_string())
363        );
364    }
365
366    #[rstest]
367    #[tokio::test]
368    async fn test_begin_supersedes_previous_sender() {
369        let tracker = AuthTracker::new();
370
371        let first = tracker.begin();
372        let second = tracker.begin();
373
374        // First receiver should get superseded error
375        let result = first.await.expect("oneshot closed unexpectedly");
376        assert_eq!(result, Err("Authentication attempt superseded".to_string()));
377
378        // Second attempt should succeed
379        tracker.succeed();
380        let result: Result<(), TestError> = tracker
381            .wait_for_result(Duration::from_secs(1), second)
382            .await;
383
384        assert!(result.is_ok());
385    }
386
387    #[rstest]
388    #[tokio::test]
389    async fn test_succeed_without_pending_auth() {
390        let tracker = AuthTracker::new();
391
392        // Calling succeed without begin should not panic
393        tracker.succeed();
394    }
395
396    #[rstest]
397    #[tokio::test]
398    async fn test_fail_without_pending_auth() {
399        let tracker = AuthTracker::new();
400
401        // Calling fail without begin should not panic
402        tracker.fail("Some error");
403    }
404
405    #[rstest]
406    #[tokio::test]
407    async fn test_multiple_sequential_authentications() {
408        let tracker = AuthTracker::new();
409
410        // First auth succeeds
411        let rx1 = tracker.begin();
412        tracker.succeed();
413        let result1: Result<(), TestError> =
414            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
415        assert!(result1.is_ok());
416
417        // Second auth fails
418        let rx2 = tracker.begin();
419        tracker.fail("Credentials expired");
420        let result2: Result<(), TestError> =
421            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
422        assert_eq!(
423            result2.unwrap_err(),
424            TestError("Credentials expired".to_string())
425        );
426
427        // Third auth succeeds
428        let rx3 = tracker.begin();
429        tracker.succeed();
430        let result3: Result<(), TestError> =
431            tracker.wait_for_result(Duration::from_secs(1), rx3).await;
432        assert!(result3.is_ok());
433    }
434
435    #[rstest]
436    #[tokio::test]
437    async fn test_channel_closed_before_result() {
438        let tracker = AuthTracker::new();
439        let rx = tracker.begin();
440
441        // Drop the tracker's sender by starting a new auth
442        tracker.begin();
443
444        // Original receiver should get channel closed error
445        let result: Result<(), TestError> =
446            tracker.wait_for_result(Duration::from_secs(1), rx).await;
447
448        assert_eq!(
449            result.unwrap_err(),
450            TestError("Authentication attempt superseded".to_string())
451        );
452    }
453
454    #[rstest]
455    #[tokio::test]
456    async fn test_concurrent_auth_attempts() {
457        let tracker = Arc::new(AuthTracker::new());
458        let mut handles = vec![];
459
460        // Spawn 10 concurrent auth attempts
461        for i in 0..10 {
462            let tracker_clone = Arc::clone(&tracker);
463            let handle = tokio::spawn(async move {
464                let rx = tracker_clone.begin();
465
466                // Only the last one should succeed
467                if i == 9 {
468                    tokio::time::sleep(Duration::from_millis(10)).await;
469                    tracker_clone.succeed();
470                }
471
472                let result: Result<(), TestError> = tracker_clone
473                    .wait_for_result(Duration::from_secs(1), rx)
474                    .await;
475
476                (i, result)
477            });
478            handles.push(handle);
479        }
480
481        let mut successes = 0;
482        let mut superseded = 0;
483
484        for handle in handles {
485            let (i, result) = handle.await.unwrap();
486            match result {
487                Ok(()) => {
488                    // Only task 9 should succeed
489                    assert_eq!(i, 9);
490                    successes += 1;
491                }
492                Err(TestError(msg)) if msg.contains("superseded") => {
493                    superseded += 1;
494                }
495                Err(e) => panic!("Unexpected error: {e:?}"),
496            }
497        }
498
499        assert_eq!(successes, 1);
500        assert_eq!(superseded, 9);
501    }
502
503    #[rstest]
504    fn test_default_trait() {
505        let _tracker = AuthTracker::default();
506    }
507
508    #[rstest]
509    #[tokio::test]
510    async fn test_clone_trait() {
511        let tracker = AuthTracker::new();
512        let cloned = tracker.clone();
513
514        // Verify cloned instance shares state with original (Arc behavior)
515        let rx = tracker.begin();
516        cloned.succeed(); // Succeed via clone affects original
517        let result: Result<(), TestError> =
518            tracker.wait_for_result(Duration::from_secs(1), rx).await;
519        assert!(result.is_ok());
520    }
521
522    #[rstest]
523    fn test_debug_trait() {
524        let tracker = AuthTracker::new();
525        let debug_str = format!("{tracker:?}");
526        assert!(debug_str.contains("AuthTracker"));
527    }
528
529    #[rstest]
530    #[tokio::test]
531    async fn test_timeout_clears_sender() {
532        let tracker = AuthTracker::new();
533
534        // Start auth that will timeout
535        let rx1 = tracker.begin();
536        let result1: Result<(), TestError> = tracker
537            .wait_for_result(Duration::from_millis(50), rx1)
538            .await;
539        assert_eq!(
540            result1.unwrap_err(),
541            TestError("Authentication timed out".to_string())
542        );
543
544        // Verify sender was cleared - new auth should work
545        let rx2 = tracker.begin();
546        tracker.succeed();
547        let result2: Result<(), TestError> =
548            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
549        assert!(result2.is_ok());
550    }
551
552    #[rstest]
553    #[tokio::test]
554    async fn test_fail_clears_sender() {
555        let tracker = AuthTracker::new();
556
557        // Auth fails
558        let rx1 = tracker.begin();
559        tracker.fail("Bad credentials");
560        let result1: Result<(), TestError> =
561            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
562        assert!(result1.is_err());
563
564        // Verify sender was cleared - new auth should work
565        let rx2 = tracker.begin();
566        tracker.succeed();
567        let result2: Result<(), TestError> =
568            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
569        assert!(result2.is_ok());
570    }
571
572    #[rstest]
573    #[tokio::test]
574    async fn test_succeed_clears_sender() {
575        let tracker = AuthTracker::new();
576
577        // Auth succeeds
578        let rx1 = tracker.begin();
579        tracker.succeed();
580        let result1: Result<(), TestError> =
581            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
582        assert!(result1.is_ok());
583
584        // Verify sender was cleared - new auth should work
585        let rx2 = tracker.begin();
586        tracker.succeed();
587        let result2: Result<(), TestError> =
588            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
589        assert!(result2.is_ok());
590    }
591
592    #[rstest]
593    #[tokio::test]
594    async fn test_rapid_begin_succeed_cycles() {
595        let tracker = AuthTracker::new();
596
597        // Rapidly cycle through auth attempts
598        for _ in 0..100 {
599            let rx = tracker.begin();
600            tracker.succeed();
601            let result: Result<(), TestError> =
602                tracker.wait_for_result(Duration::from_secs(1), rx).await;
603            assert!(result.is_ok());
604        }
605    }
606
607    #[rstest]
608    #[tokio::test]
609    async fn test_double_succeed_is_safe() {
610        let tracker = AuthTracker::new();
611        let rx = tracker.begin();
612
613        // Call succeed twice
614        tracker.succeed();
615        tracker.succeed(); // Second call should be no-op
616
617        let result: Result<(), TestError> =
618            tracker.wait_for_result(Duration::from_secs(1), rx).await;
619        assert!(result.is_ok());
620    }
621
622    #[rstest]
623    #[tokio::test]
624    async fn test_double_fail_is_safe() {
625        let tracker = AuthTracker::new();
626        let rx = tracker.begin();
627
628        // Call fail twice
629        tracker.fail("Error 1");
630        tracker.fail("Error 2"); // Second call should be no-op
631
632        let result: Result<(), TestError> =
633            tracker.wait_for_result(Duration::from_secs(1), rx).await;
634        assert_eq!(
635            result.unwrap_err(),
636            TestError("Error 1".to_string()) // Should be first error
637        );
638    }
639
640    #[rstest]
641    #[tokio::test]
642    async fn test_succeed_after_fail_is_ignored() {
643        let tracker = AuthTracker::new();
644        let rx = tracker.begin();
645
646        tracker.fail("Auth failed");
647        tracker.succeed(); // This should be no-op
648
649        let result: Result<(), TestError> =
650            tracker.wait_for_result(Duration::from_secs(1), rx).await;
651        assert!(result.is_err()); // Should still be error
652    }
653
654    #[rstest]
655    #[tokio::test]
656    async fn test_fail_after_succeed_is_ignored() {
657        let tracker = AuthTracker::new();
658        let rx = tracker.begin();
659
660        tracker.succeed();
661        tracker.fail("Auth failed"); // This should be no-op
662
663        let result: Result<(), TestError> =
664            tracker.wait_for_result(Duration::from_secs(1), rx).await;
665        assert!(result.is_ok()); // Should still be success
666    }
667
668    /// Simulates a reconnect flow where authentication must complete before resubscription.
669    ///
670    /// This is an integration-style test that verifies:
671    /// 1. On reconnect, authentication starts first
672    /// 2. Subscription logic waits for auth to complete
673    /// 3. Subscriptions only proceed after successful auth
674    #[rstest]
675    #[tokio::test]
676    async fn test_reconnect_flow_waits_for_auth() {
677        let tracker = Arc::new(AuthTracker::new());
678        let subscribed = Arc::new(tokio::sync::Notify::new());
679        let auth_completed = Arc::new(tokio::sync::Notify::new());
680
681        // Simulate reconnect handler
682        let tracker_reconnect = Arc::clone(&tracker);
683        let subscribed_reconnect = Arc::clone(&subscribed);
684        let auth_completed_reconnect = Arc::clone(&auth_completed);
685
686        let reconnect_task = tokio::spawn(async move {
687            // Step 1: Begin authentication
688            let rx = tracker_reconnect.begin();
689
690            // Step 2: Spawn resubscription task that waits for auth
691            let tracker_resub = Arc::clone(&tracker_reconnect);
692            let subscribed_resub = Arc::clone(&subscribed_reconnect);
693            let auth_completed_resub = Arc::clone(&auth_completed_reconnect);
694
695            let resub_task = tokio::spawn(async move {
696                // Wait for auth to complete
697                let result: Result<(), TestError> = tracker_resub
698                    .wait_for_result(Duration::from_secs(5), rx)
699                    .await;
700
701                if result.is_ok() {
702                    auth_completed_resub.notify_one();
703                    // Simulate resubscription
704                    tokio::time::sleep(Duration::from_millis(10)).await;
705                    subscribed_resub.notify_one();
706                }
707            });
708
709            resub_task.await.unwrap();
710        });
711
712        // Simulate server auth response after delay
713        tokio::time::sleep(Duration::from_millis(100)).await;
714        tracker.succeed();
715
716        // Wait for reconnect flow to complete
717        reconnect_task.await.unwrap();
718
719        // Verify auth completed before subscription
720        tokio::select! {
721            () = auth_completed.notified() => {
722                // Good - auth completed
723            }
724            () = tokio::time::sleep(Duration::from_secs(1)) => {
725                panic!("Auth never completed");
726            }
727        }
728
729        // Verify subscription completed
730        tokio::select! {
731            () = subscribed.notified() => {
732                // Good - subscribed
733            }
734            () = tokio::time::sleep(Duration::from_secs(1)) => {
735                panic!("Subscription never completed");
736            }
737        }
738    }
739
740    /// Verifies that failed authentication prevents resubscription in reconnect flow.
741    #[rstest]
742    #[tokio::test]
743    async fn test_reconnect_flow_blocks_on_auth_failure() {
744        let tracker = Arc::new(AuthTracker::new());
745        let subscribed = Arc::new(AtomicBool::new(false));
746
747        let tracker_reconnect = Arc::clone(&tracker);
748        let subscribed_reconnect = Arc::clone(&subscribed);
749
750        let reconnect_task = tokio::spawn(async move {
751            let rx = tracker_reconnect.begin();
752
753            // Spawn resubscription task that waits for auth
754            let tracker_resub = Arc::clone(&tracker_reconnect);
755            let subscribed_resub = Arc::clone(&subscribed_reconnect);
756
757            let resub_task = tokio::spawn(async move {
758                let result: Result<(), TestError> = tracker_resub
759                    .wait_for_result(Duration::from_secs(5), rx)
760                    .await;
761
762                // Only subscribe if auth succeeds
763                if result.is_ok() {
764                    subscribed_resub.store(true, Ordering::Relaxed);
765                }
766            });
767
768            resub_task.await.unwrap();
769        });
770
771        // Simulate server auth failure
772        tokio::time::sleep(Duration::from_millis(50)).await;
773        tracker.fail("Invalid credentials");
774
775        // Wait for reconnect flow to complete
776        reconnect_task.await.unwrap();
777
778        // Verify subscription never happened
779        tokio::time::sleep(Duration::from_millis(100)).await;
780        assert!(!subscribed.load(Ordering::Relaxed));
781    }
782
783    /// Tests state machine transitions exhaustively.
784    #[rstest]
785    #[tokio::test]
786    async fn test_state_machine_transitions() {
787        let tracker = AuthTracker::new();
788
789        // Transition 1: Initial -> Pending (begin)
790        let rx1 = tracker.begin();
791
792        // Transition 2: Pending -> Success (succeed)
793        tracker.succeed();
794        let result1: Result<(), TestError> =
795            tracker.wait_for_result(Duration::from_secs(1), rx1).await;
796        assert!(result1.is_ok());
797
798        // Transition 3: Success -> Pending (begin again)
799        let rx2 = tracker.begin();
800
801        // Transition 4: Pending -> Failure (fail)
802        tracker.fail("Error");
803        let result2: Result<(), TestError> =
804            tracker.wait_for_result(Duration::from_secs(1), rx2).await;
805        assert!(result2.is_err());
806
807        // Transition 5: Failure -> Pending (begin again)
808        let rx3 = tracker.begin();
809
810        // Transition 6: Pending -> Timeout
811        let result3: Result<(), TestError> = tracker
812            .wait_for_result(Duration::from_millis(50), rx3)
813            .await;
814        assert_eq!(
815            result3.unwrap_err(),
816            TestError("Authentication timed out".to_string())
817        );
818
819        // Transition 7: Timeout -> Pending (begin again)
820        let rx4 = tracker.begin();
821
822        // Transition 8: Pending -> Superseded (begin interrupts)
823        let rx5 = tracker.begin();
824        let result4: Result<(), TestError> =
825            tracker.wait_for_result(Duration::from_secs(1), rx4).await;
826        assert_eq!(
827            result4.unwrap_err(),
828            TestError("Authentication attempt superseded".to_string())
829        );
830
831        // Final success to clean up
832        tracker.succeed();
833        let result5: Result<(), TestError> =
834            tracker.wait_for_result(Duration::from_secs(1), rx5).await;
835        assert!(result5.is_ok());
836    }
837
838    /// Verifies no memory leaks from orphaned senders.
839    #[rstest]
840    #[tokio::test]
841    async fn test_no_sender_leaks() {
842        let tracker = AuthTracker::new();
843
844        for _ in 0..100 {
845            let rx = tracker.begin();
846            let _result: Result<(), TestError> =
847                tracker.wait_for_result(Duration::from_millis(1), rx).await;
848        }
849
850        let rx = tracker.begin();
851        tracker.succeed();
852        let result: Result<(), TestError> =
853            tracker.wait_for_result(Duration::from_secs(1), rx).await;
854        assert!(result.is_ok());
855    }
856
857    /// Tests concurrent success/fail calls don't cause panics.
858    #[rstest]
859    #[tokio::test]
860    async fn test_concurrent_succeed_fail_calls() {
861        let tracker = Arc::new(AuthTracker::new());
862        let rx = tracker.begin();
863
864        let mut handles = vec![];
865
866        // Spawn many tasks trying to succeed
867        for _ in 0..50 {
868            let tracker_clone = Arc::clone(&tracker);
869            handles.push(tokio::spawn(async move {
870                tracker_clone.succeed();
871            }));
872        }
873
874        // Spawn many tasks trying to fail
875        for _ in 0..50 {
876            let tracker_clone = Arc::clone(&tracker);
877            handles.push(tokio::spawn(async move {
878                tracker_clone.fail("Error");
879            }));
880        }
881
882        // Wait for all tasks
883        for handle in handles {
884            handle.await.unwrap();
885        }
886
887        // Should get either success or failure, but not panic
888        let result: Result<(), TestError> =
889            tracker.wait_for_result(Duration::from_secs(1), rx).await;
890        // Don't care which outcome, just that it doesn't panic
891        let _ = result;
892    }
893
894    #[rstest]
895    fn test_is_authenticated_initial_state() {
896        let tracker = AuthTracker::new();
897        assert!(!tracker.is_authenticated());
898    }
899
900    #[rstest]
901    #[tokio::test]
902    async fn test_is_authenticated_after_succeed() {
903        let tracker = AuthTracker::new();
904        assert!(!tracker.is_authenticated());
905
906        let _rx = tracker.begin();
907        assert!(!tracker.is_authenticated());
908
909        tracker.succeed();
910        assert!(tracker.is_authenticated());
911    }
912
913    #[rstest]
914    #[tokio::test]
915    async fn test_is_authenticated_after_fail() {
916        let tracker = AuthTracker::new();
917        let _rx = tracker.begin();
918        tracker.fail("error");
919        assert!(!tracker.is_authenticated());
920    }
921
922    #[rstest]
923    #[tokio::test]
924    async fn test_invalidate_clears_auth_state() {
925        let tracker = AuthTracker::new();
926        let _rx = tracker.begin();
927        tracker.succeed();
928        assert!(tracker.is_authenticated());
929
930        tracker.invalidate();
931        assert!(!tracker.is_authenticated());
932    }
933
934    #[rstest]
935    #[tokio::test]
936    async fn test_begin_clears_auth_state() {
937        let tracker = AuthTracker::new();
938        let _rx1 = tracker.begin();
939        tracker.succeed();
940        assert!(tracker.is_authenticated());
941
942        let _rx2 = tracker.begin();
943        assert!(!tracker.is_authenticated());
944    }
945
946    #[rstest]
947    fn test_is_authenticated_shared_across_clones() {
948        let tracker = AuthTracker::new();
949        let cloned = tracker.clone();
950
951        let _rx = tracker.begin();
952        tracker.succeed();
953
954        assert!(cloned.is_authenticated());
955    }
956
957    #[rstest]
958    fn test_invalidate_shared_across_clones() {
959        let tracker = AuthTracker::new();
960        let cloned = tracker.clone();
961
962        let _rx = tracker.begin();
963        tracker.succeed();
964        assert!(tracker.is_authenticated());
965
966        cloned.invalidate();
967        assert!(!tracker.is_authenticated());
968    }
969
970    #[rstest]
971    fn test_succeed_without_begin_still_updates_auth_state() {
972        let tracker = AuthTracker::new();
973        assert!(!tracker.is_authenticated());
974
975        // State updates even without begin() to handle late responses after timeout
976        tracker.succeed();
977        assert!(tracker.is_authenticated());
978    }
979
980    #[rstest]
981    fn test_fail_without_begin_still_updates_auth_state() {
982        let tracker = AuthTracker::new();
983        tracker.succeed();
984        assert!(tracker.is_authenticated());
985
986        // State updates even without begin() to handle late responses
987        tracker.fail("error");
988        assert!(!tracker.is_authenticated());
989    }
990
991    #[rstest]
992    #[tokio::test]
993    async fn test_auth_state_false_after_timeout_until_late_response() {
994        let tracker = AuthTracker::new();
995        let rx = tracker.begin();
996        assert!(!tracker.is_authenticated());
997
998        let result: Result<(), TestError> =
999            tracker.wait_for_result(Duration::from_millis(10), rx).await;
1000
1001        assert!(result.is_err());
1002        assert!(!tracker.is_authenticated());
1003
1004        // Late response after timeout still updates state
1005        tracker.succeed();
1006        assert!(tracker.is_authenticated());
1007    }
1008
1009    #[rstest]
1010    #[tokio::test]
1011    async fn test_wait_for_authenticated_already_authenticated() {
1012        let tracker = AuthTracker::new();
1013        let _rx = tracker.begin();
1014        tracker.succeed();
1015
1016        assert!(
1017            tracker
1018                .wait_for_authenticated(Duration::from_millis(50))
1019                .await
1020        );
1021    }
1022
1023    #[rstest]
1024    #[tokio::test]
1025    async fn test_wait_for_authenticated_succeeds_after_delay() {
1026        let tracker = AuthTracker::new();
1027        let _rx = tracker.begin();
1028
1029        let tracker_clone = tracker.clone();
1030
1031        tokio::spawn(async move {
1032            tokio::time::sleep(Duration::from_millis(50)).await;
1033            tracker_clone.succeed();
1034        });
1035
1036        assert!(tracker.wait_for_authenticated(Duration::from_secs(1)).await);
1037    }
1038
1039    #[rstest]
1040    #[tokio::test]
1041    async fn test_wait_for_authenticated_returns_false_on_failure() {
1042        let tracker = AuthTracker::new();
1043        let _rx = tracker.begin();
1044
1045        let tracker_clone = tracker.clone();
1046
1047        tokio::spawn(async move {
1048            tokio::time::sleep(Duration::from_millis(50)).await;
1049            tracker_clone.fail("rejected");
1050        });
1051
1052        let start = tokio::time::Instant::now();
1053        let result = tracker.wait_for_authenticated(Duration::from_secs(5)).await;
1054        let elapsed = start.elapsed();
1055
1056        assert!(!result);
1057        assert!(elapsed < Duration::from_secs(1));
1058    }
1059
1060    #[rstest]
1061    #[tokio::test]
1062    async fn test_wait_for_authenticated_times_out() {
1063        let tracker = AuthTracker::new();
1064        let _rx = tracker.begin();
1065
1066        assert!(
1067            !tracker
1068                .wait_for_authenticated(Duration::from_millis(50))
1069                .await
1070        );
1071    }
1072
1073    #[rstest]
1074    #[tokio::test]
1075    async fn test_wait_for_authenticated_begin_clears_failed() {
1076        let tracker = AuthTracker::new();
1077        let _rx = tracker.begin();
1078        tracker.fail("first attempt");
1079
1080        assert!(
1081            !tracker
1082                .wait_for_authenticated(Duration::from_millis(10))
1083                .await
1084        );
1085
1086        // begin() clears the failed flag, allowing a fresh wait
1087        let _rx = tracker.begin();
1088
1089        let tracker_clone = tracker.clone();
1090
1091        tokio::spawn(async move {
1092            tokio::time::sleep(Duration::from_millis(50)).await;
1093            tracker_clone.succeed();
1094        });
1095
1096        assert!(tracker.wait_for_authenticated(Duration::from_secs(1)).await);
1097    }
1098
1099    #[rstest]
1100    #[tokio::test]
1101    async fn test_wait_for_authenticated_invalidate_does_not_return_false() {
1102        let tracker = AuthTracker::new();
1103        let _rx = tracker.begin();
1104
1105        let tracker_clone = tracker.clone();
1106
1107        tokio::spawn(async move {
1108            // invalidate wakes the loop but should not cause early false return
1109            tokio::time::sleep(Duration::from_millis(20)).await;
1110            tracker_clone.invalidate();
1111            // then succeed shortly after
1112            tokio::time::sleep(Duration::from_millis(20)).await;
1113            tracker_clone.succeed();
1114        });
1115
1116        assert!(tracker.wait_for_authenticated(Duration::from_secs(1)).await);
1117    }
1118
1119    #[rstest]
1120    #[tokio::test]
1121    async fn test_wait_for_authenticated_concurrent_waiters() {
1122        let tracker = Arc::new(AuthTracker::new());
1123        let _rx = tracker.begin();
1124
1125        let mut handles = vec![];
1126
1127        for _ in 0..10 {
1128            let t = Arc::clone(&tracker);
1129            handles.push(tokio::spawn(async move {
1130                t.wait_for_authenticated(Duration::from_secs(1)).await
1131            }));
1132        }
1133
1134        tokio::time::sleep(Duration::from_millis(50)).await;
1135        tracker.succeed();
1136
1137        for handle in handles {
1138            assert!(handle.await.unwrap());
1139        }
1140    }
1141
1142    #[rstest]
1143    #[tokio::test]
1144    async fn test_wait_for_authenticated_not_authenticated_initially() {
1145        let tracker = AuthTracker::new();
1146
1147        // Not authenticated, no begin() called, no failed flag set
1148        // Should time out
1149        assert!(
1150            !tracker
1151                .wait_for_authenticated(Duration::from_millis(50))
1152                .await
1153        );
1154    }
1155}
1156
#[cfg(test)]
mod proptest_tests {
    use std::{sync::Arc, time::Duration};

    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// Applies one encoded tracker operation and returns whether the tracker is
    /// expected to be authenticated immediately afterwards.
    ///
    /// Encoding (matching the `0u8..4` strategies below): `0` = `begin()`,
    /// `1` = `succeed()`, `2` = `fail("test")`, `3` = `invalidate()`. Only
    /// `succeed()` leaves the tracker authenticated.
    ///
    /// # Panics
    ///
    /// Panics if `op` is outside `0..4`.
    fn apply_op(tracker: &AuthTracker, op: u8) -> bool {
        match op {
            0 => {
                // The receiver is dropped at the end of this arm; only the state change matters.
                let _rx = tracker.begin();
                false
            }
            1 => {
                tracker.succeed();
                true
            }
            2 => {
                tracker.fail("test");
                false
            }
            3 => {
                tracker.invalidate();
                false
            }
            _ => unreachable!(),
        }
    }

    proptest! {
        /// Verifies that any sequence of begin/succeed/fail/invalidate calls
        /// leaves the tracker in a consistent state where `is_authenticated`
        /// agrees with the last state-setting call.
        #[rstest]
        fn test_state_consistency_after_random_operations(
            ops in proptest::collection::vec(0u8..4, 1..50)
        ) {
            let tracker = AuthTracker::new();
            let mut expected_auth = false;

            for &op in &ops {
                expected_auth = apply_op(&tracker, op);
            }

            prop_assert_eq!(tracker.is_authenticated(), expected_auth);
        }

        /// Verifies that begin() always clears the failed flag regardless of
        /// prior state, so a new auth attempt starts clean.
        #[rstest]
        fn test_begin_always_clears_failed(
            prior_ops in proptest::collection::vec(0u8..4, 0..20)
        ) {
            let tracker = AuthTracker::new();

            for &op in &prior_ops {
                apply_op(&tracker, op);
            }

            let _rx = tracker.begin();
            // After begin(), state is Unauthenticated
            prop_assert_eq!(tracker.auth_state(), AuthState::Unauthenticated);
        }

        /// Verifies that succeed() always transitions to Authenticated,
        /// regardless of prior state.
        #[rstest]
        fn test_succeed_always_sets_authenticated(
            prior_ops in proptest::collection::vec(0u8..4, 0..20)
        ) {
            let tracker = AuthTracker::new();

            for &op in &prior_ops {
                apply_op(&tracker, op);
            }

            tracker.succeed();
            prop_assert_eq!(tracker.auth_state(), AuthState::Authenticated);
        }
    }

    /// Verifies that `wait_for_authenticated` returns within a bounded time
    /// when `succeed()` or `fail()` is called, regardless of the timeout value.
    #[rstest]
    #[tokio::test]
    async fn test_wait_responds_within_bounded_time() {
        for auth_result in [true, false] {
            let tracker = Arc::new(AuthTracker::new());
            let _rx = tracker.begin();

            let tracker_clone = Arc::clone(&tracker);

            let responder = tokio::spawn(async move {
                tokio::time::sleep(Duration::from_millis(30)).await;

                if auth_result {
                    tracker_clone.succeed();
                } else {
                    tracker_clone.fail("rejected");
                }
            });

            let start = tokio::time::Instant::now();
            let result = tracker
                .wait_for_authenticated(Duration::from_secs(10))
                .await;
            let elapsed = start.elapsed();

            // Propagate any panic from the responder task instead of letting it
            // vanish with a detached handle.
            responder.await.unwrap();

            assert_eq!(result, auth_result);
            assert!(
                elapsed < Duration::from_millis(500),
                "wait_for_authenticated took {elapsed:?} for auth_result={auth_result}"
            );
        }
    }
}