1use std::{sync::Arc, vec::IntoIter};
17
18use futures::{Stream, StreamExt};
19use tokio::{
20 runtime::Runtime,
21 sync::mpsc::{self, Receiver},
22 task::JoinHandle,
23};
24
25use super::{
26 binary_heap::{BinaryHeap, PeekMut},
27 compare::Compare,
28};
29
30pub struct EagerStream<T> {
31 rx: Receiver<T>,
32 task: JoinHandle<()>,
33 runtime: Arc<Runtime>,
34}
35
36impl<T> EagerStream<T> {
37 pub fn from_stream_with_runtime<S>(stream: S, runtime: Arc<Runtime>) -> Self
38 where
39 S: Stream<Item = T> + Send + 'static,
40 T: Send + 'static,
41 {
42 let (tx, rx) = mpsc::channel(1);
43
44 let task = runtime.spawn(async move {
45 futures::pin_mut!(stream);
46 while let Some(item) = stream.next().await {
47 if tx.send(item).await.is_err() {
48 break;
49 }
50 }
51 });
52
53 Self { rx, task, runtime }
54 }
55}
56
57impl<T> Iterator for EagerStream<T> {
58 type Item = T;
59
60 fn next(&mut self) -> Option<Self::Item> {
61 self.runtime.block_on(self.rx.recv())
62 }
63}
64
65impl<T> Drop for EagerStream<T> {
66 fn drop(&mut self) {
67 self.rx.close();
68 self.task.abort();
69 }
70}
71
/// Heap entry for one merge source: the source's next item together with
/// the rest of the source.
///
/// A source is an iterator of batches (`IntoIter<T>`). `item` is public so
/// comparators can order entries by it.
pub struct ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Next item this source will hand out.
    pub item: T,
    /// Remaining items of the batch `item` was taken from.
    batch: I::Item,
    /// Batches of this source that have not been started yet.
    iter: I,
}

impl<I, T> ElementBatchIter<I, T>
where
    I: Iterator<Item = IntoIter<T>>,
{
    /// Positions a new entry at the first item of `iter`, discarding any
    /// empty batches in front of it.
    ///
    /// Returns `None` if no batch of `iter` contains an item.
    fn new_from_iter(mut iter: I) -> Option<Self> {
        let (item, batch) = iter
            .by_ref()
            .find_map(|mut batch| batch.next().map(|item| (item, batch)))?;
        Some(Self { item, batch, iter })
    }
}
99
100pub struct KMerge<I, T, C>
101where
102 I: Iterator<Item = IntoIter<T>>,
103{
104 heap: BinaryHeap<ElementBatchIter<I, T>, C>,
105}
106
107impl<I, T, C> KMerge<I, T, C>
108where
109 I: Iterator<Item = IntoIter<T>>,
110 C: Compare<ElementBatchIter<I, T>>,
111{
112 pub fn new(cmp: C) -> Self {
114 Self {
115 heap: BinaryHeap::from_vec_cmp(Vec::new(), cmp),
116 }
117 }
118
119 pub fn push_iter(&mut self, s: I) {
120 if let Some(heap_elem) = ElementBatchIter::new_from_iter(s) {
121 self.heap.push(heap_elem);
122 }
123 }
124
125 pub fn clear(&mut self) {
126 self.heap.clear();
127 }
128}
129
130impl<I, T, C> Iterator for KMerge<I, T, C>
131where
132 I: Iterator<Item = IntoIter<T>>,
133 C: Compare<ElementBatchIter<I, T>>,
134{
135 type Item = T;
136
137 fn next(&mut self) -> Option<Self::Item> {
138 match self.heap.peek_mut() {
139 Some(mut heap_elem) => {
140 match heap_elem.batch.next() {
142 Some(mut item) => {
145 std::mem::swap(&mut item, &mut heap_elem.item);
146 Some(item)
147 }
148 None => loop {
151 let Some(mut batch) = heap_elem.iter.next() else {
152 let ElementBatchIter {
153 item,
154 batch: _,
155 iter: _,
156 } = PeekMut::pop(heap_elem);
157 break Some(item);
158 };
159
160 if let Some(mut item) = batch.next() {
161 heap_elem.batch = batch;
162 std::mem::swap(&mut item, &mut heap_elem.item);
163 break Some(item);
164 }
165 },
166 }
167 }
168 None => None,
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use proptest::prelude::*;
    use rstest::rstest;

    use super::*;

    /// Compares heap entries by their head item in reverse natural order;
    /// with it, `KMerge` produces ascending output (as the tests below check).
    struct OrdComparator;

    impl<S, T> Compare<ElementBatchIter<S, T>> for OrdComparator
    where
        S: Iterator<Item = IntoIter<T>>,
        T: Ord,
    {
        fn compare(
            &self,
            l: &ElementBatchIter<S, T>,
            r: &ElementBatchIter<S, T>,
        ) -> std::cmp::Ordering {
            l.item.cmp(&r.item).reverse()
        }
    }

    #[rstest]
    fn test1() {
        let iter_a = vec![vec![1, 2, 3].into_iter(), vec![7, 8, 9].into_iter()].into_iter();
        let iter_b = vec![vec![4, 5, 6].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 8, 9]);
    }

    #[rstest]
    fn test2() {
        let iter_a = vec![vec![1, 2, 6].into_iter(), vec![7, 8, 9].into_iter()].into_iter();
        let iter_b = vec![vec![3, 4, 5, 6].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 6, 7, 8, 9]);
    }

    #[rstest]
    fn test3() {
        let iter_a = vec![vec![1, 4, 7].into_iter(), vec![24, 35, 56].into_iter()].into_iter();
        let iter_b = vec![vec![2, 4, 8].into_iter()].into_iter();
        let iter_c = vec![vec![3, 5, 9].into_iter(), vec![12, 12, 90].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);
        kmerge.push_iter(iter_c);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(
            values,
            vec![1, 2, 3, 4, 4, 5, 7, 8, 9, 12, 12, 24, 35, 56, 90]
        );
    }

    #[rstest]
    fn test5() {
        let iter_a = vec![
            vec![1, 3, 5].into_iter(),
            vec![].into_iter(),
            vec![7, 9, 11].into_iter(),
        ]
        .into_iter();
        let iter_b = vec![vec![2, 4, 6].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![1, 2, 3, 4, 5, 6, 7, 9, 11]);
    }

    // Empty batches before the first item and after the last item of a
    // source must be skipped without losing or duplicating items.
    #[rstest]
    fn test_leading_and_trailing_empty_batches() {
        let iter_a = vec![
            vec![1, 4].into_iter(),
            vec![].into_iter(),
            vec![].into_iter(),
        ]
        .into_iter();
        let iter_b = vec![vec![2, 3].into_iter()].into_iter();
        let iter_c = vec![vec![].into_iter(), vec![0, 5].into_iter()].into_iter();
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(iter_a);
        kmerge.push_iter(iter_b);
        kmerge.push_iter(iter_c);

        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![0, 1, 2, 3, 4, 5]);
    }

    // `clear` drops pending sources; the merge stays usable afterwards.
    #[rstest]
    fn test_clear() {
        let mut kmerge: KMerge<_, i32, _> = KMerge::new(OrdComparator);
        kmerge.push_iter(vec![vec![1, 2, 3].into_iter()].into_iter());
        assert_eq!(kmerge.next(), Some(1));

        kmerge.clear();
        assert_eq!(kmerge.next(), None);

        kmerge.push_iter(vec![vec![4, 5].into_iter()].into_iter());
        let values: Vec<i32> = kmerge.collect();
        assert_eq!(values, vec![4, 5]);
    }

    #[derive(Debug, Clone)]
    struct SortedNestedVec(Vec<Vec<u64>>);

    /// Generates a sorted `u64` sequence split into consecutive batches.
    ///
    /// Split points may repeat, so empty batches can appear at the start,
    /// in the middle, or at the end of a source. This exercises the merge's
    /// empty-batch skipping.
    fn sorted_nested_vec_strategy() -> impl Strategy<Value = SortedNestedVec> {
        prop::collection::vec(any::<u64>(), 0..=100).prop_flat_map(|mut flat_vec| {
            flat_vec.sort_unstable();

            let total_len = flat_vec.len();
            if total_len == 0 {
                return Just(SortedNestedVec(vec![vec![]])).boxed();
            }

            prop::collection::vec(0..=total_len, 0..=10)
                .prop_map(move |mut boundaries| {
                    boundaries.push(0);
                    boundaries.push(total_len);
                    boundaries.sort_unstable();
                    // Duplicate boundaries are kept on purpose: each repeat
                    // produces an empty batch.

                    let mut nested_vec = Vec::new();
                    for [start, end] in boundaries.array_windows() {
                        nested_vec.push(flat_vec[*start..*end].to_vec());
                    }

                    SortedNestedVec(nested_vec)
                })
                .boxed()
        })
    }

    proptest! {
        // Merging must produce exactly the sorted multiset of all inputs.
        #[rstest]
        fn prop_kmerge_equivalent_to_sort(
            all_data in prop::collection::vec(sorted_nested_vec_strategy(), 0..=10)
        ) {
            let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);

            let copy_data = all_data.clone();
            for stream in copy_data {
                let input = stream.0.into_iter().map(std::iter::IntoIterator::into_iter);
                kmerge.push_iter(input);
            }
            let merged_data: Vec<u64> = kmerge.collect();

            let mut sorted_data: Vec<u64> = all_data
                .into_iter()
                .flat_map(|stream| stream.0.into_iter().flatten())
                .collect();
            sorted_data.sort_unstable();

            prop_assert_eq!(merged_data.len(), sorted_data.len(), "Lengths should be equal");
            prop_assert_eq!(merged_data, sorted_data, "Merged data should equal sorted data");
        }

        // The merged output must be non-decreasing.
        #[rstest]
        fn prop_kmerge_preserves_sort_order(
            all_data in prop::collection::vec(sorted_nested_vec_strategy(), 1..=5)
        ) {
            let mut kmerge: KMerge<_, u64, _> = KMerge::new(OrdComparator);

            for stream in all_data {
                let input = stream.0.into_iter().map(std::iter::IntoIterator::into_iter);
                kmerge.push_iter(input);
            }
            let merged_data: Vec<u64> = kmerge.collect();

            for [a, b] in merged_data.array_windows() {
                prop_assert!(a <= b, "Merged data should be sorted");
            }
        }

        // Sources without any batches must not change the result.
        #[rstest]
        fn prop_kmerge_handles_empty_iterators(
            data in sorted_nested_vec_strategy(),
            empty_count in 0usize..=5
        ) {
            let mut kmerge_with_empty: KMerge<_, u64, _> = KMerge::new(OrdComparator);
            let mut kmerge_without_empty: KMerge<_, u64, _> = KMerge::new(OrdComparator);

            let input_with_empty = data.0.clone().into_iter().map(std::iter::IntoIterator::into_iter);
            let input_without_empty = data.0.into_iter().map(std::iter::IntoIterator::into_iter);

            kmerge_with_empty.push_iter(input_with_empty);
            kmerge_without_empty.push_iter(input_without_empty);

            for _ in 0..empty_count {
                let empty_vec: Vec<Vec<u64>> = vec![];
                let empty_input = empty_vec.into_iter().map(std::iter::IntoIterator::into_iter);
                kmerge_with_empty.push_iter(empty_input);
            }

            let result_with_empty: Vec<u64> = kmerge_with_empty.collect();
            let result_without_empty: Vec<u64> = kmerge_without_empty.collect();

            prop_assert_eq!(result_with_empty, result_without_empty, "Empty iterators should not affect result");
        }
    }
}