nautilus_network/ratelimiter/
mod.rs1pub mod clock;
20mod gcra;
21mod nanos;
22pub mod quota;
23
24use std::{
25 fmt::Debug,
26 hash::Hash,
27 num::NonZeroU64,
28 sync::atomic::{AtomicU64, Ordering},
29 time::Duration,
30};
31
32use dashmap::DashMap;
33use futures_util::StreamExt;
34
35use self::{
36 clock::{Clock, FakeRelativeClock, MonotonicClock},
37 gcra::{Gcra, NotUntil},
38 nanos::Nanos,
39 quota::Quota,
40};
41
42#[derive(Debug, Default)]
51pub struct InMemoryState(AtomicU64);
52
53impl InMemoryState {
54 pub(crate) fn measure_and_replace_one<T, F, E>(&self, mut f: F) -> Result<T, E>
60 where
61 F: FnMut(Option<Nanos>) -> Result<(T, Nanos), E>,
62 {
63 let mut prev = self.0.load(Ordering::Acquire);
64 let mut decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
65 while let Ok((result, new_data)) = decision {
66 match self.0.compare_exchange_weak(
69 prev,
70 new_data.into(),
71 Ordering::Release,
72 Ordering::Relaxed,
73 ) {
74 Ok(_) => return Ok(result),
75 Err(e) => prev = e, }
77 decision = f(NonZeroU64::new(prev).map(|n| n.get().into()));
78 }
79 decision.map(|(result, _)| result)
82 }
83}
84
85pub type DashMapStateStore<K> = DashMap<K, InMemoryState>;
87
88pub trait StateStore {
99 type Key;
101
102 fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
119 where
120 F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>;
121}
122
123impl<K: Hash + Eq + Clone> StateStore for DashMapStateStore<K> {
124 type Key = K;
125
126 fn measure_and_replace<T, F, E>(&self, key: &Self::Key, f: F) -> Result<T, E>
127 where
128 F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
129 {
130 if let Some(v) = self.get(key) {
131 return v.measure_and_replace_one(f);
133 }
134 let entry = self.entry(key.clone()).or_default();
136 (*entry).measure_and_replace_one(f)
137 }
138}
139
140pub struct RateLimiter<K, C>
145where
146 C: Clock,
147{
148 default_gcra: Option<Gcra>,
149 state: DashMapStateStore<K>,
150 gcra: DashMap<K, Gcra>,
151 clock: C,
152 start: C::Instant,
153}
154
155impl<K, C> Debug for RateLimiter<K, C>
156where
157 K: Debug,
158 C: Clock,
159{
160 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161 f.debug_struct(stringify!(RateLimiter)).finish()
162 }
163}
164
165impl<K> RateLimiter<K, MonotonicClock>
166where
167 K: Eq + Hash,
168{
169 #[must_use]
174 pub fn new_with_quota(base_quota: Option<Quota>, keyed_quotas: Vec<(K, Quota)>) -> Self {
175 let clock = MonotonicClock {};
176 let start = MonotonicClock::now(&clock);
177 let gcra: DashMap<_, _> = keyed_quotas
178 .into_iter()
179 .map(|(k, q)| (k, Gcra::new(q)))
180 .collect();
181 Self {
182 default_gcra: base_quota.map(Gcra::new),
183 state: DashMapStateStore::new(),
184 gcra,
185 clock,
186 start,
187 }
188 }
189}
190
191impl<K> RateLimiter<K, FakeRelativeClock>
192where
193 K: Hash + Eq + Clone,
194{
195 pub fn advance_clock(&self, by: Duration) {
199 self.clock.advance(by);
200 }
201}
202
203impl<K, C> RateLimiter<K, C>
204where
205 K: Hash + Eq + Clone,
206 C: Clock,
207{
208 pub fn add_quota_for_key(&self, key: K, value: Quota) {
210 self.gcra.insert(key, Gcra::new(value));
211 }
212
213 pub fn check_key(&self, key: &K) -> Result<(), NotUntil<C::Instant>> {
219 match self.gcra.get(key) {
220 Some(quota) => quota.test_and_update(self.start, key, &self.state, self.clock.now()),
221 None => self.default_gcra.as_ref().map_or(Ok(()), |gcra| {
222 gcra.test_and_update(self.start, key, &self.state, self.clock.now())
223 }),
224 }
225 }
226
227 pub async fn until_key_ready(&self, key: &K) {
229 loop {
230 match self.check_key(key) {
231 Ok(()) => {
232 break;
233 }
234 Err(e) => {
235 self.clock.sleep(e.wait_time_from(self.clock.now())).await;
236 }
237 }
238 }
239 }
240
241 pub async fn await_keys_ready(&self, keys: Option<&[K]>) {
246 let Some(keys) = keys else {
247 return;
248 };
249
250 match keys.len() {
251 0 => {}
252 1 => self.until_key_ready(&keys[0]).await,
253 2 => {
254 tokio::join!(
255 self.until_key_ready(&keys[0]),
256 self.until_key_ready(&keys[1]),
257 );
258 }
259 _ => {
260 let tasks = keys.iter().map(|key| self.until_key_ready(key));
261 futures::stream::iter(tasks)
262 .for_each_concurrent(None, |key_future| async move {
263 key_future.await;
264 })
265 .await;
266 }
267 }
268 }
269}
270
#[cfg(test)]
mod tests {
    use std::{
        num::NonZeroU32,
        sync::atomic::{AtomicU32, Ordering},
        time::Duration,
    };

    use dashmap::DashMap;
    use rstest::rstest;

    use super::{
        DashMapStateStore, RateLimiter,
        clock::{Clock, FakeRelativeClock},
        gcra::{Gcra, StateSnapshot},
        nanos::Nanos,
        quota::Quota,
    };

    /// Builds a limiter on a [`FakeRelativeClock`] with a default quota of two requests per
    /// second and no keyed quotas.
    fn initialize_mock_rate_limiter() -> RateLimiter<String, FakeRelativeClock> {
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        let gcra = DashMap::new();
        let base_quota = Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap();
        RateLimiter {
            default_gcra: Some(Gcra::new(base_quota)),
            state: DashMapStateStore::new(),
            gcra,
            clock,
            start,
        }
    }

    #[rstest]
    fn test_default_quota() {
        let mock_limiter = initialize_mock_rate_limiter();

        // The default quota admits a burst of two at the same instant...
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());

        // ...and rejects the third.
        assert!(mock_limiter.check_key(&"default".to_string()).is_err());

        // One second later capacity is available again.
        mock_limiter.advance_clock(Duration::from_secs(1));
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
    }

    #[rstest]
    fn test_custom_key_quota() {
        let mock_limiter = initialize_mock_rate_limiter();

        // A dedicated quota of one per second overrides the default for this key only.
        mock_limiter.add_quota_for_key(
            "custom".to_string(),
            Quota::per_second(NonZeroU32::new(1).unwrap()).unwrap(),
        );

        assert!(mock_limiter.check_key(&"custom".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"custom".to_string()).is_err());

        // Other keys still get the default burst of two.
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_err());
    }

    #[rstest]
    fn test_multiple_keys() {
        let mock_limiter = initialize_mock_rate_limiter();

        mock_limiter.add_quota_for_key(
            "key1".to_string(),
            Quota::per_second(NonZeroU32::new(1).unwrap()).unwrap(),
        );
        mock_limiter.add_quota_for_key(
            "key2".to_string(),
            Quota::per_second(NonZeroU32::new(3).unwrap()).unwrap(),
        );

        assert!(mock_limiter.check_key(&"key1".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"key1".to_string()).is_err());

        // key2's quota is independent of key1's.
        assert!(mock_limiter.check_key(&"key2".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"key2".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"key2".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"key2".to_string()).is_err());
    }

    #[rstest]
    fn test_quota_reset() {
        let mock_limiter = initialize_mock_rate_limiter();

        assert!(mock_limiter.check_key(&"reset".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"reset".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"reset".to_string()).is_err());

        // Just short of one replenish interval (1 s / 2) the key is still limited.
        mock_limiter.advance_clock(Duration::from_millis(499));
        assert!(mock_limiter.check_key(&"reset".to_string()).is_err());

        // After a full second in total, a request is admitted again.
        mock_limiter.advance_clock(Duration::from_millis(501));
        assert!(mock_limiter.check_key(&"reset".to_string()).is_ok());
    }

    #[rstest]
    fn test_different_quotas() {
        let mock_limiter = initialize_mock_rate_limiter();

        mock_limiter.add_quota_for_key(
            "per_second".to_string(),
            Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap(),
        );
        mock_limiter.add_quota_for_key(
            "per_minute".to_string(),
            Quota::per_minute(NonZeroU32::new(3).unwrap()),
        );

        assert!(mock_limiter.check_key(&"per_second".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_second".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_second".to_string()).is_err());

        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_err());

        // One second restores the per-second key but not the per-minute one.
        mock_limiter.advance_clock(Duration::from_secs(1));
        assert!(mock_limiter.check_key(&"per_second".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"per_minute".to_string()).is_err());
    }

    #[tokio::test]
    async fn test_await_keys_ready() {
        let mock_limiter = initialize_mock_rate_limiter();

        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());

        assert!(mock_limiter.check_key(&"default".to_string()).is_err());

        mock_limiter.advance_clock(Duration::from_secs(1));
        let keys = ["default".to_string()];
        mock_limiter.await_keys_ready(Some(keys.as_slice())).await;
        assert!(mock_limiter.check_key(&"default".to_string()).is_ok());
    }

    #[tokio::test]
    async fn test_await_keys_ready_multiple_keys() {
        let mock_limiter = initialize_mock_rate_limiter();

        // No keys, or an empty list, completes immediately.
        mock_limiter.await_keys_ready(None).await;
        let no_keys: Vec<String> = Vec::new();
        mock_limiter.await_keys_ready(Some(no_keys.as_slice())).await;

        // Exercise the two-key branch and the three-or-more-key branch. Every key is fresh,
        // so each wait completes on its first check and consumes exactly one slot.
        let pair = ["a".to_string(), "b".to_string()];
        mock_limiter.await_keys_ready(Some(pair.as_slice())).await;
        let triple = ["c".to_string(), "d".to_string(), "e".to_string()];
        mock_limiter.await_keys_ready(Some(triple.as_slice())).await;

        // With the default burst of two, each awaited key has exactly one slot left.
        for key in pair.iter().chain(triple.iter()) {
            assert!(mock_limiter.check_key(key).is_ok());
            assert!(mock_limiter.check_key(key).is_err());
        }
    }

    #[rstest]
    fn test_remaining_burst_capacity_zero_t() {
        let snapshot = StateSnapshot::new(
            Nanos::from(0u64),
            Nanos::from(1_000_000u64),
            Nanos::from(0u64),
            Nanos::from(0u64),
        );
        assert_eq!(snapshot.remaining_burst_capacity(), 0);
    }

    #[rstest]
    fn test_per_second_returns_none_on_zero_replenish_interval() {
        assert!(Quota::per_second(NonZeroU32::new(u32::MAX).unwrap()).is_none());
    }

    #[rstest]
    fn test_per_minute_accepts_max_burst() {
        let quota = Quota::per_minute(NonZeroU32::new(u32::MAX).unwrap());
        assert!(quota.replenish_interval().as_nanos() > 0);
    }

    #[rstest]
    fn test_per_hour_accepts_max_burst() {
        let quota = Quota::per_hour(NonZeroU32::new(u32::MAX).unwrap());
        assert!(quota.replenish_interval().as_nanos() > 0);
    }

    mod property_tests {
        use proptest::prelude::*;
        use rstest::rstest;

        use crate::ratelimiter::{gcra::StateSnapshot, nanos::Nanos};

        /// One hour expressed in nanoseconds; upper bound for every generated input.
        const MAX_NANOS: u64 = 3_600_000_000_000;

        proptest! {
            #![proptest_config(ProptestConfig {
                failure_persistence: Some(Box::new(
                    proptest::test_runner::FileFailurePersistence::WithSource("ratelimiter")
                )),
                ..ProptestConfig::default()
            })]

            #[rstest]
            fn remaining_burst_capacity_never_panics(
                t in 0u64..=MAX_NANOS,
                tau in 0u64..=MAX_NANOS,
                time_of_measurement in 0u64..=MAX_NANOS,
                tat in 0u64..=MAX_NANOS,
            ) {
                let snapshot = StateSnapshot::new(
                    Nanos::from(t),
                    Nanos::from(tau),
                    Nanos::from(time_of_measurement),
                    Nanos::from(tat),
                );

                let _ = snapshot.remaining_burst_capacity();
            }
        }
    }

    #[rstest]
    fn test_gcra_boundary_exact_replenishment() {
        let mock_limiter = initialize_mock_rate_limiter();
        let key = "boundary_test".to_string();

        assert!(mock_limiter.check_key(&key).is_ok());
        assert!(mock_limiter.check_key(&key).is_ok());
        assert!(mock_limiter.check_key(&key).is_err());

        // Same quota as the default, used only to obtain its replenish interval.
        let quota = Quota::per_second(NonZeroU32::new(2).unwrap()).unwrap();
        let replenish_interval = quota.replenish_interval();
        mock_limiter.advance_clock(replenish_interval);

        assert!(
            mock_limiter.check_key(&key).is_ok(),
            "Request at exact replenish boundary should be allowed"
        );
        assert!(
            mock_limiter.check_key(&key).is_err(),
            "Immediate follow-up should be rate-limited"
        );
    }

    #[rstest]
    fn test_per_second_boundary_exact_limit() {
        let quota = Quota::per_second(NonZeroU32::new(1_000_000_000).unwrap()).unwrap();
        assert_eq!(quota.replenish_interval().as_nanos(), 1);
    }

    #[rstest]
    fn test_per_second_returns_none_above_one_billion() {
        assert!(Quota::per_second(NonZeroU32::new(1_000_000_001).unwrap()).is_none());
    }

    #[rstest]
    fn test_burst_size_replenished_in_truncation() {
        let quota = Quota::with_period(Duration::from_secs(100))
            .unwrap()
            .allow_burst(NonZeroU32::new(u32::MAX).unwrap());

        let replenished_in = quota.burst_size_replenished_in();
        let full: u128 = 100_000_000_000u128 * u128::from(u32::MAX);
        // Deliberately truncating: the expected value is the low 64 bits of the exact product.
        let truncated = full as u64;

        assert_eq!(replenished_in, Duration::from_nanos(truncated));
        assert_ne!(
            full,
            u128::from(truncated),
            "Truncation should have occurred"
        );
    }

    #[rstest]
    #[should_panic(expected = "t cannot be zero")]
    fn test_from_gcra_parameters_panics_on_zero_t() {
        let _ = Quota::from_gcra_parameters(Nanos::from(0u64), Nanos::from(100u64));
    }

    #[rstest]
    #[should_panic(expected = "tau/t results in zero burst capacity")]
    fn test_from_gcra_parameters_panics_on_zero_division() {
        let _ = Quota::from_gcra_parameters(Nanos::from(2u64), Nanos::from(1u64));
    }

    #[rstest]
    #[should_panic(expected = "tau/t exceeds u32::MAX")]
    fn test_from_gcra_parameters_panics_on_overflow() {
        let _ = Quota::from_gcra_parameters(Nanos::from(1u64), Nanos::from(u64::MAX));
    }

    #[rstest]
    fn test_concurrent_check_key_respects_burst() {
        let rate = 10u32;
        let clock = FakeRelativeClock::default();
        let start = clock.now();
        let limiter = RateLimiter {
            default_gcra: Some(Gcra::new(
                Quota::per_second(NonZeroU32::new(rate).unwrap()).unwrap(),
            )),
            state: DashMapStateStore::new(),
            gcra: DashMap::new(),
            clock,
            start,
        };

        let accepted = AtomicU32::new(0);
        let num_threads = 50;

        // Every thread hits the same key while the fake clock stays frozen.
        std::thread::scope(|s| {
            for _ in 0..num_threads {
                s.spawn(|| {
                    if limiter.check_key(&"hot_key".to_string()).is_ok() {
                        accepted.fetch_add(1, Ordering::Relaxed);
                    }
                });
            }
        });

        // The clock never advances and there are more attempts than burst capacity, so
        // exactly `rate` requests must be admitted. Fewer means a request was rejected
        // while capacity remained; more means concurrent updates to the state were lost.
        let total = accepted.load(Ordering::Relaxed);
        assert_eq!(
            total, rate,
            "Accepted {total} but burst capacity is {rate}"
        );
    }
}