nautilus_indicators/average/
sma.rs1use std::fmt::Display;
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::{
20 data::{Bar, QuoteTick, TradeTick},
21 enums::PriceType,
22};
23
24use crate::indicator::{Indicator, MovingAverage};
25
/// Capacity of the fixed-size ring buffer that stores the averaging window;
/// this is the largest `period` that `SimpleMovingAverage::new` accepts.
const MAX_PERIOD: usize = 1024;
27
28#[repr(C)]
29#[derive(Debug)]
30#[cfg_attr(
31 feature = "python",
32 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
33)]
34#[cfg_attr(
35 feature = "python",
36 pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
37)]
38pub struct SimpleMovingAverage {
39 pub period: usize,
40 pub price_type: PriceType,
41 pub value: f64,
42 sum: f64,
43 pub count: usize,
44 buf: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
45 pub initialized: bool,
46}
47
48impl Display for SimpleMovingAverage {
49 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50 write!(f, "{}({})", self.name(), self.period)
51 }
52}
53
54impl Indicator for SimpleMovingAverage {
55 fn name(&self) -> String {
56 stringify!(SimpleMovingAverage).into()
57 }
58
59 fn has_inputs(&self) -> bool {
60 self.count > 0
61 }
62
63 fn initialized(&self) -> bool {
64 self.initialized
65 }
66
67 fn handle_quote(&mut self, quote: &QuoteTick) {
68 self.process_raw(quote.extract_price(self.price_type).into());
69 }
70
71 fn handle_trade(&mut self, trade: &TradeTick) {
72 self.process_raw(trade.price.into());
73 }
74
75 fn handle_bar(&mut self, bar: &Bar) {
76 self.process_raw(bar.close.into());
77 }
78
79 fn reset(&mut self) {
80 self.value = 0.0;
81 self.sum = 0.0;
82 self.count = 0;
83 self.buf.clear();
84 self.initialized = false;
85 }
86}
87
88impl MovingAverage for SimpleMovingAverage {
89 fn value(&self) -> f64 {
90 self.value
91 }
92
93 fn count(&self) -> usize {
94 self.count
95 }
96
97 fn update_raw(&mut self, value: f64) {
98 self.process_raw(value);
99 }
100}
101
102impl SimpleMovingAverage {
103 #[must_use]
109 pub fn new(period: usize, price_type: Option<PriceType>) -> Self {
110 assert!(period > 0, "SimpleMovingAverage: period must be > 0");
111 assert!(
112 period <= MAX_PERIOD,
113 "SimpleMovingAverage: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
114 );
115
116 Self {
117 period,
118 price_type: price_type.unwrap_or(PriceType::Last),
119 value: 0.0,
120 sum: 0.0,
121 count: 0,
122 buf: ArrayDeque::new(),
123 initialized: false,
124 }
125 }
126
127 fn process_raw(&mut self, price: f64) {
128 if self.count == self.period {
129 if let Some(oldest) = self.buf.pop_front() {
130 self.sum -= oldest;
131 }
132 } else {
133 self.count += 1;
134 }
135
136 let _ = self.buf.push_back(price);
137 self.sum += price;
138
139 self.value = self.sum / self.count as f64;
140 self.initialized = self.count >= self.period;
141 }
142}
143
#[cfg(test)]
mod tests {
    // Unit tests for `SimpleMovingAverage`: construction, window bookkeeping,
    // rolling-mean values and internal buffer/sum invariants.
    use arraydeque::{ArrayDeque, Wrapping};
    use nautilus_model::{
        data::{QuoteTick, TradeTick},
        enums::PriceType,
    };
    use rstest::rstest;

    use super::MAX_PERIOD;
    use crate::{
        average::sma::SimpleMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    #[rstest]
    fn sma_initialized_state(indicator_sma_10: SimpleMovingAverage) {
        let display_str = format!("{indicator_sma_10}");
        assert_eq!(display_str, "SimpleMovingAverage(10)");
        assert_eq!(indicator_sma_10.period, 10);
        assert_eq!(indicator_sma_10.price_type, PriceType::Mid);
        assert_eq!(indicator_sma_10.value, 0.0);
        assert_eq!(indicator_sma_10.count, 0);
        assert!(!indicator_sma_10.initialized());
        assert!(!indicator_sma_10.has_inputs());
    }

    #[rstest]
    fn sma_update_raw_exact_period(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        for i in 1..=10 {
            sma.update_raw(f64::from(i));
        }
        assert!(sma.has_inputs());
        assert!(sma.initialized());
        assert_eq!(sma.count, 10);
        assert_eq!(sma.value, 5.5);
    }

    #[rstest]
    fn sma_reset_smoke(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        sma.update_raw(1.0);
        assert_eq!(sma.count, 1);
        sma.reset();
        assert_eq!(sma.count, 0);
        assert_eq!(sma.value, 0.0);
        assert!(!sma.initialized());
    }

    #[rstest]
    fn sma_handle_single_quote(indicator_sma_10: SimpleMovingAverage, stub_quote: QuoteTick) {
        let mut sma = indicator_sma_10;
        sma.handle_quote(&stub_quote);
        assert_eq!(sma.count, 1);
        assert_eq!(sma.value, 1501.0);
    }

    #[rstest]
    fn sma_handle_multiple_quotes(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        let q1 = stub_quote("1500.0", "1502.0");
        let q2 = stub_quote("1502.0", "1504.0");

        sma.handle_quote(&q1);
        sma.handle_quote(&q2);
        assert_eq!(sma.count, 2);
        assert_eq!(sma.value, 1502.0);
    }

    #[rstest]
    fn sma_handle_trade(indicator_sma_10: SimpleMovingAverage, stub_trade: TradeTick) {
        let mut sma = indicator_sma_10;
        sma.handle_trade(&stub_trade);
        assert_eq!(sma.count, 1);
        assert_eq!(sma.value, 1500.0);
    }

    #[rstest]
    #[case(1)]
    #[case(3)]
    #[case(5)]
    #[case(16)]
    fn count_progression_respects_period(#[case] period: usize) {
        let mut sma = SimpleMovingAverage::new(period, None);

        for i in 0..(period * 3) {
            sma.update_raw(i as f64);

            assert!(
                sma.count() <= period,
                "period={period}, step={i}, count={}",
                sma.count()
            );

            let expected = usize::min(i + 1, period);
            assert_eq!(
                sma.count(),
                expected,
                "period={period}, step={i}, expected={expected}, was={}",
                sma.count()
            );
        }
    }

    #[rstest]
    #[case(1)]
    #[case(4)]
    #[case(10)]
    fn count_after_reset_is_zero(#[case] period: usize) {
        let mut sma = SimpleMovingAverage::new(period, None);

        for i in 0..(period + 2) {
            sma.update_raw(i as f64);
        }
        assert_eq!(sma.count(), period, "pre-reset saturation failed");

        sma.reset();
        assert_eq!(sma.count(), 0, "count not reset to zero");
        assert_eq!(sma.value(), 0.0, "value not reset to zero");
        assert!(!sma.initialized(), "initialized flag not cleared");
    }

    #[rstest]
    fn count_edge_case_period_one() {
        let mut sma = SimpleMovingAverage::new(1, None);

        sma.update_raw(10.0);
        assert_eq!(sma.count(), 1);
        assert_eq!(sma.value(), 10.0);

        sma.update_raw(20.0);
        assert_eq!(sma.count(), 1, "count exceeded 1 with period==1");
        assert_eq!(sma.value(), 20.0, "value not equal to latest price");
    }

    #[rstest]
    fn sliding_window_correctness() {
        let mut sma = SimpleMovingAverage::new(3, None);

        let prices = [1.0, 2.0, 3.0, 4.0, 5.0];
        let expect_avg = [1.0, 1.5, 2.0, 3.0, 4.0];

        for (i, (&p, &expected)) in prices.iter().zip(expect_avg.iter()).enumerate() {
            sma.update_raw(p);
            assert!(
                (sma.value() - expected).abs() < 1e-9,
                "step {i}: expected {expected}, was {}",
                sma.value()
            );
        }
    }

    #[rstest]
    #[case(2)]
    #[case(6)]
    fn initialized_transitions_with_count(#[case] period: usize) {
        let mut sma = SimpleMovingAverage::new(period, None);

        for i in 0..(period - 1) {
            sma.update_raw(i as f64);
            assert!(
                !sma.initialized(),
                "initialized early at i={i} (period={period})"
            );
        }

        sma.update_raw(42.0);
        assert_eq!(sma.count(), period);
        assert!(sma.initialized(), "initialized flag not set at period");
    }

    #[rstest]
    #[should_panic(expected = "period must be > 0")]
    fn sma_new_with_zero_period_panics() {
        let _ = SimpleMovingAverage::new(0, None);
    }

    #[rstest]
    #[should_panic(expected = "exceeds MAX_PERIOD")]
    fn sma_new_with_period_above_max_panics() {
        let _ = SimpleMovingAverage::new(MAX_PERIOD + 1, None);
    }

    #[rstest]
    fn sma_rolling_mean_exact_values() {
        let mut sma = SimpleMovingAverage::new(3, None);
        let inputs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let expected = [1.0, 1.5, 2.0, 3.0, 4.0];

        for (&price, &exp_mean) in inputs.iter().zip(expected.iter()) {
            sma.update_raw(price);
            assert!(
                (sma.value() - exp_mean).abs() < 1e-12,
                "input={price}, expected={exp_mean}, was={}",
                sma.value()
            );
        }
    }

    #[rstest]
    fn sma_matches_reference_implementation() {
        const PERIOD: usize = 5;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);
        let mut window: ArrayDeque<f64, PERIOD, Wrapping> = ArrayDeque::new();

        for step in 0..20 {
            let price = f64::from(step) * 10.0;
            sma.update_raw(price);

            if window.len() == PERIOD {
                let _ = window.pop_front();
            }
            let _ = window.push_back(price);

            let ref_mean: f64 = window.iter().sum::<f64>() / window.len() as f64;
            assert!(
                (sma.value() - ref_mean).abs() < 1e-12,
                "step={step}, expected={ref_mean}, was={}",
                sma.value()
            );
        }
    }

    #[rstest]
    #[case(f64::NAN)]
    #[case(f64::INFINITY)]
    #[case(f64::NEG_INFINITY)]
    fn sma_handles_bad_floats(#[case] bad: f64) {
        let mut sma = SimpleMovingAverage::new(3, None);
        sma.update_raw(1.0);
        sma.update_raw(bad);
        sma.update_raw(3.0);
        // `is_finite` is false for NaN as well as for both infinities.
        assert!(!sma.value().is_finite(), "bad float not propagated");
    }

    #[rstest]
    fn deque_and_count_always_match() {
        const PERIOD: usize = 8;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);
        for i in 0..50 {
            sma.update_raw(f64::from(i));
            assert_eq!(
                sma.buf.len(),
                sma.count,
                "buf.len() != count at step {i}"
            );
        }
    }

    #[rstest]
    fn sma_multiple_resets() {
        let mut sma = SimpleMovingAverage::new(4, None);

        for cycle in 0..5 {
            for x in 0..4 {
                sma.update_raw(f64::from(x));
            }
            assert!(sma.initialized(), "cycle {cycle}: not initialized");
            sma.reset();
            assert_eq!(sma.count(), 0);
            assert_eq!(sma.value(), 0.0);
            assert!(!sma.initialized());
        }
    }

    #[rstest]
    fn sma_buffer_never_exceeds_capacity() {
        const PERIOD: usize = MAX_PERIOD;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);

        for i in 0..(PERIOD * 2) {
            sma.update_raw(i as f64);

            assert!(
                sma.buf.len() <= PERIOD,
                "step {i}: buf.len()={}, exceeds PERIOD={PERIOD}",
                sma.buf.len(),
            );
        }
        assert!(
            sma.buf.is_full(),
            "buffer not reported as full after saturation"
        );
        assert_eq!(
            sma.count(),
            PERIOD,
            "count diverged from logical window length"
        );
    }

    #[rstest]
    fn sma_deque_eviction_order() {
        let mut sma = SimpleMovingAverage::new(3, None);

        sma.update_raw(1.0);
        sma.update_raw(2.0);
        sma.update_raw(3.0);
        sma.update_raw(4.0);

        assert_eq!(sma.buf.front().copied(), Some(2.0), "oldest element wrong");
        assert_eq!(sma.buf.back().copied(), Some(4.0), "newest element wrong");

        assert!(
            (sma.value() - 3.0).abs() < 1e-12,
            "unexpected mean after eviction: {}",
            sma.value()
        );
    }

    #[rstest]
    fn sma_sum_consistent_with_buffer() {
        const PERIOD: usize = 7;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);

        for i in 0..40 {
            sma.update_raw(f64::from(i));

            let deque_sum: f64 = sma.buf.iter().copied().sum();
            assert!(
                (sma.sum - deque_sum).abs() < 1e-12,
                "step {i}: internal sum={} differs from buf sum={}",
                sma.sum,
                deque_sum
            );
        }
    }
}
471}