nautilus_persistence/backend/session.rs

use std::{sync::Arc, vec::IntoIter};

use ahash::{AHashMap, AHashSet};
use datafusion::{
    arrow::record_batch::RecordBatch, error::Result, logical_expr::expr::Sort,
    physical_plan::SendableRecordBatchStream, prelude::*,
};
use futures::StreamExt;
use nautilus_core::{UnixNanos, ffi::cvec::CVec};
use nautilus_model::data::{Data, HasTsInit};
use nautilus_serialization::arrow::{
    DataStreamingError, DecodeDataFromRecordBatch, EncodeToRecordBatch, WriteStream,
};
use object_store::ObjectStore;
use url::Url;

use super::{
    compare::Compare,
    kmerge_batch::{EagerStream, ElementBatchIter, KMerge},
};

/// Compares batch iterator elements by the `ts_init` of their current item.
#[derive(Debug, Default)]
pub struct TsInitComparator;

impl<I> Compare<ElementBatchIter<I, Data>> for TsInitComparator
where
    I: Iterator<Item = IntoIter<Data>>,
{
    fn compare(
        &self,
        l: &ElementBatchIter<I, Data>,
        r: &ElementBatchIter<I, Data>,
    ) -> std::cmp::Ordering {
        // Reverse the ordering so the max-heap behind `KMerge` yields the
        // element with the smallest `ts_init` first (ascending time order).
        l.item.ts_init().cmp(&r.item.ts_init()).reverse()
    }
}

/// A k-way merge over eagerly driven batch streams, yielding [`Data`] in
/// ascending `ts_init` order.
pub type QueryResult = KMerge<EagerStream<std::vec::IntoIter<Data>>, Data, TsInitComparator>;

/// A DataFusion-backed session that registers Parquet data sources as tables
/// and streams their record batches, decoded into [`Data`], for merging by
/// `ts_init`.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.persistence")
)]
pub struct DataBackendSession {
    pub chunk_size: usize,
    pub runtime: Arc<tokio::runtime::Runtime>,
    session_ctx: SessionContext,
    batch_streams: Vec<EagerStream<IntoIter<Data>>>,
    registered_tables: AHashSet<String>,
}

impl DataBackendSession {
    /// Creates a new [`DataBackendSession`] with its own Tokio runtime.
    #[must_use]
    pub fn new(chunk_size: usize) -> Self {
        let runtime = tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .expect("Failed to create tokio runtime");
        // Disable file-scan repartitioning and prefer the existing sort so
        // DataFusion preserves the on-disk `ts_init` ordering.
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        let session_ctx = SessionContext::new_with_config(session_cfg);
        Self {
            session_ctx,
            batch_streams: Vec::default(),
            chunk_size,
            runtime: Arc::new(runtime),
            registered_tables: AHashSet::new(),
        }
    }

    /// Registers an object store with the session's DataFusion context.
    pub fn register_object_store(&mut self, url: &Url, object_store: Arc<dyn ObjectStore>) {
        self.session_ctx.register_object_store(url, object_store);
    }

    /// Registers an object store resolved from `uri` and the optional
    /// storage options.
    pub fn register_object_store_from_uri(
        &mut self,
        uri: &str,
        storage_options: Option<AHashMap<String, String>>,
    ) -> anyhow::Result<()> {
        let location =
            crate::parquet::create_object_store_location_from_path(uri, storage_options)?;

        if let Some(root_url) = location.store_root_url().cloned() {
            self.register_object_store(&root_url, location.object_store);
        }

        Ok(())
    }

    /// Encodes `data` into a single record batch with the given metadata and
    /// writes it to `stream`.
    pub fn write_data<T: EncodeToRecordBatch>(
        data: &[T],
        metadata: &AHashMap<String, String>,
        stream: &mut dyn WriteStream,
    ) -> Result<(), DataStreamingError> {
        // Convert to a std HashMap, which is what `encode_batch` expects
        let metadata: std::collections::HashMap<String, String> = metadata
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();
        let record_batch = T::encode_batch(&metadata, data)?;
        stream.write(&record_batch)?;
        Ok(())
    }

    /// Registers the Parquet file at `file_path` as `table_name` (if not
    /// already registered) and queues a query whose results are decoded as
    /// `T`. When `sql_query` is `None`, the whole table is selected ordered
    /// by `ts_init`. Repeated calls with the same table name are no-ops.
    pub fn add_file<T>(
        &mut self,
        table_name: &str,
        file_path: &str,
        sql_query: Option<&str>,
        custom_type_name: Option<&str>,
    ) -> Result<()>
    where
        T: DecodeDataFromRecordBatch,
    {
        let is_new_table = !self.registered_tables.contains(table_name);

        if is_new_table {
            // Declare the on-disk sort order so DataFusion can skip re-sorting
            let parquet_options = ParquetReadOptions::<'_> {
                skip_metadata: Some(false),
                file_sort_order: vec![vec![Sort {
                    expr: col("ts_init"),
                    asc: true,
                    nulls_first: false,
                }]],
                ..Default::default()
            };
            self.runtime.block_on(self.session_ctx.register_parquet(
                table_name,
                file_path,
                parquet_options,
            ))?;

            self.registered_tables.insert(table_name.to_string());

            let default_query = format!("SELECT * FROM {table_name} ORDER BY ts_init");
            let sql_query = sql_query.unwrap_or(&default_query);
            let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;
            let batch_stream = self.runtime.block_on(query.execute_stream())?;
            self.add_batch_stream::<T>(batch_stream, custom_type_name.map(String::from));
        }

        Ok(())
    }

    /// Registers `file_path` under `table_name` (if needed) and eagerly
    /// collects all record batches for the query, rather than streaming them.
    pub fn collect_query_batches(
        &mut self,
        table_name: &str,
        file_path: &str,
        sql_query: Option<&str>,
    ) -> Result<Vec<RecordBatch>> {
        if !self.registered_tables.contains(table_name) {
            let parquet_options = ParquetReadOptions::<'_> {
                skip_metadata: Some(false),
                file_sort_order: vec![vec![Sort {
                    expr: col("ts_init"),
                    asc: true,
                    nulls_first: false,
                }]],
                ..Default::default()
            };
            self.runtime.block_on(self.session_ctx.register_parquet(
                table_name,
                file_path,
                parquet_options,
            ))?;

            self.registered_tables.insert(table_name.to_string());
        }

        let default_query = format!("SELECT * FROM {table_name} ORDER BY ts_init");
        let sql_query = sql_query.unwrap_or(&default_query);
        let query = self.runtime.block_on(self.session_ctx.sql(sql_query))?;
        let mut batch_stream = self.runtime.block_on(query.execute_stream())?;

        self.runtime.block_on(async {
            let mut batches = Vec::new();
            while let Some(batch) = batch_stream.next().await {
                batches.push(batch?);
            }
            Ok::<_, datafusion::error::DataFusionError>(batches)
        })
    }

    fn add_batch_stream<T>(
        &mut self,
        stream: SendableRecordBatchStream,
        custom_type_name: Option<String>,
    ) where
        T: DecodeDataFromRecordBatch,
    {
        let transform = stream.map(move |result| match result {
            Ok(batch) => {
                let mut metadata: std::collections::HashMap<String, String> =
                    batch.schema().metadata().clone();

                // Allow callers to override the decoded type name
                if let Some(ref tn) = custom_type_name {
                    metadata.insert("type_name".to_string(), tn.clone());
                }
                T::decode_data_batch(&metadata, batch).unwrap().into_iter()
            }
            Err(e) => panic!("Error getting next batch from RecordBatchStream: {e}"),
        });

        self.batch_streams
            .push(EagerStream::from_stream_with_runtime(
                transform,
                self.runtime.clone(),
            ));
    }

    /// Drains the queued batch streams into a [`QueryResult`] that merges
    /// them into a single stream ordered by `ts_init`.
    pub fn get_query_result(&mut self) -> QueryResult {
        let mut kmerge: KMerge<_, _, _> = KMerge::new(TsInitComparator);

        self.batch_streams
            .drain(..)
            .for_each(|eager_stream| kmerge.push_iter(eager_stream));

        kmerge
    }

    /// Clears all registered tables and any pending batch streams.
    pub fn clear_registered_tables(&mut self) {
        self.registered_tables.clear();
        self.batch_streams.clear();

        // Recreate the session context with the same configuration so any
        // tables it still holds are dropped along with it
        let session_cfg = SessionConfig::new()
            .set_str("datafusion.optimizer.repartition_file_scans", "false")
            .set_str("datafusion.optimizer.prefer_existing_sort", "true");
        self.session_ctx = SessionContext::new_with_config(session_cfg);
    }
}
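
// A minimal usage sketch (not part of the original source); the file path,
// table name, and the `QuoteTick` element type are hypothetical stand-ins:
//
//     let mut session = DataBackendSession::new(5_000);
//     session.add_file::<QuoteTick>("quotes", "/path/to/quotes.parquet", None, None)?;
//     let merged = session.get_query_result();
//     // `merged` yields `Data` elements from all registered files in
//     // ascending `ts_init` order.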

/// Builds a SQL query selecting everything from `table`, optionally filtered
/// by a custom `where_clause` and inclusive `start`/`end` bounds on
/// `ts_init`, ordered by `ts_init`.
#[must_use]
pub fn build_query(
    table: &str,
    start: Option<UnixNanos>,
    end: Option<UnixNanos>,
    where_clause: Option<&str>,
) -> String {
    let mut conditions = Vec::new();

    if let Some(clause) = where_clause {
        conditions.push(clause.to_string());
    }

    if let Some(start_ts) = start {
        conditions.push(format!("ts_init >= {start_ts}"));
    }

    if let Some(end_ts) = end {
        conditions.push(format!("ts_init <= {end_ts}"));
    }

    let mut query = format!("SELECT * FROM {table}");

    if !conditions.is_empty() {
        query.push_str(" WHERE ");
        query.push_str(&conditions.join(" AND "));
    }

    query.push_str(" ORDER BY ts_init");

    query
}
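
// For illustration (not in the original source): with both bounds and a
// custom clause, the builder above produces SQL of the form
//
//     SELECT * FROM quotes
//     WHERE instrument_id = 'EUR/USD.SIM' AND ts_init >= 100 AND ts_init <= 200
//     ORDER BY ts_init
//
// (rendered on one line). This relies on `UnixNanos` displaying as a plain
// integer, which the `format!` calls above already assume.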

/// A persistent query result that hands out chunks of [`Data`] across the
/// FFI boundary, keeping the current chunk alive until the next is requested.
#[cfg_attr(
    feature = "python",
    pyo3::pyclass(module = "nautilus_trader.core.nautilus_pyo3.persistence", unsendable)
)]
#[cfg_attr(
    feature = "python",
    pyo3_stub_gen::derive::gen_stub_pyclass(module = "nautilus_trader.persistence")
)]
pub struct DataQueryResult {
    pub chunk: Option<CVec>,
    pub result: QueryResult,
    pub acc: Vec<Data>,
    pub size: usize,
}

impl DataQueryResult {
    /// Creates a new result that yields chunks of at most `size` elements.
    #[must_use]
    pub const fn new(result: QueryResult, size: usize) -> Self {
        Self {
            chunk: None,
            result,
            acc: Vec::new(),
            size,
        }
    }

    /// Replaces the current chunk with `data`, dropping the previous chunk,
    /// and returns it as a [`CVec`] for consumption across the FFI boundary.
    pub fn set_chunk(&mut self, data: Vec<Data>) -> CVec {
        self.drop_chunk();

        let chunk: CVec = data.into();
        self.chunk = Some(chunk);
        chunk
    }

    /// Reclaims and drops the current chunk, if any.
    pub fn drop_chunk(&mut self) {
        if let Some(CVec { ptr, len, cap }) = self.chunk.take() {
            assert!(
                len <= cap,
                "drop_chunk: len ({len}) > cap ({cap}) - memory corruption or wrong chunk type"
            );
            assert!(
                len == 0 || !ptr.is_null(),
                "drop_chunk: null ptr with non-zero len ({len}) - memory corruption"
            );

            // SAFETY: the chunk was created from a `Vec<Data>` in `set_chunk`,
            // so reconstructing the vector from its raw parts is sound here
            let data: Vec<Data> = unsafe { Vec::from_raw_parts(ptr.cast::<Data>(), len, cap) };
            drop(data);
        }
    }
}

impl Iterator for DataQueryResult {
    type Item = Vec<Data>;

    fn next(&mut self) -> Option<Self::Item> {
        for _ in 0..self.size {
            match self.result.next() {
                Some(item) => self.acc.push(item),
                None => break,
            }
        }

        // Always returns `Some`; callers detect exhaustion by receiving an
        // empty chunk rather than `None`
        let mut acc: Vec<Data> = Vec::new();
        std::mem::swap(&mut acc, &mut self.acc);
        Some(acc)
    }
}

impl Drop for DataQueryResult {
    fn drop(&mut self) {
        self.drop_chunk();
        self.result.clear();
    }
}
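
#[cfg(test)]
mod tests {
    use super::*;

    // A hedged sketch, not from the original file: checks the query builder's
    // condition joining; assumes `UnixNanos: From<u64>` and that it displays
    // as a plain integer (which the SQL built above already relies on).
    #[test]
    fn build_query_joins_conditions() {
        let query = build_query("quotes", Some(UnixNanos::from(100)), None, None);
        assert_eq!(query, "SELECT * FROM quotes WHERE ts_init >= 100 ORDER BY ts_init");
    }

    // Exhaustion is signalled by an empty chunk: `DataQueryResult::next`
    // returns `Some` even when the underlying merge is already drained
    #[test]
    fn empty_result_yields_empty_chunk() {
        let kmerge: QueryResult = KMerge::new(TsInitComparator);
        let mut result = DataQueryResult::new(kmerge, 64);
        assert!(result.next().unwrap().is_empty());
    }
}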