Skip to main content

nautilus_persistence/backend/
kmerge_batch.rs

1// -------------------------------------------------------------------------------------------------
2//  Copyright (C) 2015-2026 Nautech Systems Pty Ltd. All rights reserved.
3//  https://nautechsystems.io
4//
5//  Licensed under the GNU Lesser General Public License Version 3.0 (the "License");
6//  You may not use this file except in compliance with the License.
7//  You may obtain a copy of the License at https://www.gnu.org/licenses/lgpl-3.0.en.html
8//
9//  Unless required by applicable law or agreed to in writing, software
10//  distributed under the License is distributed on an "AS IS" BASIS,
11//  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//  See the License for the specific language governing permissions and
13//  limitations under the License.
14// -------------------------------------------------------------------------------------------------
15
16use std::{sync::Arc, vec::IntoIter};
17
18use futures::{Stream, StreamExt};
19use tokio::{
20    runtime::Runtime,
21    sync::mpsc::{self, Receiver},
22    task::JoinHandle,
23};
24
25use super::{
26    binary_heap::{BinaryHeap, PeekMut},
27    compare::Compare,
28};
29
30pub struct EagerStream<T> {
31    rx: Receiver<T>,
32    task: JoinHandle<()>,
33    runtime: Arc<Runtime>,
34}
35
36impl<T> EagerStream<T> {
37    pub fn from_stream_with_runtime<S>(stream: S, runtime: Arc<Runtime>) -> Self
38    where
39        S: Stream<Item = T> + Send + 'static,
40        T: Send + 'static,
41    {
42        let (tx, rx) = mpsc::channel(1);
43
44        let task = runtime.spawn(async move {
45            futures::pin_mut!(stream);
46            while let Some(item) = stream.next().await {
47                if tx.send(item).await.is_err() {
48                    break;
49                }
50            }
51        });
52
53        Self { rx, task, runtime }
54    }
55}
56
57impl<T> Iterator for EagerStream<T> {
58    type Item = T;
59
60    fn next(&mut self) -> Option<Self::Item> {
61        self.runtime.block_on(self.rx.recv())
62    }
63}
64
65impl<T> Drop for EagerStream<T> {
66    fn drop(&mut self) {
67        self.rx.close();
68        self.task.abort();
69    }
70}
71
// TODO: Investigate implementing `Iterator` for `ElementBatchIter` to reduce the
// duplicated "fetch next element" logic. Keeping it peekable may be difficult.
/// Cursor over one source of batches: the current head `item`, the remainder of
/// the batch it came from, and the iterator supplying further batches.
pub struct ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Current head element of this source; the value comparators inspect.
    pub item: T,
    /// Elements still left in the batch `item` was taken from.
    batch: I::Item,
    /// Supplier of the remaining batches of this source.
    iter: I,
}

impl<I, T> ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Builds a cursor positioned on the first element of `iter`.
    ///
    /// Batches that yield no element are consumed and discarded. Returns `None`
    /// if `iter` runs out before any batch produces an element.
    fn new_from_iter(mut iter: I) -> Option<Self> {
        let (item, batch) = iter
            .by_ref()
            .find_map(|mut candidate| candidate.next().map(|first| (first, candidate)))?;
        Some(Self { item, batch, iter })
    }
}
99
100pub struct KMerge<I, T, C>
101where
102    I: Iterator<Item = IntoIter<T>>,
103{
104    heap: BinaryHeap<ElementBatchIter<I, T>, C>,
105}
106
107impl<I, T, C> KMerge<I, T, C>
108where
109    I: Iterator<Item = IntoIter<T>>,
110    C: Compare<ElementBatchIter<I, T>>,
111{
112    /// Creates a new [`KMerge`] instance.
113    pub fn new(cmp: C) -> Self {
114        Self {
115            heap: BinaryHeap::from_vec_cmp(Vec::new(), cmp),
116        }
117    }
118
119    pub fn push_iter(&mut self, s: I) {
120        if let Some(heap_elem) = ElementBatchIter::new_from_iter(s) {
121            self.heap.push(heap_elem);
122        }
123    }
124
125    pub fn clear(&mut self) {
126        self.heap.clear();
127    }
128}
129
130impl<I, T, C> Iterator for KMerge<I, T, C>
131where
132    I: Iterator<Item = IntoIter<T>>,
133    C: Compare<ElementBatchIter<I, T>>,
134{
135    type Item = T;
136
137    fn next(&mut self) -> Option<Self::Item> {
138        match self.heap.peek_mut() {
139            Some(mut heap_elem) => {
140                // Get next element from batch
141                match heap_elem.batch.next() {
142                    // Swap current heap element with new element
143                    // return the old element
144                    Some(mut item) => {
145                        std::mem::swap(&mut item, &mut heap_elem.item);
146                        Some(item)
147                    }
148                    // Otherwise get the next batch and the element from it
149                    // Unless the underlying iterator is exhausted
150                    None => loop {
151                        let Some(mut batch) = heap_elem.iter.next() else {
152                            let ElementBatchIter {
153                                item,
154                                batch: _,
155                                iter: _,
156                            } = PeekMut::pop(heap_elem);
157                            break Some(item);
158                        };
159
160                        if let Some(mut item) = batch.next() {
161                            heap_elem.batch = batch;
162                            std::mem::swap(&mut item, &mut heap_elem.item);
163                            break Some(item);
164                        }
165                    },
166                }
167            }
168            None => None,
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// Orders heap cursors by their current item so the smallest item is on top.
    struct OrdComparator;

    // One generic impl covers every `Ord` item type used by these tests; it
    // replaces the former `i32` and `u64` copies, which differed only in the type.
    impl<S, T> Compare<ElementBatchIter<S, T>> for OrdComparator
    where
        S: Iterator<Item = IntoIter<T>>,
        T: Ord,
    {
        fn compare(
            &self,
            l: &ElementBatchIter<S, T>,
            r: &ElementBatchIter<S, T>,
        ) -> std::cmp::Ordering {
            // Max heap ordering must be reversed
            l.item.cmp(&r.item).reverse()
        }
    }

    #[rstest]
    fn test_merge_two_streams_with_disjoint_batches() {
        let iter_a = vec![vec![1, 2, 3].into_iter(), vec![7, 8, 9].into_iter()].into_iter();
        let iter_b = vec![vec![4, 5, 6].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[rstest]
    fn test_merge_keeps_duplicates_across_streams() {
        let iter_a = vec![vec![1, 2, 6].into_iter(), vec![7, 8, 9].into_iter()].into_iter();
        let iter_b = vec![vec![3, 4, 5, 6].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 6, 7, 8, 9]);
    }

    #[rstest]
    fn test_merge_three_interleaved_streams() {
        let iter_a = vec![vec![1, 4, 7].into_iter(), vec![24, 35, 56].into_iter()].into_iter();
        let iter_b = vec![vec![2, 4, 8].into_iter()].into_iter();
        let iter_c = vec![vec![3, 5, 9].into_iter(), vec![12, 12, 90].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);
        kmerge.push_iter(iter_c);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(
            values,
            vec![1, 2, 3, 4, 4, 5, 7, 8, 9, 12, 12, 24, 35, 56, 90]
        );
    }

    #[rstest]
    fn test_merge_skips_empty_batch_in_middle() {
        let iter_a = vec![
            vec![1, 3, 5].into_iter(),
            vec![].into_iter(),
            vec![7, 9, 11].into_iter(),
        ]
        .into_iter();
        let iter_b = vec![vec![2, 4, 6].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 9, 11]);
    }

    #[rstest]
    fn test_merge_skips_leading_and_trailing_empty_batches() {
        let iter_a = vec![
            Vec::new().into_iter(),
            vec![2, 4].into_iter(),
            Vec::new().into_iter(),
        ]
        .into_iter();
        let iter_b = vec![Vec::new().into_iter(), vec![1, 3].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4]);
    }

    #[rstest]
    fn test_empty_kmerge_yields_nothing() {
        let mut kmerge: KMerge<IntoIter<IntoIter<i32>>, i32, _> = KMerge::new(OrdComparator);
        assert_eq!(kmerge.next(), None);
    }

    #[rstest]
    fn test_stream_with_only_empty_batches_is_ignored() {
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(vec![Vec::new().into_iter(), Vec::new().into_iter()].into_iter());
        assert_eq!(kmerge.next(), None);
    }

    #[rstest]
    fn test_clear_discards_pending_items() {
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(vec![vec![1, 2].into_iter()].into_iter());
        kmerge.push_iter(vec![vec![3].into_iter()].into_iter());
        assert_eq!(kmerge.next(), Some(1));

        kmerge.clear();
        assert_eq!(kmerge.next(), None);

        // The merge stays usable after clearing
        kmerge.push_iter(vec![vec![5, 6].into_iter()].into_iter());
        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![5, 6]);
    }

    #[rstest]
    fn test_eager_stream_yields_all_items_in_order() {
        let runtime = Arc::new(
            tokio::runtime::Builder::new_current_thread()
                .build()
                .expect("failed to build Tokio runtime"),
        );
        let stream = futures::stream::iter(vec![1, 2, 3, 4, 5]);

        let values: Vec<i32> = EagerStream::from_stream_with_runtime(stream, runtime).collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5]);
    }

    #[rstest]
    fn test_eager_stream_can_be_dropped_before_exhaustion() {
        let runtime = Arc::new(
            tokio::runtime::Builder::new_current_thread()
                .build()
                .expect("failed to build Tokio runtime"),
        );
        let stream = futures::stream::iter(0..1_000);

        let mut eager = EagerStream::from_stream_with_runtime(stream, Arc::clone(&runtime));
        assert_eq!(eager.next(), Some(0));
        // Dropping mid-stream must close the channel and abort the producer without hanging
        drop(eager);
    }

    #[derive(Debug, Clone)]
    struct SortedNestedVec(Vec<Vec<u64>>);

    /// Strategy to generate nested vectors where each inner vector is sorted.
    fn sorted_nested_vec_strategy() -> impl Strategy<Value = SortedNestedVec> {
        // Generate a vector of u64 values, then split into sorted chunks
        prop::collection::vec(any::<u64>(), 0..=100).prop_flat_map(|mut flat_vec| {
            flat_vec.sort_unstable();

            // Generate chunk sizes that will split the sorted vector
            let total_len = flat_vec.len();
            if total_len == 0 {
                return Just(SortedNestedVec(vec![vec![]])).boxed();
            }

            // Generate random chunk boundaries
            prop::collection::vec(0..=total_len, 0..=10)
                .prop_map(move |mut boundaries| {
                    boundaries.push(0);
                    boundaries.push(total_len);
                    boundaries.sort_unstable();
                    boundaries.dedup();

                    let mut nested_vec = Vec::new();
                    for [start, end] in boundaries.array_windows() {
                        nested_vec.push(flat_vec[*start..*end].to_vec());
                    }

                    SortedNestedVec(nested_vec)
                })
                .boxed()
        })
    }

    proptest! {
        /// Property: K-way merge should produce the same result as sorting all data together
        #[rstest]
        fn prop_kmerge_equivalent_to_sort(
            all_data in prop::collection::vec(sorted_nested_vec_strategy(), 0..=10)
        ) {
            let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);

            let copy_data = all_data.clone();
            for stream in copy_data {
                let input = stream.0.into_iter().map(std::iter::IntoIterator::into_iter);
                kmerge.push_iter(input);
            }
            let merged_data: Vec<u64> = kmerge.collect();

            let mut sorted_data: Vec<u64> = all_data
                .into_iter()
                .flat_map(|stream| stream.0.into_iter().flatten())
                .collect();
            sorted_data.sort_unstable();

            prop_assert_eq!(merged_data.len(), sorted_data.len(), "Lengths should be equal");
            prop_assert_eq!(merged_data, sorted_data, "Merged data should equal sorted data");
        }

        /// Property: K-way merge should preserve sortedness when inputs are sorted
        #[rstest]
        fn prop_kmerge_preserves_sort_order(
            all_data in prop::collection::vec(sorted_nested_vec_strategy(), 1..=5)
        ) {
            let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);

            for stream in all_data {
                let input = stream.0.into_iter().map(std::iter::IntoIterator::into_iter);
                kmerge.push_iter(input);
            }
            let merged_data: Vec<u64> = kmerge.collect();

            // Check that the merged data is sorted
            for [a, b] in merged_data.array_windows() {
                prop_assert!(a <= b, "Merged data should be sorted");
            }
        }

        /// Property: Empty iterators should not affect the merge result
        #[rstest]
        fn prop_kmerge_handles_empty_iterators(
            data in sorted_nested_vec_strategy(),
            empty_count in 0usize..=5
        ) {
            let mut kmerge_with_empty: KMerge<_, u64, _> = KMerge::new(OrdComparator);
            let mut kmerge_without_empty: KMerge<_, u64, _> = KMerge::new(OrdComparator);

            // Add the actual data to both merges
            let input_with_empty = data.0.clone().into_iter().map(std::iter::IntoIterator::into_iter);
            let input_without_empty = data.0.into_iter().map(std::iter::IntoIterator::into_iter);

            kmerge_with_empty.push_iter(input_with_empty);
            kmerge_without_empty.push_iter(input_without_empty);

            // Add empty iterators to the first merge
            for _ in 0..empty_count {
                let empty_vec: Vec<Vec<u64>> = vec![];
                let empty_input = empty_vec.into_iter().map(std::iter::IntoIterator::into_iter);
                kmerge_with_empty.push_iter(empty_input);
            }

            let result_with_empty: Vec<u64> = kmerge_with_empty.collect();
            let result_without_empty: Vec<u64> = kmerge_without_empty.collect();

            prop_assert_eq!(result_with_empty, result_without_empty, "Empty iterators should not affect result");
        }
    }
}