Skip to main content

nautilus_indicators/average/
sma.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::Display;
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::{
20    data::{Bar, QuoteTick, TradeTick},
21    enums::PriceType,
22};
23
24use crate::indicator::{Indicator, MovingAverage};
25
26const MAX_PERIOD: usize = 1_024;
27
28#[repr(C)]
29#[derive(Debug)]
30#[cfg_attr(
31    feature = "python",
32    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
33)]
34#[cfg_attr(
35    feature = "python",
36    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
37)]
38pub struct SimpleMovingAverage {
39    pub period: usize,
40    pub price_type: PriceType,
41    pub value: f64,
42    sum: f64,
43    pub count: usize,
44    buf: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
45    pub initialized: bool,
46}
47
48impl Display for SimpleMovingAverage {
49    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
50        write!(f, "{}({})", self.name(), self.period)
51    }
52}
53
54impl Indicator for SimpleMovingAverage {
55    fn name(&self) -> String {
56        stringify!(SimpleMovingAverage).into()
57    }
58
59    fn has_inputs(&self) -> bool {
60        self.count > 0
61    }
62
63    fn initialized(&self) -> bool {
64        self.initialized
65    }
66
67    fn handle_quote(&mut self, quote: &QuoteTick) {
68        self.process_raw(quote.extract_price(self.price_type).into());
69    }
70
71    fn handle_trade(&mut self, trade: &TradeTick) {
72        self.process_raw(trade.price.into());
73    }
74
75    fn handle_bar(&mut self, bar: &Bar) {
76        self.process_raw(bar.close.into());
77    }
78
79    fn reset(&mut self) {
80        self.value = 0.0;
81        self.sum = 0.0;
82        self.count = 0;
83        self.buf.clear();
84        self.initialized = false;
85    }
86}
87
88impl MovingAverage for SimpleMovingAverage {
89    fn value(&self) -> f64 {
90        self.value
91    }
92
93    fn count(&self) -> usize {
94        self.count
95    }
96
97    fn update_raw(&mut self, value: f64) {
98        self.process_raw(value);
99    }
100}
101
102impl SimpleMovingAverage {
103    /// Creates a new [`SimpleMovingAverage`] instance.
104    ///
105    /// # Panics
106    ///
107    /// Panics if `period` is not positive (> 0).
108    #[must_use]
109    pub fn new(period: usize, price_type: Option<PriceType>) -> Self {
110        assert!(period > 0, "SimpleMovingAverage: period must be > 0");
111        assert!(
112            period <= MAX_PERIOD,
113            "SimpleMovingAverage: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
114        );
115
116        Self {
117            period,
118            price_type: price_type.unwrap_or(PriceType::Last),
119            value: 0.0,
120            sum: 0.0,
121            count: 0,
122            buf: ArrayDeque::new(),
123            initialized: false,
124        }
125    }
126
127    fn process_raw(&mut self, price: f64) {
128        if self.count == self.period {
129            if let Some(oldest) = self.buf.pop_front() {
130                self.sum -= oldest;
131            }
132        } else {
133            self.count += 1;
134        }
135
136        let _ = self.buf.push_back(price);
137        self.sum += price;
138
139        self.value = self.sum / self.count as f64;
140        self.initialized = self.count >= self.period;
141    }
142}
143
#[cfg(test)]
mod tests {
    use arraydeque::{ArrayDeque, Wrapping};
    use nautilus_model::{
        data::{QuoteTick, TradeTick},
        enums::PriceType,
    };
    use rstest::rstest;

    use super::MAX_PERIOD;
    use crate::{
        average::sma::SimpleMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    #[rstest]
    fn sma_initialized_state(indicator_sma_10: SimpleMovingAverage) {
        let display_str = format!("{indicator_sma_10}");
        assert_eq!(display_str, "SimpleMovingAverage(10)");
        assert_eq!(indicator_sma_10.period, 10);
        assert_eq!(indicator_sma_10.price_type, PriceType::Mid);
        assert_eq!(indicator_sma_10.value, 0.0);
        assert_eq!(indicator_sma_10.count, 0);
        assert!(!indicator_sma_10.initialized());
        assert!(!indicator_sma_10.has_inputs());
    }

    #[rstest]
    fn sma_update_raw_exact_period(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        for i in 1..=10 {
            sma.update_raw(f64::from(i));
        }
        assert!(sma.has_inputs());
        assert!(sma.initialized());
        assert_eq!(sma.count, 10);
        assert_eq!(sma.value, 5.5);
    }

    #[rstest]
    fn sma_reset_smoke(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        sma.update_raw(1.0);
        assert_eq!(sma.count, 1);
        sma.reset();
        assert_eq!(sma.count, 0);
        assert_eq!(sma.value, 0.0);
        assert!(!sma.initialized());
    }

    #[rstest]
    fn sma_handle_single_quote(indicator_sma_10: SimpleMovingAverage, stub_quote: QuoteTick) {
        let mut sma = indicator_sma_10;
        sma.handle_quote(&stub_quote);
        assert_eq!(sma.count, 1);
        assert_eq!(sma.value, 1501.0);
    }

    #[rstest]
    fn sma_handle_multiple_quotes(indicator_sma_10: SimpleMovingAverage) {
        let mut sma = indicator_sma_10;
        let q1 = stub_quote("1500.0", "1502.0");
        let q2 = stub_quote("1502.0", "1504.0");

        sma.handle_quote(&q1);
        sma.handle_quote(&q2);
        assert_eq!(sma.count, 2);
        assert_eq!(sma.value, 1502.0);
    }

    #[rstest]
    fn sma_handle_trade(indicator_sma_10: SimpleMovingAverage, stub_trade: TradeTick) {
        let mut sma = indicator_sma_10;
        sma.handle_trade(&stub_trade);
        assert_eq!(sma.count, 1);
        assert_eq!(sma.value, 1500.0);
    }

    #[rstest]
    #[case(1)]
    #[case(3)]
    #[case(5)]
    #[case(16)]
    fn count_progression_respects_period(#[case] period: usize) {
        let mut sma = SimpleMovingAverage::new(period, None);

        for i in 0..(period * 3) {
            sma.update_raw(i as f64);

            assert!(
                sma.count() <= period,
                "period={period}, step={i}, count={}",
                sma.count()
            );

            let expected = usize::min(i + 1, period);
            assert_eq!(
                sma.count(),
                expected,
                "period={period}, step={i}, expected={expected}, was={}",
                sma.count()
            );
        }
    }

    #[rstest]
    #[case(1)]
    #[case(4)]
    #[case(10)]
    fn count_after_reset_is_zero(#[case] period: usize) {
        let mut sma = SimpleMovingAverage::new(period, None);

        for i in 0..(period + 2) {
            sma.update_raw(i as f64);
        }
        assert_eq!(sma.count(), period, "pre-reset saturation failed");

        sma.reset();
        assert_eq!(sma.count(), 0, "count not reset to zero");
        assert_eq!(sma.value(), 0.0, "value not reset to zero");
        assert!(!sma.initialized(), "initialized flag not cleared");
    }

    #[rstest]
    fn count_edge_case_period_one() {
        let mut sma = SimpleMovingAverage::new(1, None);

        sma.update_raw(10.0);
        assert_eq!(sma.count(), 1);
        assert_eq!(sma.value(), 10.0);

        sma.update_raw(20.0);
        assert_eq!(sma.count(), 1, "count exceeded 1 with period==1");
        assert_eq!(sma.value(), 20.0, "value not equal to latest price");
    }

    #[rstest]
    fn sliding_window_correctness() {
        let mut sma = SimpleMovingAverage::new(3, None);

        let prices = [1.0, 2.0, 3.0, 4.0, 5.0];
        let expect_avg = [1.0, 1.5, 2.0, 3.0, 4.0];

        for (i, &p) in prices.iter().enumerate() {
            sma.update_raw(p);
            assert!(
                (sma.value() - expect_avg[i]).abs() < 1e-9,
                "step {i}: expected {}, was {}",
                expect_avg[i],
                sma.value()
            );
        }
    }

    #[rstest]
    #[case(2)]
    #[case(6)]
    fn initialized_transitions_with_count(#[case] period: usize) {
        let mut sma = SimpleMovingAverage::new(period, None);

        for i in 0..(period - 1) {
            sma.update_raw(i as f64);
            assert!(
                !sma.initialized(),
                "initialized early at i={i} (period={period})"
            );
        }

        sma.update_raw(42.0);
        assert_eq!(sma.count(), period);
        assert!(sma.initialized(), "initialized flag not set at period");
    }

    #[rstest]
    #[should_panic(expected = "period must be > 0")]
    fn sma_new_with_zero_period_panics() {
        let _ = SimpleMovingAverage::new(0, None);
    }

    /// The ring buffer has a fixed capacity, so longer windows must be rejected up front.
    #[rstest]
    #[should_panic(expected = "exceeds MAX_PERIOD")]
    fn sma_new_with_period_above_max_panics() {
        let _ = SimpleMovingAverage::new(MAX_PERIOD + 1, None);
    }

    #[rstest]
    fn sma_rolling_mean_exact_values() {
        let mut sma = SimpleMovingAverage::new(3, None);
        let inputs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let expected = [1.0, 1.5, 2.0, 3.0, 4.0];

        for (&price, &exp_mean) in inputs.iter().zip(expected.iter()) {
            sma.update_raw(price);
            assert!(
                (sma.value() - exp_mean).abs() < 1e-12,
                "input={price}, expected={exp_mean}, was={}",
                sma.value()
            );
        }
    }

    #[rstest]
    fn sma_matches_reference_implementation() {
        const PERIOD: usize = 5;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);
        let mut window: ArrayDeque<f64, PERIOD, Wrapping> = ArrayDeque::new();

        for step in 0..20 {
            let price = f64::from(step) * 10.0;
            sma.update_raw(price);

            if window.len() == PERIOD {
                window.pop_front();
            }
            let _ = window.push_back(price);

            let ref_mean: f64 = window.iter().sum::<f64>() / window.len() as f64;
            assert!(
                (sma.value() - ref_mean).abs() < 1e-12,
                "step={step}, expected={ref_mean}, was={}",
                sma.value()
            );
        }
    }

    #[rstest]
    #[case(f64::NAN)]
    #[case(f64::INFINITY)]
    #[case(f64::NEG_INFINITY)]
    fn sma_handles_bad_floats(#[case] bad: f64) {
        let mut sma = SimpleMovingAverage::new(3, None);
        sma.update_raw(1.0);
        sma.update_raw(bad);
        sma.update_raw(3.0);
        // NaN is not finite either, so this single check covers all three cases.
        assert!(!sma.value().is_finite(), "bad float not propagated");
    }

    #[rstest]
    fn deque_and_count_always_match() {
        const PERIOD: usize = 8;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);
        for i in 0..50 {
            sma.update_raw(f64::from(i));
            assert_eq!(sma.buf.len(), sma.count, "buf.len() != count at step {i}");
        }
    }

    #[rstest]
    fn sma_multiple_resets() {
        let mut sma = SimpleMovingAverage::new(4, None);

        for cycle in 0..5 {
            for x in 0..4 {
                sma.update_raw(f64::from(x));
            }
            assert!(sma.initialized(), "cycle {cycle}: not initialized");
            sma.reset();
            assert_eq!(sma.count(), 0);
            assert_eq!(sma.value(), 0.0);
            assert!(!sma.initialized());
        }
    }

    #[rstest]
    fn sma_buffer_never_exceeds_capacity() {
        const PERIOD: usize = MAX_PERIOD;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);

        for i in 0..(PERIOD * 2) {
            sma.update_raw(i as f64);

            assert!(
                sma.buf.len() <= PERIOD,
                "step {i}: buf.len()={}, exceeds PERIOD={PERIOD}",
                sma.buf.len(),
            );
        }
        assert!(
            sma.buf.is_full(),
            "buffer not reported as full after saturation"
        );
        assert_eq!(
            sma.count(),
            PERIOD,
            "count diverged from logical window length"
        );
    }

    #[rstest]
    fn sma_deque_eviction_order() {
        let mut sma = SimpleMovingAverage::new(3, None);

        sma.update_raw(1.0);
        sma.update_raw(2.0);
        sma.update_raw(3.0);
        sma.update_raw(4.0);

        assert_eq!(sma.buf.front().copied(), Some(2.0), "oldest element wrong");
        assert_eq!(sma.buf.back().copied(), Some(4.0), "newest element wrong");

        assert!(
            (sma.value() - 3.0).abs() < 1e-12,
            "unexpected mean after eviction: {}",
            sma.value()
        );
    }

    #[rstest]
    fn sma_sum_consistent_with_buffer() {
        const PERIOD: usize = 7;
        let mut sma = SimpleMovingAverage::new(PERIOD, None);

        for i in 0..40 {
            sma.update_raw(f64::from(i));

            let deque_sum: f64 = sma.buf.iter().copied().sum();
            assert!(
                (sma.sum - deque_sum).abs() < 1e-12,
                "step {i}: internal sum={} differs from buf sum={}",
                sma.sum,
                deque_sum
            );
        }
    }
}