use itertools::Itertools;
use vortex_dtype::PType;
use vortex_error::VortexResult;
use vortex_scalar::Scalar;
use crate::array::chunked::ChunkedArray;
use crate::array::ChunkedEncoding;
use crate::compute::{
scalar_at, search_sorted_usize, slice, sub_scalar, take, try_cast, SearchSortedSide, TakeFn,
};
use crate::stats::ArrayStatistics;
use crate::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData};
impl TakeFn<ChunkedArray> for ChunkedEncoding {
    /// Gathers the elements of `array` found at the positions listed in `indices`.
    ///
    /// Strictly sorted indices are delegated to [`take_strict_sorted`]. Any other
    /// ordering is handled by splitting `indices` into consecutive runs that land in
    /// the same chunk and issuing one `take` per run, so the output is a chunked array
    /// with one chunk per run, in the order the indices were given.
    ///
    /// An empty `indices` array yields a chunked array with no chunks.
    ///
    /// # Errors
    ///
    /// Propagates any failure from casting `indices` to `u64`, converting an index to
    /// `usize`, fetching a chunk, taking from a chunk, or assembling the result.
    fn take(&self, array: &ChunkedArray, indices: &ArrayData) -> VortexResult<ArrayData> {
        // Fast path for strictly sorted indices.
        if indices
            .statistics()
            .compute_is_strict_sorted()
            .unwrap_or(false)
        {
            // As many strictly increasing indices as there are elements: the array is
            // handed back unchanged (this relies on the indices being in bounds).
            if array.len() == indices.len() {
                return Ok(array.to_array());
            }
            return take_strict_sorted(array, indices);
        }

        let indices = try_cast(indices, PType::U64.into())?.into_primitive()?;
        // NOTE(review): only the raw values are read here; the validity of `indices`
        // is not consulted — confirm callers never pass null indices.
        let index_values = indices.maybe_null_slice::<u64>();

        // Nothing requested: return an empty chunked array rather than indexing into
        // an empty slice (which would panic).
        let Some(&first_idx) = index_values.first() else {
            return Ok(ChunkedArray::try_new(Vec::new(), array.dtype().clone())?.into_array());
        };

        // While the chunk index stays the same, accumulate chunk-local indices; flush
        // them with a single `take` whenever the run moves on to a different chunk.
        let mut chunks = Vec::new();
        let mut indices_in_chunk: Vec<u64> = Vec::new();
        let mut prev_chunk_idx = array.find_chunk_idx(usize::try_from(first_idx)?).0;

        for &idx in index_values {
            let (chunk_idx, idx_in_chunk) = array.find_chunk_idx(usize::try_from(idx)?);

            if chunk_idx != prev_chunk_idx {
                // Hand the finished run over without copying it; this leaves an empty
                // vector behind for the next run.
                let run = std::mem::take(&mut indices_in_chunk).into_array();
                chunks.push(take(&array.chunk(prev_chunk_idx)?, &run)?);
            }

            // usize -> u64 is a widening conversion on 64-bit (and narrower) targets.
            indices_in_chunk.push(idx_in_chunk as u64);
            prev_chunk_idx = chunk_idx;
        }

        // Flush the trailing run.
        if !indices_in_chunk.is_empty() {
            let run = indices_in_chunk.into_array();
            chunks.push(take(&array.chunk(prev_chunk_idx)?, &run)?);
        }

        Ok(ChunkedArray::try_new(chunks, array.dtype().clone())?.into_array())
    }
}
/// Take implementation for `indices` that are strictly sorted.
///
/// Walks `indices` one chunk at a time: it finds the chunk holding the next
/// unprocessed index, slices off every index that falls before that chunk's end
/// offset, rebases the slice so it is relative to the chunk's start, and records it
/// for that chunk. Each chunk that received indices is then taken from, and the
/// per-chunk results are assembled (in chunk order) into a new chunked array.
///
/// # Errors
///
/// Propagates failures from scalar lookups, searching, slicing, casting, the scalar
/// subtraction, the per-chunk `take`, and construction of the result.
fn take_strict_sorted(chunked: &ChunkedArray, indices: &ArrayData) -> VortexResult<ArrayData> {
    // Slot `i` holds the chunk-local indices requested from chunk `i`, if any.
    let mut per_chunk_indices = vec![None; chunked.nchunks()];

    let mut cursor = 0;
    while cursor < indices.len() {
        // Chunk that contains the first index not yet assigned.
        let global_idx = usize::try_from(&scalar_at(indices, cursor)?)?;
        let (chunk_idx, _) = chunked.find_chunk_idx(global_idx);

        // Offsets bounding that chunk.
        let chunk_start = usize::try_from(&scalar_at(chunked.chunk_offsets(), chunk_idx)?)?;
        let chunk_stop = usize::try_from(&scalar_at(chunked.chunk_offsets(), chunk_idx + 1)?)?;

        // Left-sided search: position of the first index that is >= `chunk_stop`.
        let run_end = search_sorted_usize(indices, chunk_stop, SearchSortedSide::Left)?.to_index();
        let global_run = slice(indices, cursor, run_end)?;

        // Subtract in the indices' own dtype when the chunk start is below that dtype's
        // maximum; otherwise widen the indices to u64 before subtracting.
        let index_dtype_max: usize = PType::try_from(global_run.dtype())?
            .max_value_as_u64()
            .try_into()?;
        let local_run = if chunk_start < index_dtype_max {
            let start_scalar = Scalar::from(chunk_start).cast(global_run.dtype())?;
            sub_scalar(&global_run, start_scalar)?
        } else {
            let widened = try_cast(&global_run, PType::U64.into())?;
            sub_scalar(&widened, chunk_start.into())?
        };

        per_chunk_indices[chunk_idx] = Some(local_run);
        cursor = run_end;
    }

    // Take from every chunk that was assigned indices, skipping the untouched ones.
    let taken_chunks = per_chunk_indices
        .into_iter()
        .enumerate()
        .filter_map(|(chunk_idx, maybe_run)| Some((chunk_idx, maybe_run?)))
        .map(|(chunk_idx, run)| {
            let chunk = chunked.chunk(chunk_idx)?;
            take(&chunk, &run)
        })
        .try_collect()?;

    Ok(ChunkedArray::try_new(taken_chunks, chunked.dtype().clone())?.into_array())
}
#[cfg(test)]
mod test {
    use crate::array::chunked::ChunkedArray;
    use crate::compute::take;
    use crate::{ArrayDType, ArrayLen, IntoArrayData, IntoArrayVariant};

    /// Takes from three identical `[1, 2, 3]` chunks using indices that repeat and go
    /// backwards (so they are not strictly sorted) and checks the gathered values.
    #[test]
    fn test_take() {
        let chunk = vec![1i32, 2, 3].into_array();
        let dtype = chunk.dtype().clone();
        let chunked =
            ChunkedArray::try_new(vec![chunk.clone(), chunk.clone(), chunk], dtype).unwrap();
        assert_eq!(chunked.nchunks(), 3);
        assert_eq!(chunked.len(), 9);

        // Positions 0 and 6 are the first element of a chunk; position 4 is the second
        // element of the middle chunk.
        let indices = vec![0u64, 0, 6, 4].into_array();
        let taken = take(chunked.as_ref(), &indices).unwrap();

        // The result must itself be a chunked array; canonicalize it to compare values.
        let flattened = ChunkedArray::try_from(taken)
            .unwrap()
            .into_array()
            .into_primitive()
            .unwrap();
        assert_eq!(flattened.maybe_null_slice::<i32>(), &[1, 1, 1, 2]);
    }
}