nautilus_indicators/average/
lr.rs1use std::fmt::{Debug, Display};
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20
21use crate::indicator::Indicator;
22
/// Largest accepted `period`; also the fixed capacity of the input window (2^14 = 16 384).
const MAX_PERIOD: usize = 1 << 14;
24
/// Rolling least-squares linear regression over the most recent `period` inputs.
///
/// Inputs are regressed against their positions `1..=period` inside the window
/// (oldest sample at position 1). All outputs stay at `0.0` until the window
/// has been filled for the first time.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
)]
pub struct LinearRegression {
    /// Number of samples in the rolling window.
    pub period: usize,
    /// Slope of the fitted line, per sample position.
    pub slope: f64,
    /// Intercept of the fitted line (its value at position 0).
    pub intercept: f64,
    /// Arctangent of `slope`, expressed in degrees.
    pub degree: f64,
    /// `100 * (fitted - last) / last` for the most recent input; `NaN` when
    /// that input is exactly zero.
    pub cfo: f64,
    /// Coefficient of determination, `1 - SSE / SST`; `NaN` when the total sum
    /// of squares is below `f64::EPSILON`.
    pub r2: f64,
    /// Regression estimate at the most recent sample (last input plus its residual).
    pub value: f64,
    /// `true` once `period` inputs have been received.
    pub initialized: bool,
    /// `true` once at least one input has been received.
    has_inputs: bool,
    /// Rolling window of inputs, oldest first.
    inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
    /// Cached `Σx` for `x = 1..=period`.
    x_sum: f64,
    /// Cached `Σx²` for `x = 1..=period`.
    x_mul_sum: f64,
    /// Cached least-squares denominator `period · Σx² − (Σx)²`.
    divisor: f64,
}
50
51impl Display for LinearRegression {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 write!(f, "{}({})", self.name(), self.period)
54 }
55}
56
57impl Indicator for LinearRegression {
58 fn name(&self) -> String {
59 stringify!(LinearRegression).into()
60 }
61
62 fn has_inputs(&self) -> bool {
63 self.has_inputs
64 }
65
66 fn initialized(&self) -> bool {
67 self.initialized
68 }
69
70 fn handle_bar(&mut self, bar: &Bar) {
71 self.update_raw(bar.close.into());
72 }
73
74 fn reset(&mut self) {
75 self.slope = 0.0;
76 self.intercept = 0.0;
77 self.degree = 0.0;
78 self.cfo = 0.0;
79 self.r2 = 0.0;
80 self.value = 0.0;
81 self.inputs.clear();
82 self.has_inputs = false;
83 self.initialized = false;
84 }
85}
86
87impl LinearRegression {
88 #[must_use]
96 pub fn new(period: usize) -> Self {
97 assert!(
98 period > 0,
99 "LinearRegression: period must be > 0 (received {period})"
100 );
101 assert!(
102 period <= MAX_PERIOD,
103 "LinearRegression: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
104 );
105
106 let n = period as f64;
107 let x_sum = 0.5 * n * (n + 1.0);
108 let x_mul_sum = x_sum * 2.0f64.mul_add(n, 1.0) / 3.0;
109 let divisor = n.mul_add(x_mul_sum, -(x_sum * x_sum));
110
111 Self {
112 period,
113 slope: 0.0,
114 intercept: 0.0,
115 degree: 0.0,
116 cfo: 0.0,
117 r2: 0.0,
118 value: 0.0,
119 initialized: false,
120 has_inputs: false,
121 inputs: ArrayDeque::new(),
122 x_sum,
123 x_mul_sum,
124 divisor,
125 }
126 }
127
128 pub fn update_raw(&mut self, close: f64) {
135 if self.inputs.len() == self.period {
136 let _ = self.inputs.pop_front();
137 }
138 let _ = self.inputs.push_back(close);
139
140 self.has_inputs = true;
141
142 if self.inputs.len() < self.period {
143 return;
144 }
145 self.initialized = true;
146
147 let n = self.period as f64;
148 let x_sum = self.x_sum;
149 let x_mul_sum = self.x_mul_sum;
150 let divisor = self.divisor;
151
152 let (mut y_sum, mut xy_sum) = (0.0, 0.0);
153
154 for (i, &y) in self.inputs.iter().enumerate() {
155 let x = (i + 1) as f64;
156 y_sum += y;
157 xy_sum += x * y;
158 }
159
160 self.slope = n.mul_add(xy_sum, -(x_sum * y_sum)) / divisor;
161 self.intercept = y_sum.mul_add(x_mul_sum, -(x_sum * xy_sum)) / divisor;
162
163 let (mut sse, mut y_last, mut e_last) = (0.0, 0.0, 0.0);
164
165 for (i, &y) in self.inputs.iter().enumerate() {
166 let x = (i + 1) as f64;
167 let y_hat = self.slope.mul_add(x, self.intercept);
168 let resid = y_hat - y;
169 sse += resid * resid;
170 y_last = y;
171 e_last = resid;
172 }
173
174 self.value = y_last + e_last;
175 self.degree = self.slope.atan().to_degrees();
176 self.cfo = if y_last == 0.0 {
177 f64::NAN
178 } else {
179 100.0 * e_last / y_last
180 };
181
182 let mean = y_sum / n;
183 let sst: f64 = self
184 .inputs
185 .iter()
186 .map(|&y| {
187 let d = y - mean;
188 d * d
189 })
190 .sum();
191
192 self.r2 = if sst.abs() < f64::EPSILON {
193 f64::NAN
194 } else {
195 1.0 - sse / sst
196 };
197 }
198}
199
#[cfg(test)]
mod tests {
    use nautilus_model::data::Bar;
    use rstest::rstest;

    use super::*;
    use crate::{
        average::lr::LinearRegression,
        indicator::Indicator,
        stubs::{bar_ethusdt_binance_minute_bid, indicator_lr_10},
    };

    /// Absolute tolerance used for floating-point comparisons.
    const EPS: f64 = 1e-12;

    // ---------------------------------------------------------------------
    // Construction and basic lifecycle
    // ---------------------------------------------------------------------

    #[rstest]
    fn test_lr_initialized(indicator_lr_10: LinearRegression) {
        let display_str = format!("{indicator_lr_10}");
        assert_eq!(display_str, "LinearRegression(10)");
        assert_eq!(indicator_lr_10.period, 10);
        assert!(!indicator_lr_10.initialized);
        assert!(!indicator_lr_10.has_inputs);
    }

    #[rstest]
    #[should_panic(expected = "LinearRegression: period must be > 0")]
    fn test_new_with_zero_period_panics() {
        let _ = LinearRegression::new(0);
    }

    #[rstest]
    fn test_value_with_one_input(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        assert_eq!(indicator_lr_10.value, 0.0);
    }

    #[rstest]
    fn test_value_with_three_inputs(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        indicator_lr_10.update_raw(2.0);
        indicator_lr_10.update_raw(3.0);
        assert_eq!(indicator_lr_10.value, 0.0);
    }

    #[rstest]
    fn test_initialized_with_required_input(mut indicator_lr_10: LinearRegression) {
        for i in 1..10 {
            indicator_lr_10.update_raw(f64::from(i));
        }
        assert!(!indicator_lr_10.initialized);
        indicator_lr_10.update_raw(10.0);
        assert!(indicator_lr_10.initialized);
    }

    #[rstest]
    fn test_handle_bar(mut indicator_lr_10: LinearRegression, bar_ethusdt_binance_minute_bid: Bar) {
        indicator_lr_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(indicator_lr_10.value, 0.0);
        assert!(indicator_lr_10.has_inputs);
        assert!(!indicator_lr_10.initialized);
    }

    #[rstest]
    fn test_reset(mut indicator_lr_10: LinearRegression) {
        indicator_lr_10.update_raw(1.0);
        indicator_lr_10.reset();
        assert_eq!(indicator_lr_10.value, 0.0);
        assert_eq!(indicator_lr_10.inputs.len(), 0);
        assert_eq!(indicator_lr_10.slope, 0.0);
        assert_eq!(indicator_lr_10.intercept, 0.0);
        assert_eq!(indicator_lr_10.degree, 0.0);
        assert_eq!(indicator_lr_10.cfo, 0.0);
        assert_eq!(indicator_lr_10.r2, 0.0);
        assert!(!indicator_lr_10.has_inputs);
        assert!(!indicator_lr_10.initialized);
    }

    // ---------------------------------------------------------------------
    // Rolling-window behaviour
    // ---------------------------------------------------------------------

    #[rstest]
    fn test_inputs_len_never_exceeds_period() {
        let mut lr = LinearRegression::new(3);
        for i in 0..10 {
            lr.update_raw(f64::from(i));
        }
        assert_eq!(lr.inputs.len(), lr.period);
    }

    #[rstest]
    fn test_oldest_element_evicted() {
        let mut lr = LinearRegression::new(4);
        for v in 1..=5 {
            lr.update_raw(f64::from(v));
        }
        assert!(!lr.inputs.contains(&1.0));
        assert_eq!(lr.inputs.front(), Some(&2.0));
    }

    #[rstest]
    fn test_recent_elements_preserved() {
        let mut lr = LinearRegression::new(5);
        for v in 0..5 {
            lr.update_raw(f64::from(v));
        }
        lr.update_raw(99.0);
        let expected = vec![1.0, 2.0, 3.0, 4.0, 99.0];
        assert_eq!(lr.inputs.iter().copied().collect::<Vec<_>>(), expected);
    }

    #[rstest]
    fn test_multiple_evictions() {
        let mut lr = LinearRegression::new(2);
        lr.update_raw(10.0);
        lr.update_raw(20.0);
        lr.update_raw(30.0);
        lr.update_raw(40.0);
        assert_eq!(
            lr.inputs.iter().copied().collect::<Vec<_>>(),
            vec![30.0, 40.0]
        );
    }

    #[rstest]
    fn test_value_stable_after_eviction() {
        let mut lr = LinearRegression::new(3);
        lr.update_raw(1.0);
        lr.update_raw(2.0);
        lr.update_raw(3.0);
        let before = lr.value;
        lr.update_raw(4.0);
        let after = lr.value;
        assert!(after.is_finite());
        assert_ne!(before, after);
    }

    #[rstest]
    fn test_value_with_ten_inputs(mut indicator_lr_10: LinearRegression) {
        // Eleven samples are fed on purpose: the first one is evicted, so the
        // expected value is for the last ten samples only.
        indicator_lr_10.update_raw(1.00000);
        indicator_lr_10.update_raw(1.00010);
        indicator_lr_10.update_raw(1.00030);
        indicator_lr_10.update_raw(1.00040);
        indicator_lr_10.update_raw(1.00050);
        indicator_lr_10.update_raw(1.00060);
        indicator_lr_10.update_raw(1.00050);
        indicator_lr_10.update_raw(1.00040);
        indicator_lr_10.update_raw(1.00030);
        indicator_lr_10.update_raw(1.00010);
        indicator_lr_10.update_raw(1.00000);

        assert!((indicator_lr_10.value - 1.000_232_727_272_727_6).abs() < 1e-12);
    }

    // ---------------------------------------------------------------------
    // Degenerate inputs and output signs
    // ---------------------------------------------------------------------

    #[rstest]
    fn r2_nan_for_constant_series() {
        let mut lr = LinearRegression::new(5);
        for _ in 0..5 {
            lr.update_raw(42.0);
        }
        assert!(lr.initialized);
        assert!(
            lr.r2.is_nan(),
            "R² should be NaN for a constant-value input series"
        );
    }

    #[rstest]
    fn cfo_nan_when_last_price_zero() {
        let mut lr = LinearRegression::new(3);
        lr.update_raw(1.0);
        lr.update_raw(2.0);
        lr.update_raw(0.0);
        assert!(lr.initialized);
        assert!(
            lr.cfo.is_nan(),
            "CFO should be NaN when the most-recent price equals zero"
        );
    }

    #[rstest]
    fn positive_slope_and_degree_for_uptrend() {
        let mut lr = LinearRegression::new(4);
        for v in 1..=4 {
            lr.update_raw(f64::from(v));
        }
        assert!(lr.slope > 0.0, "slope expected positive for up-trend");
        assert!(lr.degree > 0.0, "degree expected positive for up-trend");
    }

    #[rstest]
    fn negative_slope_and_degree_for_downtrend() {
        let mut lr = LinearRegression::new(4);
        for v in (1..=4).rev() {
            lr.update_raw(f64::from(v));
        }
        assert!(lr.slope < 0.0, "slope expected negative for down-trend");
        assert!(lr.degree < 0.0, "degree expected negative for down-trend");
    }

    #[rstest]
    fn not_initialized_until_enough_samples() {
        let mut lr = LinearRegression::new(6);
        for v in 0..5 {
            lr.update_raw(f64::from(v));
        }
        assert!(
            !lr.initialized,
            "indicator should remain uninitialised with fewer than `period` inputs"
        );
    }

    #[rstest]
    #[case(128)]
    #[case(1_024)]
    #[case(16_384)]
    fn large_period_initialisation_and_window_size(#[case] period: usize) {
        let mut lr = LinearRegression::new(period);
        for v in 0..period {
            lr.update_raw(v as f64);
        }
        assert!(
            lr.initialized,
            "indicator should initialise after exactly `period` samples"
        );
        assert_eq!(
            lr.inputs.len(),
            period,
            "internal window length must equal the configured period"
        );
    }

    // ---------------------------------------------------------------------
    // Cached constants
    // ---------------------------------------------------------------------

    #[rstest]
    fn cached_constants_correct() {
        let period = 10;
        let lr = LinearRegression::new(period);

        let n = period as f64;
        let expected_x_sum = 0.5 * n * (n + 1.0);
        let expected_x_mul_sum = expected_x_sum * 2.0f64.mul_add(n, 1.0) / 3.0;
        let expected_divisor = n.mul_add(expected_x_mul_sum, -(expected_x_sum * expected_x_sum));

        assert!((lr.x_sum - expected_x_sum).abs() < 1e-12, "x_sum mismatch");
        assert!(
            (lr.x_mul_sum - expected_x_mul_sum).abs() < 1e-12,
            "x_mul_sum mismatch"
        );
        assert!(
            (lr.divisor - expected_divisor).abs() < 1e-12,
            "divisor mismatch"
        );
    }

    #[rstest]
    fn cached_constants_immutable_through_updates() {
        let mut lr = LinearRegression::new(5);

        let (x_sum, x_mul_sum, divisor) = (lr.x_sum, lr.x_mul_sum, lr.divisor);

        for v in 0..20 {
            lr.update_raw(f64::from(v));
        }

        assert_eq!(lr.x_sum, x_sum, "x_sum must remain unchanged after updates");
        assert_eq!(
            lr.x_mul_sum, x_mul_sum,
            "x_mul_sum must remain unchanged after updates"
        );
        assert_eq!(
            lr.divisor, divisor,
            "divisor must remain unchanged after updates"
        );
    }

    #[rstest]
    fn cached_constants_immutable_after_reset() {
        let mut lr = LinearRegression::new(8);

        let (x_sum, x_mul_sum, divisor) = (lr.x_sum, lr.x_mul_sum, lr.divisor);

        for v in 0..8 {
            lr.update_raw(f64::from(v));
        }
        lr.reset();

        assert_eq!(lr.x_sum, x_sum, "x_sum must survive reset()");
        assert_eq!(lr.x_mul_sum, x_mul_sum, "x_mul_sum must survive reset()");
        assert_eq!(lr.divisor, divisor, "divisor must survive reset()");
    }

    // ---------------------------------------------------------------------
    // Constructor preconditions
    // ---------------------------------------------------------------------

    // The expected messages ensure these tests only pass for the intended
    // precondition failure, not for an unrelated panic.
    #[rstest]
    #[should_panic(expected = "period must be > 0")]
    fn new_zero_period_panics() {
        let _ = LinearRegression::new(0);
    }

    #[rstest]
    #[should_panic(expected = "exceeds MAX_PERIOD")]
    fn new_period_exceeds_max_panics() {
        let _ = LinearRegression::new(MAX_PERIOD + 1);
    }

    // ---------------------------------------------------------------------
    // Numerical properties
    // ---------------------------------------------------------------------

    #[rstest(
        period, value,
        case(8, 5.0),
        case(16, -std::f64::consts::PI)
    )]
    fn constant_non_zero_series(period: usize, value: f64) {
        let mut lr = LinearRegression::new(period);

        for _ in 0..period {
            lr.update_raw(value);
        }

        assert!(lr.initialized());
        assert!(lr.slope.abs() < EPS);
        assert!((lr.intercept - value).abs() < EPS);
        assert!(lr.degree.abs() < EPS);
        assert!(lr.r2.is_nan());
        assert!((lr.cfo).abs() < EPS);
        assert!((lr.value - value).abs() < EPS);
    }

    #[rstest(period, case(4), case(32))]
    fn constant_zero_series_cfo_nan(period: usize) {
        let mut lr = LinearRegression::new(period);

        for _ in 0..period {
            lr.update_raw(0.0);
        }

        assert!(lr.initialized());
        assert!(lr.cfo.is_nan());
    }

    #[rstest(period, case(6), case(13))]
    fn reset_clears_state_but_keeps_constants(period: usize) {
        let mut lr = LinearRegression::new(period);

        for i in 1..=period {
            lr.update_raw(i as f64);
        }

        let x_sum_before = lr.x_sum;
        let x_mul_sum_before = lr.x_mul_sum;
        let divisor_before = lr.divisor;

        lr.reset();

        assert!(!lr.initialized());
        assert!(!lr.has_inputs());

        assert!(lr.slope.abs() < EPS);
        assert!(lr.intercept.abs() < EPS);
        assert!(lr.degree.abs() < EPS);
        assert!(lr.cfo.abs() < EPS);
        assert!(lr.r2.abs() < EPS);
        assert!(lr.value.abs() < EPS);

        assert_eq!(lr.x_sum, x_sum_before);
        assert_eq!(lr.x_mul_sum, x_mul_sum_before);
        assert_eq!(lr.divisor, divisor_before);
    }

    #[rstest(period, case(5), case(31))]
    fn perfect_linear_series(period: usize) {
        const A: f64 = 2.0;
        const B: f64 = -3.0;
        let mut lr = LinearRegression::new(period);

        for x in 1..=period {
            lr.update_raw(A.mul_add(x as f64, B));
        }

        assert!(lr.initialized());
        assert!((lr.slope - A).abs() < EPS);
        assert!((lr.intercept - B).abs() < EPS);
        assert!((lr.r2 - 1.0).abs() < EPS);
        assert!((lr.degree.to_radians().tan() - A).abs() < EPS);
    }

    #[rstest]
    fn sliding_window_keeps_last_period() {
        const P: usize = 4;
        let mut lr = LinearRegression::new(P);
        for i in 1..=P {
            lr.update_raw(i as f64);
        }
        let slope_first_window = lr.slope;

        lr.update_raw(-100.0);
        assert!(lr.slope < slope_first_window);
        assert_eq!(lr.inputs.len(), P);
        assert_eq!(lr.inputs.front(), Some(&2.0));
    }

    #[rstest]
    fn r2_between_zero_and_one() {
        const P: usize = 32;
        let mut lr = LinearRegression::new(P);
        for x in 1..=P {
            // Alternating ±0.5 noise on top of a 3x trend.
            let noise = if x.is_multiple_of(2) { 0.5 } else { -0.5 };
            lr.update_raw(3.0f64.mul_add(x as f64, noise));
        }
        assert!(lr.r2 > 0.0 && lr.r2 < 1.0);
    }

    #[rstest]
    fn reset_before_initialized() {
        let mut lr = LinearRegression::new(10);
        lr.update_raw(1.0);
        lr.reset();

        assert!(!lr.initialized());
        assert!(!lr.has_inputs());
        assert_eq!(lr.inputs.len(), 0);
    }
}