nautilus_indicators/momentum/
cmo.rs1use std::fmt::Display;
17
18use nautilus_model::data::{Bar, QuoteTick, TradeTick};
19
20use crate::{
21 average::{MovingAverageFactory, MovingAverageType},
22 indicator::{Indicator, MovingAverage},
23};
24
/// Chande Momentum Oscillator (CMO).
///
/// Keeps moving averages of the upward and downward close-to-close changes.
/// Once both averages are initialized, it reports
/// `100 * (avg_gain - avg_loss) / (avg_gain + avg_loss)`, or `0.0` when that
/// sum is zero.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
)]
pub struct ChandeMomentumOscillator {
    /// Lookback period passed to both inner moving averages.
    pub period: usize,
    /// Moving average type used to smooth gains and losses.
    pub ma_type: MovingAverageType,
    /// Latest oscillator value; stays `0.0` until the indicator is initialized.
    pub value: f64,
    /// Number of `update_raw` calls since construction or the last reset.
    pub count: usize,
    /// Set once both inner moving averages report initialized.
    pub initialized: bool,
    /// Close seen on the previous update; the next change is measured against it.
    previous_close: f64,
    /// Moving average of positive close-to-close changes (zero otherwise).
    average_gain: Box<dyn MovingAverage + Send + 'static>,
    /// Moving average of the magnitude of negative changes (zero otherwise).
    average_loss: Box<dyn MovingAverage + Send + 'static>,
    /// Set after the first input has been received.
    has_inputs: bool,
}
46
47impl Display for ChandeMomentumOscillator {
48 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49 write!(f, "{}({})", self.name(), self.period)
50 }
51}
52
53impl Indicator for ChandeMomentumOscillator {
54 fn name(&self) -> String {
55 stringify!(ChandeMomentumOscillator).to_string()
56 }
57
58 fn has_inputs(&self) -> bool {
59 self.has_inputs
60 }
61
62 fn initialized(&self) -> bool {
63 self.initialized
64 }
65
66 fn handle_quote(&mut self, _quote: &QuoteTick) {}
67
68 fn handle_trade(&mut self, _trade: &TradeTick) {}
69
70 fn handle_bar(&mut self, bar: &Bar) {
71 self.update_raw((&bar.close).into());
72 }
73
74 fn reset(&mut self) {
75 self.value = 0.0;
76 self.count = 0;
77 self.has_inputs = false;
78 self.initialized = false;
79 self.previous_close = 0.0;
80 self.average_gain.reset();
81 self.average_loss.reset();
82 }
83}
84
85impl ChandeMomentumOscillator {
86 #[must_use]
92 pub fn new(period: usize, ma_type: Option<MovingAverageType>) -> Self {
93 assert!(period > 0, "ChandeMomentumOscillator: period must be > 0");
94 let ma_type = ma_type.unwrap_or(MovingAverageType::Wilder);
95 Self {
96 period,
97 ma_type,
98 average_gain: MovingAverageFactory::create(ma_type, period),
99 average_loss: MovingAverageFactory::create(ma_type, period),
100 previous_close: 0.0,
101 value: 0.0,
102 count: 0,
103 initialized: false,
104 has_inputs: false,
105 }
106 }
107
108 pub fn update_raw(&mut self, close: f64) {
109 self.count += 1;
110
111 if !self.has_inputs {
112 self.previous_close = close;
113 self.has_inputs = true;
114 }
115
116 let gain: f64 = close - self.previous_close;
117 if gain > 0.0 {
118 self.average_gain.update_raw(gain);
119 self.average_loss.update_raw(0.0);
120 } else if gain < 0.0 {
121 self.average_gain.update_raw(0.0);
122 self.average_loss.update_raw(-gain);
123 } else {
124 self.average_gain.update_raw(0.0);
125 self.average_loss.update_raw(0.0);
126 }
127
128 if !self.initialized && self.average_gain.initialized() && self.average_loss.initialized() {
129 self.initialized = true;
130 }
131
132 if self.initialized {
133 let divisor = self.average_gain.value() + self.average_loss.value();
134 if divisor == 0.0 {
135 self.value = 0.0;
136 } else {
137 self.value =
138 100.0 * (self.average_gain.value() - self.average_loss.value()) / divisor;
139 }
140 }
141 self.previous_close = close;
142 }
143}
144
#[cfg(test)]
mod tests {
    use nautilus_model::data::{Bar, QuoteTick};
    use rstest::rstest;

    use crate::{
        average::MovingAverageType, indicator::Indicator, momentum::cmo::ChandeMomentumOscillator,
        stubs::*,
    };

    /// Pushes each price into `cmo` in order via `update_raw`.
    fn feed(cmo: &mut ChandeMomentumOscillator, prices: &[f64]) {
        prices.iter().for_each(|&price| cmo.update_raw(price));
    }

    #[rstest]
    fn test_cmo_initialized(cmo_10: ChandeMomentumOscillator) {
        let rendered = format!("{cmo_10}");
        assert_eq!(rendered, "ChandeMomentumOscillator(10)");
        assert_eq!(cmo_10.period, 10);
        assert!(!cmo_10.initialized);
    }

    #[rstest]
    fn test_initialized_with_required_inputs_returns_true(mut cmo_10: ChandeMomentumOscillator) {
        (0_i32..12)
            .map(f64::from)
            .for_each(|price| cmo_10.update_raw(price));
        assert!(cmo_10.initialized);
    }

    #[rstest]
    fn test_value_all_higher_inputs_returns_expected_value(mut cmo_10: ChandeMomentumOscillator) {
        let prices = [
            109.93, 110.0, 109.77, 109.96, 110.29, 110.53, 110.27, 110.21, 110.06, 110.19,
            109.83, 109.9, 110.0, 110.03, 110.13, 109.95, 109.75, 110.15, 109.9, 110.04,
        ];
        feed(&mut cmo_10, &prices);
        assert_eq!(cmo_10.value, 2.089_629_456_238_705_4);
    }

    #[rstest]
    fn test_value_with_one_input_returns_expected_value(mut cmo_10: ChandeMomentumOscillator) {
        cmo_10.update_raw(1.00000);
        assert_eq!(cmo_10.value, 0.0);
    }

    #[rstest]
    fn test_reset(mut cmo_10: ChandeMomentumOscillator) {
        feed(&mut cmo_10, &[1.00020, 1.00030, 1.00050]);
        cmo_10.reset();
        assert!(!cmo_10.initialized());
        assert_eq!(cmo_10.count, 0);
        assert_eq!(cmo_10.value, 0.0);
        assert_eq!(cmo_10.previous_close, 0.0);
    }

    #[rstest]
    fn test_handle_quote_tick(mut cmo_10: ChandeMomentumOscillator, stub_quote: QuoteTick) {
        cmo_10.handle_quote(&stub_quote);
        assert_eq!(cmo_10.count, 0);
        assert_eq!(cmo_10.value, 0.0);
    }

    #[rstest]
    fn test_handle_bar(mut cmo_10: ChandeMomentumOscillator, bar_ethusdt_binance_minute_bid: Bar) {
        cmo_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(cmo_10.count, 1);
        assert_eq!(cmo_10.value, 0.0);
    }

    #[rstest]
    fn test_ma_type_affects_value() {
        let prices = [1.0, 2.0, 3.0, 2.5, 3.5];
        let mut cmo_sma = ChandeMomentumOscillator::new(3, Some(MovingAverageType::Simple));
        let mut cmo_wilder = ChandeMomentumOscillator::new(3, Some(MovingAverageType::Wilder));
        feed(&mut cmo_sma, &prices);
        feed(&mut cmo_wilder, &prices);
        assert_ne!(cmo_sma.value, cmo_wilder.value);
    }

    #[rstest]
    fn test_count_increments(mut cmo_10: ChandeMomentumOscillator) {
        (0_i32..5)
            .map(f64::from)
            .for_each(|price| cmo_10.update_raw(price));
        assert_eq!(cmo_10.count, 5);
    }

    #[rstest]
    fn test_reset_resets_inner_mas() {
        let mut cmo = ChandeMomentumOscillator::new(3, None);
        feed(&mut cmo, &[1.0, 2.0, 3.0]);
        assert!(cmo.average_gain.initialized());
        assert!(cmo.average_loss.initialized());
        assert_ne!(cmo.average_gain.value(), 0.0);

        cmo.reset();
        assert!(!cmo.average_gain.initialized());
        assert!(!cmo.average_loss.initialized());
        assert_eq!(cmo.average_gain.value(), 0.0);
        assert_eq!(cmo.average_loss.value(), 0.0);
    }

    #[rstest]
    #[should_panic]
    fn test_invalid_period_panics() {
        let _ = ChandeMomentumOscillator::new(0, None);
    }

    #[rstest]
    fn test_ma_type_propagation() {
        let cmo = ChandeMomentumOscillator::new(5, Some(MovingAverageType::Simple));
        assert_eq!(cmo.ma_type, MovingAverageType::Simple);
    }

    #[rstest]
    fn test_zero_divisor_returns_zero() {
        let mut cmo = ChandeMomentumOscillator::new(3, None);
        feed(&mut cmo, &[100.0; 5]);
        assert!(cmo.initialized);
        assert_eq!(cmo.value, 0.0);
    }

    #[rstest]
    fn test_random_walk_values_within_bounds() {
        let prices = [
            100.0, 100.5, 99.8, 100.3, 101.0, 100.7, 101.5, 101.2, 100.6, 101.1, 100.9, 101.4,
            100.8, 101.2, 100.6, 100.9, 101.3, 101.0, 100.5, 101.1, 100.7, 101.4, 100.9, 100.8,
            101.2, 100.6, 100.9, 101.3, 101.0, 100.5,
        ];
        let mut cmo = ChandeMomentumOscillator::new(10, None);
        feed(&mut cmo, &prices);
        assert!(cmo.initialized);
        assert!((-100.0..=100.0).contains(&cmo.value));
    }
}