nautilus_network/ratelimiter/
clock.rs1use std::{
26 fmt::Debug,
27 future::Future,
28 ops::Add,
29 prelude::v1::*,
30 sync::{
31 Arc,
32 atomic::{AtomicU64, Ordering},
33 },
34 time::Duration,
35};
36
37use super::nanos::Nanos;
38use crate::dst::time::Instant;
39
40pub trait Reference:
42 Sized + Add<Nanos, Output = Self> + PartialEq + Eq + Ord + Copy + Clone + Send + Sync + Debug
43{
44 fn duration_since(&self, earlier: Self) -> Nanos;
49
50 #[must_use]
54 fn saturating_sub(&self, duration: Nanos) -> Self;
55}
56
57pub trait Clock: Clone {
59 type Instant: Reference;
61
62 fn now(&self) -> Self::Instant;
64
65 fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + '_;
71}
72
73impl Reference for Duration {
74 fn duration_since(&self, earlier: Self) -> Nanos {
76 self.checked_sub(earlier)
77 .unwrap_or_else(|| Self::new(0, 0))
78 .into()
79 }
80
81 fn saturating_sub(&self, duration: Nanos) -> Self {
83 self.checked_sub(duration.into()).unwrap_or(*self)
84 }
85}
86
87impl Add<Nanos> for Duration {
88 type Output = Self;
89
90 fn add(self, other: Nanos) -> Self {
91 let other: Self = other.into();
92 self + other
93 }
94}
95
/// A manually driven clock: its reading changes only when [`advance`] is called.
///
/// The reading is a nanosecond counter that starts at zero (the `Default` of `AtomicU64`).
/// Clones share the same counter through an `Arc`, so advancing any clone advances all of
/// them.
///
/// [`advance`]: FakeRelativeClock::advance
#[derive(Debug, Clone, Default)]
pub struct FakeRelativeClock {
    now: Arc<AtomicU64>,
}

impl FakeRelativeClock {
    /// Moves the clock forward by `by`.
    ///
    /// The update is a single atomic read-modify-write, so concurrent calls from several
    /// threads (or through several clones) never lose an advance.
    ///
    /// # Panics
    ///
    /// Panics if `by` does not fit in a `u64` of nanoseconds (roughly 584 years), or if the
    /// accumulated reading would overflow `u64` nanoseconds.
    pub fn advance(&self, by: Duration) {
        let by: u64 = by
            .as_nanos()
            .try_into()
            .expect("Cannot represent durations greater than 584 years");

        // `checked_add` makes overflow a hard error in every build profile. A plain `+`
        // panics in debug builds but silently wraps in release, which would move the fake
        // clock backwards and break monotonicity.
        self.now
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |prev| {
                prev.checked_add(by)
            })
            .expect("FakeRelativeClock overflowed u64 nanoseconds");
    }
}

/// Compares the two clocks' current readings. Each side is loaded separately, so the
/// result is a snapshot.
impl PartialEq for FakeRelativeClock {
    fn eq(&self, other: &Self) -> bool {
        self.now.load(Ordering::Relaxed) == other.now.load(Ordering::Relaxed)
    }
}
139
140impl Clock for FakeRelativeClock {
141 type Instant = Nanos;
142
143 fn now(&self) -> Self::Instant {
144 self.now.load(Ordering::Relaxed).into()
145 }
146
147 fn sleep(&self, duration: Duration) -> impl Future<Output = ()> + Send + '_ {
148 self.advance(duration);
149 std::future::ready(())
150 }
151}
152
/// A clock backed by the monotonic [`Instant`] source.
///
/// This is a stateless unit struct; its readings come from `Instant::now()`.
#[derive(Debug, Default, Clone)]
pub struct MonotonicClock;
156
157impl Add<Nanos> for Instant {
158 type Output = Self;
159
160 fn add(self, other: Nanos) -> Self {
161 let other: Duration = other.into();
162 self + other
163 }
164}
165
166impl Reference for Instant {
167 fn duration_since(&self, earlier: Self) -> Nanos {
168 if earlier < *self {
169 (*self - earlier).into()
170 } else {
171 Nanos::from(Duration::new(0, 0))
172 }
173 }
174
175 fn saturating_sub(&self, duration: Nanos) -> Self {
176 self.checked_sub(duration.into()).unwrap_or(*self)
177 }
178}
179
180impl Clock for MonotonicClock {
181 type Instant = Instant;
182
183 fn now(&self) -> Self::Instant {
184 Instant::now()
185 }
186
187 async fn sleep(&self, duration: Duration) {
188 #[cfg(not(all(feature = "simulation", madsim)))]
189 tokio::time::sleep(duration).await;
190 #[cfg(all(feature = "simulation", madsim))]
191 madsim::time::sleep(duration).await;
192 }
193}
194
#[cfg(test)]
mod test {
    use std::{thread, time::Duration};

    use rstest::rstest;

    use super::*;

    #[rstest]
    fn fake_clock_parallel_advances() {
        const THREADS: u64 = 10;
        const ADVANCES_PER_THREAD: u64 = 1_000_000;

        let clock = FakeRelativeClock::default();

        // Scoped threads borrow `clock` directly. The scope joins every thread and
        // re-raises any panic (including a failed assertion) when it ends.
        thread::scope(|scope| {
            for _ in 0..THREADS {
                scope.spawn(|| {
                    for _ in 0..ADVANCES_PER_THREAD {
                        let before = clock.now();
                        clock.advance(Duration::from_nanos(1));
                        assert!(clock.now() > before);
                    }
                });
            }
        });

        // No concurrent advance may be lost: the final reading must be the exact sum.
        let expected = Nanos::from(Duration::from_nanos(THREADS * ADVANCES_PER_THREAD));
        assert_eq!(clock.now(), expected);
    }

    #[rstest]
    fn duration_addition_coverage() {
        let d = Duration::from_secs(1);
        let one_ns = Nanos::from(1);
        assert!(d + one_ns > d);
    }

    #[rstest]
    fn duration_reference_saturates() {
        let earlier = Duration::from_nanos(5);
        let later = Duration::from_nanos(8);

        assert_eq!(later.duration_since(earlier), Nanos::from(Duration::from_nanos(3)));
        assert_eq!(earlier.duration_since(later), Nanos::from(Duration::ZERO));

        // Fully qualified: `Duration` also has an inherent `saturating_sub(Duration)`.
        assert_eq!(
            Reference::saturating_sub(&later, Nanos::from(earlier)),
            Duration::from_nanos(3),
        );
        assert_eq!(Reference::saturating_sub(&earlier, Nanos::from(later)), earlier);
    }

    #[cfg(all(feature = "simulation", madsim))]
    #[madsim::test]
    async fn test_monotonic_clock_sleep_uses_virtual_time() {
        let clock = MonotonicClock;
        let start = Instant::now();
        clock.sleep(Duration::from_millis(100)).await;
        let elapsed = start.elapsed();
        assert!(elapsed >= Duration::from_millis(100));
        assert!(
            elapsed < Duration::from_millis(101),
            "virtual sleep showed real-tokio jitter: {elapsed:?}"
        );
    }
}