// nautilus_indicators/average/ema.rs
use std::fmt::Display;

use nautilus_model::{
    data::{Bar, QuoteTick, TradeTick},
    enums::PriceType,
};

use crate::indicator::{Indicator, MovingAverage};

25#[repr(C)]
26#[derive(Debug)]
27#[cfg_attr(
28 feature = "python",
29 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators")
30)]
31#[cfg_attr(
32 feature = "python",
33 pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
34)]
35pub struct ExponentialMovingAverage {
36 pub period: usize,
37 pub price_type: PriceType,
38 pub alpha: f64,
39 pub value: f64,
40 pub count: usize,
41 pub initialized: bool,
42 has_inputs: bool,
43}
44
45impl Display for ExponentialMovingAverage {
46 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47 write!(f, "{}({})", self.name(), self.period)
48 }
49}
50
51impl Indicator for ExponentialMovingAverage {
52 fn name(&self) -> String {
53 stringify!(ExponentialMovingAverage).to_string()
54 }
55
56 fn has_inputs(&self) -> bool {
57 self.has_inputs
58 }
59
60 fn initialized(&self) -> bool {
61 self.initialized
62 }
63
64 fn handle_quote(&mut self, quote: &QuoteTick) {
65 self.update_raw(quote.extract_price(self.price_type).into());
66 }
67
68 fn handle_trade(&mut self, trade: &TradeTick) {
69 self.update_raw((&trade.price).into());
70 }
71
72 fn handle_bar(&mut self, bar: &Bar) {
73 self.update_raw((&bar.close).into());
74 }
75
76 fn reset(&mut self) {
77 self.value = 0.0;
78 self.count = 0;
79 self.has_inputs = false;
80 self.initialized = false;
81 }
82}
83
84impl ExponentialMovingAverage {
85 #[must_use]
91 pub fn new(period: usize, price_type: Option<PriceType>) -> Self {
92 assert!(
93 period > 0,
94 "ExponentialMovingAverage::new → `period` must be positive (> 0); got {period}"
95 );
96 Self {
97 period,
98 price_type: price_type.unwrap_or(PriceType::Last),
99 alpha: 2.0 / (period as f64 + 1.0),
100 value: 0.0,
101 count: 0,
102 has_inputs: false,
103 initialized: false,
104 }
105 }
106}
107
108impl MovingAverage for ExponentialMovingAverage {
109 fn value(&self) -> f64 {
110 self.value
111 }
112
113 fn count(&self) -> usize {
114 self.count
115 }
116
117 fn update_raw(&mut self, value: f64) {
118 if !self.has_inputs {
119 self.has_inputs = true;
120 self.value = value;
121 self.count = 1;
122
123 if self.period == 1 {
124 self.initialized = true;
125 }
126 return;
127 }
128
129 self.value = self.alpha.mul_add(value, (1.0 - self.alpha) * self.value);
130 self.count += 1;
131
132 if !self.initialized && self.count >= self.period {
134 self.initialized = true;
135 }
136 }
137}
138
#[cfg(test)]
mod tests {
    use nautilus_model::{
        data::{Bar, QuoteTick, TradeTick},
        enums::PriceType,
    };
    use rstest::rstest;

    use crate::{
        average::ema::ExponentialMovingAverage,
        indicator::{Indicator, MovingAverage},
        stubs::*,
    };

    #[rstest]
    fn test_ema_initialized(indicator_ema_10: ExponentialMovingAverage) {
        let ema = indicator_ema_10;
        let display_str = format!("{ema}");
        assert_eq!(display_str, "ExponentialMovingAverage(10)");
        assert_eq!(ema.period, 10);
        assert_eq!(ema.price_type, PriceType::Mid);
        assert_eq!(ema.alpha, 0.181_818_181_818_181_82);
        assert!(!ema.initialized);
    }

    #[rstest]
    fn test_one_value_input(indicator_ema_10: ExponentialMovingAverage) {
        let mut ema = indicator_ema_10;
        ema.update_raw(1.0);
        assert_eq!(ema.count, 1);
        assert_eq!(ema.value, 1.0);
    }

    #[rstest]
    fn test_ema_update_raw(indicator_ema_10: ExponentialMovingAverage) {
        let mut ema = indicator_ema_10;
        ema.update_raw(1.0);
        ema.update_raw(2.0);
        ema.update_raw(3.0);
        ema.update_raw(4.0);
        ema.update_raw(5.0);
        ema.update_raw(6.0);
        ema.update_raw(7.0);
        ema.update_raw(8.0);
        ema.update_raw(9.0);
        ema.update_raw(10.0);

        assert!(ema.has_inputs());
        assert!(ema.initialized());
        assert_eq!(ema.count, 10);
        assert_eq!(ema.value, 6.239_368_480_121_215_5);
    }

    #[rstest]
    fn test_reset(indicator_ema_10: ExponentialMovingAverage) {
        let mut ema = indicator_ema_10;
        ema.update_raw(1.0);
        assert_eq!(ema.count, 1);
        ema.reset();
        assert_eq!(ema.count, 0);
        assert_eq!(ema.value, 0.0);
        // `reset` must also clear the private input flag, not just the public counters.
        assert!(!ema.has_inputs());
        assert!(!ema.initialized);
    }

    #[rstest]
    fn test_handle_quote_tick_single(
        indicator_ema_10: ExponentialMovingAverage,
        stub_quote: QuoteTick,
    ) {
        let mut ema = indicator_ema_10;
        ema.handle_quote(&stub_quote);
        assert!(ema.has_inputs());
        assert_eq!(ema.value, 1501.0);
    }

    #[rstest]
    fn test_handle_quote_tick_multi(mut indicator_ema_10: ExponentialMovingAverage) {
        let tick1 = stub_quote("1500.0", "1502.0");
        let tick2 = stub_quote("1502.0", "1504.0");

        indicator_ema_10.handle_quote(&tick1);
        indicator_ema_10.handle_quote(&tick2);
        assert_eq!(indicator_ema_10.count, 2);
        assert_eq!(indicator_ema_10.value, 1_501.363_636_363_636_3);
    }

    #[rstest]
    fn test_handle_trade_tick(indicator_ema_10: ExponentialMovingAverage, stub_trade: TradeTick) {
        let mut ema = indicator_ema_10;
        ema.handle_trade(&stub_trade);
        assert!(ema.has_inputs());
        assert_eq!(ema.value, 1500.0);
    }

    #[rstest]
    fn test_handle_bar(
        mut indicator_ema_10: ExponentialMovingAverage,
        bar_ethusdt_binance_minute_bid: Bar,
    ) {
        indicator_ema_10.handle_bar(&bar_ethusdt_binance_minute_bid);
        assert!(indicator_ema_10.has_inputs);
        assert!(!indicator_ema_10.initialized);
        assert_eq!(indicator_ema_10.value, 1522.0);
    }

    #[rstest]
    fn test_period_one_behaviour() {
        let mut ema = ExponentialMovingAverage::new(1, None);
        assert_eq!(ema.alpha, 1.0, "α must be 1 when period = 1");

        ema.update_raw(10.0);
        assert!(ema.initialized());
        assert_eq!(ema.value(), 10.0);

        ema.update_raw(42.0);
        assert_eq!(
            ema.value(),
            42.0,
            "With α = 1, the EMA must track the latest sample exactly"
        );
    }

    #[rstest]
    fn test_default_price_type_is_last() {
        let ema = ExponentialMovingAverage::new(3, None);
        assert_eq!(
            ema.price_type,
            PriceType::Last,
            "`price_type` default mismatch"
        );
    }

    // `new` rejects a zero period rather than producing alpha = 2 and a degenerate EMA.
    #[rstest]
    #[should_panic(expected = "`period` must be positive")]
    fn test_new_with_zero_period_panics() {
        let _ = ExponentialMovingAverage::new(0, None);
    }

    // `initialized` must flip exactly on the `period`-th input, not before.
    #[rstest]
    fn test_initialized_once_count_reaches_period() {
        let mut ema = ExponentialMovingAverage::new(3, None);
        ema.update_raw(1.0);
        ema.update_raw(2.0);
        assert!(!ema.initialized());
        assert_eq!(ema.count(), 2);

        ema.update_raw(3.0);
        assert!(ema.initialized());
        assert_eq!(ema.count(), 3);
    }

    #[rstest]
    fn test_nan_poisoning_and_reset_recovery() {
        let mut ema = ExponentialMovingAverage::new(4, None);
        for x in 0..3 {
            ema.update_raw(f64::from(x));
            assert!(ema.value().is_finite());
        }

        ema.update_raw(f64::NAN);
        assert!(ema.value().is_nan());

        ema.update_raw(123.456);
        assert!(ema.value().is_nan());

        ema.reset();
        assert!(!ema.has_inputs());
        ema.update_raw(7.0);
        assert_eq!(ema.value(), 7.0);
        assert!(ema.value().is_finite());
    }

    #[rstest]
    fn test_reset_without_inputs_is_safe() {
        let mut ema = ExponentialMovingAverage::new(8, None);
        ema.reset();
        assert!(!ema.has_inputs());
        assert_eq!(ema.count(), 0);
        assert!(!ema.initialized());
    }

    #[rstest]
    fn test_has_inputs_lifecycle() {
        let mut ema = ExponentialMovingAverage::new(5, None);
        assert!(!ema.has_inputs());

        ema.update_raw(1.23);
        assert!(ema.has_inputs());

        ema.reset();
        assert!(!ema.has_inputs());
    }

    #[rstest]
    fn test_subnormal_inputs_do_not_underflow() {
        let mut ema = ExponentialMovingAverage::new(2, None);
        let tiny = f64::MIN_POSITIVE / 2.0;
        ema.update_raw(tiny);
        ema.update_raw(tiny);
        assert!(
            ema.value() > 0.0,
            "Underflow: EMA value collapsed to zero for sub-normal inputs"
        );
    }
}