Skip to main content

nautilus_persistence/backend/
session.rs

// -------------------------------------------------------------------------------------------------
//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
//  https://nautechsystems.io
//
//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
//  You may not use this file except in compliance with the License.
//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
//
//  Unless required by applicable law or agreed to in writing, software
//  distributed under the License is distributed on an "AS IS" BASIS,
//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
//  See the License for the specific language governing permissions and
//  limitations under the License.
// -------------------------------------------------------------------------------------------------
15
use std::{sync::Arc, vec::IntoIter};

use ahash::{AHashMap, AHashSet};
use datafusion::{
    arrow::record_batch::RecordBatch, error::Result, logical_expr::expr::Sort,
    physical_plan::SendableRecordBatchStream, prelude::*,
};
use futures::StreamExt;
use nautilus_core::{UnixNanos, ffi::cvec::CVec};
use nautilus_model::data::{Data, HasTsInit};
use nautilus_serialization::arrow::{
    DataStreamingError, DecodeDataFromRecordBatch, EncodeToRecordBatch, WriteStream,
};
use object_store::ObjectStore;
use url::Url;

use super::{
    compare::Compare,
    kmerge_batch::{EagerStream, ElementBatchIter, KMerge},
};
36
37#[derive(Debug, Default)]
38pub struct TsInitComparator;
39
40impl<I> Compare<ElementBatchIter<I, Data>> for TsInitComparator
41where
42    I: Iterator<Item = IntoIter<Data>>,
43{
44    fn compare(
45        &self,
46        l: &ElementBatchIter<I, Data>,
47        r: &ElementBatchIter<I, Data>,
48    ) -> std::cmp::Ordering {
49        // Max heap ordering must be reversed
50        l.item.ts_init().cmp(&r.item.ts_init()).reverse()
51    }
52}
53
/// Merged query output: k-merges the per-table decoded batch streams into a
/// single `Data` iterator in ascending `ts_init` order.
pub type QueryResult = KMerge<EagerStream<std::vec::IntoIter<Data>>, Data, TsInitComparator>;
55
/// Provides a DataFusion session and registers DataFusion queries.
///
/// The session is used to register data sources and make queries on them. A
/// query returns a Chunk of Arrow records. It is decoded and converted into
/// a Vec of data by types that implement [`DecodeDataFromRecordBatch`].
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.persistence")
)]
pub struct DataBackendSession {
    /// Target number of `Data` items per chunk handed out by consumers.
    pub chunk_size: usize,
    /// Dedicated Tokio runtime used to drive DataFusion's async APIs from synchronous code.
    pub runtime: Arc<tokio::runtime::Runtime>,
    // DataFusion session context holding the registered tables and object stores.
    session_ctx: SessionContext,
    // One eagerly driven, decoded stream per registered table; drained by `get_query_result`.
    batch_streams: Vec<EagerStream<IntoIter<Data>>>,
    // Logical table names already registered, used to avoid duplicate registration.
    registered_tables: AHashSet<String>,
}
76
77impl DataBackendSession {
78    /// Creates a new [`DataBackendSession`] instance.
79    #[must_use]
80    pub fn new(chunk_size: usize) -> Self {
81        let runtime = tokio::runtime::Builder::new_multi_thread()
82            .enable_all()
83            .build()
84            .unwrap();
85        let session_cfg = SessionConfig::new()
86            .set_str("datafusion.optimizer.repartition_file_scans", "false")
87            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
88        let session_ctx = SessionContext::new_with_config(session_cfg);
89        Self {
90            session_ctx,
91            batch_streams: Vec::default(),
92            chunk_size,
93            runtime: Arc::new(runtime),
94            registered_tables: AHashSet::new(),
95        }
96    }
97
98    /// Register an object store with the session context
99    pub fn register_object_store(&mut self, url: &Url, object_store: Arc<dyn ObjectStore>) {
100        self.session_ctx.register_object_store(url, object_store);
101    }
102
103    /// Register an object store with the session context from a URI with optional storage options
104    pub fn register_object_store_from_uri(
105        &mut self,
106        uri: &str,
107        storage_options: Option<AHashMap<String, String>>,
108    ) -> anyhow::Result<()> {
109        let location =
110            crate::parquet::create_object_store_location_from_path(uri, storage_options)?;
111
112        if let Some(root_url) = location.store_root_url().cloned() {
113            self.register_object_store(&root_url, location.object_store);
114        }
115
116        Ok(())
117    }
118
119    pub fn write_data<T: EncodeToRecordBatch>(
120        data: &[T],
121        metadata: &AHashMap<String, String>,
122        stream: &mut dyn WriteStream,
123    ) -> Result<(), DataStreamingError> {
124        // Convert AHashMap to HashMap for Arrow compatibility
125        let metadata: std::collections::HashMap<String, String> = metadata
126            .iter()
127            .map(|(k, v)| (k.clone(), v.clone()))
128            .collect();
129        let record_batch = T::encode_batch(&metadata, data)?;
130        stream.write(&record_batch)?;
131        Ok(())
132    }
133
134    /// Registers a Parquet file and adds a batch stream for decoding.
135    ///
136    /// The caller must specify `T` to indicate the kind of data expected. `table_name` is
137    /// the logical name for queries; `file_path` is the Parquet path; `sql_query` defaults
138    /// to `SELECT * FROM {table_name} ORDER BY ts_init` if `None`.
139    ///
140    /// When `custom_type_name` is `Some`, it is merged into each batch's schema metadata
141    /// before decoding (as `type_name`). Use this for custom data when Parquet/DataFusion
142    /// does not preserve schema metadata so the decoder can look up the type in the registry.
143    ///
144    /// The file data must be ordered by the `ts_init` in ascending order for this
145    /// to work correctly.
146    pub fn add_file<T>(
147        &mut self,
148        table_name: &str,
149        file_path: &str,
150        sql_query: Option<&str>,
151        custom_type_name: Option<&str>,
152    ) -> Result<()>
153    where
154        T: DecodeDataFromRecordBatch,
155    {
156        // Check if table is already registered to avoid duplicates
157        let is_new_table = !self.registered_tables.contains(table_name);
158
159        if is_new_table {
160            // Register the table only if it doesn't exist
161            let parquet_options = ParquetReadOptions::<'_> {
162                skip_metadata: Some(false),
163                file_sort_order: vec![vec![Sort {
164                    expr: col("ts_init"),
165                    asc: true,
166                    nulls_first: false,
167                }]],
168                ..Default::default()
169            };
170            self.runtime.block_on(self.session_ctx.register_parquet(
171                table_name,
172                file_path,
173                parquet_options,
174            ))?;
175
176            self.registered_tables.insert(table_name.to_string());
177
178            // Only add batch stream for newly registered tables to avoid duplicates
179            let default_query = format!("SELECT * FROM {} ORDER BY ts_init", &table_name);
180            let sql_query = sql_query.unwrap_or(&default_query);
181            let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;
182            let batch_stream = self.runtime.block_on(query.execute_stream())?;
183            self.add_batch_stream::<T>(batch_stream, custom_type_name.map(String::from));
184        }
185
186        Ok(())
187    }
188
189    /// Registers a Parquet file and executes a query, returning the raw record batches.
190    pub fn collect_query_batches(
191        &mut self,
192        table_name: &str,
193        file_path: &str,
194        sql_query: Option<&str>,
195    ) -> Result<Vec<RecordBatch>> {
196        if !self.registered_tables.contains(table_name) {
197            let parquet_options = ParquetReadOptions::<'_> {
198                skip_metadata: Some(false),
199                file_sort_order: vec![vec![Sort {
200                    expr: col("ts_init"),
201                    asc: true,
202                    nulls_first: false,
203                }]],
204                ..Default::default()
205            };
206            self.runtime.block_on(self.session_ctx.register_parquet(
207                table_name,
208                file_path,
209                parquet_options,
210            ))?;
211
212            self.registered_tables.insert(table_name.to_string());
213        }
214
215        let default_query = format!("SELECT * FROM {table_name} ORDER BY ts_init");
216        let sql_query = sql_query.unwrap_or(&default_query);
217        let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;
218        let mut batch_stream = self.runtime.block_on(query.execute_stream())?;
219
220        self.runtime.block_on(async {
221            let mut batches = Vec::new();
222            while let Some(batch) = batch_stream.next().await {
223                batches.push(batch?);
224            }
225            Ok::<_, datafusion::error::DataFusionError>(batches)
226        })
227    }
228
229    fn add_batch_stream<T>(
230        &mut self,
231        stream: SendableRecordBatchStream,
232        custom_type_name: Option<String>,
233    ) where
234        T: DecodeDataFromRecordBatch,
235    {
236        let transform = stream.map(move |result| match result {
237            Ok(batch) => {
238                let mut metadata: std::collections::HashMap<String, String> =
239                    batch.schema().metadata().clone();
240
241                if let Some(ref tn) = custom_type_name {
242                    metadata.insert("type_name".to_string(), tn.clone());
243                }
244                T::decode_data_batch(&metadata, batch).unwrap().into_iter()
245            }
246            Err(e) => panic!("Error getting next batch from RecordBatchStream: {e}"),
247        });
248
249        self.batch_streams
250            .push(EagerStream::from_stream_with_runtime(
251                transform,
252                self.runtime.clone(),
253            ));
254    }
255
256    // Consumes the registered queries and returns a [`QueryResult].
257    // Passes the output of the query though the a KMerge which sorts the
258    // queries in ascending order of `ts_init`.
259    // QueryResult is an iterator that return Vec<Data>.
260    pub fn get_query_result(&mut self) -> QueryResult {
261        let mut kmerge: KMerge<_, _, _> = KMerge::new(TsInitComparator);
262
263        self.batch_streams
264            .drain(..)
265            .for_each(|eager_stream| kmerge.push_iter(eager_stream));
266
267        kmerge
268    }
269
270    /// Clears all registered tables and batch streams.
271    ///
272    /// This is useful when the underlying files have changed and we need to
273    /// re-register tables with updated data.
274    pub fn clear_registered_tables(&mut self) {
275        self.registered_tables.clear();
276        self.batch_streams.clear();
277
278        // Create a new session context to completely reset the DataFusion state
279        let session_cfg = SessionConfig::new()
280            .set_str("datafusion.optimizer.repartition_file_scans", "false")
281            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
282        self.session_ctx = SessionContext::new_with_config(session_cfg);
283    }
284}
285
286#[must_use]
287pub fn build_query(
288    table: &str,
289    start: Option<UnixNanos>,
290    end: Option<UnixNanos>,
291    where_clause: Option<&str>,
292) -> String {
293    let mut conditions = Vec::new();
294
295    // Add where clause if provided
296    if let Some(clause) = where_clause {
297        conditions.push(clause.to_string());
298    }
299
300    // Add start condition if provided
301    if let Some(start_ts) = start {
302        conditions.push(format!("ts_init >= {start_ts}"));
303    }
304
305    // Add end condition if provided
306    if let Some(end_ts) = end {
307        conditions.push(format!("ts_init <= {end_ts}"));
308    }
309
310    // Build base query
311    let mut query = format!("SELECT * FROM {table}");
312
313    // Add WHERE clause if there are conditions
314    if !conditions.is_empty() {
315        query.push_str(" WHERE ");
316        query.push_str(&conditions.join(" AND "));
317    }
318
319    // Add ORDER BY clause
320    query.push_str(" ORDER BY ts_init");
321
322    query
323}
324
/// Accumulates a [`QueryResult`] into fixed-size chunks that can be handed
/// across the FFI boundary as `CVec`s.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.persistence")
)]
pub struct DataQueryResult {
    /// The currently leased FFI chunk, if any; released via `drop_chunk`.
    pub chunk: Option<CVec>,
    /// The underlying merged query stream.
    pub result: QueryResult,
    /// Accumulator holding data drained from `result` for the next chunk.
    pub acc: Vec<Data>,
    /// Target number of items per chunk.
    pub size: usize,
}
339
impl DataQueryResult {
    /// Creates a new [`DataQueryResult`] instance.
    ///
    /// `size` is the target number of items accumulated per chunk.
    #[must_use]
    pub const fn new(result: QueryResult, size: usize) -> Self {
        Self {
            chunk: None,
            result,
            acc: Vec::new(),
            size,
        }
    }

    /// Set new `CVec` backed chunk from data
    ///
    /// It also drops previously allocated chunk
    pub fn set_chunk(&mut self, data: Vec<Data>) -> CVec {
        // Release the previous chunk first so its allocation is never leaked.
        self.drop_chunk();

        let chunk: CVec = data.into();
        self.chunk = Some(chunk);
        chunk
    }

    /// Chunks generated by iteration must be dropped after use, otherwise
    /// it will leak memory. Current chunk is held by the reader,
    /// drop if exists and reset the field.
    pub fn drop_chunk(&mut self) {
        if let Some(CVec { ptr, len, cap }) = self.chunk.take() {
            // Validate the CVec invariants before reconstituting the Vec; a
            // violation indicates memory corruption or a mismatched chunk type.
            assert!(
                len <= cap,
                "drop_chunk: len ({len}) > cap ({cap}) - memory corruption or wrong chunk type"
            );
            assert!(
                len == 0 || !ptr.is_null(),
                "drop_chunk: null ptr with non-zero len ({len}) - memory corruption"
            );

            // SAFETY: `ptr`, `len`, and `cap` originate from a valid `CVec` and the
            // assertions above verify the invariants required by `Vec::from_raw_parts`.
            let data: Vec<Data> = unsafe { Vec::from_raw_parts(ptr.cast::<Data>(), len, cap) };
            drop(data);
        }
    }
}
384
385impl Iterator for DataQueryResult {
386    type Item = Vec<Data>;
387
388    fn next(&mut self) -> Option<Self::Item> {
389        for _ in 0..self.size {
390            match self.result.next() {
391                Some(item) => self.acc.push(item),
392                None => break,
393            }
394        }
395
396        // TODO: consider using drain here if perf is unchanged
397        // Some(self.acc.drain(0..).collect())
398        let mut acc: Vec<Data> = Vec::new();
399        std::mem::swap(&mut acc, &mut self.acc);
400        Some(acc)
401    }
402}
403
impl Drop for DataQueryResult {
    fn drop(&mut self) {
        // Free any outstanding FFI chunk so its allocation is not leaked.
        self.drop_chunk();
        // Reset the merged stream; presumably this releases remaining merge
        // state eagerly — see `KMerge::clear` for the exact semantics.
        self.result.clear();
    }
}