1use std::fmt::Display;
17
18use arraydeque::{ArrayDeque, Wrapping};
19use nautilus_model::data::Bar;
20
21use crate::indicator::Indicator;
22
23const MAX_PERIOD: usize = 1_024;
24
25#[repr(C)]
26#[derive(Debug)]
27#[cfg_attr(
28 feature = "python",
29 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
30)]
31#[cfg_attr(
32 feature = "python",
33 pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
34)]
35pub struct Swings {
36 pub period: usize,
37 pub direction: i64,
38 pub changed: bool,
39 pub high_datetime: f64,
40 pub low_datetime: f64,
41 pub high_price: f64,
42 pub low_price: f64,
43 pub length: usize,
44 pub duration: usize,
45 pub since_high: usize,
46 pub since_low: usize,
47 high_inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
48 low_inputs: ArrayDeque<f64, MAX_PERIOD, Wrapping>,
49 has_inputs: bool,
50 initialized: bool,
51}
52
53impl Display for Swings {
54 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
55 write!(f, "{}({})", self.name(), self.period,)
56 }
57}
58
59impl Indicator for Swings {
60 fn name(&self) -> String {
61 stringify!(Swings).to_string()
62 }
63
64 fn has_inputs(&self) -> bool {
65 self.has_inputs
66 }
67
68 fn initialized(&self) -> bool {
69 self.initialized
70 }
71
72 fn handle_bar(&mut self, bar: &Bar) {
73 self.update_raw((&bar.high).into(), (&bar.low).into(), bar.ts_init.as_f64());
74 }
75
76 fn reset(&mut self) {
77 self.high_inputs.clear();
78 self.low_inputs.clear();
79 self.has_inputs = false;
80 self.initialized = false;
81 self.direction = 0;
82 self.changed = false;
83 self.high_datetime = 0.0;
84 self.low_datetime = 0.0;
85 self.high_price = 0.0;
86 self.low_price = 0.0;
87 self.length = 0;
88 self.duration = 0;
89 self.since_high = 0;
90 self.since_low = 0;
91 }
92}
93
94impl Swings {
95 #[must_use]
103 pub fn new(period: usize) -> Self {
104 assert!(
105 period > 0 && period <= MAX_PERIOD,
106 "Swings: period {period} exceeds MAX_PERIOD ({MAX_PERIOD})"
107 );
108
109 Self {
110 period,
111 high_inputs: ArrayDeque::new(),
112 low_inputs: ArrayDeque::new(),
113 has_inputs: false,
114 initialized: false,
115 direction: 0,
116 changed: false,
117 high_datetime: 0.0,
118 low_datetime: 0.0,
119 high_price: 0.0,
120 low_price: 0.0,
121 length: 0,
122 duration: 0,
123 since_high: 0,
124 since_low: 0,
125 }
126 }
127
128 pub fn update_raw(&mut self, high: f64, low: f64, timestamp: f64) {
129 self.changed = false;
130
131 if self.high_inputs.len() == self.period {
132 self.high_inputs.pop_front();
133 }
134
135 if self.low_inputs.len() == self.period {
136 self.low_inputs.pop_front();
137 }
138 let _ = self.high_inputs.push_back(high);
139 let _ = self.low_inputs.push_back(low);
140
141 let max_high = self.high_inputs.iter().fold(f64::MIN, |a, &b| a.max(b));
142 let min_low = self.low_inputs.iter().fold(f64::MAX, |a, &b| a.min(b));
143
144 let is_swing_high = high >= max_high && low >= min_low;
145 let is_swing_low = high <= max_high && low <= min_low;
146
147 if is_swing_high && is_swing_low {
148 if self.high_price == 0.0 {
149 self.high_price = high;
150 self.high_datetime = timestamp;
151 }
152 self.since_high += 1;
153 self.since_low += 1;
154 } else if is_swing_high {
155 if self.direction == -1 {
156 self.changed = true;
157 }
158
159 if high > self.high_price {
160 self.high_price = high;
161 self.high_datetime = timestamp;
162 }
163 self.direction = 1;
164 self.since_high = 0;
165 self.since_low += 1;
166 } else if is_swing_low {
167 if self.direction == 1 {
168 self.changed = true;
169 }
170
171 if self.high_price == 0.0 {
172 self.high_price = max_high;
173 self.high_datetime = timestamp;
174 }
175
176 if low < self.low_price || self.low_price == 0.0 {
177 self.low_price = low;
178 self.low_datetime = timestamp;
179 }
180 self.direction = -1;
181 self.since_high += 1;
182 self.since_low = 0;
183 } else {
184 self.since_high += 1;
185 self.since_low += 1;
186 }
187
188 self.has_inputs = true;
189
190 if self.high_price != 0.0 && self.low_price != 0.0 {
191 self.initialized = true;
192 self.length = ((self.high_price - self.low_price).abs().round()) as usize;
193
194 if self.direction == 1 {
195 self.duration = self.since_low;
196 } else if self.direction == -1 {
197 self.duration = self.since_high;
198 } else {
199 self.duration = 0;
200 }
201 }
202 }
203}
204
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::stubs::swings_10;

    /// Builds `count` timestamps starting at `1_643_723_400.0`, spaced `10.0` apart.
    fn timestamps(count: u32) -> Vec<f64> {
        (0..count)
            .map(|step| 1_643_723_400.0 + 10.0 * f64::from(step))
            .collect()
    }

    /// Feeds parallel high/low/timestamp series into `swings`, one update per element.
    fn feed(swings: &mut Swings, highs: &[f64], lows: &[f64], times: &[f64]) {
        for ((&high, &low), &ts) in highs.iter().zip(lows).zip(times) {
            swings.update_raw(high, low, ts);
        }
    }

    #[rstest]
    fn test_name_returns_expected_string(swings_10: Swings) {
        assert_eq!(swings_10.name(), "Swings");
    }

    #[rstest]
    fn test_str_repr_returns_expected_string(swings_10: Swings) {
        assert_eq!(format!("{swings_10}"), "Swings(10)");
    }

    #[rstest]
    fn test_period_returns_expected_value(swings_10: Swings) {
        assert_eq!(swings_10.period, 10);
    }

    #[rstest]
    fn test_initialized_without_inputs_returns_false(swings_10: Swings) {
        assert!(!swings_10.initialized());
    }

    #[rstest]
    fn test_value_with_all_higher_inputs_returns_expected_value(mut swings_10: Swings) {
        let high = [
            0.9, 1.9, 2.9, 3.9, 4.9, 3.2, 6.9, 7.9, 8.9, 9.9, 1.1, 3.2, 10.3, 11.1, 11.4,
        ];
        let low = [
            0.8, 1.8, 2.8, 3.8, 4.8, 3.1, 6.8, 7.8, 0.8, 9.8, 1.0, 3.1, 10.2, 11.0, 11.3,
        ];
        let time = timestamps(15);

        feed(&mut swings_10, &high, &low, &time);

        assert_eq!(swings_10.direction, 1);
        assert_eq!(swings_10.high_price, 11.4);
        assert_eq!(swings_10.low_price, 0.0);
        assert_eq!(swings_10.high_datetime, time[14]);
        assert_eq!(swings_10.low_datetime, 0.0);
        assert_eq!(swings_10.length, 0);
        assert_eq!(swings_10.duration, 0);
        assert_eq!(swings_10.since_high, 0);
        assert_eq!(swings_10.since_low, 15);
    }

    #[rstest]
    fn test_reset_successfully_returns_indicator_to_fresh_state(mut swings_10: Swings) {
        let high = [1.0, 2.0, 3.0, 4.0, 5.0];
        let low = [0.9, 1.9, 2.9, 3.9, 4.9];
        let time = timestamps(5);

        feed(&mut swings_10, &high, &low, &time);

        swings_10.reset();

        assert!(!swings_10.initialized());
        assert_eq!(swings_10.direction, 0);
        assert_eq!(swings_10.high_price, 0.0);
        assert_eq!(swings_10.low_price, 0.0);
        assert_eq!(swings_10.high_datetime, 0.0);
        assert_eq!(swings_10.low_datetime, 0.0);
        assert_eq!(swings_10.length, 0);
        assert_eq!(swings_10.duration, 0);
        assert_eq!(swings_10.since_high, 0);
        assert_eq!(swings_10.since_low, 0);
        assert!(swings_10.high_inputs.is_empty());
        assert!(swings_10.low_inputs.is_empty());
    }

    #[rstest]
    fn test_changed_flag_flips() {
        let mut swings = Swings::new(2);

        // (high, low, timestamp, expected `changed` after the update)
        let steps = [
            (1.0, 0.5, 1.0, false),
            (2.0, 1.5, 2.0, false),
            (0.0, -1.0, 3.0, true),
            (-0.5, -1.5, 4.0, false),
        ];

        for (high, low, ts, expected) in steps {
            swings.update_raw(high, low, ts);
            assert_eq!(swings.changed, expected);
        }
    }

    #[rstest]
    fn test_length_computation_after_initialization() {
        let mut swings = Swings::new(2);

        for (high, low, ts) in [(10.0, 9.0, 1.0), (8.0, 7.0, 2.0), (8.0, 7.5, 3.0)] {
            swings.update_raw(high, low, ts);
        }

        assert_eq!(swings.length, 3);
    }

    #[rstest]
    fn test_length_rounds_fractional_difference() {
        let mut swings = Swings::new(2);

        for (high, low, ts) in [(10.9, 10.7, 1.0), (9.7, 9.4, 2.0), (9.7, 9.4, 3.0)] {
            swings.update_raw(high, low, ts);
        }

        assert_eq!(swings.length, 2);
    }

    #[rstest]
    fn test_queue_eviction_does_not_exceed_capacity() {
        let period = 3;
        let mut swings = Swings::new(period);

        let highs = [1.0, 2.0, 3.0, 4.0, 5.0];
        let lows = [0.5, 1.5, 2.5, 3.5, 4.5];

        for (idx, (&high, &low)) in highs.iter().zip(&lows).enumerate() {
            swings.update_raw(high, low, (idx + 1) as f64);

            // The windows must never grow past `period`, at any step.
            assert!(swings.high_inputs.len() <= period);
            assert!(swings.low_inputs.len() <= period);
        }

        assert_eq!(swings.high_inputs.len(), period);
        assert_eq!(swings.low_inputs.len(), period);
        assert_eq!(swings.high_inputs.front().copied(), Some(3.0));
        assert_eq!(swings.low_inputs.front().copied(), Some(2.5));
    }

    #[rstest]
    fn test_changed_flag_toggles_on_every_direction_flip() {
        let mut swings = Swings::new(2);

        // (high, low, timestamp, expected `changed` after the update)
        let steps = [
            (1.0, 0.7, 1.0, false),
            (2.0, 1.7, 2.0, false),
            (0.0, -1.0, 3.0, true),
            (-0.5, -1.5, 4.0, false),
            (2.5, 1.5, 5.0, true),
            (3.0, 2.0, 6.0, false),
        ];

        for (high, low, ts, expected) in steps {
            swings.update_raw(high, low, ts);
            assert_eq!(swings.changed, expected);
        }
    }

    #[rstest]
    fn test_length_precision_rounding() {
        let mut swings = Swings::new(3);

        // Each scenario: three (high, low, timestamp) updates and the expected length.
        let scenarios: [([(f64, f64, f64); 3], usize); 3] = [
            ([(10.49, 9.9, 1.0), (9.00, 8.0, 2.0), (9.00, 8.0, 3.0)], 2),
            ([(10.5, 10.4, 10.0), (8.0, 7.5, 20.0), (8.0, 7.5, 30.0)], 3),
            ([(10.8, 10.6, 40.0), (8.2, 7.4, 50.0), (8.2, 7.4, 60.0)], 3),
        ];

        for (idx, (updates, expected_length)) in scenarios.into_iter().enumerate() {
            // Reuse the same indicator, resetting between scenarios.
            if idx > 0 {
                swings.reset();
            }

            for (high, low, ts) in updates {
                swings.update_raw(high, low, ts);
            }

            assert_eq!(swings.length, expected_length);
        }
    }
}