nautilus_indicators/volatility/
atr.rs1use std::fmt::{Debug, Display};
17
18use nautilus_model::data::Bar;
19
20use crate::{
21 average::{MovingAverageFactory, MovingAverageType},
22 indicator::{Indicator, MovingAverage},
23};
24
25#[repr(C)]
27#[derive(Debug)]
28#[cfg_attr(
29 feature = "python",
30 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
31)]
32#[cfg_attr(
33 feature = "python",
34 pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
35)]
36pub struct AverageTrueRange {
37 pub period: usize,
38 pub ma_type: MovingAverageType,
39 pub use_previous: bool,
40 pub value_floor: f64,
41 pub value: f64,
42 pub count: usize,
43 pub initialized: bool,
44 ma: Box<dyn MovingAverage + Send + 'static>,
45 has_inputs: bool,
46 previous_close: f64,
47}
48
49impl Display for AverageTrueRange {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 write!(
52 f,
53 "{}({},{},{},{})",
54 self.name(),
55 self.period,
56 self.ma_type,
57 self.use_previous,
58 self.value_floor,
59 )
60 }
61}
62
63impl Indicator for AverageTrueRange {
64 fn name(&self) -> String {
65 stringify!(AverageTrueRange).to_string()
66 }
67
68 fn has_inputs(&self) -> bool {
69 self.has_inputs
70 }
71
72 fn initialized(&self) -> bool {
73 self.initialized
74 }
75
76 fn handle_bar(&mut self, bar: &Bar) {
77 self.update_raw((&bar.high).into(), (&bar.low).into(), (&bar.close).into());
78 }
79
80 fn reset(&mut self) {
81 self.previous_close = 0.0;
82 self.value = 0.0;
83 self.count = 0;
84 self.has_inputs = false;
85 self.initialized = false;
86 }
87}
88
89impl AverageTrueRange {
90 #[must_use]
92 pub fn new(
93 period: usize,
94 ma_type: Option<MovingAverageType>,
95 use_previous: Option<bool>,
96 value_floor: Option<f64>,
97 ) -> Self {
98 Self {
99 period,
100 ma_type: ma_type.unwrap_or(MovingAverageType::Simple),
101 use_previous: use_previous.unwrap_or(true),
102 value_floor: value_floor.unwrap_or(0.0),
103 value: 0.0,
104 count: 0,
105 previous_close: 0.0,
106 ma: MovingAverageFactory::create(MovingAverageType::Simple, period),
107 has_inputs: false,
108 initialized: false,
109 }
110 }
111
112 pub fn update_raw(&mut self, high: f64, low: f64, close: f64) {
113 if self.use_previous {
114 if !self.has_inputs {
115 self.previous_close = close;
116 }
117 self.ma.update_raw(
118 f64::max(self.previous_close, high) - f64::min(low, self.previous_close),
119 );
120 self.previous_close = close;
121 } else {
122 self.ma.update_raw(high - low);
123 }
124
125 self._floor_value();
126 self.increment_count();
127 }
128
129 fn _floor_value(&mut self) {
130 if self.value_floor == 0.0 || self.value_floor < self.ma.value() {
131 self.value = self.ma.value();
132 } else {
133 self.value = self.value_floor;
135 }
136 }
137
138 const fn increment_count(&mut self) {
139 self.count += 1;
140
141 if !self.initialized {
142 self.has_inputs = true;
143
144 if self.count >= self.period {
145 self.initialized = true;
146 }
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use crate::testing::approx_equal;

    /// Builds a 10-period, simple-average ATR with the given floor and the
    /// default `use_previous` setting.
    fn make_atr(value_floor: Option<f64>) -> AverageTrueRange {
        AverageTrueRange::new(10, Some(MovingAverageType::Simple), None, value_floor)
    }

    /// Feeds the same bar into `atr` the requested number of times.
    fn feed_repeated(atr: &mut AverageTrueRange, times: usize, high: f64, low: f64, close: f64) {
        for _ in 0..times {
            atr.update_raw(high, low, close);
        }
    }

    #[rstest]
    fn test_name_returns_expected_string() {
        let indicator = make_atr(None);
        assert_eq!(indicator.name(), "AverageTrueRange");
    }

    #[rstest]
    fn test_str_repr_returns_expected_string() {
        let indicator =
            AverageTrueRange::new(10, Some(MovingAverageType::Simple), Some(true), Some(0.0));
        assert_eq!(format!("{indicator}"), "AverageTrueRange(10,SIMPLE,true,0)");
    }

    #[rstest]
    fn test_period() {
        let indicator = make_atr(None);
        assert_eq!(indicator.period, 10);
    }

    #[rstest]
    fn test_initialized_without_inputs_returns_false() {
        let indicator = make_atr(None);
        assert!(!indicator.initialized());
    }

    #[rstest]
    fn test_initialized_with_required_inputs_returns_true() {
        let mut indicator = make_atr(None);
        feed_repeated(&mut indicator, 10, 1.0, 1.0, 1.0);
        assert!(indicator.initialized());
    }

    #[rstest]
    fn test_value_with_no_inputs_returns_zero() {
        let indicator = make_atr(None);
        assert_eq!(indicator.value, 0.0);
    }

    #[rstest]
    fn test_value_with_epsilon_input() {
        let mut indicator = make_atr(None);
        let eps = f64::EPSILON;
        indicator.update_raw(eps, eps, eps);
        assert_eq!(indicator.value, 0.0);
    }

    #[rstest]
    fn test_value_with_one_ones_input() {
        let mut indicator = make_atr(None);
        indicator.update_raw(1.0, 1.0, 1.0);
        assert_eq!(indicator.value, 0.0);
    }

    #[rstest]
    fn test_value_with_one_input() {
        let mut indicator = make_atr(None);
        indicator.update_raw(1.00020, 1.0, 1.00010);
        assert!(approx_equal(indicator.value, 0.0002));
    }

    #[rstest]
    fn test_value_with_three_inputs() {
        let mut indicator = make_atr(None);
        feed_repeated(&mut indicator, 3, 1.00020, 1.0, 1.00010);
        assert!(approx_equal(indicator.value, 0.0002));
    }

    #[rstest]
    fn test_value_with_close_on_high() {
        let mut indicator = make_atr(None);
        let (mut bar_high, mut bar_low) = (1.00010, 1.0);

        // Steadily rising bars that always close on their high.
        for _ in 0..1000 {
            bar_high += 0.00010;
            bar_low += 0.00010;
            indicator.update_raw(bar_high, bar_low, bar_high);
        }
        assert!(approx_equal(indicator.value, 0.000_099_999_999_999_988_99));
    }

    #[rstest]
    fn test_value_with_close_on_low() {
        let mut indicator = make_atr(None);
        let (mut bar_high, mut bar_low) = (1.00010, 1.0);

        // Steadily falling bars that always close on their low.
        for _ in 0..1000 {
            bar_high -= 0.00010;
            bar_low -= 0.00010;
            indicator.update_raw(bar_high, bar_low, bar_low);
        }
        assert!(approx_equal(indicator.value, 0.000_099_999_999_999_988_99));
    }

    #[rstest]
    fn test_floor_with_ten_ones_inputs() {
        let floor = 0.00005;
        let mut indicator = make_atr(Some(floor));

        // Zero-range bars keep the average at zero, so the floor must win.
        feed_repeated(&mut indicator, 20, 1.0, 1.0, 1.0);
        assert_eq!(indicator.value, 5e-05);
    }

    #[rstest]
    fn test_floor_with_exponentially_decreasing_high_inputs() {
        let floor = 0.00005;
        let mut indicator = make_atr(Some(floor));
        let mut bar_high = 1.00020;
        let bar_low = 1.0;
        let bar_close = 1.0;

        // Halve the bar range on every step until the average sinks below the floor.
        for _ in 0..20 {
            bar_high -= (bar_high - bar_low) / 2.0;
            indicator.update_raw(bar_high, bar_low, bar_close);
        }
        assert_eq!(indicator.value, floor);
    }

    #[rstest]
    fn test_reset_successfully_returns_indicator_to_fresh_state() {
        let mut indicator = make_atr(None);
        feed_repeated(&mut indicator, 1000, 1.00010, 1.0, 1.00005);
        indicator.reset();
        assert!(!indicator.initialized);
        assert_eq!(indicator.value, 0.0);
    }
}