nautilus_indicators/momentum/
rsi.rs1use std::fmt::{Debug, Display};
17
18use nautilus_model::{
19 data::{Bar, QuoteTick, TradeTick},
20 enums::PriceType,
21};
22
23use crate::{
24 average::{MovingAverageFactory, MovingAverageType},
25 indicator::{Indicator, MovingAverage},
26};
27
28#[repr(C)]
30#[derive(Debug)]
31#[cfg_attr(
32 feature = "python",
33 pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.indicators", unsendable)
34)]
35#[cfg_attr(
36 feature = "python",
37 pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.indicators")
38)]
39pub struct RelativeStrengthIndex {
40 pub period: usize,
41 pub ma_type: MovingAverageType,
42 pub value: f64,
43 pub count: usize,
44 pub initialized: bool,
45 has_inputs: bool,
46 last_value: f64,
47 average_gain: Box<dyn MovingAverage + Send + 'static>,
48 average_loss: Box<dyn MovingAverage + Send + 'static>,
49 rsi_max: f64,
50}
51
52impl Display for RelativeStrengthIndex {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 write!(f, "{}({},{})", self.name(), self.period, self.ma_type)
55 }
56}
57
58impl Indicator for RelativeStrengthIndex {
59 fn name(&self) -> String {
60 stringify!(RelativeStrengthIndex).to_string()
61 }
62
63 fn has_inputs(&self) -> bool {
64 self.has_inputs
65 }
66
67 fn initialized(&self) -> bool {
68 self.initialized
69 }
70
71 fn handle_quote(&mut self, quote: &QuoteTick) {
72 self.update_raw(quote.extract_price(PriceType::Mid).into());
73 }
74
75 fn handle_trade(&mut self, trade: &TradeTick) {
76 self.update_raw((trade.price).into());
77 }
78
79 fn handle_bar(&mut self, bar: &Bar) {
80 self.update_raw((&bar.close).into());
81 }
82
83 fn reset(&mut self) {
84 self.value = 0.0;
85 self.last_value = 0.0;
86 self.count = 0;
87 self.has_inputs = false;
88 self.initialized = false;
89 self.average_gain.reset();
90 self.average_loss.reset();
91 }
92}
93
94impl RelativeStrengthIndex {
95 #[must_use]
97 pub fn new(period: usize, ma_type: Option<MovingAverageType>) -> Self {
98 Self {
99 period,
100 ma_type: ma_type.unwrap_or(MovingAverageType::Exponential),
101 value: 0.0,
102 last_value: 0.0,
103 count: 0,
104 has_inputs: false,
105 average_gain: MovingAverageFactory::create(MovingAverageType::Exponential, period),
106 average_loss: MovingAverageFactory::create(MovingAverageType::Exponential, period),
107 rsi_max: 1.0,
108 initialized: false,
109 }
110 }
111
112 pub fn update_raw(&mut self, value: f64) {
113 if !self.has_inputs {
114 self.last_value = value;
115 self.has_inputs = true;
116 }
117 let gain = value - self.last_value;
118 if gain > 0.0 {
119 self.average_gain.update_raw(gain);
120 self.average_loss.update_raw(0.0);
121 } else if gain < 0.0 {
122 self.average_loss.update_raw(-gain);
123 self.average_gain.update_raw(0.0);
124 } else {
125 self.average_loss.update_raw(0.0);
126 self.average_gain.update_raw(0.0);
127 }
128 self.count = self.average_gain.count();
129 if !self.initialized && self.average_loss.initialized() && self.average_gain.initialized() {
130 self.initialized = true;
131 }
132
133 if self.average_loss.value() == 0.0 {
134 self.value = self.rsi_max;
135 return;
136 }
137
138 let rs = self.average_gain.value() / self.average_loss.value();
139 self.value = self.rsi_max - (self.rsi_max / (1.0 + rs));
140 self.last_value = value;
141
142 if !self.initialized && self.count >= self.period {
143 self.initialized = true;
144 }
145 }
146}
147
#[cfg(test)]
mod tests {
    use nautilus_model::data::{Bar, QuoteTick, TradeTick};
    use rstest::rstest;

    use crate::{indicator::Indicator, momentum::rsi::RelativeStrengthIndex, stubs::*};

    /// Pushes every price in `prices` into `rsi`, in order.
    fn feed(rsi: &mut RelativeStrengthIndex, prices: &[f64]) {
        for &price in prices {
            rsi.update_raw(price);
        }
    }

    #[rstest]
    fn test_rsi_initialized(rsi_10: RelativeStrengthIndex) {
        assert_eq!(rsi_10.to_string(), "RelativeStrengthIndex(10,EXPONENTIAL)");
        assert_eq!(rsi_10.period, 10);
        assert!(!rsi_10.initialized);
    }

    #[rstest]
    fn test_initialized_with_required_inputs_returns_true(mut rsi_10: RelativeStrengthIndex) {
        // Twelve strictly rising inputs: 0.0, 1.0, ..., 11.0.
        let rising: Vec<f64> = (0_i32..12).map(f64::from).collect();
        feed(&mut rsi_10, &rising);
        assert!(rsi_10.initialized);
    }

    #[rstest]
    fn test_value_with_one_input_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0]);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_value_all_higher_inputs_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0, 2.0, 3.0]);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_value_with_all_lower_inputs_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[3.0, 2.0, 1.0]);
        assert_eq!(rsi_10.value, 0.0);
    }

    #[rstest]
    fn test_value_with_various_input_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[3.0, 2.0, 5.0, 6.0, 7.0, 6.0]);

        assert_eq!(rsi_10.value, 0.683_736_332_582_526_5);
    }

    #[rstest]
    fn test_value_at_returns_expected_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[3.0, 2.0, 5.0, 6.0, 7.0, 6.0, 6.0, 7.0]);

        assert_eq!(rsi_10.value, 0.761_534_466_766_272_5);
    }

    #[rstest]
    fn test_reset(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0, 2.0]);
        rsi_10.reset();

        assert!(!rsi_10.initialized());
        assert_eq!(rsi_10.count, 0);
    }

    #[rstest]
    fn test_reset_resets_inner_mas(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0, 2.0]);
        rsi_10.reset();

        // Both inner averages must be cleared along with the indicator itself.
        assert_eq!(rsi_10.average_gain.count(), 0);
        assert_eq!(rsi_10.average_loss.count(), 0);
    }

    #[rstest]
    fn test_handle_quote_tick(mut rsi_10: RelativeStrengthIndex, stub_quote: QuoteTick) {
        rsi_10.handle_quote(&stub_quote);

        assert_eq!(rsi_10.count, 1);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_handle_trade_tick(mut rsi_10: RelativeStrengthIndex, stub_trade: TradeTick) {
        rsi_10.handle_trade(&stub_trade);

        assert_eq!(rsi_10.count, 1);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_handle_bar(mut rsi_10: RelativeStrengthIndex, bar_ethusdt_binance_minute_bid: Bar) {
        rsi_10.handle_bar(&bar_ethusdt_binance_minute_bid);

        assert_eq!(rsi_10.count, 1);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_constant_inputs_initializes_and_value_max(mut rsi_10: RelativeStrengthIndex) {
        // A flat series produces neither gains nor losses.
        feed(&mut rsi_10, &[5.0; 12]);

        assert!(rsi_10.initialized);
        assert_eq!(rsi_10.value, 1.0);
    }

    #[rstest]
    fn test_reset_resets_has_inputs_and_value(mut rsi_10: RelativeStrengthIndex) {
        feed(&mut rsi_10, &[1.0]);
        rsi_10.reset();

        assert!(!rsi_10.has_inputs());
        assert_eq!(rsi_10.value, 0.0);
    }
}