Skip to main content

nautilus_network/
backoff.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Provides an implementation of an exponential backoff mechanism with jitter support.
17//! It is used for managing reconnection delays in the socket clients.
18//!
19//! The backoff mechanism allows the delay to grow exponentially up to a configurable
20//! maximum, optionally applying random jitter to avoid synchronized reconnection storms.
21//! An "immediate first" flag is available so that the very first reconnect attempt
22//! can occur without any delay.
23
24use std::time::Duration;
25
26use nautilus_core::correctness::{check_in_range_inclusive_f64, check_predicate_true};
27use rand::RngExt;
28
/// An exponential backoff mechanism with optional jitter and immediate-first behavior.
///
/// Computes successive delays for reconnect attempts. Starting from an initial delay, the base
/// delay is multiplied by a factor on each iteration and capped at a maximum. Random jitter (up
/// to a configured number of milliseconds) is added to each returned delay, and the returned
/// delay is also clamped to the maximum. When immediate-first is enabled, the first delay
/// returned is zero.
///
/// Construct via [`ExponentialBackoff::new`], which validates the configuration.
#[derive(Clone, Debug)]
pub struct ExponentialBackoff {
    /// The initial backoff delay.
    delay_initial: Duration,
    /// The maximum delay to cap the backoff.
    delay_max: Duration,
    /// The current base backoff delay (used for the next returned delay, before jitter).
    delay_current: Duration,
    /// The factor to multiply the delay on each iteration.
    factor: f64,
    /// The maximum random jitter to add (in milliseconds).
    jitter_ms: u64,
    /// If true, the next call to `next_duration()` returns zero delay (immediate reconnect).
    immediate_reconnect: bool,
    /// The original value of `immediate_reconnect` for reset purposes.
    immediate_reconnect_original: bool,
}
46
47/// An exponential backoff mechanism with optional jitter and immediate-first behavior.
48///
49/// This struct computes successive delays for reconnect attempts.
50/// It starts from an initial delay and multiplies it by a factor on each iteration,
51/// capping the delay at a maximum value. Random jitter is added (up to a configured
52/// maximum) to the delay. When `immediate_first` is true, the first call to `next_duration`
53/// returns zero delay, triggering an immediate reconnect, after which the immediate flag is disabled.
54impl ExponentialBackoff {
55    /// Creates a new [`ExponentialBackoff]` instance.
56    ///
57    /// # Errors
58    ///
59    /// Returns an error if:
60    /// - `delay_initial` is zero.
61    /// - `delay_max` is less than `delay_initial`.
62    /// - `delay_max` exceeds `Duration::from_nanos(u64::MAX)` (≈584 years).
63    /// - `factor` is not in the range [1.0, 100.0] (to prevent reconnect spam).
64    pub fn new(
65        delay_initial: Duration,
66        delay_max: Duration,
67        factor: f64,
68        jitter_ms: u64,
69        immediate_first: bool,
70    ) -> anyhow::Result<Self> {
71        check_predicate_true(!delay_initial.is_zero(), "delay_initial must be non-zero")?;
72        check_predicate_true(
73            delay_max >= delay_initial,
74            "delay_max must be >= delay_initial",
75        )?;
76        check_predicate_true(
77            delay_max.as_nanos() <= u128::from(u64::MAX),
78            "delay_max exceeds maximum representable duration (≈584 years)",
79        )?;
80        check_in_range_inclusive_f64(factor, 1.0, 100.0, "factor")?;
81
82        Ok(Self {
83            delay_initial,
84            delay_max,
85            delay_current: delay_initial,
86            factor,
87            jitter_ms,
88            immediate_reconnect: immediate_first,
89            immediate_reconnect_original: immediate_first,
90        })
91    }
92
93    /// Return the next backoff delay with jitter and update the internal state.
94    ///
95    /// If the `immediate_first` flag is set and this is the first call (i.e. the current
96    /// delay equals the initial delay), it returns `Duration::ZERO` to trigger an immediate
97    /// reconnect and disables the immediate behavior for subsequent calls.
98    pub fn next_duration(&mut self) -> Duration {
99        if self.immediate_reconnect && self.delay_current == self.delay_initial {
100            self.immediate_reconnect = false;
101            return Duration::ZERO;
102        }
103
104        // Generate random jitter
105        let jitter = rand::rng().random_range(0..=self.jitter_ms); // dst-ok: transport-layer reconnect jitter, out of DST scope
106        let delay_with_jitter = self.delay_current + Duration::from_millis(jitter);
107
108        // Clamp the returned delay to never exceed delay_max
109        let clamped_delay = std::cmp::min(delay_with_jitter, self.delay_max);
110
111        // Prepare the next delay with overflow protection
112        // Keep all math in u128 to avoid silent truncation
113        let current_nanos = self.delay_current.as_nanos();
114        let max_nanos = self.delay_max.as_nanos();
115
116        // Use checked floating point multiplication to prevent overflow
117        let next_nanos_u128 = if current_nanos > u128::from(u64::MAX) {
118            // Current is already at max representable value, cap to max
119            max_nanos
120        } else {
121            let current_u64 = current_nanos as u64;
122            let next_f64 = current_u64 as f64 * self.factor;
123
124            // Check for overflow in the float result
125            if next_f64 > u64::MAX as f64 {
126                u128::from(u64::MAX)
127            } else {
128                u128::from(next_f64 as u64)
129            }
130        };
131
132        let clamped = std::cmp::min(next_nanos_u128, max_nanos);
133        let final_nanos = if clamped > u128::from(u64::MAX) {
134            u64::MAX
135        } else {
136            clamped as u64
137        };
138
139        self.delay_current = Duration::from_nanos(final_nanos);
140
141        clamped_delay
142    }
143
144    /// Reset the backoff to its initial state.
145    pub const fn reset(&mut self) {
146        self.delay_current = self.delay_initial;
147        self.immediate_reconnect = self.immediate_reconnect_original;
148    }
149
150    /// Returns the current base delay without jitter.
151    /// This represents the delay that would be used as the base for the next call to `next()`,
152    /// before any jitter is applied.
153    #[must_use]
154    pub const fn current_delay(&self) -> Duration {
155        self.delay_current
156    }
157}
158
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rstest::rstest;

    use super::*;

    #[rstest]
    fn test_no_jitter_exponential_growth() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        // 1st call returns the initial delay
        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(100));

        // 2nd call: current becomes 200ms
        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_millis(200));

        // 3rd call: current becomes 400ms
        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_millis(400));

        // 4th call: current becomes 800ms
        let d4 = backoff.next_duration();
        assert_eq!(d4, Duration::from_millis(800));

        // 5th call: current would be 1600ms (800 * 2) which is within the cap
        let d5 = backoff.next_duration();
        assert_eq!(d5, Duration::from_millis(1600));

        // 6th call: should still be capped at 1600ms
        let d6 = backoff.next_duration();
        assert_eq!(d6, Duration::from_millis(1600));
    }

    #[rstest]
    fn test_reset() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        // Call next_duration() once so that the internal state updates
        let _ = backoff.next_duration(); // current_delay becomes 200ms
        backoff.reset();
        let d = backoff.next_duration();
        // After reset, the next delay should be the initial delay (100ms)
        assert_eq!(d, Duration::from_millis(100));
    }

    #[rstest]
    fn test_jitter_within_bounds() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_secs(1);
        let factor = 2.0;
        let jitter = 50;

        // Repeat with fresh instances so several random jitter draws are exercised
        for _ in 0..10 {
            let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

            // Walk the whole growth sequence (100, 200, 400, 800, 1000, 1000ms) so the bounds
            // are checked at every step, including once the cap has been reached
            for _ in 0..6 {
                // Capture the expected base delay before jitter is applied
                let base = backoff.current_delay();
                let delay = backoff.next_duration();
                // The returned delay must be at least the base delay and at most base + jitter,
                // while never exceeding delay_max
                let min_expected = base;
                let max_expected = std::cmp::min(base + Duration::from_millis(jitter), max);
                assert!(
                    delay >= min_expected,
                    "Delay {delay:?} is less than expected minimum {min_expected:?}"
                );
                assert!(
                    delay <= max_expected,
                    "Delay {delay:?} exceeds expected maximum {max_expected:?}"
                );
            }
        }
    }

    #[rstest]
    fn test_factor_less_than_two() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(200);
        let factor = 1.5;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        // First call returns 100ms
        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(100));

        // Second call: current_delay becomes 100 * 1.5 = 150ms
        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_millis(150));

        // Third call: current_delay becomes 150 * 1.5 = 225ms, but capped to 200ms
        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_millis(200));

        // Fourth call: remains at the max of 200ms
        let d4 = backoff.next_duration();
        assert_eq!(d4, Duration::from_millis(200));
    }

    #[rstest]
    fn test_factor_one_keeps_delay_constant() {
        let initial = Duration::from_millis(250);
        let max = Duration::from_secs(1);
        let mut backoff = ExponentialBackoff::new(initial, max, 1.0, 0, false).unwrap();

        // A factor of exactly 1.0 must never grow the delay
        for _ in 0..5 {
            assert_eq!(backoff.next_duration(), initial);
            assert_eq!(backoff.current_delay(), initial);
        }
    }

    #[rstest]
    fn test_max_delay_is_respected() {
        let initial = Duration::from_millis(500);
        let max = Duration::from_secs(1);
        let factor = 3.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        // 1st call returns 500ms
        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::from_millis(500));

        // 2nd call: would be 500 * 3 = 1500ms but is capped to 1000ms
        let d2 = backoff.next_duration();
        assert_eq!(d2, Duration::from_secs(1));

        // Subsequent calls should continue to return the max delay
        let d3 = backoff.next_duration();
        assert_eq!(d3, Duration::from_secs(1));
    }

    #[rstest]
    fn test_growth_saturates_at_largest_delay_max() {
        // The largest delay_max accepted by `new`; growth must saturate there without overflow
        let max = Duration::from_nanos(u64::MAX);
        let mut backoff =
            ExponentialBackoff::new(Duration::from_millis(100), max, 100.0, 0, false).unwrap();

        for _ in 0..16 {
            let _ = backoff.next_duration();
        }

        assert_eq!(backoff.current_delay(), max);
        assert_eq!(backoff.next_duration(), max);
    }

    #[rstest]
    fn test_current_delay_getter() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        assert_eq!(backoff.current_delay(), initial);

        let _ = backoff.next_duration();
        assert_eq!(backoff.current_delay(), Duration::from_millis(200));

        let _ = backoff.next_duration();
        assert_eq!(backoff.current_delay(), Duration::from_millis(400));

        backoff.reset();
        assert_eq!(backoff.current_delay(), initial);
    }

    #[rstest]
    fn test_validation_zero_initial_delay() {
        let result = ExponentialBackoff::new(Duration::ZERO, Duration::from_secs(1), 2.0, 0, false);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("delay_initial must be non-zero")
        );
    }

    #[rstest]
    fn test_validation_max_less_than_initial() {
        let result = ExponentialBackoff::new(
            Duration::from_secs(1),
            Duration::from_millis(500),
            2.0,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("delay_max must be >= delay_initial")
        );
    }

    #[rstest]
    fn test_validation_factor_too_small() {
        let result = ExponentialBackoff::new(
            Duration::from_millis(100),
            Duration::from_secs(1),
            0.5,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("factor"));
    }

    #[rstest]
    fn test_validation_factor_too_large() {
        let result = ExponentialBackoff::new(
            Duration::from_millis(100),
            Duration::from_secs(1),
            150.0,
            0,
            false,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("factor"));
    }

    #[rstest]
    #[case(1.0)]
    #[case(100.0)]
    fn test_validation_factor_bounds_are_inclusive(#[case] factor: f64) {
        let result = ExponentialBackoff::new(
            Duration::from_millis(100),
            Duration::from_secs(1),
            factor,
            0,
            false,
        );
        assert!(result.is_ok(), "factor {factor} should be accepted");
    }

    #[rstest]
    fn test_validation_delay_max_exceeds_u64_max_nanos() {
        // Duration::from_nanos(u64::MAX) is approximately 584 years
        // Try to create a backoff with delay_max exceeding this
        let max_valid = Duration::from_nanos(u64::MAX);
        let too_large = max_valid + Duration::from_nanos(1);

        let result = ExponentialBackoff::new(Duration::from_millis(100), too_large, 2.0, 0, false);
        assert!(result.is_err());
        assert!(
            result
                .unwrap_err()
                .to_string()
                .contains("delay_max exceeds maximum representable duration")
        );
    }

    #[rstest]
    fn test_immediate_first() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, true).unwrap();

        // The first call should yield an immediate (zero) delay
        let d1 = backoff.next_duration();
        assert_eq!(
            d1,
            Duration::ZERO,
            "Expected immediate reconnect (zero delay) on first call"
        );

        // The next call should return the current delay (i.e. the base initial delay)
        let d2 = backoff.next_duration();
        assert_eq!(
            d2, initial,
            "Expected the delay to be the initial delay after immediate reconnect"
        );

        // Subsequent calls should continue with the exponential growth
        let d3 = backoff.next_duration();
        let expected = initial * 2; // 100ms * 2 = 200ms
        assert_eq!(
            d3, expected,
            "Expected exponential growth from the initial delay"
        );
    }

    #[rstest]
    fn test_reset_restores_immediate_first() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_millis(1600);
        let factor = 2.0;
        let jitter = 0;
        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, true).unwrap();

        // Use immediate first
        let d1 = backoff.next_duration();
        assert_eq!(d1, Duration::ZERO);

        // Now immediate_first should be disabled
        let d2 = backoff.next_duration();
        assert_eq!(d2, initial);

        // Reset should restore immediate_first
        backoff.reset();
        let d3 = backoff.next_duration();
        assert_eq!(
            d3,
            Duration::ZERO,
            "Reset should restore immediate_first behavior"
        );
    }

    #[rstest]
    fn test_jitter_never_exceeds_max_delay() {
        let initial = Duration::from_millis(100);
        let max = Duration::from_secs(1);
        let factor = 2.0;
        let jitter = 500;

        let mut backoff = ExponentialBackoff::new(initial, max, factor, jitter, false).unwrap();

        // Run backoff until it reaches the cap (bounded so a growth regression cannot hang
        // the test suite)
        for _ in 0..32 {
            if backoff.current_delay() >= max {
                break;
            }
            let _ = backoff.next_duration();
        }
        assert_eq!(backoff.current_delay(), max, "Backoff never reached the cap");

        // Now that we're at the cap, verify jitter doesn't push us over delay_max
        for _ in 0..100 {
            let delay = backoff.next_duration();
            assert!(
                delay <= max,
                "Delay with jitter {delay:?} exceeded max {max:?}"
            );
        }
    }
}
458                "Delay with jitter {delay:?} exceeded max {max:?}"
459            );
460        }
461    }
462}