1use std::time::Duration;
25
26use nautilus_core::correctness::{check_in_range_inclusive_f64, check_predicate_true};
27use rand::RngExt;
28
/// Exponential backoff state with optional random jitter and an optional
/// zero-delay first attempt.
///
/// Instances are built through [`ExponentialBackoff::new`], which validates the
/// delay bounds and the growth factor.
#[derive(Clone, Debug)]
pub struct ExponentialBackoff {
    /// Starting base delay; `reset` restores `delay_current` to this value.
    delay_initial: Duration,
    /// Upper bound applied to both the stored base delay and the returned delay.
    delay_max: Duration,
    /// Base delay (before jitter) used by the next non-immediate `next_duration` call.
    delay_current: Duration,
    /// Multiplier applied to `delay_current` after each non-immediate `next_duration` call.
    factor: f64,
    /// Inclusive upper bound, in milliseconds, of the random jitter added to a returned delay.
    jitter_ms: u64,
    /// Whether `next_duration` may still return a zero delay; cleared once that happens.
    immediate_reconnect: bool,
    /// The `immediate_first` value passed to `new`, restored by `reset`.
    immediate_reconnect_original: bool,
}
46
47impl ExponentialBackoff {
55 pub fn new(
65 delay_initial: Duration,
66 delay_max: Duration,
67 factor: f64,
68 jitter_ms: u64,
69 immediate_first: bool,
70 ) -> anyhow::Result<Self> {
71 check_predicate_true(!delay_initial.is_zero(), "delay_initial must be non-zero")?;
72 check_predicate_true(
73 delay_max >= delay_initial,
74 "delay_max must be >= delay_initial",
75 )?;
76 check_predicate_true(
77 delay_max.as_nanos() <= u128::from(u64::MAX),
78 "delay_max exceeds maximum representable duration (≈584 years)",
79 )?;
80 check_in_range_inclusive_f64(factor, 1.0, 100.0, "factor")?;
81
82 Ok(Self {
83 delay_initial,
84 delay_max,
85 delay_current: delay_initial,
86 factor,
87 jitter_ms,
88 immediate_reconnect: immediate_first,
89 immediate_reconnect_original: immediate_first,
90 })
91 }
92
93 pub fn next_duration(&mut self) -> Duration {
99 if self.immediate_reconnect && self.delay_current == self.delay_initial {
100 self.immediate_reconnect = false;
101 return Duration::ZERO;
102 }
103
104 let jitter = rand::rng().random_range(0..=self.jitter_ms); let delay_with_jitter = self.delay_current + Duration::from_millis(jitter);
107
108 let clamped_delay = std::cmp::min(delay_with_jitter, self.delay_max);
110
111 let current_nanos = self.delay_current.as_nanos();
114 let max_nanos = self.delay_max.as_nanos();
115
116 let next_nanos_u128 = if current_nanos > u128::from(u64::MAX) {
118 max_nanos
120 } else {
121 let current_u64 = current_nanos as u64;
122 let next_f64 = current_u64 as f64 * self.factor;
123
124 if next_f64 > u64::MAX as f64 {
126 u128::from(u64::MAX)
127 } else {
128 u128::from(next_f64 as u64)
129 }
130 };
131
132 let clamped = std::cmp::min(next_nanos_u128, max_nanos);
133 let final_nanos = if clamped > u128::from(u64::MAX) {
134 u64::MAX
135 } else {
136 clamped as u64
137 };
138
139 self.delay_current = Duration::from_nanos(final_nanos);
140
141 clamped_delay
142 }
143
144 pub const fn reset(&mut self) {
146 self.delay_current = self.delay_initial;
147 self.immediate_reconnect = self.immediate_reconnect_original;
148 }
149
150 #[must_use]
154 pub const fn current_delay(&self) -> Duration {
155 self.delay_current
156 }
157}
158
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use rstest::rstest;

    use super::*;

    /// Builds a backoff from millisecond bounds, panicking if the parameters are invalid.
    fn backoff_ms(
        initial_ms: u64,
        max_ms: u64,
        factor: f64,
        jitter_ms: u64,
        immediate_first: bool,
    ) -> ExponentialBackoff {
        ExponentialBackoff::new(
            Duration::from_millis(initial_ms),
            Duration::from_millis(max_ms),
            factor,
            jitter_ms,
            immediate_first,
        )
        .unwrap()
    }

    /// Extracts the error message from a constructor call that is expected to fail.
    fn rejection_message(result: anyhow::Result<ExponentialBackoff>) -> String {
        result.unwrap_err().to_string()
    }

    #[rstest]
    fn test_no_jitter_exponential_growth() {
        let mut backoff = backoff_ms(100, 1600, 2.0, 0, false);

        // Doubles on every call, then stays pinned at the 1600 ms ceiling.
        for expected_ms in [100, 200, 400, 800, 1600, 1600] {
            assert_eq!(backoff.next_duration(), Duration::from_millis(expected_ms));
        }
    }

    #[rstest]
    fn test_reset() {
        let mut backoff = backoff_ms(100, 1600, 2.0, 0, false);

        // Advance once so there is grown state for the reset to discard.
        let _ = backoff.next_duration();
        backoff.reset();

        assert_eq!(backoff.next_duration(), Duration::from_millis(100));
    }

    #[rstest]
    fn test_jitter_within_bounds() {
        let jitter_ms = 50;

        // Fresh instance per round so every sample starts from the initial base delay.
        for _ in 0..10 {
            let mut backoff = backoff_ms(100, 1000, 2.0, jitter_ms, false);
            let min_expected = backoff.delay_current;
            let max_expected = min_expected + Duration::from_millis(jitter_ms);

            let delay = backoff.next_duration();

            assert!(
                delay >= min_expected,
                "Delay {delay:?} is less than expected minimum {min_expected:?}"
            );
            assert!(
                delay <= max_expected,
                "Delay {delay:?} exceeds expected maximum {max_expected:?}"
            );
        }
    }

    #[rstest]
    fn test_factor_less_than_two() {
        let mut backoff = backoff_ms(100, 200, 1.5, 0, false);

        // 100 -> 150 -> 225, which is capped to the 200 ms ceiling from then on.
        for expected_ms in [100, 150, 200, 200] {
            assert_eq!(backoff.next_duration(), Duration::from_millis(expected_ms));
        }
    }

    #[rstest]
    fn test_max_delay_is_respected() {
        let mut backoff = backoff_ms(500, 1000, 3.0, 0, false);

        let expected_delays = [
            Duration::from_millis(500),
            Duration::from_secs(1),
            Duration::from_secs(1),
        ];
        for expected in expected_delays {
            assert_eq!(backoff.next_duration(), expected);
        }
    }

    #[rstest]
    fn test_current_delay_getter() {
        let initial = Duration::from_millis(100);
        let mut backoff = backoff_ms(100, 1600, 2.0, 0, false);
        assert_eq!(backoff.current_delay(), initial);

        // The getter reports the delay the *next* call will use, so it runs one step ahead.
        for expected_ms in [200, 400] {
            let _ = backoff.next_duration();
            assert_eq!(backoff.current_delay(), Duration::from_millis(expected_ms));
        }

        backoff.reset();
        assert_eq!(backoff.current_delay(), initial);
    }

    #[rstest]
    fn test_validation_zero_initial_delay() {
        let message = rejection_message(ExponentialBackoff::new(
            Duration::ZERO,
            Duration::from_secs(1),
            2.0,
            0,
            false,
        ));

        assert!(message.contains("delay_initial must be non-zero"));
    }

    #[rstest]
    fn test_validation_max_less_than_initial() {
        let message = rejection_message(ExponentialBackoff::new(
            Duration::from_secs(1),
            Duration::from_millis(500),
            2.0,
            0,
            false,
        ));

        assert!(message.contains("delay_max must be >= delay_initial"));
    }

    #[rstest]
    fn test_validation_factor_too_small() {
        let message = rejection_message(ExponentialBackoff::new(
            Duration::from_millis(100),
            Duration::from_secs(1),
            0.5,
            0,
            false,
        ));

        assert!(message.contains("factor"));
    }

    #[rstest]
    fn test_validation_factor_too_large() {
        let message = rejection_message(ExponentialBackoff::new(
            Duration::from_millis(100),
            Duration::from_secs(1),
            150.0,
            0,
            false,
        ));

        assert!(message.contains("factor"));
    }

    #[rstest]
    fn test_validation_delay_max_exceeds_u64_max_nanos() {
        // One nanosecond past the largest value that fits in a `u64` nanosecond count.
        let too_large = Duration::from_nanos(u64::MAX) + Duration::from_nanos(1);

        let message = rejection_message(ExponentialBackoff::new(
            Duration::from_millis(100),
            too_large,
            2.0,
            0,
            false,
        ));

        assert!(message.contains("delay_max exceeds maximum representable duration"));
    }

    #[rstest]
    fn test_immediate_first() {
        let initial = Duration::from_millis(100);
        let mut backoff = backoff_ms(100, 1600, 2.0, 0, true);

        assert_eq!(
            backoff.next_duration(),
            Duration::ZERO,
            "Expected immediate reconnect (zero delay) on first call"
        );
        assert_eq!(
            backoff.next_duration(),
            initial,
            "Expected the delay to be the initial delay after immediate reconnect"
        );
        assert_eq!(
            backoff.next_duration(),
            initial * 2,
            "Expected exponential growth from the initial delay"
        );
    }

    #[rstest]
    fn test_reset_restores_immediate_first() {
        let initial = Duration::from_millis(100);
        let mut backoff = backoff_ms(100, 1600, 2.0, 0, true);

        // Consume the zero-delay attempt and one regular attempt before resetting.
        assert_eq!(backoff.next_duration(), Duration::ZERO);
        assert_eq!(backoff.next_duration(), initial);

        backoff.reset();

        assert_eq!(
            backoff.next_duration(),
            Duration::ZERO,
            "Reset should restore immediate_first behavior"
        );
    }

    #[rstest]
    fn test_jitter_never_exceeds_max_delay() {
        let max = Duration::from_secs(1);
        let mut backoff = backoff_ms(100, 1000, 2.0, 500, false);

        // Drive the base delay up to the ceiling first.
        while backoff.current_delay() < max {
            backoff.next_duration();
        }

        // At the ceiling, any jitter must be absorbed by the cap.
        for _ in 0..100 {
            let delay = backoff.next_duration();
            assert!(
                delay <= max,
                "Delay with jitter {delay:?} exceeded max {max:?}"
            );
        }
    }
}