1use std::{
23 any::Any,
24 cell::{RefCell, UnsafeCell},
25 collections::VecDeque,
26 fmt::Debug,
27 marker::PhantomData,
28 rc::Rc,
29};
30
31use nautilus_core::{UnixNanos, correctness::FAILED};
32use serde::{Deserialize, Serialize};
33use ustr::Ustr;
34
35use crate::{
36 actor::{
37 Actor,
38 registry::{get_actor_unchecked, register_actor, try_get_actor_unchecked},
39 },
40 clock::Clock,
41 msgbus::{self, Endpoint, Handler, MStr, ShareableMessageHandler},
42 timer::{TimeEvent, TimeEventCallback},
43};
44
45#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)]
47#[serde(deny_unknown_fields)]
48pub struct RateLimit {
49 pub limit: usize,
50 pub interval_ns: u64,
51}
52
53impl RateLimit {
54 #[must_use]
56 pub const fn new(limit: usize, interval_ns: u64) -> Self {
57 Self { limit, interval_ns }
58 }
59}
60
61pub struct Throttler<T, F> {
66 pub recv_count: usize,
68 pub sent_count: usize,
70 pub is_limiting: bool,
72 pub limit: usize,
74 pub buffer: VecDeque<T>,
76 pub timestamps: VecDeque<UnixNanos>,
78 pub clock: Rc<RefCell<dyn Clock>>,
80 pub actor_id: Ustr,
82 interval: u64,
84 timer_name: Ustr,
86 output_send: F,
88 output_drop: Option<F>,
90}
91
92impl<T, F> Actor for Throttler<T, F>
93where
94 T: 'static + Debug,
95 F: Fn(T) + 'static,
96{
97 fn id(&self) -> Ustr {
98 self.actor_id
99 }
100
101 fn handle(&mut self, _msg: &dyn Any) {}
102
103 fn as_any(&self) -> &dyn Any {
104 self
105 }
106}
107
108impl<T, F> Debug for Throttler<T, F>
109where
110 T: Debug,
111{
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 f.debug_struct(stringify!(InnerThrottler))
114 .field("recv_count", &self.recv_count)
115 .field("sent_count", &self.sent_count)
116 .field("is_limiting", &self.is_limiting)
117 .field("limit", &self.limit)
118 .field("buffer", &self.buffer)
119 .field("timestamps", &self.timestamps)
120 .field("interval", &self.interval)
121 .field("timer_name", &self.timer_name)
122 .finish()
123 }
124}
125
126impl<T, F> Throttler<T, F>
127where
128 T: Debug,
129{
130 #[inline]
131 pub fn new(
132 limit: usize,
133 interval: u64,
134 clock: Rc<RefCell<dyn Clock>>,
135 timer_name: &str,
136 output_send: F,
137 output_drop: Option<F>,
138 actor_id: Ustr,
139 ) -> Self {
140 Self {
141 recv_count: 0,
142 sent_count: 0,
143 is_limiting: false,
144 limit,
145 buffer: VecDeque::new(),
146 timestamps: VecDeque::with_capacity(limit.min(1024)),
147 clock,
148 interval,
149 timer_name: Ustr::from(timer_name),
150 output_send,
151 output_drop,
152 actor_id,
153 }
154 }
155
156 #[inline]
166 pub fn set_timer(&mut self, callback: Option<TimeEventCallback>) {
167 let delta = self.delta_next();
168 let mut clock = self.clock.borrow_mut();
169 if clock.timer_exists(&self.timer_name) {
170 clock.cancel_timer(&self.timer_name);
171 }
172 let alert_ts = clock.timestamp_ns() + delta;
173
174 clock
175 .set_time_alert_ns(&self.timer_name, alert_ts, callback, None)
176 .expect(FAILED);
177 }
178
179 #[inline]
181 pub fn delta_next(&mut self) -> u64 {
182 match self.timestamps.get(self.limit - 1) {
183 Some(ts) => {
184 let diff = self.clock.borrow().timestamp_ns().as_u64() - ts.as_u64();
185 self.interval.saturating_sub(diff)
186 }
187 None => 0,
188 }
189 }
190
191 #[inline]
193 pub fn reset(&mut self) {
194 self.buffer.clear();
195 self.recv_count = 0;
196 self.sent_count = 0;
197 self.is_limiting = false;
198 self.timestamps.clear();
199 }
200
201 #[inline]
203 pub fn used(&self) -> f64 {
204 if self.timestamps.is_empty() {
205 return 0.0;
206 }
207
208 let now = self.clock.borrow().timestamp_ns().as_i64();
209 let interval_start = now - self.interval as i64;
210
211 let messages_in_current_interval = self
212 .timestamps
213 .iter()
214 .take_while(|&&ts| ts.as_i64() > interval_start)
215 .count();
216
217 (messages_in_current_interval as f64) / (self.limit as f64)
218 }
219
220 #[inline]
222 pub fn qsize(&self) -> usize {
223 self.buffer.len()
224 }
225}
226
227impl<T, F> Throttler<T, F>
228where
229 T: 'static + Debug,
230 F: Fn(T) + 'static,
231{
232 pub fn to_actor(self) -> Rc<UnsafeCell<Self>> {
233 let process_handler = ThrottlerProcess::<T, F>::new(self.actor_id);
235 msgbus::register_any(
236 process_handler.id().as_str().into(),
237 ShareableMessageHandler::from(Rc::new(process_handler) as Rc<dyn Handler<dyn Any>>),
238 );
239
240 register_actor(self)
242 }
243
244 #[inline]
245 pub fn send_msg(&mut self, msg: T) {
246 let now = self.clock.borrow().timestamp_ns();
247
248 if self.timestamps.len() >= self.limit {
249 self.timestamps.pop_back();
250 }
251 self.timestamps.push_front(now);
252
253 self.sent_count += 1;
254 (self.output_send)(msg);
255 }
256
257 #[inline]
258 pub fn limit_msg(&mut self, msg: T) {
259 if self.output_drop.is_none() {
260 self.buffer.push_front(msg);
261 log::debug!("Buffering {}", self.buffer.len());
262
263 if !self.is_limiting {
264 log::debug!("Limiting");
265 let cb = Some(ThrottlerProcess::<T, F>::new(self.actor_id).get_timer_callback());
266 self.set_timer(cb);
267 self.is_limiting = true;
268 }
269 } else {
270 log::debug!("Dropping");
271
272 if let Some(drop) = &self.output_drop {
273 drop(msg);
274 }
275
276 if !self.is_limiting {
277 log::debug!("Limiting");
278 self.set_timer(Some(throttler_resume::<T, F>(self.actor_id)));
279 self.is_limiting = true;
280 }
281 }
282 }
283
284 #[inline]
285 pub fn send(&mut self, msg: T)
286 where
287 T: 'static,
288 F: Fn(T) + 'static,
289 {
290 self.recv_count += 1;
291
292 let delta = self.delta_next();
293
294 if self.is_limiting && delta == 0 && self.buffer.is_empty() {
299 self.is_limiting = false;
300 }
301
302 if self.is_limiting || delta > 0 {
303 self.limit_msg(msg);
304 } else {
305 self.send_msg(msg);
306 }
307 }
308}
309
310struct ThrottlerProcess<T, F> {
315 actor_id: Ustr,
316 endpoint: MStr<Endpoint>,
317 phantom_t: PhantomData<T>,
318 phantom_f: PhantomData<F>,
319}
320
321impl<T, F> ThrottlerProcess<T, F>
322where
323 T: Debug,
324{
325 pub fn new(actor_id: Ustr) -> Self {
326 let endpoint = MStr::endpoint(format!("{actor_id}_process")).expect(FAILED);
327 Self {
328 actor_id,
329 endpoint,
330 phantom_t: PhantomData,
331 phantom_f: PhantomData,
332 }
333 }
334
335 pub fn get_timer_callback(&self) -> TimeEventCallback {
336 let endpoint = self.endpoint;
337 TimeEventCallback::from(move |event: TimeEvent| {
338 msgbus::send_any(endpoint, &(event));
339 })
340 }
341}
342
343impl<T, F> Handler<dyn Any> for ThrottlerProcess<T, F>
344where
345 T: 'static + Debug,
346 F: Fn(T) + 'static,
347{
348 fn id(&self) -> Ustr {
349 *self.endpoint
350 }
351
352 fn handle(&self, _message: &dyn Any) {
353 let mut throttler = get_actor_unchecked::<Throttler<T, F>>(&self.actor_id);
354 while let Some(msg) = throttler.buffer.pop_back() {
355 throttler.send_msg(msg);
356
357 if !throttler.buffer.is_empty() && throttler.delta_next() > 0 {
361 throttler.is_limiting = true;
362
363 let endpoint = self.endpoint;
364
365 throttler.set_timer(Some(TimeEventCallback::from(move |event: TimeEvent| {
367 msgbus::send_any(endpoint, &(event));
368 })));
369 return;
370 }
371 }
372
373 throttler.is_limiting = false;
374 }
375}
376
377pub fn throttler_resume<T, F>(actor_id: Ustr) -> TimeEventCallback
383where
384 T: 'static + Debug,
385 F: Fn(T) + 'static,
386{
387 TimeEventCallback::from(move |_event: TimeEvent| {
388 if let Some(mut throttler) = try_get_actor_unchecked::<Throttler<T, F>>(&actor_id) {
389 throttler.is_limiting = false;
390 }
391 })
392}
393
#[cfg(test)]
mod tests {
    use std::{
        cell::{RefCell, UnsafeCell},
        rc::Rc,
    };

    use nautilus_core::{UUID4, UnixNanos};
    use proptest::prelude::*;
    use rstest::{fixture, rstest};
    use ustr::Ustr;

    use super::{RateLimit, Throttler, ThrottlerProcess};
    use crate::{clock::TestClock, msgbus::Handler};

    type SharedThrottler = Rc<UnsafeCell<Throttler<u64, Box<dyn Fn(u64)>>>>;

    /// A registered throttler plus the concrete test clock driving it.
    #[derive(Clone)]
    struct TestThrottler {
        throttler: SharedThrottler,
        clock: Rc<RefCell<TestClock>>,
        interval: u64,
    }

    #[allow(unsafe_code)]
    impl TestThrottler {
        #[expect(clippy::mut_from_ref)]
        pub fn get_throttler(&self) -> &mut Throttler<u64, Box<dyn Fn(u64)>> {
            // NOTE(review): the returned `&mut` can alias accesses made through the actor
            // registry while timer callbacks run. These tests rely on being single-threaded.
            unsafe { &mut *self.throttler.get() }
        }
    }

    /// Advances `clock` to `to_time_ns` and invokes the handler of every time event that fired.
    ///
    /// The clock borrow is released before any callback runs, because throttler callbacks
    /// borrow the same clock again (e.g. to re-arm their timer).
    fn advance_clock_and_fire(clock: &Rc<RefCell<TestClock>>, to_time_ns: UnixNanos) {
        let mut clock_ref = clock.borrow_mut();
        let time_events = clock_ref.advance_time(to_time_ns, true);
        let handlers = clock_ref.match_handlers(time_events);
        drop(clock_ref);

        for handler in handlers {
            handler.callback.call(handler.event);
        }
    }

    /// Builds a registered throttler limited to 5 messages per 10ns.
    ///
    /// The throttler uses a fresh test clock and a unique actor id.
    fn build_test_throttler(
        timer_name: &str,
        output_drop: Option<Box<dyn Fn(u64)>>,
    ) -> TestThrottler {
        let output_send: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Sent: {msg}");
        });
        let clock = Rc::new(RefCell::new(TestClock::new()));
        let inner_clock = Rc::clone(&clock);
        let rate_limit = RateLimit::new(5, 10);
        let interval = rate_limit.interval_ns;
        let actor_id = Ustr::from(UUID4::new().as_str());

        TestThrottler {
            throttler: Throttler::new(
                rate_limit.limit,
                rate_limit.interval_ns,
                clock,
                timer_name,
                output_send,
                output_drop,
                actor_id,
            )
            .to_actor(),
            clock: inner_clock,
            interval,
        }
    }

    #[fixture]
    pub fn test_throttler_buffered() -> TestThrottler {
        build_test_throttler("buffer_timer", None)
    }

    #[fixture]
    pub fn test_throttler_unbuffered() -> TestThrottler {
        let output_drop: Box<dyn Fn(u64)> = Box::new(|msg: u64| {
            log::debug!("Dropped: {msg}");
        });
        build_test_throttler("dropper_timer", Some(output_drop))
    }

    #[rstest]
    fn test_buffering_send_to_limit_becomes_throttled(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }
        assert_eq!(throttler.qsize(), 1);

        assert!(throttler.is_limiting);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
        assert_eq!(throttler.clock.borrow().timer_names(), vec!["buffer_timer"]);
    }

    #[rstest]
    fn test_buffering_used_when_sent_to_limit_returns_one(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..5 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 5);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_buffering_used_when_half_interval_from_limit_returns_one(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..5 {
            throttler.send(42);
        }

        // No timer is armed here, so the clock is advanced without firing handlers.
        let half_interval = test_throttler_buffered.interval / 2;
        test_throttler_buffered
            .clock
            .borrow_mut()
            .advance_time(half_interval.into(), true);

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 5);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_buffering_used_before_limit_when_halfway_returns_half(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..3 {
            throttler.send(42);
        }

        // 3 of 5 slots used.
        assert_eq!(throttler.used(), 0.6);
        assert_eq!(throttler.recv_count, 3);
        assert_eq!(throttler.sent_count, 3);
    }

    #[rstest]
    fn test_buffering_refresh_when_at_limit_sends_remaining_items(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(42);
        }

        advance_clock_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 6);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn test_buffering_send_message_after_buffering_message(test_throttler_buffered: TestThrottler) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(43);
        }

        advance_clock_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        for _ in 0..6 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.recv_count, 12);
        assert_eq!(throttler.sent_count, 10);
        assert_eq!(throttler.qsize(), 2);
    }

    #[rstest]
    fn test_buffering_send_message_after_halfway_after_buffering_message(
        test_throttler_buffered: TestThrottler,
    ) {
        let throttler = test_throttler_buffered.get_throttler();

        for _ in 0..6 {
            throttler.send(42);
        }

        advance_clock_and_fire(
            &test_throttler_buffered.clock,
            test_throttler_buffered.interval.into(),
        );

        for _ in 0..3 {
            throttler.send(42);
        }

        assert_eq!(throttler.used(), 0.8);
        assert_eq!(throttler.recv_count, 9);
        assert_eq!(throttler.sent_count, 9);
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn test_dropping_send_sends_message_to_handler(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        throttler.send(42);

        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 1);
        assert_eq!(throttler.sent_count, 1);
    }

    #[rstest]
    fn test_dropping_send_to_limit_drops_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }
        assert_eq!(throttler.qsize(), 0);

        assert!(throttler.is_limiting);
        assert_eq!(throttler.used(), 1.0);
        assert_eq!(throttler.clock.borrow().timer_count(), 1);
        assert_eq!(
            throttler.clock.borrow().timer_names(),
            vec!["dropper_timer"]
        );
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_dropping_advance_time_when_at_limit_dropped_message(
        test_throttler_unbuffered: TestThrottler,
    ) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }

        advance_clock_and_fire(
            &test_throttler_unbuffered.clock,
            test_throttler_unbuffered.interval.into(),
        );

        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.used(), 0.0);
        assert_eq!(throttler.recv_count, 6);
        assert_eq!(throttler.sent_count, 5);
    }

    #[rstest]
    fn test_dropping_send_message_after_dropping_message(test_throttler_unbuffered: TestThrottler) {
        let throttler = test_throttler_unbuffered.get_throttler();
        for _ in 0..6 {
            throttler.send(42);
        }

        advance_clock_and_fire(
            &test_throttler_unbuffered.clock,
            test_throttler_unbuffered.interval.into(),
        );

        throttler.send(42);

        assert_eq!(throttler.used(), 0.2);
        assert_eq!(throttler.clock.borrow().timer_count(), 0);
        assert!(!throttler.is_limiting);
        assert_eq!(throttler.recv_count, 7);
        assert_eq!(throttler.sent_count, 6);
    }

    #[derive(Clone, Debug)]
    enum ThrottlerInput {
        SendMessage(u64),
        AdvanceClock(u8),
    }

    /// Mostly clock advances of 5..=9ns, interleaved with sends.
    fn throttler_input_strategy() -> impl Strategy<Value = ThrottlerInput> {
        prop_oneof![
            2 => prop::bool::ANY.prop_map(|_| ThrottlerInput::SendMessage(42)),
            8 => prop::num::u8::ANY.prop_map(|v| ThrottlerInput::AdvanceClock(v % 5 + 5)),
        ]
    }

    fn throttler_test_strategy() -> impl Strategy<Value = Vec<ThrottlerInput>> {
        prop::collection::vec(throttler_input_strategy(), 10..=150)
    }

    /// Drives `test_throttler` with `inputs` and checks two invariants after every step:
    /// - `is_limiting` is set exactly when messages are buffered and the window is full;
    /// - no message is lost (forwarded + buffered == submitted).
    ///
    /// Finally, it checks that the buffer drains once the clock moves far enough ahead.
    fn test_throttler_with_inputs(inputs: Vec<ThrottlerInput>, test_throttler: &TestThrottler) {
        let test_clock = &test_throttler.clock;
        let interval = test_throttler.interval;
        let throttler = test_throttler.get_throttler();
        let mut submitted = 0;

        for input in inputs {
            match input {
                ThrottlerInput::SendMessage(msg) => {
                    throttler.send(msg);
                    submitted += 1;
                }
                ThrottlerInput::AdvanceClock(duration) => {
                    let current_time = test_clock.borrow().get_time_ns();
                    advance_clock_and_fire(test_clock, current_time + u64::from(duration));
                }
            }

            let buffered_messages = throttler.qsize() > 0;
            let now = throttler.clock.borrow().timestamp_ns().as_u64();
            let limit_filled_within_interval = throttler
                .timestamps
                .get(throttler.limit - 1)
                .is_some_and(|&ts| (now - ts.as_u64()) < interval);
            let expected_limiting = buffered_messages && limit_filled_within_interval;
            assert_eq!(throttler.is_limiting, expected_limiting);

            assert_eq!(submitted, throttler.sent_count + throttler.qsize());
        }

        for i in 1..=100u64 {
            if throttler.qsize() == 0 {
                break;
            }
            advance_clock_and_fire(test_clock, (interval * 100 * i).into());
        }
        assert_eq!(throttler.qsize(), 0);
    }

    #[rstest]
    fn prop_test() {
        proptest!(|(inputs in throttler_test_strategy())| {
            let test_throttler = test_throttler_buffered();
            test_throttler_with_inputs(inputs, &test_throttler);
        });
    }

    #[rstest]
    fn test_throttler_process_id_returns_ustr() {
        let actor_id = Ustr::from("test_throttler");
        let process = ThrottlerProcess::<String, fn(String)>::new(actor_id);

        let handler_id: Ustr = process.id();

        assert!(handler_id.as_str().contains("test_throttler_process"));
        assert!(!handler_id.is_empty());

        let _type_check: Ustr = handler_id;
    }
}