Skip to main content

nautilus_indicators/average/
lr.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::{Debug, Display};
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20
21use crate::indicator::Indicator;
22
/// Largest supported regression window; the input buffer is a fixed-capacity
/// `ArrayDeque` holding at most this many `f64` values.
const MAX_PERIOD: usize = 16_384;

/// Rolling linear (least-squares) regression indicator.
///
/// Fits a straight line `y = slope·x + intercept` over the most recent `period`
/// input values, using the fixed abscissae `x = 1..=period`, and exposes the
/// slope, intercept, slope angle, a forecast-deviation percentage and R² of the fit.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
)]
pub struct LinearRegression {
    /// Number of most recent inputs the regression is fitted over.
    pub period: usize,
    /// Slope of the fitted line.
    pub slope: f64,
    /// Intercept of the fitted line (its value at `x = 0`).
    pub intercept: f64,
    /// Slope expressed as an angle in degrees (`atan(slope)` converted to degrees).
    pub degree: f64,
    /// Percentage deviation of the fitted last value from the last input:
    /// `100 · (fitted − last) / last`; NaN when the last input is zero.
    pub cfo: f64,
    /// Coefficient of determination `1 − SSE/SST`; NaN when SST is below `f64::EPSILON`.
    pub r2: f64,
    /// Fitted value of the regression line at the most recent input.
    pub value: f64,
    /// Whether a full window of `period` inputs has been collected.
    pub initialized: bool,
    // Whether any input has been received since construction or the last reset.
    has_inputs: bool,
    // Rolling window of the most recent inputs, oldest at the front.
    inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    // Cached Σx for x = 1..=period (depends only on `period`).
    x_sum: f64,
    // Cached Σx² for x = 1..=period (depends only on `period`).
    x_mul_sum: f64,
    // Cached least-squares denominator `period·Σx² − (Σx)²`.
    divisor: f64,
}
50
51impl Display for LinearRegression {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        write!(f, "{}({})", self.name(), self.period)
54    }
55}
56
57impl Indicator for LinearRegression {
58    fn name(&self) -> String {
59        stringify!(LinearRegression).into()
60    }
61
62    fn has_inputs(&self) -> bool {
63        self.has_inputs
64    }
65
66    fn initialized(&self) -> bool {
67        self.initialized
68    }
69
70    fn handle_bar(&mut self, bar: &Bar) {
71        self.update_raw(bar.close.into());
72    }
73
74    fn reset(&mut self) {
75        self.slope = 0.0;
76        self.intercept = 0.0;
77        self.degree = 0.0;
78        self.cfo = 0.0;
79        self.r2 = 0.0;
80        self.value = 0.0;
81        self.inputs.clear();
82        self.has_inputs = false;
83        self.initialized = false;
84    }
85}
86
87impl LinearRegression {
88    /// Creates a new [`LinearRegression`] instance.
89    ///
90    /// # Panics
91    ///
92    /// This function panics if:
93    /// `period` is zero.
94    /// `period` exceeds `MAX_PERIOD` (16,384).
95    #[must_use]
96    pub fn new(period: usize) -> Self {
97        assert!(
98            period > 0,
99            "LinearRegression: period must be > 0 (received {period})"
100        );
101        assert!(
102            period <= MAX_PERIOD,
103            "LinearRegression: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
104        );
105
106        let n = period as f64;
107        let x_sum = 0.5 * n * (n + 1.0);
108        let x_mul_sum = x_sum * 2.0f64.mul_add(n, 1.0) / 3.0;
109        let divisor = n.mul_add(x_mul_sum, -(x_sum * x_sum));
110
111        Self {
112            period,
113            slope: 0.0,
114            intercept: 0.0,
115            degree: 0.0,
116            cfo: 0.0,
117            r2: 0.0,
118            value: 0.0,
119            initialized: false,
120            has_inputs: false,
121            inputs: ArrayDeque::new(),
122            x_sum,
123            x_mul_sum,
124            divisor,
125        }
126    }
127
128    /// Updates the linear regression with a new data point.
129    ///
130    /// # Panics
131    ///
132    /// Panics if called with an empty window – this is protected against by the logic
133    /// that returns early until enough samples have been collected.
134    pub fn update_raw(&mut self, close: f64) {
135        if self.inputs.len() == self.period {
136            let _ = self.inputs.pop_front();
137        }
138        let _ = self.inputs.push_back(close);
139
140        self.has_inputs = true;
141
142        if self.inputs.len() < self.period {
143            return;
144        }
145        self.initialized = true;
146
147        let n = self.period as f64;
148        let x_sum = self.x_sum;
149        let x_mul_sum = self.x_mul_sum;
150        let divisor = self.divisor;
151
152        let (mut y_sum, mut xy_sum) = (0.0, 0.0);
153
154        for (i, &y) in self.inputs.iter().enumerate() {
155            let x = (i + 1) as f64;
156            y_sum += y;
157            xy_sum += x * y;
158        }
159
160        self.slope = n.mul_add(xy_sum, -(x_sum * y_sum)) / divisor;
161        self.intercept = y_sum.mul_add(x_mul_sum, -(x_sum * xy_sum)) / divisor;
162
163        let (mut sse, mut y_last, mut e_last) = (0.0, 0.0, 0.0);
164
165        for (i, &y) in self.inputs.iter().enumerate() {
166            let x = (i + 1) as f64;
167            let y_hat = self.slope.mul_add(x, self.intercept);
168            let resid = y_hat - y;
169            sse += resid * resid;
170            y_last = y;
171            e_last = resid;
172        }
173
174        self.value = y_last + e_last;
175        self.degree = self.slope.atan().to_degrees();
176        self.cfo = if y_last == 0.0 {
177            f64::NAN
178        } else {
179            100.0 * e_last / y_last
180        };
181
182        let mean = y_sum / n;
183        let sst: f64 = self
184            .inputs
185            .iter()
186            .map(|&y| {
187                let d = y - mean;
188                d * d
189            })
190            .sum();
191
192        self.r2 = if sst.abs() < f64::EPSILON {
193            f64::NAN
194        } else {
195            1.0 - sse / sst
196        };
197    }
198}
199
#[cfg(test)]
mod tests {
    use nautilus_model::data::Bar;
    use rstest::rstest;

    use super::*;
    use crate::{
        average::lr::LinearRegression,
        indicator::Indicator,
        stubs::{bar_ethusdt_binance_minute_bid, indicator_lr_10},
    };

    #[rstest]
    fn test_lr_initialized(indicator_lr_10: LinearRegression) {
        let display_str = format!("{indicator_lr_10}");
        assert_eq!(display_str, "LinearRegression(10)");
        assert_eq!(indicator_lr_10.period, 10);
        assert!(!indicator_lr_10.initialized);
        assert!(!indicator_lr_10.has_inputs);
    }

    #[rstest]
    #[should_panic(expected = "LinearRegression: period must be > 0")]
    fn test_new_with_zero_period_panics() {
        let _ = LinearRegression::new(0);
    }

    #[rstest]
    fn test_value_with_one_input(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        assert_eq!(indicator_lr_10.value, 0.0);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        indicator_lr_10.update_raw(2.0);
        indicator_lr_10.update_raw(3.0);
        assert_eq!(indicator_lr_10.value, 0.0);
    }

    #[rstest]
    fn test_initialized_with_required_input(mut indicator_lr_10: LinearRegression) {
        for i in 1..10 {
            indicator_lr_10.update_raw(f64::from(i));
        }
        assert!(!indicator_lr_10.initialized);
        indicator_lr_10.update_raw(10.0);
        assert!(indicator_lr_10.initialized);
    }

    #[rstest]
    fn test_handle_bar(mut indicator_lr_10: LinearRegression, bar_ethusdt_binance_minute_bid: Bar) {
        indicator_lr_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(indicator_lr_10.value, 0.0);
        assert!(indicator_lr_10.has_inputs);
        assert!(!indicator_lr_10.initialized);
    }

    #[rstest]
    fn test_reset(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        indicator_lr_10.reset();
        assert_eq!(indicator_lr_10.value, 0.0);
        assert_eq!(indicator_lr_10.inputs.len(), 0);
        assert_eq!(indicator_lr_10.slope, 0.0);
        assert_eq!(indicator_lr_10.intercept, 0.0);
        assert_eq!(indicator_lr_10.degree, 0.0);
        assert_eq!(indicator_lr_10.cfo, 0.0);
        assert_eq!(indicator_lr_10.r2, 0.0);
        assert!(!indicator_lr_10.has_inputs);
        assert!(!indicator_lr_10.initialized);
    }

    #[rstest]
    fn test_inputs_len_never_exceeds_period() {
        let mut lr = LinearRegression::new(3);
        for i in 0..10 {
            lr.update_raw(f64::from(i));
        }
        assert_eq!(lr.inputs.len(), lr.period);
    }

    #[rstest]
    fn test_oldest_element_evicted() {
        let mut lr = LinearRegression::new(4);
        for v in 1..=5 {
            lr.update_raw(f64::from(v));
        }
        assert!(!lr.inputs.contains(&1.0));
        assert_eq!(lr.inputs.front(), Some(&2.0));
    }

    #[rstest]
    fn test_recent_elements_preserved() {
        let mut lr = LinearRegression::new(5);
        for v in 0..5 {
            lr.update_raw(f64::from(v));
        }
        lr.update_raw(99.0);
        let expected = vec![1.0, 2.0, 3.0, 4.0, 99.0];
        assert_eq!(lr.inputs.iter().copied().collect::<Vec<_>>(), expected);
    }

    #[rstest]
    fn test_multiple_evictions() {
        let mut lr = LinearRegression::new(2);
        lr.update_raw(10.0);
        lr.update_raw(20.0);
        lr.update_raw(30.0);
        lr.update_raw(40.0);
        assert_eq!(
            lr.inputs.iter().copied().collect::<Vec<_>>(),
            vec![30.0, 40.0]
        );
    }

    #[rstest]
    fn test_value_stable_after_eviction() {
        // After the window slides, the output must stay finite and reflect the new window.
        let mut lr = LinearRegression::new(3);
        lr.update_raw(1.0);
        lr.update_raw(2.0);
        lr.update_raw(3.0);
        let before = lr.value;
        lr.update_raw(4.0);
        let after = lr.value;
        assert!(after.is_finite());
        assert_ne!(before, after);
    }

    #[rstest]
    fn test_value_with_ten_inputs(mut indicator_lr_10: LinearRegression) {
        // Eleven values are fed so the 10-period window has slid by one; the first 1.0 is evicted.
        indicator_lr_10.update_raw(1.00000);
        indicator_lr_10.update_raw(1.00010);
        indicator_lr_10.update_raw(1.00030);
        indicator_lr_10.update_raw(1.00040);
        indicator_lr_10.update_raw(1.00050);
        indicator_lr_10.update_raw(1.00060);
        indicator_lr_10.update_raw(1.00050);
        indicator_lr_10.update_raw(1.00040);
        indicator_lr_10.update_raw(1.00030);
        indicator_lr_10.update_raw(1.00010);
        indicator_lr_10.update_raw(1.00000);

        assert!((indicator_lr_10.value - 1.000_232_727_272_727_6).abs() < 1e-12);
    }

    #[rstest]
    fn r2_nan_for_constant_series() {
        let mut lr = LinearRegression::new(5);
        for _ in 0..5 {
            lr.update_raw(42.0);
        }
        assert!(lr.initialized);
        assert!(
            lr.r2.is_nan(),
            "R² should be NaN for a constant-value input series"
        );
    }

    #[rstest]
    fn cfo_nan_when_last_price_zero() {
        let mut lr = LinearRegression::new(3);
        lr.update_raw(1.0);
        lr.update_raw(2.0);
        lr.update_raw(0.0);
        assert!(lr.initialized);
        assert!(
            lr.cfo.is_nan(),
            "CFO should be NaN when the most-recent price equals zero"
        );
    }

    #[rstest]
    fn positive_slope_and_degree_for_uptrend() {
        let mut lr = LinearRegression::new(4);
        for v in 1..=4 {
            lr.update_raw(f64::from(v));
        }
        assert!(lr.slope > 0.0, "slope expected positive for up-trend");
        assert!(lr.degree > 0.0, "degree expected positive for up-trend");
    }

    #[rstest]
    fn negative_slope_and_degree_for_downtrend() {
        let mut lr = LinearRegression::new(4);
        for v in (1..=4).rev() {
            lr.update_raw(f64::from(v));
        }
        assert!(lr.slope < 0.0, "slope expected negative for down-trend");
        assert!(lr.degree < 0.0, "degree expected negative for down-trend");
    }

    #[rstest]
    fn not_initialized_until_enough_samples() {
        let mut lr = LinearRegression::new(6);
        for v in 0..5 {
            lr.update_raw(f64::from(v));
        }
        assert!(
            !lr.initialized,
            "indicator should remain uninitialised with fewer than `period` inputs"
        );
    }

    #[rstest]
    #[case(128)]
    #[case(1_024)]
    #[case(16_384)]
    fn large_period_initialisation_and_window_size(#[case] period: usize) {
        let mut lr = LinearRegression::new(period);
        for v in 0..period {
            lr.update_raw(v as f64);
        }
        assert!(
            lr.initialized,
            "indicator should initialise after exactly `period` samples"
        );
        assert_eq!(
            lr.inputs.len(),
            period,
            "internal window length must equal the configured period"
        );
    }

    #[rstest]
    fn cached_constants_correct() {
        let period = 10;
        let lr = LinearRegression::new(period);

        let n = period as f64;
        let expected_x_sum = 0.5 * n * (n + 1.0);
        let expected_x_mul_sum = expected_x_sum * 2.0f64.mul_add(n, 1.0) / 3.0;
        let expected_divisor = n.mul_add(expected_x_mul_sum, -(expected_x_sum * expected_x_sum));

        assert!((lr.x_sum - expected_x_sum).abs() < 1e-12, "x_sum mismatch");
        assert!(
            (lr.x_mul_sum - expected_x_mul_sum).abs() < 1e-12,
            "x_mul_sum mismatch"
        );
        assert!(
            (lr.divisor - expected_divisor).abs() < 1e-12,
            "divisor mismatch"
        );
    }

    #[rstest]
    fn cached_constants_immutable_through_updates() {
        let mut lr = LinearRegression::new(5);

        let (x_sum, x_mul_sum, divisor) = (lr.x_sum, lr.x_mul_sum, lr.divisor);

        for v in 0..20 {
            lr.update_raw(f64::from(v));
        }

        assert_eq!(lr.x_sum, x_sum, "x_sum must remain unchanged after updates");
        assert_eq!(
            lr.x_mul_sum, x_mul_sum,
            "x_mul_sum must remain unchanged after updates"
        );
        assert_eq!(
            lr.divisor, divisor,
            "divisor must remain unchanged after updates"
        );
    }

    #[rstest]
    fn cached_constants_immutable_after_reset() {
        let mut lr = LinearRegression::new(8);

        let (x_sum, x_mul_sum, divisor) = (lr.x_sum, lr.x_mul_sum, lr.divisor);

        for v in 0..8 {
            lr.update_raw(f64::from(v));
        }
        lr.reset();

        assert_eq!(lr.x_sum, x_sum, "x_sum must survive reset()");
        assert_eq!(lr.x_mul_sum, x_mul_sum, "x_mul_sum must survive reset()");
        assert_eq!(lr.divisor, divisor, "divisor must survive reset()");
    }

    /// Tolerance for floating-point comparisons in the parameterised tests below.
    const EPS: f64 = 1e-12;

    #[rstest]
    #[should_panic(expected = "exceeds MAX_PERIOD")]
    fn new_period_exceeds_max_panics() {
        let _ = LinearRegression::new(MAX_PERIOD + 1);
    }

    #[rstest(
        period, value,
        case(8, 5.0),
        case(16, -std::f64::consts::PI)
    )]
    fn constant_non_zero_series(period: usize, value: f64) {
        let mut lr = LinearRegression::new(period);

        for _ in 0..period {
            lr.update_raw(value);
        }

        assert!(lr.initialized());
        assert!(lr.slope.abs() < EPS);
        assert!((lr.intercept - value).abs() < EPS);
        assert!(lr.degree.abs() < EPS);
        assert!(lr.r2.is_nan());
        assert!((lr.cfo).abs() < EPS);
        assert!((lr.value - value).abs() < EPS);
    }

    #[rstest(period, case(4), case(32))]
    fn constant_zero_series_cfo_nan(period: usize) {
        let mut lr = LinearRegression::new(period);

        for _ in 0..period {
            lr.update_raw(0.0);
        }

        assert!(lr.initialized());
        assert!(lr.cfo.is_nan());
    }

    #[rstest(period, case(6), case(13))]
    fn reset_clears_state_but_keeps_constants(period: usize) {
        let mut lr = LinearRegression::new(period);

        for i in 1..=period {
            lr.update_raw(i as f64);
        }

        let x_sum_before = lr.x_sum;
        let x_mul_sum_before = lr.x_mul_sum;
        let divisor_before = lr.divisor;

        lr.reset();

        assert!(!lr.initialized());
        assert!(!lr.has_inputs());

        assert!(lr.slope.abs() < EPS);
        assert!(lr.intercept.abs() < EPS);
        assert!(lr.degree.abs() < EPS);
        assert!(lr.cfo.abs() < EPS);
        assert!(lr.r2.abs() < EPS);
        assert!(lr.value.abs() < EPS);

        assert_eq!(lr.x_sum, x_sum_before);
        assert_eq!(lr.x_mul_sum, x_mul_sum_before);
        assert_eq!(lr.divisor, divisor_before);
    }

    #[rstest(period, case(5), case(31))]
    fn perfect_linear_series(period: usize) {
        const A: f64 = 2.0;
        const B: f64 = -3.0;
        let mut lr = LinearRegression::new(period);

        for x in 1..=period {
            lr.update_raw(A.mul_add(x as f64, B));
        }

        assert!(lr.initialized());
        assert!((lr.slope - A).abs() < EPS);
        assert!((lr.intercept - B).abs() < EPS);
        assert!((lr.r2 - 1.0).abs() < EPS);
        assert!((lr.degree.to_radians().tan() - A).abs() < EPS);
    }

    #[rstest]
    fn sliding_window_keeps_last_period() {
        const P: usize = 4;
        let mut lr = LinearRegression::new(P);
        for i in 1..=P {
            lr.update_raw(i as f64);
        }
        let slope_first_window = lr.slope;

        lr.update_raw(-100.0);
        assert!(lr.slope < slope_first_window);
        assert_eq!(lr.inputs.len(), P);
        assert_eq!(lr.inputs.front(), Some(&2.0));
    }

    #[rstest]
    fn r2_between_zero_and_one() {
        const P: usize = 32;
        let mut lr = LinearRegression::new(P);
        for x in 1..=P {
            let noise = if x.is_multiple_of(2) { 0.5 } else { -0.5 };
            lr.update_raw(3.0f64.mul_add(x as f64, noise));
        }
        assert!(lr.r2 > 0.0 && lr.r2 < 1.0);
    }

    #[rstest]
    fn reset_before_initialized() {
        let mut lr = LinearRegression::new(10);
        lr.update_raw(1.0);
        lr.reset();

        assert!(!lr.initialized());
        assert!(!lr.has_inputs());
        assert_eq!(lr.inputs.len(), 0);
    }
}