1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
use arrow_buffer::BooleanBuffer;
use itertools::Itertools;
use num_traits::AsPrimitive;
use vortex_dtype::match_each_integer_ptype;
use vortex_error::VortexResult;

use crate::array::{BoolArray, BoolEncoding};
use crate::compute::TakeFn;
use crate::variants::PrimitiveArrayTrait;
use crate::{ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant};

impl TakeFn<BoolArray> for BoolEncoding {
    /// Builds a new boolean array holding the values of `array` at the positions in `indices`.
    ///
    /// `indices` is first converted into a primitive array and then read as a slice of its
    /// integer ptype. The validity of `array` is gathered with the same indices.
    ///
    /// # Errors
    ///
    /// Propagates any failure from converting `indices` into a primitive array, from taking the
    /// validity, or from constructing the resulting [`BoolArray`].
    fn take(&self, array: &BoolArray, indices: &ArrayData) -> VortexResult<ArrayData> {
        let validity = array.validity();
        let indices = indices.clone().into_primitive()?;

        let buffer = if array.len() > 4096 {
            // Larger arrays: read each requested value straight from the `BooleanBuffer`.
            match_each_integer_ptype!(indices.ptype(), |$I| {
                take_bool(&array.boolean_buffer(), indices.maybe_null_slice::<$I>())
            })
        } else {
            // Arrays of at most 4096 values roughly fit into a single page (at least, on
            // Linux), so paying to expand them into a Vec<bool> first is worthwhile.
            let bools = array.boolean_buffer().into_iter().collect_vec();
            match_each_integer_ptype!(indices.ptype(), |$I| {
                take_byte_bool(bools, indices.maybe_null_slice::<$I>())
            })
        };

        let taken_validity = validity.take(indices.as_ref())?;
        let taken = BoolArray::try_new(buffer, taken_validity)?;
        Ok(taken.into_array())
    }

    /// Same as [`Self::take`], but the values and the validity are gathered through the
    /// unchecked helpers.
    ///
    /// # Safety
    ///
    /// The caller must guarantee that every value in `indices` is a valid position in `array`.
    ///
    /// # Errors
    ///
    /// Propagates any failure from converting `indices` into a primitive array, from taking the
    /// validity, or from constructing the resulting [`BoolArray`].
    unsafe fn take_unchecked(
        &self,
        array: &BoolArray,
        indices: &ArrayData,
    ) -> VortexResult<ArrayData> {
        let validity = array.validity();
        let indices = indices.clone().into_primitive()?;

        let buffer = if array.len() > 4096 {
            // Larger arrays: read each requested value straight from the `BooleanBuffer`.
            match_each_integer_ptype!(indices.ptype(), |$I| {
                take_bool_unchecked(&array.boolean_buffer(), indices.maybe_null_slice::<$I>())
            })
        } else {
            // Arrays of at most 4096 values roughly fit into a single page (at least, on
            // Linux), so paying to expand them into a Vec<bool> first is worthwhile.
            let bools = array.boolean_buffer().into_iter().collect_vec();
            match_each_integer_ptype!(indices.ptype(), |$I| {
                take_byte_bool_unchecked(bools, indices.maybe_null_slice::<$I>())
            })
        };

        // SAFETY: caller enforces indices are valid for array, and array has same len as validity.
        let taken_validity = unsafe { validity.take_unchecked(indices.as_ref())? };
        let taken = BoolArray::try_new(buffer, taken_validity)?;
        Ok(taken.into_array())
    }
}

/// Gathers `bools[indices[i]]` for every `i` into a new [`BooleanBuffer`] of `indices.len()`
/// values.
///
/// # Panics
///
/// Panics if any index, once converted to `usize`, is not a valid position in `bools`.
fn take_byte_bool<I>(bools: Vec<bool>, indices: &[I]) -> BooleanBuffer
where
    I: AsPrimitive<usize>,
{
    BooleanBuffer::collect_bool(indices.len(), |pos| {
        // SAFETY: collect_bool just iterates `indices.len()` positions, so `pos` is always a
        // valid position in `indices`.
        let source: usize = unsafe { *indices.get_unchecked(pos) }.as_();
        // The lookup into `bools` stays bounds-checked.
        bools[source]
    })
}

/// Gathers `bools[indices[i]]` for every `i` into a new [`BooleanBuffer`] of `indices.len()`
/// values, without bounds-checking the lookups into `bools`.
///
/// NOTE(review): this function is not marked `unsafe`, yet an index that is out of range for
/// `bools` is undefined behaviour. Callers must guarantee every index is in range.
fn take_byte_bool_unchecked<I>(bools: Vec<bool>, indices: &[I]) -> BooleanBuffer
where
    I: AsPrimitive<usize>,
{
    BooleanBuffer::collect_bool(indices.len(), |pos| {
        // SAFETY: collect_bool just iterates `indices.len()` positions, so `pos` is always a
        // valid position in `indices`.
        let source: usize = unsafe { *indices.get_unchecked(pos) }.as_();
        // SAFETY: relies on the caller passing only indices that are in range for `bools`.
        unsafe { *bools.get_unchecked(source) }
    })
}

/// Gathers the value of `bools` at each position in `indices` into a new [`BooleanBuffer`] of
/// `indices.len()` values, reading through `BooleanBuffer::value`.
fn take_bool<I>(bools: &BooleanBuffer, indices: &[I]) -> BooleanBuffer
where
    I: AsPrimitive<usize>,
{
    BooleanBuffer::collect_bool(indices.len(), |pos| {
        // SAFETY: collect_bool just iterates `indices.len()` positions, so `pos` is always a
        // valid position in `indices`.
        let source: usize = unsafe { *indices.get_unchecked(pos) }.as_();
        bools.value(source)
    })
}

/// Gathers the value of `bools` at each position in `indices` into a new [`BooleanBuffer`] of
/// `indices.len()` values, reading through `BooleanBuffer::value_unchecked`.
///
/// NOTE(review): this function is not marked `unsafe`, yet it performs unchecked reads of
/// `bools`. Callers must guarantee every index is in range for `bools`.
fn take_bool_unchecked<I>(bools: &BooleanBuffer, indices: &[I]) -> BooleanBuffer
where
    I: AsPrimitive<usize>,
{
    BooleanBuffer::collect_bool(indices.len(), |pos| {
        // SAFETY: collect_bool just iterates `indices.len()` positions, so `pos` is always a
        // valid position in `indices`.
        let source: usize = unsafe { *indices.get_unchecked(pos) }.as_();
        // SAFETY: relies on the caller passing only indices that are in range for `bools`.
        unsafe { bools.value_unchecked(source) }
    })
}

#[cfg(test)]
mod test {
    use crate::array::primitive::PrimitiveArray;
    use crate::array::BoolArray;
    use crate::compute::take;

    /// Takes positions 0, 3 and 4 from a five-element nullable array (position 3 is null) and
    /// compares the resulting boolean buffer against one built from the expected values.
    ///
    /// Only the value bits are compared here; the validity of the result is not asserted.
    #[test]
    fn take_nullable() {
        let reference = BoolArray::from_iter(vec![
            Some(false),
            Some(true),
            Some(false),
            None,
            Some(false),
        ]);
        let indices = PrimitiveArray::from(vec![0, 3, 4]);

        let taken = take(&reference, indices).unwrap();
        let actual = BoolArray::try_from(taken).unwrap();

        let expected = BoolArray::from_iter(vec![Some(false), None, Some(false)]);
        assert_eq!(actual.boolean_buffer(), expected.boolean_buffer());
    }
}