Skip to main content

nautilus_indicators/average/
wma.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::Display;
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_core::correctness::{FAILED, check_predicate_true};
20use nautilus_model::{
21    data::{Bar, QuoteTick, TradeTick},
22    enums::PriceType,
23};
24
25use crate::indicator::{Indicator, MovingAverage};
26
27const MAX_PERIOD: usize = 8_192;
28
29/// An indicator which calculates a weighted moving average across a rolling window.
30#[repr(C)]
31#[derive(Debug)]
32#[cfg_attr(
33    feature = "python",
34    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
35)]
36#[cfg_attr(
37    feature = "python",
38    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
39)]
40pub struct WeightedMovingAverage {
41    /// The rolling window period for the indicator (> 0).
42    pub period: usize,
43    /// The weights for the moving average calculation
44    pub weights: Vec<f64>,
45    /// Price type
46    pub price_type: PriceType,
47    /// The last indicator value.
48    pub value: f64,
49    /// Whether the indicator is initialized.
50    pub initialized: bool,
51    /// Inputs
52    pub inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
53}
54
55impl Display for WeightedMovingAverage {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        write!(f, "{}({},{:?})", self.name(), self.period, self.weights)
58    }
59}
60
61impl WeightedMovingAverage {
62    /// Creates a new [`WeightedMovingAverage`] instance.
63    ///
64    /// # Panics
65    ///
66    /// This function panics if:
67    /// - `period` is zero.
68    /// - `weights.len()` does not equal `period`.
69    /// - `weights` sum is effectively zero.
70    #[must_use]
71    pub fn new(period: usize, weights: Vec<f64>, price_type: Option<PriceType>) -> Self {
72        Self::new_checked(period, weights, price_type).expect(FAILED)
73    }
74
75    /// Creates a new [`WeightedMovingAverage`] instance with the given period and weights.
76    ///
77    /// # Errors
78    ///
79    /// Returns an error if **any** of the validation rules fails:
80    /// - `period` must be **positive**.
81    /// - `weights` must be **exactly** `period` elements long.
82    /// - `weights` must contain at least one non-zero value (∑wᵢ > ε).
83    pub fn new_checked(
84        period: usize,
85        weights: Vec<f64>,
86        price_type: Option<PriceType>,
87    ) -> anyhow::Result<Self> {
88        const EPS: f64 = f64::EPSILON;
89
90        check_predicate_true(period > 0, "`period` must be positive")?;
91
92        check_predicate_true(
93            period == weights.len(),
94            "`period` must equal `weights.len()`",
95        )?;
96
97        let weight_sum: f64 = weights.iter().copied().sum();
98        check_predicate_true(
99            weight_sum > EPS,
100            "`weights` sum must be positive and > f64::EPSILON",
101        )?;
102
103        Ok(Self {
104            period,
105            weights,
106            price_type: price_type.unwrap_or(PriceType::Last),
107            value: 0.0,
108            inputs: ArrayDeque::new(),
109            initialized: false,
110        })
111    }
112
113    fn weighted_average(&self) -> f64 {
114        let n = self.inputs.len();
115        let weights_slice = &self.weights[self.period - n..];
116
117        let mut sum = 0.0;
118        let mut weight_sum = 0.0;
119
120        for (input, weight) in self.inputs.iter().rev().zip(weights_slice.iter().rev()) {
121            sum += input * weight;
122            weight_sum += weight;
123        }
124        sum / weight_sum
125    }
126}
127
128impl Indicator for WeightedMovingAverage {
129    fn name(&self) -> String {
130        stringify!(WeightedMovingAverage).to_string()
131    }
132
133    fn has_inputs(&self) -> bool {
134        !self.inputs.is_empty()
135    }
136
137    fn initialized(&self) -> bool {
138        self.initialized
139    }
140
141    fn handle_quote(&mut self, quote: &QuoteTick) {
142        self.update_raw(quote.extract_price(self.price_type).into());
143    }
144
145    fn handle_trade(&mut self, trade: &TradeTick) {
146        self.update_raw((&trade.price).into());
147    }
148
149    fn handle_bar(&mut self, bar: &Bar) {
150        self.update_raw((&bar.close).into());
151    }
152
153    fn reset(&mut self) {
154        self.value = 0.0;
155        self.initialized = false;
156        self.inputs.clear();
157    }
158}
159
160impl MovingAverage for WeightedMovingAverage {
161    fn value(&self) -> f64 {
162        self.value
163    }
164
165    fn count(&self) -> usize {
166        self.inputs.len()
167    }
168
169    fn update_raw(&mut self, value: f64) {
170        if self.inputs.len() == self.period.min(MAX_PERIOD) {
171            self.inputs.pop_front();
172        }
173        let _ = self.inputs.push_back(value);
174
175        self.value = self.weighted_average();
176        self.initialized = self.count() >= self.period;
177    }
178}
179
#[cfg(test)]
mod tests {

    use arraydeque::{ArrayDeque, Wrapping};
    use rstest::rstest;

    use crate::{
        average::wma::WeightedMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    /// Copies the indicator's buffered inputs into a `Vec`, oldest first.
    fn buffered_inputs(wma: &WeightedMovingAverage) -> Vec<f64> {
        wma.inputs.iter().copied().collect()
    }

    /// A fresh indicator displays its configuration and reports no inputs.
    #[rstest]
    fn test_wma_initialized(indicator_wma_10: WeightedMovingAverage) {
        assert_eq!(
            indicator_wma_10.to_string(),
            "WeightedMovingAverage(10,[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])"
        );
        assert_eq!(indicator_wma_10.name(), "WeightedMovingAverage");
        assert!(!indicator_wma_10.has_inputs());
        assert!(!indicator_wma_10.initialized());
    }

    /// A weights vector whose length differs from the period is rejected.
    #[rstest]
    #[should_panic]
    fn test_different_weights_len_and_period_error() {
        let mismatched = vec![0.5, 0.5, 0.5];
        let _ = WeightedMovingAverage::new(10, mismatched, None);
    }

    /// With a single input the average equals that input.
    #[rstest]
    fn test_value_with_one_input(mut indicator_wma_10: WeightedMovingAverage) {
        indicator_wma_10.update_raw(1.0);
        assert_eq!(indicator_wma_10.value, 1.0);
    }

    #[rstest]
    fn test_value_with_two_inputs_equal_weights() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        for sample in [1.0, 2.0] {
            wma.update_raw(sample);
        }
        assert_eq!(wma.value, 1.5);
    }

    #[rstest]
    fn test_value_with_four_inputs_equal_weights() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.25, 0.25, 0.25, 0.25], None);
        for sample in [1.0, 2.0, 3.0, 4.0] {
            wma.update_raw(sample);
        }
        assert_eq!(wma.value, 2.5);
    }

    /// During warm-up only the trailing weights are applied.
    #[rstest]
    fn test_value_with_two_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        for sample in [1.0, 2.0] {
            indicator_wma_10.update_raw(sample);
        }
        let expected = 2.0f64.mul_add(1.0, 1.0 * 0.9) / 1.9;
        assert_eq!(indicator_wma_10.value, expected);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        for sample in [1.0, 2.0, 3.0] {
            indicator_wma_10.update_raw(sample);
        }
        let expected = 1.0f64.mul_add(0.8, 3.0f64.mul_add(1.0, 2.0 * 0.9)) / (1.0 + 0.9 + 0.8);
        assert_eq!(indicator_wma_10.value, expected);
    }

    #[rstest]
    fn test_value_expected_with_exact_period(mut indicator_wma_10: WeightedMovingAverage) {
        for sample in 1..=10 {
            indicator_wma_10.update_raw(f64::from(sample));
        }
        assert_eq!(indicator_wma_10.value, 7.0);
    }

    #[rstest]
    fn test_value_expected_with_more_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        for sample in 1..=11 {
            indicator_wma_10.update_raw(f64::from(sample));
        }
        assert_eq!(indicator_wma_10.value(), 8.000_000_000_000_002);
    }

    #[rstest]
    fn test_reset(mut indicator_wma_10: WeightedMovingAverage) {
        for sample in [1.0, 2.0] {
            indicator_wma_10.update_raw(sample);
        }
        indicator_wma_10.reset();
        assert_eq!(indicator_wma_10.value, 0.0);
        assert_eq!(indicator_wma_10.count(), 0);
        assert!(!indicator_wma_10.initialized);
    }

    #[rstest]
    #[should_panic]
    fn new_panics_on_zero_period() {
        let _ = WeightedMovingAverage::new(0, vec![1.0], None);
    }

    #[rstest]
    fn new_checked_err_on_zero_period() {
        let outcome = WeightedMovingAverage::new_checked(0, vec![1.0], None);
        assert!(outcome.is_err());
    }

    #[rstest]
    #[should_panic]
    fn new_panics_on_zero_weight_sum() {
        let _ = WeightedMovingAverage::new(3, vec![0.0, 0.0, 0.0], None);
    }

    #[rstest]
    fn new_checked_err_on_zero_weight_sum() {
        let outcome = WeightedMovingAverage::new_checked(3, vec![0.0, 0.0, 0.0], None);
        assert!(outcome.is_err());
    }

    #[rstest]
    #[should_panic]
    fn new_panics_when_weight_sum_below_epsilon() {
        let tiny = f64::EPSILON / 10.0;
        let _ = WeightedMovingAverage::new(3, vec![tiny; 3], None);
    }

    /// The initialized flag flips exactly when the window becomes full.
    #[rstest]
    fn initialized_flag_transitions() {
        let period = 3;
        let mut wma = WeightedMovingAverage::new(period, vec![1.0, 2.0, 3.0], None);

        assert!(!wma.initialized());

        for fed in 1..=period {
            wma.update_raw((fed - 1) as f64);
            assert_eq!(wma.initialized(), fed >= period);
        }
        assert!(wma.initialized());
    }

    #[rstest]
    fn count_matches_inputs_and_has_inputs() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.25; 4], None);

        assert_eq!(wma.count(), 0);
        assert!(!wma.has_inputs());

        for sample in [1.0, 2.0] {
            wma.update_raw(sample);
        }
        assert_eq!(wma.count(), 2);
        assert!(wma.has_inputs());
    }

    #[rstest]
    fn reset_restores_pristine_state() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        for sample in [1.0, 2.0] {
            wma.update_raw(sample);
        }
        assert!(wma.initialized());

        wma.reset();

        assert_eq!(wma.count(), 0);
        assert_eq!(wma.value(), 0.0);
        assert!(!wma.initialized());
        assert!(!wma.has_inputs());
    }

    #[rstest]
    fn weighted_average_with_non_uniform_weights() {
        let mut wma = WeightedMovingAverage::new(3, vec![1.0, 2.0, 3.0], None);
        for sample in [10.0, 20.0, 30.0] {
            wma.update_raw(sample);
        }
        let expected = 23.333_333_333_333_332;
        let tol = f64::EPSILON.sqrt();
        let deviation = (wma.value() - expected).abs();
        assert!(
            deviation < tol,
            "value = {}, expected ≈ {}",
            wma.value(),
            expected
        );
    }

    #[rstest]
    fn test_window_never_exceeds_period(mut indicator_wma_10: WeightedMovingAverage) {
        for sample in 0..100 {
            indicator_wma_10.update_raw(f64::from(sample));
            assert!(indicator_wma_10.count() <= indicator_wma_10.period);
        }
    }

    /// Individual weights may be negative as long as their sum is positive.
    #[rstest]
    fn test_negative_weights_positive_sum() {
        let mut wma = WeightedMovingAverage::new(3, vec![-1.0, 2.0, 2.0], None);
        for sample in [1.0, 2.0, 3.0] {
            wma.update_raw(sample);
        }

        let expected = 2.0f64.mul_add(3.0, 2.0f64.mul_add(2.0, -1.0)) / 3.0;
        let tol = f64::EPSILON.sqrt();
        assert!((wma.value() - expected).abs() < tol);
    }

    #[rstest]
    fn test_nan_input_propagates() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        for sample in [1.0, f64::NAN] {
            wma.update_raw(sample);
        }

        assert!(wma.value().is_nan());
    }

    #[rstest]
    #[should_panic]
    fn new_panics_when_weight_sum_equals_epsilon() {
        let eps_third = f64::EPSILON / 3.0;
        let _ = WeightedMovingAverage::new(3, vec![eps_third; 3], None);
    }

    #[rstest]
    fn new_checked_err_when_weight_sum_equals_epsilon() {
        let eps_third = f64::EPSILON / 3.0;
        let outcome = WeightedMovingAverage::new_checked(3, vec![eps_third; 3], None);
        assert!(outcome.is_err());
    }

    #[rstest]
    fn new_checked_err_when_weight_sum_below_epsilon() {
        let weight = f64::EPSILON * 0.9;
        let outcome = WeightedMovingAverage::new_checked(1, vec![weight], None);
        assert!(outcome.is_err());
    }

    #[rstest]
    fn new_ok_when_weight_sum_above_epsilon() {
        let weight = f64::EPSILON * 1.1;
        let outcome = WeightedMovingAverage::new_checked(1, vec![weight], None);
        assert!(outcome.is_ok());
    }

    /// Weights that cancel to zero are rejected even though some are non-zero.
    #[rstest]
    #[should_panic]
    fn new_panics_on_cancelled_weights_sum() {
        let _ = WeightedMovingAverage::new(3, vec![1.0, -1.0, 0.0], None);
    }

    #[rstest]
    fn new_checked_err_on_cancelled_weights_sum() {
        let outcome = WeightedMovingAverage::new_checked(3, vec![1.0, -1.0, 0.0], None);
        assert!(outcome.is_err());
    }

    #[rstest]
    fn single_period_returns_latest_input() {
        let mut wma = WeightedMovingAverage::new(1, vec![1.0], None);

        for step in 0..5 {
            let latest = f64::from(step);
            wma.update_raw(latest);
            assert_eq!(wma.value(), latest);
        }
    }

    #[rstest]
    fn value_with_sparse_weights() {
        let mut wma = WeightedMovingAverage::new(3, vec![0.0, 1.0, 0.0], None);
        for sample in [10.0, 20.0, 30.0] {
            wma.update_raw(sample);
        }
        assert_eq!(wma.value(), 20.0);
    }

    #[rstest]
    fn warm_up_len1() {
        let mut wma = WeightedMovingAverage::new(4, vec![1.0, 2.0, 3.0, 4.0], None);
        wma.update_raw(42.0);
        assert_eq!(wma.value(), 42.0);
    }

    #[rstest]
    fn warm_up_len2() {
        let mut wma = WeightedMovingAverage::new(4, vec![1.0, 2.0, 3.0, 4.0], None);
        for sample in [10.0, 20.0] {
            wma.update_raw(sample);
        }
        let expected = 20.0f64.mul_add(4.0, 10.0 * 3.0) / (4.0 + 3.0);
        assert_eq!(wma.value(), expected);
    }

    #[rstest]
    fn warm_up_len3() {
        let mut wma = WeightedMovingAverage::new(4, vec![1.0, 2.0, 3.0, 4.0], None);
        for sample in [1.0, 2.0, 3.0] {
            wma.update_raw(sample);
        }
        let expected = 1.0f64.mul_add(2.0, 3.0f64.mul_add(4.0, 2.0 * 3.0)) / (4.0 + 3.0 + 2.0);
        assert_eq!(wma.value(), expected);
    }

    /// After overflowing the window, only the latest `period` inputs remain.
    #[rstest]
    fn input_window_contains_latest_period() {
        let period = 3;
        let mut wma = WeightedMovingAverage::new(period, vec![1.0; period], None);
        let vals = [1.0, 2.0, 3.0, 4.0];
        for v in vals {
            wma.update_raw(v);
        }
        let expected: Vec<f64> = vals[vals.len() - period..].to_vec();
        assert_eq!(buffered_inputs(&wma), expected);
    }

    #[rstest]
    fn window_slides_correctly() {
        let mut wma = WeightedMovingAverage::new(2, vec![1.0; 2], None);

        wma.update_raw(1.0);
        assert_eq!(buffered_inputs(&wma), vec![1.0]);

        wma.update_raw(2.0);
        assert_eq!(buffered_inputs(&wma), vec![1.0, 2.0]);

        wma.update_raw(3.0);
        assert_eq!(buffered_inputs(&wma), vec![2.0, 3.0]);
    }

    #[rstest]
    fn window_len_constant_after_many_updates() {
        let period = 5;
        let mut wma = WeightedMovingAverage::new(period, vec![1.0; period], None);
        for fed in 1..=100 {
            wma.update_raw((fed - 1) as f64);
            assert_eq!(wma.inputs.len(), period.min(fed));
        }
    }

    /// A wrapping deque drops its oldest element when pushed to while full.
    #[rstest]
    fn arraydeque_wraps_when_full() {
        const CAP: usize = 3;
        let mut ring: ArrayDeque<usize, CAP, Wrapping> = ArrayDeque::new();
        for item in 0..=CAP {
            let _ = ring.push_back(item);
        }
        assert_eq!(ring.len(), CAP);
        assert_eq!(ring.front().copied(), Some(1));
        assert_eq!(ring.back().copied(), Some(3));
    }

    #[rstest]
    fn arraydeque_sliding_window_with_pop() {
        const CAP: usize = 3;
        let mut ring: ArrayDeque<usize, CAP, Wrapping> = ArrayDeque::new();
        for item in 0..10 {
            if ring.len() == CAP {
                ring.pop_front();
            }
            let _ = ring.push_back(item);
            assert!(ring.len() <= CAP);
        }
        assert_eq!(ring.len(), CAP);
    }

    #[rstest]
    fn new_ok_with_infinite_weight() {
        let outcome = WeightedMovingAverage::new_checked(2, vec![f64::INFINITY, 1.0], None);
        assert!(outcome.is_ok());
    }

    #[rstest]
    #[should_panic]
    fn new_panics_on_nan_weight() {
        let _ = WeightedMovingAverage::new(2, vec![f64::NAN, 1.0], None);
    }

    #[rstest]
    #[should_panic]
    fn new_panics_on_empty_weights() {
        let _ = WeightedMovingAverage::new(1, Vec::new(), None);
    }

    #[rstest]
    fn inf_input_propagates() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        for sample in [1.0, f64::INFINITY] {
            wma.update_raw(sample);
        }
        assert!(wma.value().is_infinite());
    }

    /// Leading zero weights do not affect warm-up, which uses the trailing weights.
    #[rstest]
    fn warm_up_with_front_zero_weights() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.0, 0.0, 1.0, 1.0], None);
        for sample in [10.0, 20.0] {
            wma.update_raw(sample);
        }
        let expected = 20.0f64.mul_add(1.0, 10.0 * 1.0) / 2.0;
        assert_eq!(wma.value(), expected);
    }
}