Skip to main content

nautilus_analysis/python/
analyzer.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{collections::HashMap, sync::Arc};
17
18use nautilus_core::{UnixNanos, python::to_pyvalue_err};
19use nautilus_model::{
20    identifiers::PositionId,
21    position::Position,
22    types::{Currency, Money},
23};
24use pyo3::prelude::*;
25
26use crate::{
27    analyzer::{PortfolioAnalyzer, Statistic},
28    statistics::{
29        expectancy::Expectancy, long_ratio::LongRatio, loser_avg::AvgLoser, loser_max::MaxLoser,
30        loser_min::MinLoser, profit_factor::ProfitFactor, returns_avg::ReturnsAverage,
31        returns_avg_loss::ReturnsAverageLoss, returns_avg_win::ReturnsAverageWin,
32        returns_volatility::ReturnsVolatility, risk_return_ratio::RiskReturnRatio,
33        sharpe_ratio::SharpeRatio, sortino_ratio::SortinoRatio, win_rate::WinRate,
34        winner_avg::AvgWinner, winner_max::MaxWinner, winner_min::MinWinner,
35    },
36};
37
38#[pymethods]
39#[pyo3_stub_gen::derive::gen_stub_pymethods]
40impl PortfolioAnalyzer {
41    /// Analyzes portfolio performance and calculates various statistics.
42    ///
43    /// The `PortfolioAnalyzer` tracks account balances, positions, and realized PnLs
44    /// to provide portfolio analysis including returns, PnL calculations,
45    /// and customizable statistics.
46    #[new]
47    #[must_use]
48    pub fn py_new() -> Self {
49        Self::new()
50    }
51
52    fn __repr__(&self) -> String {
53        format!("PortfolioAnalyzer(currencies={})", self.currencies().len())
54    }
55
56    /// Returns all tracked currencies.
57    #[pyo3(name = "currencies")]
58    fn py_currencies(&self) -> Vec<Currency> {
59        self.currencies().into_iter().copied().collect()
60    }
61
62    /// Calculates total PnL including unrealized PnL if provided.
63    #[pyo3(name = "get_performance_stats_returns")]
64    fn py_get_performance_stats_returns(&self) -> HashMap<String, f64> {
65        self.get_performance_stats_returns().into_iter().collect()
66    }
67
68    /// Gets all position-return-based performance statistics.
69    #[pyo3(name = "get_performance_stats_position_returns")]
70    fn py_get_performance_stats_position_returns(&self) -> HashMap<String, f64> {
71        self.get_performance_stats_position_returns()
72            .into_iter()
73            .collect()
74    }
75
76    /// Gets all portfolio-return-based performance statistics.
77    #[pyo3(name = "get_performance_stats_portfolio_returns")]
78    fn py_get_performance_stats_portfolio_returns(&self) -> HashMap<String, f64> {
79        self.get_performance_stats_portfolio_returns()
80            .into_iter()
81            .collect()
82    }
83
84    #[pyo3(name = "get_performance_stats_pnls")]
85    fn py_get_performance_stats_pnls(
86        &self,
87        currency: Option<&Currency>,
88        unrealized_pnl: Option<&Money>,
89    ) -> PyResult<HashMap<String, f64>> {
90        self.get_performance_stats_pnls(currency, unrealized_pnl)
91            .map(|m| m.into_iter().collect())
92            .map_err(to_pyvalue_err)
93    }
94
95    /// Gets general portfolio statistics.
96    #[pyo3(name = "get_performance_stats_general")]
97    fn py_get_performance_stats_general(&self) -> HashMap<String, f64> {
98        self.get_performance_stats_general().into_iter().collect()
99    }
100
101    /// Records a position return at a specific timestamp.
102    #[pyo3(name = "add_position_return")]
103    fn py_add_position_return(&mut self, timestamp: u64, value: f64) {
104        self.add_position_return(UnixNanos::from(timestamp), value);
105    }
106
107    /// Records a return at a specific timestamp.
108    ///
109    /// This is a backward-compatible alias for `Self.add_position_return`.
110    #[pyo3(name = "add_return")]
111    fn py_add_return(&mut self, timestamp: u64, value: f64) {
112        self.add_return(UnixNanos::from(timestamp), value);
113    }
114
115    /// Resets all analysis data to initial state.
116    #[pyo3(name = "reset")]
117    fn py_reset(&mut self) {
118        self.reset();
119    }
120
121    /// Registers a new portfolio statistic for calculation.
122    #[pyo3(name = "register_statistic")]
123    #[expect(clippy::needless_pass_by_value)]
124    fn py_register_statistic(&mut self, py: Python, statistic: Py<PyAny>) -> PyResult<()> {
125        let type_name = statistic
126            .getattr(py, "__class__")?
127            .getattr(py, "__name__")?
128            .extract::<String>(py)?;
129
130        match type_name.as_str() {
131            "MaxWinner" => {
132                let stat = statistic.extract::<MaxWinner>(py)?;
133                self.register_statistic(Arc::new(stat));
134            }
135            "MinWinner" => {
136                let stat = statistic.extract::<MinWinner>(py)?;
137                self.register_statistic(Arc::new(stat));
138            }
139            "AvgWinner" => {
140                let stat = statistic.extract::<AvgWinner>(py)?;
141                self.register_statistic(Arc::new(stat));
142            }
143            "MaxLoser" => {
144                let stat = statistic.extract::<MaxLoser>(py)?;
145                self.register_statistic(Arc::new(stat));
146            }
147            "MinLoser" => {
148                let stat = statistic.extract::<MinLoser>(py)?;
149                self.register_statistic(Arc::new(stat));
150            }
151            "AvgLoser" => {
152                let stat = statistic.extract::<AvgLoser>(py)?;
153                self.register_statistic(Arc::new(stat));
154            }
155            "Expectancy" => {
156                let stat = statistic.extract::<Expectancy>(py)?;
157                self.register_statistic(Arc::new(stat));
158            }
159            "WinRate" => {
160                let stat = statistic.extract::<WinRate>(py)?;
161                self.register_statistic(Arc::new(stat));
162            }
163            "ReturnsVolatility" => {
164                let stat = statistic.extract::<ReturnsVolatility>(py)?;
165                self.register_statistic(Arc::new(stat));
166            }
167            "ReturnsAverage" => {
168                let stat = statistic.extract::<ReturnsAverage>(py)?;
169                self.register_statistic(Arc::new(stat));
170            }
171            "ReturnsAverageLoss" => {
172                let stat = statistic.extract::<ReturnsAverageLoss>(py)?;
173                self.register_statistic(Arc::new(stat));
174            }
175            "ReturnsAverageWin" => {
176                let stat = statistic.extract::<ReturnsAverageWin>(py)?;
177                self.register_statistic(Arc::new(stat));
178            }
179            "SharpeRatio" => {
180                let stat = statistic.extract::<SharpeRatio>(py)?;
181                self.register_statistic(Arc::new(stat));
182            }
183            "SortinoRatio" => {
184                let stat = statistic.extract::<SortinoRatio>(py)?;
185                self.register_statistic(Arc::new(stat));
186            }
187            "ProfitFactor" => {
188                let stat = statistic.extract::<ProfitFactor>(py)?;
189                self.register_statistic(Arc::new(stat));
190            }
191            "RiskReturnRatio" => {
192                let stat = statistic.extract::<RiskReturnRatio>(py)?;
193                self.register_statistic(Arc::new(stat));
194            }
195            "LongRatio" => {
196                let stat = statistic.extract::<LongRatio>(py)?;
197                self.register_statistic(Arc::new(stat));
198            }
199            _ => {
200                return Err(to_pyvalue_err(format!(
201                    "Unknown statistic type: {type_name}"
202                )));
203            }
204        }
205
206        Ok(())
207    }
208
209    /// Removes a specific statistic from calculation.
210    #[pyo3(name = "deregister_statistic")]
211    #[expect(clippy::needless_pass_by_value)]
212    fn py_deregister_statistic(&mut self, py: Python, statistic: Py<PyAny>) -> PyResult<()> {
213        let type_name = statistic
214            .getattr(py, "__class__")?
215            .getattr(py, "__name__")?
216            .extract::<String>(py)?;
217
218        match type_name.as_str() {
219            "MaxWinner" => {
220                let stat = statistic.extract::<MaxWinner>(py)?;
221                self.deregister_statistic(&(Arc::new(stat) as Statistic));
222            }
223            "MinWinner" => {
224                let stat = statistic.extract::<MinWinner>(py)?;
225                self.deregister_statistic(&(Arc::new(stat) as Statistic));
226            }
227            "AvgWinner" => {
228                let stat = statistic.extract::<AvgWinner>(py)?;
229                self.deregister_statistic(&(Arc::new(stat) as Statistic));
230            }
231            "MaxLoser" => {
232                let stat = statistic.extract::<MaxLoser>(py)?;
233                self.deregister_statistic(&(Arc::new(stat) as Statistic));
234            }
235            "MinLoser" => {
236                let stat = statistic.extract::<MinLoser>(py)?;
237                self.deregister_statistic(&(Arc::new(stat) as Statistic));
238            }
239            "AvgLoser" => {
240                let stat = statistic.extract::<AvgLoser>(py)?;
241                self.deregister_statistic(&(Arc::new(stat) as Statistic));
242            }
243            "Expectancy" => {
244                let stat = statistic.extract::<Expectancy>(py)?;
245                self.deregister_statistic(&(Arc::new(stat) as Statistic));
246            }
247            "WinRate" => {
248                let stat = statistic.extract::<WinRate>(py)?;
249                self.deregister_statistic(&(Arc::new(stat) as Statistic));
250            }
251            "ReturnsVolatility" => {
252                let stat = statistic.extract::<ReturnsVolatility>(py)?;
253                self.deregister_statistic(&(Arc::new(stat) as Statistic));
254            }
255            "ReturnsAverage" => {
256                let stat = statistic.extract::<ReturnsAverage>(py)?;
257                self.deregister_statistic(&(Arc::new(stat) as Statistic));
258            }
259            "ReturnsAverageLoss" => {
260                let stat = statistic.extract::<ReturnsAverageLoss>(py)?;
261                self.deregister_statistic(&(Arc::new(stat) as Statistic));
262            }
263            "ReturnsAverageWin" => {
264                let stat = statistic.extract::<ReturnsAverageWin>(py)?;
265                self.deregister_statistic(&(Arc::new(stat) as Statistic));
266            }
267            "SharpeRatio" => {
268                let stat = statistic.extract::<SharpeRatio>(py)?;
269                self.deregister_statistic(&(Arc::new(stat) as Statistic));
270            }
271            "SortinoRatio" => {
272                let stat = statistic.extract::<SortinoRatio>(py)?;
273                self.deregister_statistic(&(Arc::new(stat) as Statistic));
274            }
275            "ProfitFactor" => {
276                let stat = statistic.extract::<ProfitFactor>(py)?;
277                self.deregister_statistic(&(Arc::new(stat) as Statistic));
278            }
279            "RiskReturnRatio" => {
280                let stat = statistic.extract::<RiskReturnRatio>(py)?;
281                self.deregister_statistic(&(Arc::new(stat) as Statistic));
282            }
283            "LongRatio" => {
284                let stat = statistic.extract::<LongRatio>(py)?;
285                self.deregister_statistic(&(Arc::new(stat) as Statistic));
286            }
287            _ => {
288                return Err(to_pyvalue_err(format!(
289                    "Unknown statistic type: {type_name}"
290                )));
291            }
292        }
293
294        Ok(())
295    }
296
297    /// Removes all registered statistics.
298    #[pyo3(name = "deregister_statistics")]
299    fn py_deregister_statistics(&mut self) {
300        self.deregister_statistics();
301    }
302
303    /// Adds new positions for analysis.
304    #[pyo3(name = "add_positions")]
305    #[expect(clippy::needless_pass_by_value)]
306    fn py_add_positions(&mut self, py: Python, positions: Vec<Py<PyAny>>) -> PyResult<()> {
307        // Extract Position objects from Cython wrappers
308        let positions: Vec<Position> = positions
309            .iter()
310            .map(|p| {
311                // Try to get the underlying Rust Position
312                // For now, we'll need to handle Cython Position by accessing its _mem field
313                p.getattr(py, "_mem")?
314                    .extract::<Position>(py)
315                    .map_err(Into::into)
316            })
317            .collect::<PyResult<Vec<Position>>>()?;
318
319        self.add_positions(&positions);
320        Ok(())
321    }
322
323    /// Records a trade's PnL.
324    #[pyo3(name = "add_trade")]
325    fn py_add_trade(&mut self, position_id: &PositionId, realized_pnl: &Money) {
326        self.add_trade(position_id, realized_pnl);
327    }
328
329    // Note: calculate_statistics is not exposed to Python because it requires
330    // complex conversions of Account and dict types. Use the Python analyzer.py wrapper instead.
331
332    /// Retrieves a specific statistic by name.
333    #[pyo3(name = "statistic")]
334    fn py_statistic(&self, name: &str) -> Option<String> {
335        self.statistic(name).map(|s| s.name())
336    }
337
338    /// Returns the primary calculated returns.
339    ///
340    /// This returns portfolio returns when available, otherwise it falls back
341    /// to position returns for backward compatibility.
342    #[pyo3(name = "returns")]
343    fn py_returns(&self, py: Python) -> PyResult<Py<PyAny>> {
344        // Convert BTreeMap<UnixNanos, f64> to Python dict
345        let dict = pyo3::types::PyDict::new(py);
346        for (timestamp, value) in self.returns() {
347            dict.set_item(timestamp.as_u64(), value)?;
348        }
349        Ok(dict.into())
350    }
351
352    /// Returns the per-position calculated returns.
353    #[pyo3(name = "position_returns")]
354    fn py_position_returns(&self, py: Python) -> PyResult<Py<PyAny>> {
355        let dict = pyo3::types::PyDict::new(py);
356        for (timestamp, value) in self.position_returns() {
357            dict.set_item(timestamp.as_u64(), value)?;
358        }
359        Ok(dict.into())
360    }
361
362    /// Returns the portfolio calculated returns.
363    #[pyo3(name = "portfolio_returns")]
364    fn py_portfolio_returns(&self, py: Python) -> PyResult<Py<PyAny>> {
365        let dict = pyo3::types::PyDict::new(py);
366        for (timestamp, value) in self.portfolio_returns() {
367            dict.set_item(timestamp.as_u64(), value)?;
368        }
369        Ok(dict.into())
370    }
371
372    /// Retrieves realized PnLs for a specific currency.
373    ///
374    /// Returns `None` if no PnLs exist, or if multiple currencies exist
375    /// without an explicit currency specified.
376    #[pyo3(name = "realized_pnls")]
377    fn py_realized_pnls(&self, py: Python, currency: Option<&Currency>) -> PyResult<Py<PyAny>> {
378        match self.realized_pnls(currency) {
379            Some(pnls) => {
380                // Convert Vec<(PositionId, f64)> to Python list of tuples or dict
381                let dict = pyo3::types::PyDict::new(py);
382                for (position_id, pnl) in pnls {
383                    dict.set_item(position_id.to_string(), pnl)?;
384                }
385                Ok(dict.into())
386            }
387            None => Ok(py.None()),
388        }
389    }
390
391    #[pyo3(name = "total_pnl")]
392    fn py_total_pnl(
393        &self,
394        currency: Option<&Currency>,
395        unrealized_pnl: Option<&Money>,
396    ) -> PyResult<f64> {
397        self.total_pnl(currency, unrealized_pnl)
398            .map_err(to_pyvalue_err)
399    }
400
401    #[pyo3(name = "total_pnl_percentage")]
402    fn py_total_pnl_percentage(
403        &self,
404        currency: Option<&Currency>,
405        unrealized_pnl: Option<&Money>,
406    ) -> PyResult<f64> {
407        self.total_pnl_percentage(currency, unrealized_pnl)
408            .map_err(to_pyvalue_err)
409    }
410
411    /// Gets formatted PnL statistics as strings.
412    #[pyo3(name = "get_stats_pnls_formatted")]
413    fn py_get_stats_pnls_formatted(
414        &self,
415        currency: Option<&Currency>,
416        unrealized_pnl: Option<&Money>,
417    ) -> PyResult<Vec<String>> {
418        self.get_stats_pnls_formatted(currency, unrealized_pnl)
419            .map_err(to_pyvalue_err)
420    }
421
422    /// Gets formatted return statistics as strings.
423    #[pyo3(name = "get_stats_returns_formatted")]
424    fn py_get_stats_returns_formatted(&self) -> Vec<String> {
425        self.get_stats_returns_formatted()
426    }
427
428    /// Gets formatted position-return statistics as strings.
429    #[pyo3(name = "get_stats_position_returns_formatted")]
430    fn py_get_stats_position_returns_formatted(&self) -> Vec<String> {
431        self.get_stats_position_returns_formatted()
432    }
433
434    /// Gets formatted portfolio-return statistics as strings.
435    #[pyo3(name = "get_stats_portfolio_returns_formatted")]
436    fn py_get_stats_portfolio_returns_formatted(&self) -> Vec<String> {
437        self.get_stats_portfolio_returns_formatted()
438    }
439
440    /// Gets formatted general statistics as strings.
441    #[pyo3(name = "get_stats_general_formatted")]
442    fn py_get_stats_general_formatted(&self) -> Vec<String> {
443        self.get_stats_general_formatted()
444    }
445}