1use std::fmt::{Debug, Display};
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20use strum::{AsRefStr, Display as StrumDisplay, EnumIter, EnumString, FromRepr};
21
22use crate::{
23 average::{MovingAverageFactory, MovingAverageType},
24 indicator::{Indicator, MovingAverage},
25};
26
/// Largest window length accepted for `period_k`, `period_d` and `slowing`.
///
/// Also the fixed capacity of the ring buffers backing every rolling window.
const MAX_PERIOD: usize = 1 << 10;
28
29#[repr(C)]
39#[derive(
40 Copy,
41 Clone,
42 Debug,
43 Default,
44 Hash,
45 PartialEq,
46 Eq,
47 PartialOrd,
48 Ord,
49 AsRefStr,
50 FromRepr,
51 EnumIter,
52 EnumString,
53 StrumDisplay,
54)]
55#[strum(ascii_case_insensitive)]
56#[strum(serialize_all = "SCREAMING_SNAKE_CASE")]
57#[cfg_attr(
58 feature = "python",
59 pyo3::pyclass(
60 frozen,
61 eq,
62 eq_int,
63 hash,
64 module = "nautilus_trader.core.nautilus_pyo3.indicators",
65 from_py_object,
66 )
67)]
68#[cfg_attr(
69 feature = "python",
70 pyo3_stub_gen::derive::gen_stub_pyclass_enum(module = "nautilus_trader.indicators")
71)]
72pub enum StochasticsDMethod {
73 #[default]
76 Ratio,
77 MovingAverage,
80}
81
82#[repr(C)]
83#[cfg_attr(
84 feature = "python",
85 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
86)]
87#[cfg_attr(
88 feature = "python",
89 pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
90)]
91pub struct Stochastics {
92 pub period_k: usize,
94 pub period_d: usize,
96 pub slowing: usize,
98 pub ma_type: MovingAverageType,
100 pub d_method: StochasticsDMethod,
102 pub value_k: f64,
104 pub value_d: f64,
106 pub initialized: bool,
108 has_inputs: bool,
109 highs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
110 lows: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
111 c_sub_1: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
112 h_sub_l: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
113 slowing_ma: Option<Box<dyn MovingAverage + Send + Sync>>,
115 d_ma: Option<Box<dyn MovingAverage + Send + Sync>>,
117}
118
119impl Debug for Stochastics {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 f.debug_struct(stringify!(Stochastics))
122 .field("period_k", &self.period_k)
123 .field("period_d", &self.period_d)
124 .field("slowing", &self.slowing)
125 .field("ma_type", &self.ma_type)
126 .field("d_method", &self.d_method)
127 .field("value_k", &self.value_k)
128 .field("value_d", &self.value_d)
129 .field("initialized", &self.initialized)
130 .field("has_inputs", &self.has_inputs)
131 .field(
132 "slowing_ma",
133 &self.slowing_ma.as_ref().map(|_| "MovingAverage"),
134 )
135 .field("d_ma", &self.d_ma.as_ref().map(|_| "MovingAverage"))
136 .finish()
137 }
138}
139
140impl Display for Stochastics {
141 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
142 write!(f, "{}({},{})", self.name(), self.period_k, self.period_d,)
143 }
144}
145
146impl Indicator for Stochastics {
147 fn name(&self) -> String {
148 stringify!(Stochastics).to_string()
149 }
150
151 fn has_inputs(&self) -> bool {
152 self.has_inputs
153 }
154
155 fn initialized(&self) -> bool {
156 self.initialized
157 }
158
159 fn handle_bar(&mut self, bar: &Bar) {
160 self.update_raw((&bar.high).into(), (&bar.low).into(), (&bar.close).into());
161 }
162
163 fn reset(&mut self) {
164 self.highs.clear();
165 self.lows.clear();
166 self.c_sub_1.clear();
167 self.h_sub_l.clear();
168 self.value_k = 0.0;
169 self.value_d = 0.0;
170 self.has_inputs = false;
171 self.initialized = false;
172
173 if let Some(ref mut ma) = self.slowing_ma {
175 ma.reset();
176 }
177
178 if let Some(ref mut ma) = self.d_ma {
180 ma.reset();
181 }
182 }
183}
184
185impl Stochastics {
186 #[must_use]
199 pub fn new(period_k: usize, period_d: usize) -> Self {
200 Self::new_with_params(
201 period_k,
202 period_d,
203 1, MovingAverageType::Exponential, StochasticsDMethod::Ratio, )
207 }
208
209 #[must_use]
224 pub fn new_with_params(
225 period_k: usize,
226 period_d: usize,
227 slowing: usize,
228 ma_type: MovingAverageType,
229 d_method: StochasticsDMethod,
230 ) -> Self {
231 assert!(
232 period_k > 0 && period_k <= MAX_PERIOD,
233 "Stochastics: period_k {period_k} exceeds bounds (1..={MAX_PERIOD})"
234 );
235 assert!(
236 period_d > 0 && period_d <= MAX_PERIOD,
237 "Stochastics: period_d {period_d} exceeds bounds (1..={MAX_PERIOD})"
238 );
239 assert!(
240 slowing > 0 && slowing <= MAX_PERIOD,
241 "Stochastics: slowing {slowing} exceeds bounds (1..={MAX_PERIOD})"
242 );
243
244 let slowing_ma = if slowing > 1 {
246 Some(MovingAverageFactory::create(ma_type, slowing))
247 } else {
248 None
249 };
250
251 let d_ma = match d_method {
253 StochasticsDMethod::MovingAverage => {
254 Some(MovingAverageFactory::create(ma_type, period_d))
255 }
256 StochasticsDMethod::Ratio => None,
257 };
258
259 Self {
260 period_k,
261 period_d,
262 slowing,
263 ma_type,
264 d_method,
265 has_inputs: false,
266 initialized: false,
267 value_k: 0.0,
268 value_d: 0.0,
269 highs: ArrayDeque::new(),
270 lows: ArrayDeque::new(),
271 h_sub_l: ArrayDeque::new(),
272 c_sub_1: ArrayDeque::new(),
273 slowing_ma,
274 d_ma,
275 }
276 }
277
278 pub fn update_raw(&mut self, high: f64, low: f64, close: f64) {
286 if !self.has_inputs {
287 self.has_inputs = true;
288 }
289
290 if self.highs.len() == self.period_k {
292 self.highs.pop_front();
293 self.lows.pop_front();
294 }
295 let _ = self.highs.push_back(high);
296 let _ = self.lows.push_back(low);
297
298 if !self.initialized
300 && self.highs.len() == self.period_k
301 && self.lows.len() == self.period_k
302 {
303 if self.slowing_ma.is_none() && self.d_method == StochasticsDMethod::Ratio {
306 self.initialized = true;
307 }
308 }
309
310 let k_max_high = self.highs.iter().copied().fold(f64::NEG_INFINITY, f64::max);
312 let k_min_low = self.lows.iter().copied().fold(f64::INFINITY, f64::min);
313
314 if self.d_method == StochasticsDMethod::Ratio {
316 if self.c_sub_1.len() == self.period_d {
317 self.c_sub_1.pop_front();
318 self.h_sub_l.pop_front();
319 }
320 let _ = self.c_sub_1.push_back(close - k_min_low);
321 let _ = self.h_sub_l.push_back(k_max_high - k_min_low);
322 }
323
324 if k_max_high == k_min_low {
326 return;
327 }
328
329 let raw_k = 100.0 * ((close - k_min_low) / (k_max_high - k_min_low));
331
332 let slowed_k = match &mut self.slowing_ma {
334 Some(ma) => {
335 ma.update_raw(raw_k);
336 ma.value()
337 }
338 None => raw_k, };
340 self.value_k = slowed_k;
341
342 self.value_d = match self.d_method {
344 StochasticsDMethod::Ratio => {
345 let sum_h_sub_l: f64 = self.h_sub_l.iter().sum();
348 if sum_h_sub_l == 0.0 {
349 0.0
350 } else {
351 100.0 * (self.c_sub_1.iter().sum::<f64>() / sum_h_sub_l)
352 }
353 }
354 StochasticsDMethod::MovingAverage => {
355 if let Some(ref mut ma) = self.d_ma {
357 ma.update_raw(slowed_k);
358 ma.value()
359 } else {
360 50.0 }
362 }
363 };
364
365 if !self.initialized {
369 let base_ready = self.highs.len() == self.period_k;
370 let slowing_ready = match &self.slowing_ma {
371 Some(ma) => ma.initialized(),
372 None => true,
373 };
374 let d_ready = match self.d_method {
375 StochasticsDMethod::Ratio => true, StochasticsDMethod::MovingAverage => match &self.d_ma {
377 Some(ma) => ma.initialized(),
378 None => true,
379 },
380 };
381
382 if base_ready && slowing_ready && d_ready {
383 self.initialized = true;
384 }
385 }
386 }
387}
388
#[cfg(test)]
mod tests {
    use nautilus_model::data::Bar;
    use rstest::rstest;

    use crate::{
        average::MovingAverageType,
        indicator::Indicator,
        momentum::stochastics::{Stochastics, StochasticsDMethod},
        stubs::{bar_ethusdt_binance_minute_bid, stochastics_10},
    };

    // `stochastics_10` is a fixture from `crate::stubs`. The assertions below
    // expect it to have period_k = 10 and period_d = 10.
    #[rstest]
    fn test_stochastics_initialized(stochastics_10: Stochastics) {
        let display_str = format!("{stochastics_10}");
        assert_eq!(display_str, "Stochastics(10,10)");
        assert_eq!(stochastics_10.period_d, 10);
        assert_eq!(stochastics_10.period_k, 10);
        assert!(!stochastics_10.initialized);
        assert!(!stochastics_10.has_inputs);
    }

    // A single bar with high == low gives a zero %K range. The early return
    // therefore leaves both lines at their initial 0.0.
    #[rstest]
    fn test_value_with_one_input(mut stochastics_10: Stochastics) {
        stochastics_10.update_raw(1.0, 1.0, 1.0);
        assert_eq!(stochastics_10.value_d, 0.0);
        assert_eq!(stochastics_10.value_k, 0.0);
    }

    // Rising bars with close at the highest high put %K at 100. The Ratio %D
    // terms are close - low = [0, 1, 2] and high - low = [0, 1, 2], so %D is
    // also 100.
    #[rstest]
    fn test_value_with_three_inputs(mut stochastics_10: Stochastics) {
        stochastics_10.update_raw(1.0, 1.0, 1.0);
        stochastics_10.update_raw(2.0, 2.0, 2.0);
        stochastics_10.update_raw(3.0, 3.0, 3.0);
        assert_eq!(stochastics_10.value_d, 100.0);
        assert_eq!(stochastics_10.value_k, 100.0);
    }

    // 15 updates through a 10-bar window. Each close equals its high, which is
    // also the window's highest high. %K and every Ratio %D term pair therefore
    // coincide, giving exactly 100 for both lines.
    #[rstest]
    fn test_value_with_ten_inputs(mut stochastics_10: Stochastics) {
        let high_values = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
        ];
        let low_values = [
            0.9, 1.9, 2.9, 3.9, 4.9, 5.9, 6.9, 7.9, 8.9, 9.9, 10.1, 10.2, 10.3, 11.1, 11.4,
        ];
        let close_values = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0,
        ];

        for i in 0..15 {
            stochastics_10.update_raw(high_values[i], low_values[i], close_values[i]);
        }

        assert!(stochastics_10.initialized());
        assert_eq!(stochastics_10.value_d, 100.0);
        assert_eq!(stochastics_10.value_k, 100.0);
    }

    // Nine updates leave the 10-bar window one short. The tenth fills it and
    // completes warm-up. Inputs are not validated, so low > high is accepted.
    #[rstest]
    fn test_initialized_with_required_input(mut stochastics_10: Stochastics) {
        for i in 1..10 {
            stochastics_10.update_raw(f64::from(i), f64::from(i), f64::from(i));
        }
        assert!(!stochastics_10.initialized);
        stochastics_10.update_raw(10.0, 12.0, 14.0);
        assert!(stochastics_10.initialized);
    }

    // One bar with a non-zero range sets both lines, but cannot complete the
    // 10-bar warm-up.
    #[rstest]
    fn test_handle_bar(mut stochastics_10: Stochastics, bar_ethusdt_binance_minute_bid: Bar) {
        stochastics_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(stochastics_10.value_d, 49.090_909_090_909_09);
        assert_eq!(stochastics_10.value_k, 49.090_909_090_909_09);
        assert!(stochastics_10.has_inputs);
        assert!(!stochastics_10.initialized);
    }

    // The Ratio accumulators receive an entry even on a zero-range bar. Reset
    // must clear them along with the outputs and flags.
    #[rstest]
    fn test_reset(mut stochastics_10: Stochastics) {
        stochastics_10.update_raw(1.0, 1.0, 1.0);
        assert_eq!(stochastics_10.c_sub_1.len(), 1);
        assert_eq!(stochastics_10.h_sub_l.len(), 1);

        stochastics_10.reset();
        assert_eq!(stochastics_10.value_d, 0.0);
        assert_eq!(stochastics_10.value_k, 0.0);
        assert_eq!(stochastics_10.h_sub_l.len(), 0);
        assert_eq!(stochastics_10.c_sub_1.len(), 0);
        assert!(!stochastics_10.has_inputs);
        assert!(!stochastics_10.initialized);
    }

    // `new` must apply the defaults: no slowing, EMA, Ratio %D, and no internal MAs.
    #[rstest]
    fn test_new_defaults_slowing_1_ratio() {
        let stoch = Stochastics::new(10, 3);
        assert_eq!(stoch.period_k, 10);
        assert_eq!(stoch.period_d, 3);
        assert_eq!(stoch.slowing, 1);
        assert_eq!(stoch.ma_type, MovingAverageType::Exponential);
        assert_eq!(stoch.d_method, StochasticsDMethod::Ratio);
        assert!(
            stoch.slowing_ma.is_none(),
            "slowing_ma should be None when slowing == 1"
        );
        assert!(
            stoch.d_ma.is_none(),
            "d_ma should be None when d_method == Ratio"
        );
    }

    // Every explicit parameter is stored, and both internal MAs are created.
    #[rstest]
    fn test_new_with_params_accepts_all_params() {
        let stoch = Stochastics::new_with_params(
            11,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );
        assert_eq!(stoch.period_k, 11);
        assert_eq!(stoch.period_d, 3);
        assert_eq!(stoch.slowing, 3);
        assert_eq!(stoch.ma_type, MovingAverageType::Exponential);
        assert_eq!(stoch.d_method, StochasticsDMethod::MovingAverage);
        assert!(
            stoch.slowing_ma.is_some(),
            "slowing_ma should exist when slowing > 1"
        );
        assert!(
            stoch.d_ma.is_some(),
            "d_ma should exist when d_method == MovingAverage"
        );
    }

    // `new` and `new_with_params` with the same defaults must produce identical output.
    #[rstest]
    fn test_backward_compatibility_identical_output() {
        let mut stoch_old = Stochastics::new(10, 10);
        let mut stoch_new = Stochastics::new_with_params(
            10,
            10,
            1,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        // Ten bars that fill the 10-bar window exactly.
        let high_values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let low_values = [0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5];
        let close_values = [0.8, 1.8, 2.8, 3.8, 4.8, 5.8, 6.8, 7.8, 8.8, 9.8];

        for i in 0..10 {
            stoch_old.update_raw(high_values[i], low_values[i], close_values[i]);
            stoch_new.update_raw(high_values[i], low_values[i], close_values[i]);
        }

        assert_eq!(stoch_old.value_k, stoch_new.value_k, "value_k mismatch");
        assert_eq!(stoch_old.value_d, stoch_new.value_d, "value_d mismatch");
        assert_eq!(stoch_old.initialized, stoch_new.initialized);
    }

    // Slowing %K with a 3-period EMA is expected to move %K away from the raw value.
    #[rstest]
    fn test_slowing_3_smoothes_k() {
        let mut stoch_no_slowing = Stochastics::new(5, 3);
        let mut stoch_with_slowing = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        // (high, low, close) bars with a varying close position within the range.
        let data = [
            (10.0, 5.0, 8.0),
            (12.0, 6.0, 7.0),
            (11.0, 4.0, 9.0),
            (13.0, 7.0, 8.0),
            (14.0, 8.0, 10.0),
            (12.0, 6.0, 7.0),
            (15.0, 9.0, 14.0),
            (16.0, 10.0, 11.0),
        ];

        for (high, low, close) in data {
            stoch_no_slowing.update_raw(high, low, close);
            stoch_with_slowing.update_raw(high, low, close);
        }

        assert!(
            // The final raw and slowed %K values should differ noticeably.
            (stoch_no_slowing.value_k - stoch_with_slowing.value_k).abs() > 0.01,
            "Slowing should produce different %K values"
        );
    }

    // Every supported MA kind used for slowing should keep %K finite and
    // within [0, 100] on well-formed bars.
    #[rstest]
    #[case(MovingAverageType::Simple)]
    #[case(MovingAverageType::Exponential)]
    #[case(MovingAverageType::Wilder)]
    #[case(MovingAverageType::Hull)]
    fn test_slowing_with_different_ma_types(#[case] ma_type: MovingAverageType) {
        let mut stoch = Stochastics::new_with_params(5, 3, 3, ma_type, StochasticsDMethod::Ratio);

        for i in 1..=10 {
            // Bars with low <= close <= high.
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(
            stoch.value_k.is_finite(),
            "value_k should be finite with {ma_type:?}"
        );
        assert!(
            stoch.value_d.is_finite(),
            "value_d should be finite with {ma_type:?}"
        );
        assert!(
            stoch.value_k >= 0.0 && stoch.value_k <= 100.0,
            "value_k out of range with {ma_type:?}"
        );
    }

    // Ratio %D without slowing initializes once the %K window is full.
    #[rstest]
    fn test_d_method_ratio_preserves_nautilus_behavior() {
        let mut stoch = Stochastics::new_with_params(
            10,
            3,
            1, // slowing disabled
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        for i in 1..=15 {
            // Rising bars with close at the high and a 0.1 range.
            stoch.update_raw(f64::from(i), f64::from(i) - 0.1, f64::from(i));
        }

        assert!(stoch.initialized);
        // Closes sit above the window lows, so the Ratio %D is positive.
        assert!(stoch.value_d > 0.0);
    }

    // Moving-average %D over slowed %K should stay finite and within [0, 100].
    // Only %D is asserted here.
    #[rstest]
    fn test_d_method_ma_produces_smoothed_k() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3, // slowing period
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage, // %D as a moving average of %K
        );

        let data = [
            (10.0, 5.0, 8.0),
            (12.0, 6.0, 7.0),
            (11.0, 4.0, 9.0),
            (13.0, 7.0, 8.0),
            (14.0, 8.0, 10.0),
            (12.0, 6.0, 7.0),
            (15.0, 9.0, 14.0),
            (16.0, 10.0, 11.0),
            (14.0, 8.0, 12.0),
            (13.0, 7.0, 10.0),
        ];

        for (high, low, close) in data {
            stoch.update_raw(high, low, close);
        }

        assert!(stoch.value_d.is_finite());
        // With low <= close <= high inputs, %D is expected to stay in range.
        assert!(stoch.value_d >= 0.0 && stoch.value_d <= 100.0);
    }

    // With slowing enabled, warm-up needs both a full %K window and an
    // initialized slowing MA.
    #[rstest]
    fn test_warmup_period_with_slowing() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3, // slowing period
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        for i in 1..=4 {
            // With period_k = 5 the %K window cannot be full before the fifth
            // update, so warm-up cannot be complete yet.
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
            assert!(!stoch.initialized, "Should not be initialized at bar {i}");
        }

        for i in 5..=15 {
            // Plenty of additional bars for the slowing MA to warm up.
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(
            stoch.initialized,
            "Should be initialized after sufficient bars"
        );
    }

    // With moving-average %D, warm-up also waits for the %D average.
    #[rstest]
    fn test_warmup_period_with_ma_d_method() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage, // %D as a moving average of %K
        );

        for i in 1..=4 {
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }
        assert!(!stoch.initialized);

        for i in 5..=20 {
            // Enough bars for both the slowing and %D averages to warm up.
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(
            stoch.initialized,
            "Should be initialized after sufficient bars"
        );
    }

    // Reset must also reset the internal MAs so that the indicator can be reused.
    #[rstest]
    fn test_reset_clears_slowing_ma_state() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        for i in 1..=10 {
            // Warm up with well-formed bars.
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }

        assert!(stoch.has_inputs);

        stoch.reset();

        // Everything observable is back at its initial state.
        assert!(!stoch.has_inputs);
        assert!(!stoch.initialized);
        assert_eq!(stoch.value_k, 0.0);
        assert_eq!(stoch.value_d, 0.0);
        assert_eq!(stoch.highs.len(), 0);
        assert_eq!(stoch.lows.len(), 0);

        for i in 1..=10 {
            // Feeding the same data again must produce %K once more.
            stoch.update_raw(f64::from(i) + 5.0, f64::from(i), f64::from(i) + 2.0);
        }
        assert!(stoch.value_k > 0.0);
    }

    // slowing == 1 is the identity, so no slowing MA is allocated.
    #[rstest]
    fn test_slowing_1_bypasses_ma() {
        let stoch = Stochastics::new_with_params(
            10,
            3,
            1, // slowing disabled
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );

        assert!(
            stoch.slowing_ma.is_none(),
            "slowing = 1 should not create MA"
        );
    }

    // slowing == 0 is rejected by the constructor bounds check, whose panic
    // message names the offending parameter.
    #[rstest]
    #[should_panic(expected = "slowing")]
    fn test_slowing_0_panics() {
        let _ = Stochastics::new_with_params(
            10,
            3,
            0, // invalid slowing
            MovingAverageType::Exponential,
            StochasticsDMethod::Ratio,
        );
    }

    // Flat bars give a zero %K range on every update, exercising the early
    // return. Outputs must remain finite.
    #[rstest]
    fn test_division_by_zero_protection() {
        let mut stoch = Stochastics::new_with_params(
            5,
            3,
            3,
            MovingAverageType::Exponential,
            StochasticsDMethod::MovingAverage,
        );

        for _ in 0..10 {
            // high == low == close: the range is always zero.
            stoch.update_raw(100.0, 100.0, 100.0);
        }

        assert!(stoch.value_k.is_finite());
        // %D is never divided by a zero range either.
        assert!(stoch.value_d.is_finite());
    }
}