Skip to main content

nautilus_infrastructure/python/sql/
cache.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
16use bytes::Bytes;
17use nautilus_common::{cache::database::CacheDatabaseAdapter, live::get_runtime, signal::Signal};
18use nautilus_core::python::to_pyruntime_err;
19use nautilus_model::{
20    data::{Bar, CustomData, DataType, QuoteTick, TradeTick},
21    events::{OrderSnapshot, PositionSnapshot},
22    identifiers::{AccountId, ClientId, ClientOrderId, InstrumentId, PositionId},
23    python::{
24        account::{account_any_to_pyobject, pyobject_to_account_any},
25        events::order::pyobject_to_order_event,
26        instruments::{instrument_any_to_pyobject, pyobject_to_instrument_any},
27        orders::{order_any_to_pyobject, pyobject_to_order_any},
28    },
29    types::Currency,
30};
31use pyo3::{IntoPyObjectExt, prelude::*};
32
33use crate::sql::{cache::PostgresCacheDatabase, queries::DatabaseQueries};
34
35#[pymethods]
36impl PostgresCacheDatabase {
37    /// Connects to the Postgres cache database using the provided connection parameters.
38    #[staticmethod]
39    #[pyo3(name = "connect")]
40    #[pyo3(signature = (host=None, port=None, username=None, password=None, database=None))]
41    fn py_connect(
42        host: Option<String>,
43        port: Option<u16>,
44        username: Option<String>,
45        password: Option<String>,
46        database: Option<String>,
47    ) -> PyResult<Self> {
48        let result = get_runtime()
49            .block_on(async { Self::connect(host, port, username, password, database).await });
50        result.map_err(to_pyruntime_err)
51    }
52
53    #[pyo3(name = "close")]
54    fn py_close(&mut self) -> PyResult<()> {
55        self.close().map_err(to_pyruntime_err)
56    }
57
58    #[pyo3(name = "flush_db")]
59    fn py_flush_db(&mut self) -> PyResult<()> {
60        self.flush().map_err(to_pyruntime_err)
61    }
62
63    #[pyo3(name = "load")]
64    fn py_load(&self) -> PyResult<std::collections::HashMap<String, Vec<u8>>> {
65        get_runtime()
66            .block_on(async { DatabaseQueries::load(&self.pool).await })
67            .map(|m| m.into_iter().collect())
68            .map_err(to_pyruntime_err)
69    }
70
71    #[pyo3(name = "load_currency")]
72    fn py_load_currency(&self, code: &str) -> PyResult<Option<Currency>> {
73        let result = get_runtime()
74            .block_on(async { DatabaseQueries::load_currency(&self.pool, code).await });
75        result.map_err(to_pyruntime_err)
76    }
77
78    #[pyo3(name = "load_currencies")]
79    fn py_load_currencies(&self) -> PyResult<Vec<Currency>> {
80        let result =
81            get_runtime().block_on(async { DatabaseQueries::load_currencies(&self.pool).await });
82        result.map_err(to_pyruntime_err)
83    }
84
85    #[pyo3(name = "load_instrument")]
86    fn py_load_instrument(
87        &self,
88        py: Python,
89        instrument_id: InstrumentId,
90    ) -> PyResult<Option<Py<PyAny>>> {
91        get_runtime().block_on(async {
92            let result = DatabaseQueries::load_instrument(&self.pool, &instrument_id)
93                .await
94                .unwrap();
95
96            match result {
97                Some(instrument) => {
98                    let py_object = instrument_any_to_pyobject(py, instrument)?;
99                    Ok(Some(py_object))
100                }
101                None => Ok(None),
102            }
103        })
104    }
105
106    #[pyo3(name = "load_instruments")]
107    fn py_load_instruments(&self, py: Python) -> PyResult<Vec<Py<PyAny>>> {
108        get_runtime().block_on(async {
109            let result = DatabaseQueries::load_instruments(&self.pool).await.unwrap();
110            let mut instruments = Vec::new();
111
112            for instrument in result {
113                let py_object = instrument_any_to_pyobject(py, instrument)?;
114                instruments.push(py_object);
115            }
116            Ok(instruments)
117        })
118    }
119
120    #[pyo3(name = "load_order")]
121    fn py_load_order(
122        &self,
123        py: Python,
124        client_order_id: ClientOrderId,
125    ) -> PyResult<Option<Py<PyAny>>> {
126        get_runtime().block_on(async {
127            let result = DatabaseQueries::load_order(&self.pool, &client_order_id)
128                .await
129                .unwrap();
130
131            match result {
132                Some(order) => {
133                    let py_object = order_any_to_pyobject(py, order)?;
134                    Ok(Some(py_object))
135                }
136                None => Ok(None),
137            }
138        })
139    }
140
141    #[pyo3(name = "load_account")]
142    fn py_load_account(&self, py: Python, account_id: AccountId) -> PyResult<Option<Py<PyAny>>> {
143        get_runtime().block_on(async {
144            let result = DatabaseQueries::load_account(&self.pool, &account_id)
145                .await
146                .unwrap();
147
148            match result {
149                Some(account) => {
150                    let py_object = account_any_to_pyobject(py, account)?;
151                    Ok(Some(py_object))
152                }
153                None => Ok(None),
154            }
155        })
156    }
157
158    #[pyo3(name = "load_quotes")]
159    fn py_load_quotes(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<Py<PyAny>>> {
160        get_runtime().block_on(async {
161            let result = DatabaseQueries::load_quotes(&self.pool, &instrument_id)
162                .await
163                .unwrap();
164            let mut quotes = Vec::new();
165
166            for quote in result {
167                let py_object = quote.into_py_any(py)?;
168                quotes.push(py_object);
169            }
170            Ok(quotes)
171        })
172    }
173
174    #[pyo3(name = "load_trades")]
175    fn py_load_trades(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<Py<PyAny>>> {
176        get_runtime().block_on(async {
177            let result = DatabaseQueries::load_trades(&self.pool, &instrument_id)
178                .await
179                .unwrap();
180            let mut trades = Vec::new();
181
182            for trade in result {
183                let py_object = trade.into_py_any(py)?;
184                trades.push(py_object);
185            }
186            Ok(trades)
187        })
188    }
189
190    #[pyo3(name = "load_bars")]
191    fn py_load_bars(&self, py: Python, instrument_id: InstrumentId) -> PyResult<Vec<Py<PyAny>>> {
192        get_runtime().block_on(async {
193            let result = DatabaseQueries::load_bars(&self.pool, &instrument_id)
194                .await
195                .unwrap();
196            let mut bars = Vec::new();
197
198            for bar in result {
199                let py_object = bar.into_py_any(py)?;
200                bars.push(py_object);
201            }
202            Ok(bars)
203        })
204    }
205
206    #[pyo3(name = "load_signals")]
207    fn py_load_signals(&self, name: &str) -> PyResult<Vec<Signal>> {
208        get_runtime().block_on(async {
209            DatabaseQueries::load_signals(&self.pool, name)
210                .await
211                .map_err(to_pyruntime_err)
212        })
213    }
214
215    #[pyo3(name = "load_custom_data")]
216    #[expect(clippy::needless_pass_by_value)]
217    fn py_load_custom_data(&self, data_type: DataType) -> PyResult<Vec<CustomData>> {
218        get_runtime()
219            .block_on(async { DatabaseQueries::load_custom_data(&self.pool, &data_type).await })
220            .map_err(to_pyruntime_err)
221    }
222
223    #[pyo3(name = "load_order_snapshot")]
224    fn py_load_order_snapshot(
225        &self,
226        client_order_id: ClientOrderId,
227    ) -> PyResult<Option<OrderSnapshot>> {
228        get_runtime().block_on(async {
229            DatabaseQueries::load_order_snapshot(&self.pool, &client_order_id)
230                .await
231                .map_err(to_pyruntime_err)
232        })
233    }
234
235    #[pyo3(name = "load_position_snapshot")]
236    fn py_load_position_snapshot(
237        &self,
238        position_id: PositionId,
239    ) -> PyResult<Option<PositionSnapshot>> {
240        get_runtime().block_on(async {
241            DatabaseQueries::load_position_snapshot(&self.pool, &position_id)
242                .await
243                .map_err(to_pyruntime_err)
244        })
245    }
246
247    #[pyo3(name = "add")]
248    fn py_add(&self, key: String, value: Vec<u8>) -> PyResult<()> {
249        self.add(key, Bytes::from(value)).map_err(to_pyruntime_err)
250    }
251
252    #[pyo3(name = "add_currency")]
253    fn py_add_currency(&self, currency: Currency) -> PyResult<()> {
254        self.add_currency(&currency).map_err(to_pyruntime_err)
255    }
256
257    #[pyo3(name = "add_instrument")]
258    fn py_add_instrument(&self, py: Python, instrument: Py<PyAny>) -> PyResult<()> {
259        let instrument_any = pyobject_to_instrument_any(py, instrument)?;
260        self.add_instrument(&instrument_any)
261            .map_err(to_pyruntime_err)
262    }
263
264    #[pyo3(name = "add_order")]
265    #[pyo3(signature = (order, client_id=None))]
266    fn py_add_order(
267        &self,
268        py: Python,
269        order: Py<PyAny>,
270        client_id: Option<ClientId>,
271    ) -> PyResult<()> {
272        let order_any = pyobject_to_order_any(py, order)?;
273        self.add_order(&order_any, client_id)
274            .map_err(to_pyruntime_err)
275    }
276
277    #[pyo3(name = "add_order_snapshot")]
278    #[expect(clippy::needless_pass_by_value)]
279    fn py_add_order_snapshot(&self, snapshot: OrderSnapshot) -> PyResult<()> {
280        self.add_order_snapshot(&snapshot).map_err(to_pyruntime_err)
281    }
282
283    #[pyo3(name = "add_position_snapshot")]
284    #[expect(clippy::needless_pass_by_value)]
285    fn py_add_position_snapshot(&self, snapshot: PositionSnapshot) -> PyResult<()> {
286        self.add_position_snapshot(&snapshot)
287            .map_err(to_pyruntime_err)
288    }
289
290    #[pyo3(name = "add_account")]
291    fn py_add_account(&self, py: Python, account: Py<PyAny>) -> PyResult<()> {
292        let account_any = pyobject_to_account_any(py, account)?;
293        self.add_account(&account_any).map_err(to_pyruntime_err)
294    }
295
296    #[pyo3(name = "add_quote")]
297    fn py_add_quote(&self, quote: QuoteTick) -> PyResult<()> {
298        self.add_quote(&quote).map_err(to_pyruntime_err)
299    }
300
301    #[pyo3(name = "add_trade")]
302    fn py_add_trade(&self, trade: TradeTick) -> PyResult<()> {
303        self.add_trade(&trade).map_err(to_pyruntime_err)
304    }
305
306    #[pyo3(name = "add_bar")]
307    fn py_add_bar(&self, bar: Bar) -> PyResult<()> {
308        self.add_bar(&bar).map_err(to_pyruntime_err)
309    }
310
311    #[pyo3(name = "add_signal")]
312    #[expect(clippy::needless_pass_by_value)]
313    fn py_add_signal(&self, signal: Signal) -> PyResult<()> {
314        self.add_signal(&signal).map_err(to_pyruntime_err)
315    }
316
317    #[pyo3(name = "add_custom_data")]
318    #[expect(clippy::needless_pass_by_value)]
319    fn py_add_custom_data(&self, data: CustomData) -> PyResult<()> {
320        self.add_custom_data(&data).map_err(to_pyruntime_err)
321    }
322
323    #[pyo3(name = "update_order")]
324    fn py_update_order(&self, py: Python, order_event: Py<PyAny>) -> PyResult<()> {
325        let event = pyobject_to_order_event(py, order_event)?;
326        self.update_order(&event).map_err(to_pyruntime_err)
327    }
328
329    #[pyo3(name = "update_account")]
330    fn py_update_account(&self, py: Python, order: Py<PyAny>) -> PyResult<()> {
331        let order_any = pyobject_to_account_any(py, order)?;
332        self.update_account(&order_any).map_err(to_pyruntime_err)
333    }
334}