nautilus_analysis/statistics/
max_drawdown.rs1use std::collections::BTreeMap;
19
20use nautilus_core::UnixNanos;
21
22use crate::statistic::PortfolioStatistic;
23
/// Portfolio statistic measuring the maximum drawdown of a series of returns.
///
/// The value is computed by its [`PortfolioStatistic`] implementation from compounded
/// returns and reported as a non-positive fraction.
#[repr(C)]
#[derive(Clone, Debug, Default)]
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.analysis", from_py_object),
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.analysis")
)]
pub struct MaxDrawdown {}
42
43impl MaxDrawdown {
44 #[must_use]
46 pub fn new() -> Self {
47 Self {}
48 }
49}
50
51impl PortfolioStatistic for MaxDrawdown {
52 type Item = f64;
53
54 fn name(&self) -> String {
55 "Max Drawdown".to_string()
56 }
57
58 fn calculate_from_returns(&self, returns: &BTreeMap<UnixNanos, f64>) -> Option<Self::Item> {
59 if returns.is_empty() {
60 return Some(0.0);
61 }
62
63 let mut cumulative = 1.0;
65 let mut running_max = 1.0;
66 let mut max_drawdown = 0.0;
67
68 for &ret in returns.values() {
69 cumulative *= 1.0 + ret;
70
71 if cumulative > running_max {
73 running_max = cumulative;
74 }
75
76 let drawdown = (running_max - cumulative) / running_max;
78
79 if drawdown > max_drawdown {
81 max_drawdown = drawdown;
82 }
83 }
84
85 Some(-max_drawdown)
87 }
88}
89
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    /// Tolerance used when comparing computed drawdowns with hand-derived values.
    const TOLERANCE: f64 = 0.01;

    /// Builds a return series keyed by consecutive timestamps `0, 1, 2, ...`.
    fn create_returns(values: &[f64]) -> BTreeMap<UnixNanos, f64> {
        let mut series = BTreeMap::new();
        for (idx, &value) in values.iter().enumerate() {
            series.insert(UnixNanos::from(idx as u64), value);
        }
        series
    }

    /// Runs a fresh statistic over `values` and unwraps the result.
    fn drawdown_of(values: &[f64]) -> f64 {
        MaxDrawdown::new()
            .calculate_from_returns(&create_returns(values))
            .unwrap()
    }

    #[rstest]
    fn test_name() {
        assert_eq!(MaxDrawdown::new().name(), "Max Drawdown");
    }

    #[rstest]
    fn test_empty_returns() {
        let empty: BTreeMap<UnixNanos, f64> = BTreeMap::new();
        assert_eq!(MaxDrawdown::new().calculate_from_returns(&empty), Some(0.0));
    }

    #[rstest]
    fn test_no_drawdown() {
        // Equity rises every period, so it never falls below a prior peak.
        assert_eq!(drawdown_of(&[0.01, 0.02, 0.01, 0.015]), 0.0);
    }

    #[rstest]
    fn test_simple_drawdown() {
        // 1.0 -> 1.10 -> 0.99: a 10% fall from the 1.10 peak.
        let result = drawdown_of(&[0.10, -0.10]);
        assert!((result + 0.10).abs() < TOLERANCE);
    }

    #[rstest]
    fn test_multiple_drawdowns() {
        // Peaks at 1.10 and 1.485; the deeper fall is 1.485 -> 1.188 (20%).
        let result = drawdown_of(&[0.10, -0.10, 0.50, -0.20, 0.10]);
        assert!((result + 0.20).abs() < TOLERANCE);
    }

    #[rstest]
    fn test_initial_loss() {
        // The starting equity of 1.0 acts as the peak: 1.0 -> 0.60 -> 0.54 is a 46% fall.
        let result = drawdown_of(&[-0.40, -0.10]);
        assert!((result + 0.46).abs() < TOLERANCE);
    }
}