1use std::fmt::Display;
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_core::correctness::{FAILED, check_predicate_true};
20use nautilus_model::{
21 data::{Bar, QuoteTick, TradeTick},
22 enums::PriceType,
23};
24
25use crate::indicator::{Indicator, MovingAverage};
26
/// Fixed size parameter of the `ArrayDeque` that stores the rolling inputs.
///
/// `update_raw` clamps the rolling window to at most this many values, so it acts as the upper
/// bound on the number of inputs that can be held at once.
const MAX_PERIOD: usize = 8_192;
28
/// A moving average in which every input of the rolling window is scaled by a
/// caller-supplied weight and the result is normalized by the sum of the weights used.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
)]
pub struct WeightedMovingAverage {
    /// The rolling window length; the constructor requires it to equal `weights.len()`.
    pub period: usize,
    /// The weights, ordered oldest to newest: the last weight is applied to the most
    /// recent input.
    pub weights: Vec<f64>,
    /// The price type extracted from quote ticks.
    pub price_type: PriceType,
    /// The most recently computed average (`0.0` on construction and after `reset`).
    pub value: f64,
    /// Whether at least `period` inputs have been received since construction or `reset`.
    pub initialized: bool,
    /// The buffered inputs, oldest at the front and newest at the back.
    pub inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
}
54
55impl Display for WeightedMovingAverage {
56 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57 write!(f, "{}({},{:?})", self.name(), self.period, self.weights)
58 }
59}
60
61impl WeightedMovingAverage {
62 #[must_use]
71 pub fn new(period: usize, weights: Vec<f64>, price_type: Option<PriceType>) -> Self {
72 Self::new_checked(period, weights, price_type).expect(FAILED)
73 }
74
75 pub fn new_checked(
84 period: usize,
85 weights: Vec<f64>,
86 price_type: Option<PriceType>,
87 ) -> anyhow::Result<Self> {
88 const EPS: f64 = f64::EPSILON;
89
90 check_predicate_true(period > 0, "`period` must be positive")?;
91
92 check_predicate_true(
93 period == weights.len(),
94 "`period` must equal `weights.len()`",
95 )?;
96
97 let weight_sum: f64 = weights.iter().copied().sum();
98 check_predicate_true(
99 weight_sum > EPS,
100 "`weights` sum must be positive and > f64::EPSILON",
101 )?;
102
103 Ok(Self {
104 period,
105 weights,
106 price_type: price_type.unwrap_or(PriceType::Last),
107 value: 0.0,
108 inputs: ArrayDeque::new(),
109 initialized: false,
110 })
111 }
112
113 fn weighted_average(&self) -> f64 {
114 let n = self.inputs.len();
115 let weights_slice = &self.weights[self.period - n..];
116
117 let mut sum = 0.0;
118 let mut weight_sum = 0.0;
119
120 for (input, weight) in self.inputs.iter().rev().zip(weights_slice.iter().rev()) {
121 sum += input * weight;
122 weight_sum += weight;
123 }
124 sum / weight_sum
125 }
126}
127
128impl Indicator for WeightedMovingAverage {
129 fn name(&self) -> String {
130 stringify!(WeightedMovingAverage).to_string()
131 }
132
133 fn has_inputs(&self) -> bool {
134 !self.inputs.is_empty()
135 }
136
137 fn initialized(&self) -> bool {
138 self.initialized
139 }
140
141 fn handle_quote(&mut self, quote: &QuoteTick) {
142 self.update_raw(quote.extract_price(self.price_type).into());
143 }
144
145 fn handle_trade(&mut self, trade: &TradeTick) {
146 self.update_raw((&trade.price).into());
147 }
148
149 fn handle_bar(&mut self, bar: &Bar) {
150 self.update_raw((&bar.close).into());
151 }
152
153 fn reset(&mut self) {
154 self.value = 0.0;
155 self.initialized = false;
156 self.inputs.clear();
157 }
158}
159
160impl MovingAverage for WeightedMovingAverage {
161 fn value(&self) -> f64 {
162 self.value
163 }
164
165 fn count(&self) -> usize {
166 self.inputs.len()
167 }
168
169 fn update_raw(&mut self, value: f64) {
170 if self.inputs.len() == self.period.min(MAX_PERIOD) {
171 self.inputs.pop_front();
172 }
173 let _ = self.inputs.push_back(value);
174
175 self.value = self.weighted_average();
176 self.initialized = self.count() >= self.period;
177 }
178}
179
// Unit tests for `WeightedMovingAverage`: constructor validation, warm-up behavior,
// sliding-window bookkeeping, and propagation of non-finite inputs.
#[cfg(test)]
mod tests {

    use arraydeque::{ArrayDeque, Wrapping};
    use rstest::rstest;

    use crate::{
        average::wma::WeightedMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    // The `indicator_wma_10` fixture comes from `stubs`; per the display string asserted
    // here it has period 10 and weights 0.1, 0.2, ..., 1.0.
    #[rstest]
    fn test_wma_initialized(indicator_wma_10: WeightedMovingAverage) {
        let display_str = format!("{indicator_wma_10}");
        assert_eq!(
            display_str,
            "WeightedMovingAverage(10,[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0])"
        );
        assert_eq!(indicator_wma_10.name(), "WeightedMovingAverage");
        assert!(!indicator_wma_10.has_inputs());
        assert!(!indicator_wma_10.initialized());
    }

    // `period` (10) does not match the number of weights (3).
    #[rstest]
    #[should_panic]
    fn test_different_weights_len_and_period_error() {
        let _ = WeightedMovingAverage::new(10, vec![0.5, 0.5, 0.5], None);
    }

    // With a single input the average is that input, whatever its weight.
    #[rstest]
    fn test_value_with_one_input(mut indicator_wma_10: WeightedMovingAverage) {
        indicator_wma_10.update_raw(1.0);
        assert_eq!(indicator_wma_10.value, 1.0);
    }

    #[rstest]
    fn test_value_with_two_inputs_equal_weights() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        wma.update_raw(1.0);
        wma.update_raw(2.0);
        assert_eq!(wma.value, 1.5);
    }

    #[rstest]
    fn test_value_with_four_inputs_equal_weights() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.25, 0.25, 0.25, 0.25], None);
        wma.update_raw(1.0);
        wma.update_raw(2.0);
        wma.update_raw(3.0);
        wma.update_raw(4.0);
        assert_eq!(wma.value, 2.5);
    }

    // During warm-up only the trailing weights apply: the newest input takes the last
    // weight (1.0) and the previous one takes 0.9.
    #[rstest]
    fn test_value_with_two_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        indicator_wma_10.update_raw(1.0);
        indicator_wma_10.update_raw(2.0);
        let result = 2.0f64.mul_add(1.0, 1.0 * 0.9) / 1.9;
        assert_eq!(indicator_wma_10.value, result);
    }

    // Same warm-up rule with three inputs: weights 1.0, 0.9 and 0.8 from newest to oldest.
    #[rstest]
    fn test_value_with_three_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        indicator_wma_10.update_raw(1.0);
        indicator_wma_10.update_raw(2.0);
        indicator_wma_10.update_raw(3.0);
        let result = 1.0f64.mul_add(0.8, 3.0f64.mul_add(1.0, 2.0 * 0.9)) / (1.0 + 0.9 + 0.8);
        assert_eq!(indicator_wma_10.value, result);
    }

    // Exactly `period` inputs (1..=10) fill the window.
    #[rstest]
    fn test_value_expected_with_exact_period(mut indicator_wma_10: WeightedMovingAverage) {
        for i in 1..11 {
            indicator_wma_10.update_raw(f64::from(i));
        }
        assert_eq!(indicator_wma_10.value, 7.0);
    }

    // The 11th input pushes the oldest one out of the window; the expected value pins the
    // exact floating-point result.
    #[rstest]
    fn test_value_expected_with_more_inputs(mut indicator_wma_10: WeightedMovingAverage) {
        for i in 1..=11 {
            indicator_wma_10.update_raw(f64::from(i));
        }
        assert_eq!(indicator_wma_10.value(), 8.000_000_000_000_002);
    }

    #[rstest]
    fn test_reset(mut indicator_wma_10: WeightedMovingAverage) {
        indicator_wma_10.update_raw(1.0);
        indicator_wma_10.update_raw(2.0);
        indicator_wma_10.reset();
        assert_eq!(indicator_wma_10.value, 0.0);
        assert_eq!(indicator_wma_10.count(), 0);
        assert!(!indicator_wma_10.initialized);
    }

    #[rstest]
    #[should_panic]
    fn new_panics_on_zero_period() {
        let _ = WeightedMovingAverage::new(0, vec![1.0], None);
    }

    #[rstest]
    fn new_checked_err_on_zero_period() {
        let res = WeightedMovingAverage::new_checked(0, vec![1.0], None);
        assert!(res.is_err());
    }

    #[rstest]
    #[should_panic]
    fn new_panics_on_zero_weight_sum() {
        let _ = WeightedMovingAverage::new(3, vec![0.0, 0.0, 0.0], None);
    }

    #[rstest]
    fn new_checked_err_on_zero_weight_sum() {
        let res = WeightedMovingAverage::new_checked(3, vec![0.0, 0.0, 0.0], None);
        assert!(res.is_err());
    }

    // Three weights of EPSILON / 10 sum to well below the EPSILON threshold.
    #[rstest]
    #[should_panic]
    fn new_panics_when_weight_sum_below_epsilon() {
        let tiny = f64::EPSILON / 10.0;
        let _ = WeightedMovingAverage::new(3, vec![tiny; 3], None);
    }

    // `initialized` must flip to true exactly when the `period`-th input arrives.
    #[rstest]
    fn initialized_flag_transitions() {
        let period = 3;
        let weights = vec![1.0, 2.0, 3.0];
        let mut wma = WeightedMovingAverage::new(period, weights, None);

        assert!(!wma.initialized());

        for i in 0..period {
            wma.update_raw(i as f64);
            let expected = (i + 1) >= period;
            assert_eq!(wma.initialized(), expected);
        }
        assert!(wma.initialized());
    }

    #[rstest]
    fn count_matches_inputs_and_has_inputs() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.25; 4], None);

        assert_eq!(wma.count(), 0);
        assert!(!wma.has_inputs());

        wma.update_raw(1.0);
        wma.update_raw(2.0);
        assert_eq!(wma.count(), 2);
        assert!(wma.has_inputs());
    }

    #[rstest]
    fn reset_restores_pristine_state() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        wma.update_raw(1.0);
        wma.update_raw(2.0);
        assert!(wma.initialized());

        wma.reset();

        assert_eq!(wma.count(), 0);
        assert_eq!(wma.value(), 0.0);
        assert!(!wma.initialized());
        assert!(!wma.has_inputs());
    }

    // (10*1 + 20*2 + 30*3) / 6, compared within a tolerance rather than exactly.
    #[rstest]
    fn weighted_average_with_non_uniform_weights() {
        let mut wma = WeightedMovingAverage::new(3, vec![1.0, 2.0, 3.0], None);
        wma.update_raw(10.0);
        wma.update_raw(20.0);
        wma.update_raw(30.0);
        let expected = 23.333_333_333_333_332;
        let tol = f64::EPSILON.sqrt();
        assert!(
            (wma.value() - expected).abs() < tol,
            "value = {}, expected ≈ {}",
            wma.value(),
            expected
        );
    }

    #[rstest]
    fn test_window_never_exceeds_period(mut indicator_wma_10: WeightedMovingAverage) {
        for i in 0..100 {
            indicator_wma_10.update_raw(f64::from(i));
            assert!(indicator_wma_10.count() <= indicator_wma_10.period);
        }
    }

    // Individual weights may be negative as long as their sum passes validation.
    #[rstest]
    fn test_negative_weights_positive_sum() {
        let period = 3;
        let weights = vec![-1.0, 2.0, 2.0];
        let mut wma = WeightedMovingAverage::new(period, weights, None);
        wma.update_raw(1.0);
        wma.update_raw(2.0);
        wma.update_raw(3.0);

        let expected = 2.0f64.mul_add(3.0, 2.0f64.mul_add(2.0, -1.0)) / 3.0;
        let tol = f64::EPSILON.sqrt();
        assert!((wma.value() - expected).abs() < tol);
    }

    // A NaN input is not filtered out; it makes the average NaN.
    #[rstest]
    fn test_nan_input_propagates() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        wma.update_raw(1.0);
        wma.update_raw(f64::NAN);

        assert!(wma.value().is_nan());
    }

    // Three weights of EPSILON / 3 are meant to put the sum at the EPSILON boundary, which
    // the strict `>` comparison in the constructor rejects.
    #[rstest]
    #[should_panic]
    fn new_panics_when_weight_sum_equals_epsilon() {
        let eps_third = f64::EPSILON / 3.0;
        let _ = WeightedMovingAverage::new(3, vec![eps_third; 3], None);
    }

    #[rstest]
    fn new_checked_err_when_weight_sum_equals_epsilon() {
        let eps_third = f64::EPSILON / 3.0;
        let res = WeightedMovingAverage::new_checked(3, vec![eps_third; 3], None);
        assert!(res.is_err());
    }

    #[rstest]
    fn new_checked_err_when_weight_sum_below_epsilon() {
        let w = f64::EPSILON * 0.9;
        let res = WeightedMovingAverage::new_checked(1, vec![w], None);
        assert!(res.is_err());
    }

    #[rstest]
    fn new_ok_when_weight_sum_above_epsilon() {
        let w = f64::EPSILON * 1.1;
        let res = WeightedMovingAverage::new_checked(1, vec![w], None);
        assert!(res.is_ok());
    }

    // Weights that cancel out to a zero sum are rejected even though none is invalid alone.
    #[rstest]
    #[should_panic]
    fn new_panics_on_cancelled_weights_sum() {
        let _ = WeightedMovingAverage::new(3, vec![1.0, -1.0, 0.0], None);
    }

    #[rstest]
    fn new_checked_err_on_cancelled_weights_sum() {
        let res = WeightedMovingAverage::new_checked(3, vec![1.0, -1.0, 0.0], None);
        assert!(res.is_err());
    }

    // With period 1 the window holds only the newest input, so the value tracks it.
    #[rstest]
    fn single_period_returns_latest_input() {
        let mut wma = WeightedMovingAverage::new(1, vec![1.0], None);

        for i in 0..5 {
            let v = f64::from(i);
            wma.update_raw(v);
            assert_eq!(wma.value(), v);
        }
    }

    // Only the middle weight is non-zero, so the full-window value is the middle input.
    #[rstest]
    fn value_with_sparse_weights() {
        let mut wma = WeightedMovingAverage::new(3, vec![0.0, 1.0, 0.0], None);
        wma.update_raw(10.0);
        wma.update_raw(20.0);
        wma.update_raw(30.0);
        assert_eq!(wma.value(), 20.0);
    }

    // Warm-up cases: with n < period inputs only the trailing n weights are used.
    #[rstest]
    fn warm_up_len1() {
        let mut wma = WeightedMovingAverage::new(4, vec![1.0, 2.0, 3.0, 4.0], None);
        wma.update_raw(42.0);
        assert_eq!(wma.value(), 42.0);
    }

    #[rstest]
    fn warm_up_len2() {
        let mut wma = WeightedMovingAverage::new(4, vec![1.0, 2.0, 3.0, 4.0], None);
        wma.update_raw(10.0);
        wma.update_raw(20.0);
        let expected = 20.0f64.mul_add(4.0, 10.0 * 3.0) / (4.0 + 3.0);
        assert_eq!(wma.value(), expected);
    }

    #[rstest]
    fn warm_up_len3() {
        let mut wma = WeightedMovingAverage::new(4, vec![1.0, 2.0, 3.0, 4.0], None);
        wma.update_raw(1.0);
        wma.update_raw(2.0);
        wma.update_raw(3.0);
        let expected = 1.0f64.mul_add(2.0, 3.0f64.mul_add(4.0, 2.0 * 3.0)) / (4.0 + 3.0 + 2.0);
        assert_eq!(wma.value(), expected);
    }

    // After more than `period` updates the buffer holds the latest `period` values in
    // arrival order.
    #[rstest]
    fn input_window_contains_latest_period() {
        let period = 3;
        let mut wma = WeightedMovingAverage::new(period, vec![1.0; period], None);
        let vals = [1.0, 2.0, 3.0, 4.0];
        for v in vals {
            wma.update_raw(v);
        }
        let expected: Vec<f64> = vals[vals.len() - period..].to_vec();
        assert_eq!(wma.inputs.iter().copied().collect::<Vec<_>>(), expected);
    }

    #[rstest]
    fn window_slides_correctly() {
        let mut wma = WeightedMovingAverage::new(2, vec![1.0; 2], None);
        wma.update_raw(1.0);
        assert_eq!(wma.inputs.iter().copied().collect::<Vec<_>>(), vec![1.0]);
        wma.update_raw(2.0);
        assert_eq!(
            wma.inputs.iter().copied().collect::<Vec<_>>(),
            vec![1.0, 2.0]
        );
        wma.update_raw(3.0);
        assert_eq!(
            wma.inputs.iter().copied().collect::<Vec<_>>(),
            vec![2.0, 3.0]
        );
    }

    // The buffer grows by one per update until it reaches `period`, then stays there.
    #[rstest]
    fn window_len_constant_after_many_updates() {
        let period = 5;
        let mut wma = WeightedMovingAverage::new(period, vec![1.0; period], None);
        for i in 0..100 {
            wma.update_raw(i as f64);
            assert_eq!(wma.inputs.len(), period.min(i + 1));
        }
    }

    // Checks the `ArrayDeque` behavior the indicator relies on: pushing onto a full
    // `Wrapping` deque drops the front element.
    #[rstest]
    fn arraydeque_wraps_when_full() {
        const CAP: usize = 3;
        let mut buf: ArrayDeque<usize, CAP, Wrapping> = ArrayDeque::new();
        for i in 0..=CAP {
            let _ = buf.push_back(i);
        }
        assert_eq!(buf.len(), CAP);
        assert_eq!(buf.front().copied(), Some(1));
        assert_eq!(buf.back().copied(), Some(3));
    }

    // Mirrors the pop-then-push pattern used by `update_raw`.
    #[rstest]
    fn arraydeque_sliding_window_with_pop() {
        const CAP: usize = 3;
        let mut buf: ArrayDeque<usize, CAP, Wrapping> = ArrayDeque::new();
        for i in 0..10 {
            if buf.len() == CAP {
                buf.pop_front();
            }
            let _ = buf.push_back(i);
            assert!(buf.len() <= CAP);
        }
        assert_eq!(buf.len(), CAP);
    }

    // An infinite weight gives an infinite sum, which passes the `> EPSILON` check.
    #[rstest]
    fn new_ok_with_infinite_weight() {
        let res = WeightedMovingAverage::new_checked(2, vec![f64::INFINITY, 1.0], None);
        assert!(res.is_ok());
    }

    // A NaN weight gives a NaN sum, which fails the `> EPSILON` comparison.
    #[rstest]
    #[should_panic]
    fn new_panics_on_nan_weight() {
        let _ = WeightedMovingAverage::new(2, vec![f64::NAN, 1.0], None);
    }

    // An empty weights vector cannot match a period of 1.
    #[rstest]
    #[should_panic]
    fn new_panics_on_empty_weights() {
        let _ = WeightedMovingAverage::new(1, Vec::new(), None);
    }

    #[rstest]
    fn inf_input_propagates() {
        let mut wma = WeightedMovingAverage::new(2, vec![0.5, 0.5], None);
        wma.update_raw(1.0);
        wma.update_raw(f64::INFINITY);
        assert!(wma.value().is_infinite());
    }

    // Leading zero weights play no part while only the two trailing weights are in use.
    #[rstest]
    fn warm_up_with_front_zero_weights() {
        let mut wma = WeightedMovingAverage::new(4, vec![0.0, 0.0, 1.0, 1.0], None);
        wma.update_raw(10.0);
        wma.update_raw(20.0);
        let expected = 20.0f64.mul_add(1.0, 10.0 * 1.0) / 2.0;
        assert_eq!(wma.value(), expected);
    }
}