Skip to main content

nautilus_indicators/volatility/
atr.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::fmt::{Debug, Display};
17
18use nautilus_model::data::Bar;
19
20use crate::{
21    average::{MovingAverageFactory, MovingAverageType},
22    indicator::{Indicator, MovingAverage},
23};
24
/// An indicator which calculates an Average True Range (ATR) across a rolling window.
///
/// Every update feeds one range value into an internal moving average; `value` holds
/// that average after the optional `value_floor` has been applied.
#[repr(C)]
#[derive(Debug)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
)]
pub struct AverageTrueRange {
    /// Window length handed to the moving average; also the number of updates
    /// required before the indicator reports itself as initialized.
    pub period: usize,
    /// The moving average type recorded at construction (`Simple` when none is given).
    pub ma_type: MovingAverageType,
    /// Whether the previous close takes part in the range calculation.
    pub use_previous: bool,
    /// Lower bound applied to `value`; a floor of `0.0` disables flooring.
    pub value_floor: f64,
    /// The current indicator value (`0.0` until the first update).
    pub value: f64,
    /// Number of updates received since construction or the last reset.
    pub count: usize,
    /// Set once `count` has reached `period`.
    pub initialized: bool,
    // Moving average that smooths the per-update range values.
    ma: Box<dyn MovingAverage + Send + 'static>,
    // Set by the first update; cleared by `reset`.
    has_inputs: bool,
    // Close of the most recent update; only maintained when `use_previous` is set.
    previous_close: f64,
}
48
49impl Display for AverageTrueRange {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        write!(
52            f,
53            "{}({},{},{},{})",
54            self.name(),
55            self.period,
56            self.ma_type,
57            self.use_previous,
58            self.value_floor,
59        )
60    }
61}
62
63impl Indicator for AverageTrueRange {
64    fn name(&self) -> String {
65        stringify!(AverageTrueRange).to_string()
66    }
67
68    fn has_inputs(&self) -> bool {
69        self.has_inputs
70    }
71
72    fn initialized(&self) -> bool {
73        self.initialized
74    }
75
76    fn handle_bar(&mut self, bar: &Bar) {
77        self.update_raw((&bar.high).into(), (&bar.low).into(), (&bar.close).into());
78    }
79
80    fn reset(&mut self) {
81        self.previous_close = 0.0;
82        self.value = 0.0;
83        self.count = 0;
84        self.has_inputs = false;
85        self.initialized = false;
86    }
87}
88
89impl AverageTrueRange {
90    /// Creates a new [`AverageTrueRange`] instance.
91    #[must_use]
92    pub fn new(
93        period: usize,
94        ma_type: Option<MovingAverageType>,
95        use_previous: Option<bool>,
96        value_floor: Option<f64>,
97    ) -> Self {
98        Self {
99            period,
100            ma_type: ma_type.unwrap_or(MovingAverageType::Simple),
101            use_previous: use_previous.unwrap_or(true),
102            value_floor: value_floor.unwrap_or(0.0),
103            value: 0.0,
104            count: 0,
105            previous_close: 0.0,
106            ma: MovingAverageFactory::create(MovingAverageType::Simple, period),
107            has_inputs: false,
108            initialized: false,
109        }
110    }
111
112    pub fn update_raw(&mut self, high: f64, low: f64, close: f64) {
113        if self.use_previous {
114            if !self.has_inputs {
115                self.previous_close = close;
116            }
117            self.ma.update_raw(
118                f64::max(self.previous_close, high) - f64::min(low, self.previous_close),
119            );
120            self.previous_close = close;
121        } else {
122            self.ma.update_raw(high - low);
123        }
124
125        self._floor_value();
126        self.increment_count();
127    }
128
129    fn _floor_value(&mut self) {
130        if self.value_floor == 0.0 || self.value_floor < self.ma.value() {
131            self.value = self.ma.value();
132        } else {
133            // Floor the value
134            self.value = self.value_floor;
135        }
136    }
137
138    const fn increment_count(&mut self) {
139        self.count += 1;
140
141        if !self.initialized {
142            self.has_inputs = true;
143
144            if self.count >= self.period {
145                self.initialized = true;
146            }
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::testing::approx_equal;

    #[rstest]
    fn test_name_returns_expected_string() {
        let atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        assert_eq!(atr.name(), "AverageTrueRange");
    }

    #[rstest]
    fn test_str_repr_returns_expected_string() {
        let atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), Some(true), Some(0.0));
        assert_eq!(format!("{atr}"), "AverageTrueRange(10,SIMPLE,true,0)");
    }

    #[rstest]
    fn test_period() {
        let atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        assert_eq!(atr.period, 10);
    }

    #[rstest]
    fn test_initialized_without_inputs_returns_false() {
        let atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        assert!(!atr.initialized());
    }

    #[rstest]
    fn test_initialized_with_required_inputs_returns_true() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        for _ in 0..10 {
            atr.update_raw(1.0, 1.0, 1.0);
        }
        assert!(atr.initialized());
    }

    #[rstest]
    fn test_value_with_no_inputs_returns_zero() {
        let atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        assert_eq!(atr.value, 0.0);
    }

    #[rstest]
    fn test_value_with_epsilon_input() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        let epsilon = f64::EPSILON;
        atr.update_raw(epsilon, epsilon, epsilon);
        assert_eq!(atr.value, 0.0);
    }

    #[rstest]
    fn test_value_with_one_ones_input() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        atr.update_raw(1.0, 1.0, 1.0);
        assert_eq!(atr.value, 0.0);
    }

    #[rstest]
    fn test_value_with_one_input() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        atr.update_raw(1.00020, 1.0, 1.00010);
        assert!(approx_equal(atr.value, 0.0002));
    }

    // With `use_previous` disabled the range is simply `high - low`
    #[rstest]
    fn test_value_with_one_input_without_previous_close() {
        let mut atr =
            AverageTrueRange::new(10, Some(MovingAverageType::Simple), Some(false), None);
        atr.update_raw(1.00020, 1.0, 1.00010);
        assert!(approx_equal(atr.value, 0.0002));
    }

    #[rstest]
    fn test_value_with_three_inputs() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        atr.update_raw(1.00020, 1.0, 1.00010);
        atr.update_raw(1.00020, 1.0, 1.00010);
        atr.update_raw(1.00020, 1.0, 1.00010);
        assert!(approx_equal(atr.value, 0.0002));
    }

    #[rstest]
    fn test_value_with_close_on_high() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        let mut high = 1.00010;
        let mut low = 1.0;

        for _ in 0..1000 {
            high += 0.00010;
            low += 0.00010;
            let close = high;
            atr.update_raw(high, low, close);
        }
        assert!(approx_equal(atr.value, 0.000_099_999_999_999_988_99));
    }

    #[rstest]
    fn test_value_with_close_on_low() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        let mut high = 1.00010;
        let mut low = 1.0;

        for _ in 0..1000 {
            high -= 0.00010;
            low -= 0.00010;
            let close = low;
            atr.update_raw(high, low, close);
        }
        assert!(approx_equal(atr.value, 0.000_099_999_999_999_988_99));
    }

    #[rstest]
    fn test_floor_with_ten_ones_inputs() {
        let floor = 0.00005;
        let mut floored_atr =
            AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, Some(floor));

        for _ in 0..20 {
            floored_atr.update_raw(1.0, 1.0, 1.0);
        }
        assert_eq!(floored_atr.value, 5e-05);
    }

    #[rstest]
    fn test_floor_with_exponentially_decreasing_high_inputs() {
        let floor = 0.00005;
        let mut floored_atr =
            AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, Some(floor));
        let mut high = 1.00020;
        let low = 1.0;
        let close = 1.0;

        for _ in 0..20 {
            high -= (high - low) / 2.0;
            floored_atr.update_raw(high, low, close);
        }
        assert_eq!(floored_atr.value, floor);
    }

    #[rstest]
    fn test_reset_successfully_returns_indicator_to_fresh_state() {
        let mut atr = AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, None);
        for _ in 0..1000 {
            atr.update_raw(1.00010, 1.0, 1.00005);
        }
        atr.reset();

        // Every piece of public state must be back at its constructor default
        assert!(!atr.initialized);
        assert!(!atr.has_inputs());
        assert_eq!(atr.count, 0);
        assert_eq!(atr.value, 0.0);
    }
}
296}