nautilus_indicators/momentum/
amat.rs1use std::fmt::{Debug, Display};
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20
21use crate::{
22 average::{MovingAverageFactory, MovingAverageType},
23 indicator::{Indicator, MovingAverage},
24};
25
26const DEFAULT_MA_TYPE: MovingAverageType = MovingAverageType::Exponential;
27const MAX_SIGNAL: usize = 1_024;
28
29type SignalBuf = ArrayDeque<f64, { MAX_SIGNAL + 1 }, Wrapping>;
30
/// Archer Moving Averages Trends (AMAT) indicator.
///
/// Feeds each close into a fast and a slow moving average of the same kind.
/// Once the slow average is initialized, it buffers their recent values to
/// maintain the `long_run` / `short_run` trend flags (see `update_raw`).
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
)]
pub struct ArcherMovingAveragesTrends {
    /// Period of the fast moving average.
    pub fast_period: usize,
    /// Period of the slow moving average; `new` requires it to exceed `fast_period`.
    pub slow_period: usize,
    /// Window length, in updates, used to measure moving-average change.
    /// The value buffers are capped at `signal_period + 1` entries.
    pub signal_period: usize,
    /// Moving-average kind used for both the fast and the slow average.
    pub ma_type: MovingAverageType,
    /// Bullish trend flag maintained by `update_raw`; cleared by `reset`.
    pub long_run: bool,
    /// Bearish trend flag maintained by `update_raw`; cleared by `reset`.
    pub short_run: bool,
    /// Set by `update_raw` once the slow average is initialized and the value
    /// buffers hold `signal_period + 1` entries; cleared by `reset`.
    pub initialized: bool,
    /// Fast moving average, fed every close.
    fast_ma: Box<dyn MovingAverage + Send + 'static>,
    /// Slow moving average, fed every close.
    slow_ma: Box<dyn MovingAverage + Send + 'static>,
    /// Recent fast-average values, oldest at the front.
    fast_ma_price: SignalBuf,
    /// Recent slow-average values, oldest at the front.
    /// Pushed and popped together with `fast_ma_price`, so both have the same length.
    slow_ma_price: SignalBuf,
    /// Whether any input has been received since construction or the last `reset`.
    has_inputs: bool,
}
55
56impl Display for ArcherMovingAveragesTrends {
57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
58 write!(
59 f,
60 "{}({},{},{},{})",
61 self.name(),
62 self.fast_period,
63 self.slow_period,
64 self.signal_period,
65 self.ma_type,
66 )
67 }
68}
69
70impl Indicator for ArcherMovingAveragesTrends {
71 fn name(&self) -> String {
72 stringify!(ArcherMovingAveragesTrends).into()
73 }
74
75 fn has_inputs(&self) -> bool {
76 self.has_inputs
77 }
78
79 fn initialized(&self) -> bool {
80 self.initialized
81 }
82
83 fn handle_bar(&mut self, bar: &Bar) {
84 self.update_raw(bar.close.into());
85 }
86
87 fn reset(&mut self) {
88 self.fast_ma.reset();
89 self.slow_ma.reset();
90 self.long_run = false;
91 self.short_run = false;
92 self.fast_ma_price.clear();
93 self.slow_ma_price.clear();
94 self.has_inputs = false;
95 self.initialized = false;
96 }
97}
98
99impl ArcherMovingAveragesTrends {
100 #[must_use]
109 pub fn new(
110 fast_period: usize,
111 slow_period: usize,
112 signal_period: usize,
113 ma_type: Option<MovingAverageType>,
114 ) -> Self {
115 assert!(
116 fast_period > 0,
117 "fast_period must be positive (received {fast_period})"
118 );
119 assert!(
120 slow_period > 0,
121 "slow_period must be positive (received {slow_period})"
122 );
123 assert!(
124 signal_period > 0,
125 "signal_period must be positive (received {signal_period})"
126 );
127 assert!(
128 slow_period > fast_period,
129 "slow_period ({slow_period}) must be greater than fast_period ({fast_period})"
130 );
131 assert!(
132 signal_period <= MAX_SIGNAL,
133 "signal_period ({signal_period}) must not exceed MAX_SIGNAL ({MAX_SIGNAL})"
134 );
135
136 let ma_type = ma_type.unwrap_or(DEFAULT_MA_TYPE);
137
138 Self {
139 fast_period,
140 slow_period,
141 signal_period,
142 ma_type,
143 long_run: false,
144 short_run: false,
145 fast_ma: MovingAverageFactory::create(ma_type, fast_period),
146 slow_ma: MovingAverageFactory::create(ma_type, slow_period),
147 fast_ma_price: SignalBuf::new(),
148 slow_ma_price: SignalBuf::new(),
149 has_inputs: false,
150 initialized: false,
151 }
152 }
153
154 pub fn update_raw(&mut self, close: f64) {
159 self.fast_ma.update_raw(close);
160 self.slow_ma.update_raw(close);
161
162 if self.slow_ma.initialized() {
163 self.fast_ma_price.push_back(self.fast_ma.value());
164 self.slow_ma_price.push_back(self.slow_ma.value());
165
166 let max_len = self.signal_period + 1;
167 if self.fast_ma_price.len() > max_len {
168 self.fast_ma_price.pop_front();
169 self.slow_ma_price.pop_front();
170 }
171
172 let fast_back = self.fast_ma.value();
173 let fast_front = *self
174 .fast_ma_price
175 .front()
176 .expect("buffer has at least one element");
177
178 let fast_diff = fast_back - fast_front;
179 self.long_run = fast_diff > 0.0 || self.long_run;
180 self.short_run = fast_diff < 0.0 || self.short_run;
181 }
182
183 if !self.initialized {
184 self.has_inputs = true;
185 let max_len = self.signal_period + 1;
186 if self.slow_ma_price.len() == max_len && self.slow_ma.initialized() {
187 self.initialized = true;
188 }
189 }
190 }
191}
192
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::stubs::amat_345;

    /// Constructs (and discards) an indicator with the default MA type.
    /// Used to exercise the constructor's argument validation.
    fn make(fast: usize, slow: usize, signal: usize) {
        let _ = ArcherMovingAveragesTrends::new(fast, slow, signal, None);
    }

    #[rstest]
    fn default_ma_type_is_exponential() {
        let ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        assert_eq!(ind.ma_type, MovingAverageType::Exponential);
    }

    #[rstest]
    fn test_name_returns_expected_string(amat_345: ArcherMovingAveragesTrends) {
        assert_eq!(amat_345.name(), "ArcherMovingAveragesTrends");
    }

    // NOTE(review): expects SIMPLE, so the `amat_345` stub presumably passes
    // `Some(MovingAverageType::Simple)` rather than the default; confirm in `stubs`.
    #[rstest]
    fn test_str_repr_returns_expected_string(amat_345: ArcherMovingAveragesTrends) {
        assert_eq!(
            format!("{amat_345}"),
            "ArcherMovingAveragesTrends(3,4,5,SIMPLE)"
        );
    }

    #[rstest]
    fn test_period_returns_expected_value(amat_345: ArcherMovingAveragesTrends) {
        assert_eq!(amat_345.fast_period, 3);
        assert_eq!(amat_345.slow_period, 4);
        assert_eq!(amat_345.signal_period, 5);
    }

    #[rstest]
    fn test_initialized_without_inputs_returns_false(amat_345: ArcherMovingAveragesTrends) {
        assert!(!amat_345.initialized());
    }

    #[rstest]
    #[should_panic(expected = "fast_period must be positive")]
    fn new_panics_on_zero_fast_period() {
        make(0, 4, 5);
    }

    #[rstest]
    #[should_panic(expected = "slow_period must be positive")]
    fn new_panics_on_zero_slow_period() {
        make(3, 0, 5);
    }

    #[rstest]
    #[should_panic(expected = "signal_period must be positive")]
    fn new_panics_on_zero_signal_period() {
        make(3, 5, 0);
    }

    #[rstest]
    #[should_panic(expected = "slow_period (3) must be greater than fast_period (3)")]
    fn new_panics_when_slow_not_greater_than_fast() {
        make(3, 3, 5);
    }

    #[rstest]
    #[should_panic(expected = "slow_period (2) must be greater than fast_period (3)")]
    fn new_panics_when_slow_less_than_fast() {
        make(3, 2, 5);
    }

    /// Feeds `count` closes forming the arithmetic sequence `start, start + step, ...`.
    fn feed_sequence(ind: &mut ArcherMovingAveragesTrends, start: i64, count: usize, step: i64) {
        (0..count).for_each(|i| ind.update_raw((start + i as i64 * step) as f64));
    }

    #[rstest]
    fn buffer_len_never_exceeds_signal_plus_one() {
        let mut ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        feed_sequence(&mut ind, 0, 100, 1);
        assert_eq!(ind.fast_ma_price.len(), ind.signal_period + 1);
        assert_eq!(ind.slow_ma_price.len(), ind.signal_period + 1);
    }

    #[rstest]
    fn initialized_becomes_true_after_slow_ready_and_buffer_full() {
        let mut ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        feed_sequence(&mut ind, 0, 11, 1);
        assert!(ind.initialized());
    }

    #[rstest]
    fn has_inputs_set_after_first_update_before_initialized() {
        let mut ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        assert!(!ind.has_inputs());

        ind.update_raw(1.0);

        assert!(ind.has_inputs());
        // A single value cannot fill a `signal_period + 1` buffer.
        assert!(!ind.initialized());
    }

    #[rstest]
    fn long_run_flag_sets_on_bullish_trend() {
        let mut ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        feed_sequence(&mut ind, 0, 60, 1);
        assert!(ind.long_run, "Expected long_run=TRUE on up-trend");
        assert!(!ind.short_run, "short_run should remain FALSE here");
    }

    #[rstest]
    fn short_run_flag_sets_on_bearish_trend() {
        let mut ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        feed_sequence(&mut ind, 100, 60, -1);
        assert!(ind.short_run, "Expected short_run=TRUE on down-trend");
        assert!(!ind.long_run, "long_run should remain FALSE here");
    }

    #[rstest]
    fn reset_clears_internal_state() {
        let mut ind = ArcherMovingAveragesTrends::new(3, 4, 5, None);
        feed_sequence(&mut ind, 0, 50, 1);
        assert!(ind.long_run || ind.short_run);
        assert!(!ind.fast_ma_price.is_empty());

        ind.reset();

        assert!(!ind.long_run && !ind.short_run);
        assert_eq!(ind.fast_ma_price.len(), 0);
        assert_eq!(ind.slow_ma_price.len(), 0);
        assert!(!ind.initialized());
        assert!(!ind.has_inputs());
    }

    #[rstest]
    #[should_panic(expected = "signal_period (1025) must not exceed MAX_SIGNAL (1024)")]
    fn new_panics_when_signal_exceeds_max() {
        let _ = ArcherMovingAveragesTrends::new(3, 4, MAX_SIGNAL + 1, None);
    }

    #[rstest]
    fn ma_type_override_is_respected() {
        let ind = ArcherMovingAveragesTrends::new(3, 4, 5, Some(MovingAverageType::Simple));
        assert_eq!(ind.ma_type, MovingAverageType::Simple);
    }
}