1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};

use crate::encoding::Encoding;
use crate::stats::{ArrayStatistics, Stat};
use crate::{ArrayDType, ArrayData, IntoArrayData, IntoCanonical};

/// Compute function that gathers elements of an array at a set of positions.
///
/// The `Array` type parameter is the array representation the implementor
/// knows how to read from.
pub trait TakeFn<Array> {
    /// Build a new array holding the values of `array` found at each position
    /// listed in `indices`.
    ///
    /// # Panics
    ///
    /// Passing `indices` that are not valid for `array` will cause a panic.
    fn take(&self, array: &Array, indices: &ArrayData) -> VortexResult<ArrayData>;

    /// Build a new array holding the values of `array` found at each position
    /// listed in `indices`, without validating the positions first.
    ///
    /// Unless an implementor overrides it, this simply forwards to the
    /// bounds-checked [`TakeFn::take`].
    ///
    /// # Safety
    ///
    /// No bounds checking is performed on `indices`. The caller must guarantee that
    /// every index is valid for the provided `array`; otherwise the result may be an
    /// out of bounds memory access or other undefined behavior.
    unsafe fn take_unchecked(&self, array: &Array, indices: &ArrayData) -> VortexResult<ArrayData> {
        <Self as TakeFn<Array>>::take(self, array, indices)
    }
}

/// Type-erased adapter: an encoding `E` that implements [`TakeFn`] for its own typed
/// array (`E::Array`) automatically implements [`TakeFn`] over untyped [`ArrayData`].
///
/// Both methods convert the incoming `ArrayData` into `&E::Array`, downcast the array's
/// encoding to `E`, and forward to the matching typed method.
impl<E: Encoding> TakeFn<ArrayData> for E
where
    E: TakeFn<E::Array>,
    for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
    /// Bounds-checked take, forwarded to `E`'s typed [`TakeFn::take`].
    ///
    /// # Errors
    ///
    /// Fails if `array` cannot be viewed as `&E::Array`, or if the array's encoding is
    /// not `E` ("Mismatched encoding").
    fn take(&self, array: &ArrayData, indices: &ArrayData) -> VortexResult<ArrayData> {
        let array_ref = <&E::Array>::try_from(array)?;
        let encoding = array
            .encoding()
            .as_any()
            .downcast_ref::<E>()
            .ok_or_else(|| vortex_err!("Mismatched encoding"))?;
        TakeFn::take(encoding, array_ref, indices)
    }

    /// Unchecked take, forwarded to `E`'s typed [`TakeFn::take_unchecked`].
    ///
    /// Without this override the trait's default body would be used, which calls the
    /// bounds-checked `take` and never reaches an unchecked implementation provided by
    /// the encoding.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that every value in `indices` is a valid position in
    /// `array`; the typed implementation is free to skip bounds checks.
    ///
    /// # Errors
    ///
    /// Fails if `array` cannot be viewed as `&E::Array`, or if the array's encoding is
    /// not `E` ("Mismatched encoding").
    unsafe fn take_unchecked(
        &self,
        array: &ArrayData,
        indices: &ArrayData,
    ) -> VortexResult<ArrayData> {
        let array_ref = <&E::Array>::try_from(array)?;
        let encoding = array
            .encoding()
            .as_any()
            .downcast_ref::<E>()
            .ok_or_else(|| vortex_err!("Mismatched encoding"))?;
        // SAFETY: the caller of this method promises that `indices` are all in bounds for
        // `array`; `array_ref` is a typed view of that same array, so the promise carries over.
        unsafe { TakeFn::take_unchecked(encoding, array_ref, indices) }
    }
}

/// Create a new array by taking the values of `array` at the positions given by `indices`.
///
/// When the statistics already attached to `indices` show that every index is in bounds,
/// the encoding's unchecked take is used; otherwise the bounds-checked take is used.
///
/// # Errors
///
/// Returns an error if `indices` is not a non-nullable integer array, or if the underlying
/// take implementation fails.
pub fn take(
    array: impl AsRef<ArrayData>,
    indices: impl AsRef<ArrayData>,
) -> VortexResult<ArrayData> {
    // TODO(ngates): if indices are sorted and unique (strict-sorted), then we should delegate to
    //  the filter function since they're typically optimised for this case.
    // TODO(ngates): if indices min is quite high, we could slice self and offset the indices
    //  such that canonicalize does less work.

    let array = array.as_ref();
    let indices = indices.as_ref();

    if !indices.dtype().is_int() || indices.dtype().is_nullable() {
        vortex_bail!(
            "Take indices must be a non-nullable integer type, got {}",
            indices.dtype()
        );
    }

    // If the indices are all within bounds, we can skip bounds checking.
    //
    // A maximum below `array.len()` only proves this for unsigned indices: a signed index
    // array can have an in-range maximum and still contain negative entries. Signed indices
    // therefore always take the bounds-checked path.
    // TODO: signed indices could use the fast path too once a non-negative minimum is known.
    let checked_indices = indices.dtype().is_unsigned_int()
        && indices
            .statistics()
            .get_as::<usize>(Stat::Max)
            .is_some_and(|max| max < array.len());

    let taken = take_impl(array, indices, checked_indices)?;

    // Debug-only sanity checks on the implementation's output: one value per index, and
    // the same dtype as the input array.
    debug_assert_eq!(
        taken.len(),
        indices.len(),
        "Take length mismatch {}",
        array.encoding().id()
    );
    debug_assert_eq!(
        array.dtype(),
        taken.dtype(),
        "Take dtype mismatch {}",
        array.encoding().id()
    );

    Ok(taken)
}

/// Dispatch a take to the array's encoding, falling back to the canonical form of the
/// array when the encoding does not provide a take function.
///
/// `checked_indices` is `true` when the caller has already established that every index
/// is in bounds; in that case the unchecked variant of the take function is invoked.
fn take_impl(
    array: &ArrayData,
    indices: &ArrayData,
    checked_indices: bool,
) -> VortexResult<ArrayData> {
    // Encodings without their own TakeFn are handled through canonicalization.
    let encoding_take_fn = match array.encoding().take_fn() {
        Some(take_fn) => take_fn,
        None => return take_impl_canonical_fallback(array, indices, checked_indices),
    };

    // The encoding defines TakeFn: delegate to it, skipping bounds checks when the
    // indices are already known to be valid.
    let outcome = if checked_indices {
        // SAFETY: indices are all inbounds per stats.
        // TODO(aduffy): this means stats must be trusted, can still trigger UB if stats are bad.
        unsafe { encoding_take_fn.take_unchecked(array, indices) }
    } else {
        encoding_take_fn.take(array, indices)
    };
    let result = outcome?;

    // A take must never alter the dtype of the array it reads from.
    let dtype_preserved = array.dtype() == result.dtype();
    if !dtype_preserved {
        vortex_bail!(
            "TakeFn {} changed array dtype from {} to {}",
            array.encoding().id(),
            array.dtype(),
            result.dtype()
        );
    }

    Ok(result)
}

/// Canonicalize `array` and run the canonical encoding's take function on the result.
fn take_impl_canonical_fallback(
    array: &ArrayData,
    indices: &ArrayData,
    checked_indices: bool,
) -> VortexResult<ArrayData> {
    log::debug!("No take implementation found for {}", array.encoding().id());

    let canonical = array.clone().into_canonical()?.into_array();
    let canonical_take_fn = match canonical.encoding().take_fn() {
        Some(take_fn) => take_fn,
        None => return Err(vortex_err!(NotImplemented: "take", canonical.encoding().id())),
    };

    if !checked_indices {
        return canonical_take_fn.take(&canonical, indices);
    }

    // SAFETY: indices are known to be in-bound from stats
    unsafe { canonical_take_fn.take_unchecked(&canonical, indices) }
}