Skip to main content

nautilus_indicators/momentum/
swings.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::Display;
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20
21use crate::indicator::Indicator;
22
23const MAX_PERIOD: usize = 1_024;
24
25#[repr(C)]
26#[derive(Debug)]
27#[cfg_attr(
28    feature = "python",
29    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
30)]
31#[cfg_attr(
32    feature = "python",
33    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
34)]
35pub struct Swings {
36    pub period: usize,
37    pub direction: i64,
38    pub changed: bool,
39    pub high_datetime: f64,
40    pub low_datetime: f64,
41    pub high_price: f64,
42    pub low_price: f64,
43    pub length: usize,
44    pub duration: usize,
45    pub since_high: usize,
46    pub since_low: usize,
47    high_inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
48    low_inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
49    has_inputs: bool,
50    initialized: bool,
51}
52
53impl Display for Swings {
54    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55        write!(f, "{}({})", self.name(), self.period,)
56    }
57}
58
59impl Indicator for Swings {
60    fn name(&self) -> String {
61        stringify!(Swings).to_string()
62    }
63
64    fn has_inputs(&self) -> bool {
65        self.has_inputs
66    }
67
68    fn initialized(&self) -> bool {
69        self.initialized
70    }
71
72    fn handle_bar(&mut self, bar: &Bar) {
73        self.update_raw((&bar.high).into(), (&bar.low).into(), bar.ts_init.as_f64());
74    }
75
76    fn reset(&mut self) {
77        self.high_inputs.clear();
78        self.low_inputs.clear();
79        self.has_inputs = false;
80        self.initialized = false;
81        self.direction = 0;
82        self.changed = false;
83        self.high_datetime = 0.0;
84        self.low_datetime = 0.0;
85        self.high_price = 0.0;
86        self.low_price = 0.0;
87        self.length = 0;
88        self.duration = 0;
89        self.since_high = 0;
90        self.since_low = 0;
91    }
92}
93
94impl Swings {
95    /// Creates a new [`Swings`] instance.
96    ///
97    /// # Panics
98    ///
99    /// This function panics if:
100    /// - `period` is less than or equal to 0.
101    /// - `period` exceeds the maximum allowed value of `MAX_PERIOD`.
102    #[must_use]
103    pub fn new(period: usize) -> Self {
104        assert!(
105            period > 0 && period <= MAX_PERIOD,
106            "Swings: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
107        );
108
109        Self {
110            period,
111            high_inputs: ArrayDeque::new(),
112            low_inputs: ArrayDeque::new(),
113            has_inputs: false,
114            initialized: false,
115            direction: 0,
116            changed: false,
117            high_datetime: 0.0,
118            low_datetime: 0.0,
119            high_price: 0.0,
120            low_price: 0.0,
121            length: 0,
122            duration: 0,
123            since_high: 0,
124            since_low: 0,
125        }
126    }
127
128    pub fn update_raw(&mut self, high: f64, low: f64, timestamp: f64) {
129        self.changed = false;
130
131        if self.high_inputs.len() == self.period {
132            self.high_inputs.pop_front();
133        }
134
135        if self.low_inputs.len() == self.period {
136            self.low_inputs.pop_front();
137        }
138        let _ = self.high_inputs.push_back(high);
139        let _ = self.low_inputs.push_back(low);
140
141        let max_high = self.high_inputs.iter().fold(f64::MIN, |a, &b| a.max(b));
142        let min_low = self.low_inputs.iter().fold(f64::MAX, |a, &b| a.min(b));
143
144        let is_swing_high = high >= max_high && low >= min_low;
145        let is_swing_low = high <= max_high && low <= min_low;
146
147        if is_swing_high && is_swing_low {
148            if self.high_price == 0.0 {
149                self.high_price = high;
150                self.high_datetime = timestamp;
151            }
152            self.since_high += 1;
153            self.since_low += 1;
154        } else if is_swing_high {
155            if self.direction == -1 {
156                self.changed = true;
157            }
158
159            if high > self.high_price {
160                self.high_price = high;
161                self.high_datetime = timestamp;
162            }
163            self.direction = 1;
164            self.since_high = 0;
165            self.since_low += 1;
166        } else if is_swing_low {
167            if self.direction == 1 {
168                self.changed = true;
169            }
170
171            if self.high_price == 0.0 {
172                self.high_price = max_high;
173                self.high_datetime = timestamp;
174            }
175
176            if low < self.low_price || self.low_price == 0.0 {
177                self.low_price = low;
178                self.low_datetime = timestamp;
179            }
180            self.direction = -1;
181            self.since_high += 1;
182            self.since_low = 0;
183        } else {
184            self.since_high += 1;
185            self.since_low += 1;
186        }
187
188        self.has_inputs = true;
189
190        if self.high_price != 0.0 && self.low_price != 0.0 {
191            self.initialized = true;
192            self.length = ((self.high_price - self.low_price).abs().round()) as usize;
193
194            if self.direction == 1 {
195                self.duration = self.since_low;
196            } else if self.direction == -1 {
197                self.duration = self.since_high;
198            } else {
199                self.duration = 0;
200            }
201        }
202    }
203}
204
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::stubs::swings_10;

    /// Feeds parallel high/low/timestamp series into `swings`, one update per index.
    fn feed(swings: &mut Swings, highs: &[f64], lows: &[f64], times: &[f64]) {
        assert_eq!(highs.len(), lows.len());
        assert_eq!(highs.len(), times.len());
        for ((&high, &low), &time) in highs.iter().zip(lows).zip(times) {
            swings.update_raw(high, low, time);
        }
    }

    #[rstest]
    fn test_name_returns_expected_string(swings_10: Swings) {
        assert_eq!(swings_10.name(), "Swings");
    }

    #[rstest]
    fn test_str_repr_returns_expected_string(swings_10: Swings) {
        assert_eq!(format!("{swings_10}"), "Swings(10)");
    }

    #[rstest]
    fn test_period_returns_expected_value(swings_10: Swings) {
        assert_eq!(swings_10.period, 10);
    }

    #[rstest]
    fn test_initialized_without_inputs_returns_false(swings_10: Swings) {
        assert!(!swings_10.initialized());
    }

    #[rstest]
    #[should_panic]
    fn test_new_with_zero_period_panics() {
        let _ = Swings::new(0);
    }

    #[rstest]
    #[should_panic]
    fn test_new_with_period_above_max_panics() {
        let _ = Swings::new(MAX_PERIOD + 1);
    }

    #[rstest]
    fn test_new_accepts_max_period() {
        let swings = Swings::new(MAX_PERIOD);
        assert_eq!(swings.period, MAX_PERIOD);
        assert!(!swings.has_inputs());
        assert!(!swings.initialized());
    }

    #[rstest]
    fn test_value_with_all_higher_inputs_returns_expected_value(mut swings_10: Swings) {
        let high = [
            0.9, 1.9, 2.9, 3.9, 4.9, 3.2, 6.9, 7.9, 8.9, 9.9, 1.1, 3.2, 10.3, 11.1, 11.4,
        ];
        let low = [
            0.8, 1.8, 2.8, 3.8, 4.8, 3.1, 6.8, 7.8, 0.8, 9.8, 1.0, 3.1, 10.2, 11.0, 11.3,
        ];
        let time = [
            1_643_723_400.0,
            1_643_723_410.0,
            1_643_723_420.0,
            1_643_723_430.0,
            1_643_723_440.0,
            1_643_723_450.0,
            1_643_723_460.0,
            1_643_723_470.0,
            1_643_723_480.0,
            1_643_723_490.0,
            1_643_723_500.0,
            1_643_723_510.0,
            1_643_723_520.0,
            1_643_723_530.0,
            1_643_723_540.0,
        ];

        feed(&mut swings_10, &high, &low, &time);

        assert_eq!(swings_10.direction, 1);
        assert_eq!(swings_10.high_price, 11.4);
        assert_eq!(swings_10.low_price, 0.0);
        assert_eq!(swings_10.high_datetime, time[14]);
        assert_eq!(swings_10.low_datetime, 0.0);
        assert_eq!(swings_10.length, 0);
        assert_eq!(swings_10.duration, 0);
        assert_eq!(swings_10.since_high, 0);
        assert_eq!(swings_10.since_low, 15);
    }

    #[rstest]
    fn test_swing_low_records_low_and_initializes() {
        let mut swings = Swings::new(2);

        swings.update_raw(10.0, 9.0, 1.0);
        assert!(swings.has_inputs());
        assert!(!swings.initialized());

        swings.update_raw(8.0, 7.0, 2.0);
        assert!(swings.initialized());
        assert!(!swings.changed);
        assert_eq!(swings.direction, -1);
        assert_eq!(swings.high_price, 10.0);
        assert_eq!(swings.high_datetime, 1.0);
        assert_eq!(swings.low_price, 7.0);
        assert_eq!(swings.low_datetime, 2.0);
        assert_eq!(swings.since_high, 2);
        assert_eq!(swings.since_low, 0);
        assert_eq!(swings.length, 3);
        assert_eq!(swings.duration, 2);
    }

    #[rstest]
    fn test_reset_successfully_returns_indicator_to_fresh_state(mut swings_10: Swings) {
        let high = [1.0, 2.0, 3.0, 4.0, 5.0];
        let low = [0.9, 1.9, 2.9, 3.9, 4.9];
        let time = [
            1_643_723_400.0,
            1_643_723_410.0,
            1_643_723_420.0,
            1_643_723_430.0,
            1_643_723_440.0,
        ];

        feed(&mut swings_10, &high, &low, &time);

        swings_10.reset();

        assert!(!swings_10.initialized());
        assert!(!swings_10.has_inputs());
        assert_eq!(swings_10.direction, 0);
        assert_eq!(swings_10.high_price, 0.0);
        assert_eq!(swings_10.low_price, 0.0);
        assert_eq!(swings_10.high_datetime, 0.0);
        assert_eq!(swings_10.low_datetime, 0.0);
        assert_eq!(swings_10.length, 0);
        assert_eq!(swings_10.duration, 0);
        assert_eq!(swings_10.since_high, 0);
        assert_eq!(swings_10.since_low, 0);
        assert!(swings_10.high_inputs.is_empty());
        assert!(swings_10.low_inputs.is_empty());
    }

    #[rstest]
    fn test_changed_flag_flips() {
        let mut swings = Swings::new(2);

        swings.update_raw(1.0, 0.5, 1.0);
        assert!(!swings.changed);

        swings.update_raw(2.0, 1.5, 2.0);
        assert!(!swings.changed);

        swings.update_raw(0.0, -1.0, 3.0);
        assert!(swings.changed);

        swings.update_raw(-0.5, -1.5, 4.0);
        assert!(!swings.changed);
    }

    #[rstest]
    fn test_length_computation_after_initialization() {
        let mut swings = Swings::new(2);
        swings.update_raw(10.0, 9.0, 1.0);
        swings.update_raw(8.0, 7.0, 2.0);
        swings.update_raw(8.0, 7.5, 3.0);
        assert_eq!(swings.length, 3);
    }

    #[rstest]
    fn test_length_rounds_fractional_difference() {
        let mut swings = Swings::new(2);
        swings.update_raw(10.9, 10.7, 1.0);
        swings.update_raw(9.7, 9.4, 2.0);
        swings.update_raw(9.7, 9.4, 3.0);
        assert_eq!(swings.length, 2);
    }

    #[rstest]
    fn test_queue_eviction_does_not_exceed_capacity() {
        let period = 3;
        let mut swings = Swings::new(period);

        let highs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let lows = [0.5, 1.5, 2.5, 3.5, 4.5];

        for (i, (&high, &low)) in highs.iter().zip(&lows).enumerate() {
            swings.update_raw(high, low, (i + 1) as f64);

            assert!(swings.high_inputs.len() <= period);
            assert!(swings.low_inputs.len() <= period);
        }

        assert_eq!(swings.high_inputs.len(), period);
        assert_eq!(swings.low_inputs.len(), period);
        assert_eq!(swings.high_inputs.front().copied(), Some(3.0));
        assert_eq!(swings.low_inputs.front().copied(), Some(2.5));
    }

    #[rstest]
    fn test_changed_flag_toggles_on_every_direction_flip() {
        let mut swings = Swings::new(2);

        swings.update_raw(1.0, 0.7, 1.0);
        assert!(!swings.changed);
        swings.update_raw(2.0, 1.7, 2.0);
        assert!(!swings.changed);

        swings.update_raw(0.0, -1.0, 3.0);
        assert!(swings.changed);
        swings.update_raw(-0.5, -1.5, 4.0);
        assert!(!swings.changed);

        swings.update_raw(2.5, 1.5, 5.0);
        assert!(swings.changed);
        swings.update_raw(3.0, 2.0, 6.0);
        assert!(!swings.changed);
    }

    #[rstest]
    fn test_length_precision_rounding() {
        let mut swings = Swings::new(3);
        swings.update_raw(10.49, 9.9, 1.0);
        swings.update_raw(9.00, 8.0, 2.0);
        swings.update_raw(9.00, 8.0, 3.0);
        assert_eq!(swings.length, 2);

        swings.reset();
        swings.update_raw(10.5, 10.4, 10.0);
        swings.update_raw(8.0, 7.5, 20.0);
        swings.update_raw(8.0, 7.5, 30.0);
        assert_eq!(swings.length, 3);

        swings.reset();
        swings.update_raw(10.8, 10.6, 40.0);
        swings.update_raw(8.2, 7.4, 50.0);
        swings.update_raw(8.2, 7.4, 60.0);
        assert_eq!(swings.length, 3);
    }
}
397        swings.update_raw(8.2, 7.4, 50.0);
398        swings.update_raw(8.2, 7.4, 60.0);
399        assert_eq!(swings.length, 3);
400    }
401}