Skip to main content

nautilus_indicators/momentum/
cmo.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::Display;
17
18use nautilus_model::data::{Bar, QuoteTick, TradeTick};
19
20use crate::{
21    average::{MovingAverageFactory, MovingAverageType},
22    indicator::{Indicator, MovingAverage},
23};
24
/// Chande Momentum Oscillator (CMO).
///
/// Smooths the positive ("gain") and negative ("loss") close-to-close changes
/// with two moving averages of the same type and period, and reports
/// `100 * (gain - loss) / (gain + loss)` once both averages are initialized.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
)]
pub struct ChandeMomentumOscillator {
    /// Lookback period handed to both inner moving averages.
    pub period: usize,
    /// Moving-average type used for both the gain and loss averages.
    pub ma_type: MovingAverageType,
    /// Latest oscillator value; `0.0` until initialized, and whenever the gain
    /// and loss averages sum to zero.
    pub value: f64,
    /// Number of raw inputs received since construction or the last reset.
    pub count: usize,
    /// Set once both inner moving averages report themselves initialized.
    pub initialized: bool,
    /// Close from the previous update, used to compute the next change.
    previous_close: f64,
    /// Moving average of the upward changes (zero is fed on non-up moves).
    average_gain: Box<dyn MovingAverage + Send + 'static>,
    /// Moving average of the magnitude of downward changes (zero on non-down moves).
    average_loss: Box<dyn MovingAverage + Send + 'static>,
    /// Whether at least one input has been received since construction or reset.
    has_inputs: bool,
}
46
47impl Display for ChandeMomentumOscillator {
48    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
49        write!(f, "{}({})", self.name(), self.period)
50    }
51}
52
53impl Indicator for ChandeMomentumOscillator {
54    fn name(&self) -> String {
55        stringify!(ChandeMomentumOscillator).to_string()
56    }
57
58    fn has_inputs(&self) -> bool {
59        self.has_inputs
60    }
61
62    fn initialized(&self) -> bool {
63        self.initialized
64    }
65
66    fn handle_quote(&mut self, _quote: &QuoteTick) {}
67
68    fn handle_trade(&mut self, _trade: &TradeTick) {}
69
70    fn handle_bar(&mut self, bar: &Bar) {
71        self.update_raw((&bar.close).into());
72    }
73
74    fn reset(&mut self) {
75        self.value = 0.0;
76        self.count = 0;
77        self.has_inputs = false;
78        self.initialized = false;
79        self.previous_close = 0.0;
80        self.average_gain.reset();
81        self.average_loss.reset();
82    }
83}
84
85impl ChandeMomentumOscillator {
86    /// Creates a new [`ChandeMomentumOscillator`] instance.
87    ///
88    /// # Panics
89    ///
90    /// Panics if `period` is not positive (> 0).
91    #[must_use]
92    pub fn new(period: usize, ma_type: Option<MovingAverageType>) -> Self {
93        assert!(period > 0, "ChandeMomentumOscillator: period must be > 0");
94        let ma_type = ma_type.unwrap_or(MovingAverageType::Wilder);
95        Self {
96            period,
97            ma_type,
98            average_gain: MovingAverageFactory::create(ma_type, period),
99            average_loss: MovingAverageFactory::create(ma_type, period),
100            previous_close: 0.0,
101            value: 0.0,
102            count: 0,
103            initialized: false,
104            has_inputs: false,
105        }
106    }
107
108    pub fn update_raw(&mut self, close: f64) {
109        self.count += 1;
110
111        if !self.has_inputs {
112            self.previous_close = close;
113            self.has_inputs = true;
114        }
115
116        let gain: f64 = close - self.previous_close;
117        if gain > 0.0 {
118            self.average_gain.update_raw(gain);
119            self.average_loss.update_raw(0.0);
120        } else if gain < 0.0 {
121            self.average_gain.update_raw(0.0);
122            self.average_loss.update_raw(-gain);
123        } else {
124            self.average_gain.update_raw(0.0);
125            self.average_loss.update_raw(0.0);
126        }
127
128        if !self.initialized && self.average_gain.initialized() && self.average_loss.initialized() {
129            self.initialized = true;
130        }
131
132        if self.initialized {
133            let divisor = self.average_gain.value() + self.average_loss.value();
134            if divisor == 0.0 {
135                self.value = 0.0;
136            } else {
137                self.value =
138                    100.0 * (self.average_gain.value() - self.average_loss.value()) / divisor;
139            }
140        }
141        self.previous_close = close;
142    }
143}
144
#[cfg(test)]
mod tests {
    use nautilus_model::data::{Bar, QuoteTick};
    use rstest::rstest;

    use crate::{
        average::MovingAverageType, indicator::Indicator, momentum::cmo::ChandeMomentumOscillator,
        stubs::*,
    };

    /// Feeds every price in `prices` to `cmo` via `update_raw`, in order.
    fn feed(cmo: &mut ChandeMomentumOscillator, prices: &[f64]) {
        for &price in prices {
            cmo.update_raw(price);
        }
    }

    #[rstest]
    fn test_cmo_initialized(cmo_10: ChandeMomentumOscillator) {
        let display_str = format!("{cmo_10}");
        assert_eq!(display_str, "ChandeMomentumOscillator(10)");
        assert_eq!(cmo_10.period, 10);
        assert!(!cmo_10.initialized);
    }

    #[rstest]
    fn test_initialized_with_required_inputs_returns_true(mut cmo_10: ChandeMomentumOscillator) {
        for i in 0..12 {
            cmo_10.update_raw(f64::from(i));
        }
        assert!(cmo_10.initialized);
    }

    // The sequence moves both up and down, so both the gain and loss paths are exercised.
    #[rstest]
    fn test_value_mixed_inputs_returns_expected_value(mut cmo_10: ChandeMomentumOscillator) {
        feed(
            &mut cmo_10,
            &[
                109.93, 110.0, 109.77, 109.96, 110.29, 110.53, 110.27, 110.21, 110.06, 110.19,
                109.83, 109.9, 110.0, 110.03, 110.13, 109.95, 109.75, 110.15, 109.9, 110.04,
            ],
        );
        assert_eq!(cmo_10.value, 2.089_629_456_238_705_4);
    }

    // With only upward moves the loss average stays zero, so the value is +100.
    #[rstest]
    fn test_strictly_rising_inputs_return_plus_100(mut cmo_10: ChandeMomentumOscillator) {
        for i in 0..15 {
            cmo_10.update_raw(f64::from(i));
        }
        assert!(cmo_10.initialized);
        assert!((cmo_10.value - 100.0).abs() < 1e-9);
    }

    // With only downward moves the gain average stays zero, so the value is -100.
    #[rstest]
    fn test_strictly_falling_inputs_return_minus_100(mut cmo_10: ChandeMomentumOscillator) {
        for i in 0..15 {
            cmo_10.update_raw(100.0 - f64::from(i));
        }
        assert!(cmo_10.initialized);
        assert!((cmo_10.value + 100.0).abs() < 1e-9);
    }

    #[rstest]
    fn test_value_with_one_input_returns_expected_value(mut cmo_10: ChandeMomentumOscillator) {
        cmo_10.update_raw(1.00000);
        assert_eq!(cmo_10.value, 0.0);
    }

    #[rstest]
    fn test_reset(mut cmo_10: ChandeMomentumOscillator) {
        feed(&mut cmo_10, &[1.00020, 1.00030, 1.00050]);
        cmo_10.reset();
        assert!(!cmo_10.initialized());
        assert_eq!(cmo_10.count, 0);
        assert_eq!(cmo_10.value, 0.0);
        assert_eq!(cmo_10.previous_close, 0.0);
    }

    #[rstest]
    fn test_handle_quote_tick(mut cmo_10: ChandeMomentumOscillator, stub_quote: QuoteTick) {
        cmo_10.handle_quote(&stub_quote);
        assert_eq!(cmo_10.count, 0);
        assert_eq!(cmo_10.value, 0.0);
    }

    #[rstest]
    fn test_handle_bar(mut cmo_10: ChandeMomentumOscillator, bar_ethusdt_binance_minute_bid: Bar) {
        cmo_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert_eq!(cmo_10.count, 1);
        assert_eq!(cmo_10.value, 0.0);
    }

    #[rstest]
    fn test_ma_type_affects_value() {
        let mut cmo_sma = ChandeMomentumOscillator::new(3, Some(MovingAverageType::Simple));
        let mut cmo_wilder = ChandeMomentumOscillator::new(3, Some(MovingAverageType::Wilder));
        let prices = [1.0, 2.0, 3.0, 2.5, 3.5];
        feed(&mut cmo_sma, &prices);
        feed(&mut cmo_wilder, &prices);
        assert_ne!(cmo_sma.value, cmo_wilder.value);
    }

    #[rstest]
    fn test_count_increments(mut cmo_10: ChandeMomentumOscillator) {
        for i in 0..5 {
            cmo_10.update_raw(f64::from(i));
        }
        assert_eq!(cmo_10.count, 5);
    }

    #[rstest]
    fn test_reset_resets_inner_mas() {
        let mut cmo = ChandeMomentumOscillator::new(3, None);
        feed(&mut cmo, &[1.0, 2.0, 3.0]);
        assert!(cmo.average_gain.initialized());
        assert!(cmo.average_loss.initialized());
        assert_ne!(cmo.average_gain.value(), 0.0);
        cmo.reset();
        assert!(!cmo.average_gain.initialized());
        assert!(!cmo.average_loss.initialized());
        assert_eq!(cmo.average_gain.value(), 0.0);
        assert_eq!(cmo.average_loss.value(), 0.0);
    }

    #[rstest]
    #[should_panic]
    fn test_invalid_period_panics() {
        let _ = ChandeMomentumOscillator::new(0, None);
    }

    #[rstest]
    fn test_ma_type_propagation() {
        let cmo = ChandeMomentumOscillator::new(5, Some(MovingAverageType::Simple));
        assert_eq!(cmo.ma_type, MovingAverageType::Simple);
    }

    #[rstest]
    fn test_zero_divisor_returns_zero() {
        let mut cmo = ChandeMomentumOscillator::new(3, None);
        for _ in 0..5 {
            cmo.update_raw(100.0);
        }
        assert!(cmo.initialized);
        assert_eq!(cmo.value, 0.0);
    }

    // The bound is checked after every update, not just at the end of the walk.
    #[rstest]
    fn test_random_walk_values_within_bounds() {
        let prices = [
            100.0, 100.5, 99.8, 100.3, 101.0, 100.7, 101.5, 101.2, 100.6, 101.1, 100.9, 101.4,
            100.8, 101.2, 100.6, 100.9, 101.3, 101.0, 100.5, 101.1, 100.7, 101.4, 100.9, 100.8,
            101.2, 100.6, 100.9, 101.3, 101.0, 100.5,
        ];
        let mut cmo = ChandeMomentumOscillator::new(10, None);
        for price in prices {
            cmo.update_raw(price);
            assert!(
                (-100.0..=100.0).contains(&cmo.value),
                "value {} out of bounds after input {price}",
                cmo.value
            );
        }
        assert!(cmo.initialized);
    }
}