Skip to main content

nautilus_risk/python/
config.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16//! Python bindings for risk engine configuration.
17
18use std::{collections::HashMap, str::FromStr};
19
20use ahash::AHashMap;
21use nautilus_common::throttler::RateLimit;
22use nautilus_core::{datetime::NANOSECONDS_IN_SECOND, python::to_pyvalue_err};
23use nautilus_model::identifiers::InstrumentId;
24use pyo3::{Py, PyAny, PyResult, Python, prelude::PyAnyMethods, pymethods};
25use rust_decimal::Decimal;
26
27use crate::engine::config::RiskEngineConfig;
28
29fn format_rate_limit(rate: &RateLimit) -> String {
30    let total_secs = rate.interval_ns / NANOSECONDS_IN_SECOND;
31    let hours = total_secs / 3_600;
32    let minutes = (total_secs % 3_600) / 60;
33    let seconds = total_secs % 60;
34    format!("{}/{:02}:{:02}:{:02}", rate.limit, hours, minutes, seconds)
35}
36
37fn parse_rate_limit(name: &str, value: &str) -> PyResult<RateLimit> {
38    let (limit, interval) = value
39        .split_once('/')
40        .ok_or_else(|| to_pyvalue_err(format!("invalid `{name}`: expected 'limit/HH:MM:SS'")))?;
41
42    let limit = limit
43        .parse::<usize>()
44        .map_err(|e| to_pyvalue_err(format!("invalid `{name}` limit: {e}")))?;
45
46    if limit == 0 {
47        return Err(to_pyvalue_err(format!(
48            "invalid `{name}`: limit must be greater than zero"
49        )));
50    }
51
52    let mut total_secs: u64 = 0;
53    let mut parts = interval.split(':');
54    for label in ["hours", "minutes", "seconds"] {
55        let component = parts
56            .next()
57            .ok_or_else(|| {
58                to_pyvalue_err(format!(
59                    "invalid `{name}`: expected 'limit/HH:MM:SS' interval"
60                ))
61            })?
62            .parse::<u64>()
63            .map_err(|e| to_pyvalue_err(format!("invalid `{name}` {label}: {e}")))?;
64
65        let multiplier: u64 = match label {
66            "hours" => 3_600,
67            "minutes" => 60,
68            "seconds" => 1,
69            _ => unreachable!(),
70        };
71        total_secs = total_secs.saturating_add(component.saturating_mul(multiplier));
72    }
73
74    if parts.next().is_some() {
75        return Err(to_pyvalue_err(format!(
76            "invalid `{name}`: expected 'limit/HH:MM:SS'"
77        )));
78    }
79
80    if total_secs == 0 {
81        return Err(to_pyvalue_err(format!(
82            "invalid `{name}`: interval must be greater than zero"
83        )));
84    }
85
86    Ok(RateLimit::new(
87        limit,
88        total_secs.saturating_mul(NANOSECONDS_IN_SECOND),
89    ))
90}
91
92fn coerce_max_notional_per_order(
93    raw: HashMap<String, Py<PyAny>>,
94) -> PyResult<AHashMap<InstrumentId, Decimal>> {
95    Python::attach(|py| -> PyResult<AHashMap<InstrumentId, Decimal>> {
96        let mut result = AHashMap::with_capacity(raw.len());
97        for (instrument_id, value) in raw {
98            let parsed_id = InstrumentId::from_str(&instrument_id).map_err(|e| {
99                to_pyvalue_err(format!(
100                    "invalid `max_notional_per_order` instrument ID {instrument_id:?}: {e}"
101                ))
102            })?;
103            let value_str: String = value.bind(py).str()?.extract()?;
104            let notional = Decimal::from_str(&value_str).map_err(|e| {
105                to_pyvalue_err(format!(
106                    "invalid `max_notional_per_order` notional {value_str:?}: {e}"
107                ))
108            })?;
109            result.insert(parsed_id, notional);
110        }
111        Ok(result)
112    })
113}
114
115#[pymethods]
116#[pyo3_stub_gen::derive::gen_stub_pymethods]
117impl RiskEngineConfig {
118    /// Configuration for `RiskEngine` instances.
119    #[new]
120    #[pyo3(signature = (
121        bypass = None,
122        max_order_submit_rate = None,
123        max_order_modify_rate = None,
124        max_notional_per_order = None,
125        debug = None,
126    ))]
127    fn py_new(
128        bypass: Option<bool>,
129        max_order_submit_rate: Option<String>,
130        max_order_modify_rate: Option<String>,
131        max_notional_per_order: Option<HashMap<String, Py<PyAny>>>,
132        debug: Option<bool>,
133    ) -> PyResult<Self> {
134        let default = Self::default();
135
136        let max_order_submit = match max_order_submit_rate {
137            Some(value) => parse_rate_limit("max_order_submit_rate", &value)?,
138            None => default.max_order_submit,
139        };
140        let max_order_modify = match max_order_modify_rate {
141            Some(value) => parse_rate_limit("max_order_modify_rate", &value)?,
142            None => default.max_order_modify,
143        };
144        let max_notional_per_order = match max_notional_per_order {
145            Some(raw) => coerce_max_notional_per_order(raw)?,
146            None => default.max_notional_per_order,
147        };
148
149        Ok(Self {
150            bypass: bypass.unwrap_or(default.bypass),
151            max_order_submit,
152            max_order_modify,
153            max_notional_per_order,
154            debug: debug.unwrap_or(default.debug),
155        })
156    }
157
158    #[getter]
159    #[pyo3(name = "bypass")]
160    const fn py_bypass(&self) -> bool {
161        self.bypass
162    }
163
164    #[getter]
165    #[pyo3(name = "max_order_submit_rate")]
166    fn py_max_order_submit_rate(&self) -> String {
167        format_rate_limit(&self.max_order_submit)
168    }
169
170    #[getter]
171    #[pyo3(name = "max_order_modify_rate")]
172    fn py_max_order_modify_rate(&self) -> String {
173        format_rate_limit(&self.max_order_modify)
174    }
175
176    #[getter]
177    #[pyo3(name = "max_notional_per_order")]
178    fn py_max_notional_per_order(&self) -> HashMap<String, String> {
179        self.max_notional_per_order
180            .iter()
181            .map(|(id, notional)| (id.to_string(), notional.to_string()))
182            .collect()
183    }
184
185    #[getter]
186    #[pyo3(name = "debug")]
187    const fn py_debug(&self) -> bool {
188        self.debug
189    }
190
191    fn __repr__(&self) -> String {
192        format!("{self:?}")
193    }
194
195    fn __str__(&self) -> String {
196        format!("{self:?}")
197    }
198}