//! A stream adapter that selects rows at a sorted set of indices from an underlying array stream.

use std::pin::Pin;
use std::task::{Context, Poll};

use futures_util::{ready, Stream};
use pin_project::pin_project;
use vortex_dtype::match_each_integer_ptype;
use vortex_error::{vortex_bail, VortexResult};
use vortex_scalar::Scalar;

use crate::compute::{search_sorted_usize, slice, sub_scalar, take, SearchSortedSide};
use crate::stats::{ArrayStatistics, Stat};
use crate::stream::ArrayStream;
use crate::variants::PrimitiveArrayTrait;
use crate::{ArrayDType, ArrayData, IntoArrayData, IntoArrayVariant};

/// Stream adapter that yields, batch by batch, the rows of `reader` selected by `indices`.
///
/// Instances are built with `try_new`, which validates the index array before the
/// adapter is handed out.
#[pin_project]
pub struct TakeRows<R>
where
    R: ArrayStream,
{
    /// Underlying stream of array batches; structurally pinned so it can be polled.
    #[pin]
    reader: R,
    /// Global (stream-wide) row indices to select.
    indices: ArrayData,
    /// Number of rows of `reader` consumed so far, i.e. the global row number of the
    /// first row of the next batch.
    row_offset: usize,
}

impl<R: ArrayStream> TakeRows<R> {
    /// Creates a stream that takes the rows at `indices` from `reader`.
    ///
    /// An empty `indices` array is accepted without any further checks.
    ///
    /// # Errors
    ///
    /// For a non-empty `indices` array, returns an error when (checked in this order):
    /// - it is not sorted, or sortedness cannot be determined;
    /// - it contains nulls, or the null count cannot be determined;
    /// - its dtype is not an integer type;
    /// - it is a signed integer array whose minimum is negative or cannot be determined.
    pub fn try_new(reader: R, indices: ArrayData) -> VortexResult<Self> {
        if !indices.is_empty() {
            Self::validate_indices(&indices)?;
        }

        Ok(Self {
            reader,
            indices,
            row_offset: 0,
        })
    }

    /// Checks the invariants the stream relies on for a non-empty index array.
    fn validate_indices(indices: &ArrayData) -> VortexResult<()> {
        // Each statistic is treated pessimistically: if it cannot be computed the
        // indices are rejected.
        let sorted = indices.statistics().compute_is_sorted().unwrap_or(false);
        if !sorted {
            vortex_bail!("Indices must be sorted to take from IPC stream");
        }

        let has_nulls = indices
            .statistics()
            .compute_null_count()
            .map_or(true, |nc| nc > 0);
        if has_nulls {
            vortex_bail!("Indices must not contain nulls");
        }

        if !indices.dtype().is_int() {
            vortex_bail!("Indices must be integers");
        }

        // Only signed arrays can hold negative values, so the minimum is computed
        // for them alone. Note that zero passes this check despite the wording of
        // the error message.
        if indices.dtype().is_signed_int() {
            let negative_min = indices
                .statistics()
                .compute_as_cast::<i64>(Stat::Min)
                .map_or(true, |min| min < 0);
            if negative_min {
                vortex_bail!("Indices must be positive");
            }
        }

        Ok(())
    }
}

impl<R: ArrayStream> Stream for TakeRows<R> {
    type Item = VortexResult<ArrayData>;

    /// Polls the underlying reader until a batch containing at least one requested row
    /// arrives, then yields the rows of that batch selected by the indices.
    ///
    /// Batches that contain none of the requested rows are skipped. The stream ends
    /// when the reader ends, or immediately if there are no indices at all. Errors
    /// from the reader or from any compute step are yielded as `Some(Err(..))`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        // Nothing to select: finish without touching the reader.
        if this.indices.is_empty() {
            return Poll::Ready(None);
        }

        loop {
            let batch = match ready!(this.reader.as_mut().poll_next(cx)?) {
                Some(batch) => batch,
                None => return Poll::Ready(None),
            };

            // Global row range covered by this batch: [batch_start, batch_end).
            let batch_start = *this.row_offset;
            let batch_end = batch_start + batch.len();

            // Window of the sorted indices that falls inside that row range.
            let lo = search_sorted_usize(this.indices, batch_start, SearchSortedSide::Left)?
                .to_index();
            let hi =
                search_sorted_usize(this.indices, batch_end, SearchSortedSide::Left)?.to_index();

            // Advance only after both searches succeeded.
            *this.row_offset = batch_end;

            if lo == hi {
                // No requested row lives in this batch; move on to the next one.
                continue;
            }

            // TODO(ngates): this is probably too heavy to run on the event loop. We should spawn
            //  onto a worker pool.
            let batch_indices = slice(this.indices, lo, hi)?.into_primitive()?;
            // Rebase the global indices so they address rows within `batch`.
            let rebased = match_each_integer_ptype!(batch_indices.ptype(), |$T| {
                sub_scalar(&batch_indices.into_array(), Scalar::from(batch_start as $T))?
            });
            return Poll::Ready(Some(take(&batch, &rebased)));
        }
    }
}