use std::fmt::{Display, Formatter};
use std::hash::Hash;
use std::sync::Arc;
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, Buffer, MutableBuffer};
use enum_iterator::{cardinality, Sequence};
use itertools::Itertools;
use log::debug;
use num_enum::{IntoPrimitive, TryFromPrimitive};
pub use statsset::*;
use vortex_dtype::Nullability::NonNullable;
use vortex_dtype::{DType, NativePType, PType};
use vortex_error::{vortex_err, vortex_panic, VortexError, VortexExpect, VortexResult};
use vortex_scalar::Scalar;
use crate::encoding::Encoding;
use crate::ArrayData;
pub mod flatbuffers;
mod statsset;
/// The fixed subset of statistics grouped under the "pruning" label: minimum, maximum,
/// true count and null count.
///
/// NOTE(review): presumably these are the stats consulted when deciding whether data can be
/// skipped during a scan — confirm against the callers of this constant.
pub const PRUNING_STATS: &[Stat] = &[Stat::Min, Stat::Max, Stat::TrueCount, Stat::NullCount];
/// The kinds of statistic that can be stored for, or computed over, an array.
///
/// The enum is `#[repr(u8)]` and converts to/from `u8` via `IntoPrimitive` /
/// `TryFromPrimitive`. The `u8` value of a variant is used as its bit position by
/// [`as_stat_bitset_bytes`] and [`stats_from_bitset_bytes`], so reordering or removing
/// variants changes that encoding.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Sequence, IntoPrimitive, TryFromPrimitive)]
#[repr(u8)]
pub enum Stat {
    /// Bit-width frequency histogram; its value is a list of `u64` (see [`Stat::dtype`]).
    BitWidthFreq,
    /// Trailing-zero frequency histogram; its value is a list of `u64` (see [`Stat::dtype`]).
    TrailingZeroFreq,
    /// Boolean flag: whether the array is constant.
    IsConstant,
    /// Boolean flag: whether the array is sorted.
    IsSorted,
    /// Boolean flag: whether the array is strictly sorted.
    IsStrictSorted,
    /// The maximum value; has the same dtype as the array itself.
    Max,
    /// The minimum value; has the same dtype as the array itself.
    Min,
    /// The number of runs in the array, as a `u64`.
    RunCount,
    /// The number of true values in the array, as a `u64`.
    TrueCount,
    /// The number of null values in the array, as a `u64`.
    NullCount,
    /// The uncompressed size of the array in bytes, as a `u64`.
    UncompressedSizeInBytes,
}
impl Stat {
pub fn is_commutative(&self) -> bool {
matches!(
self,
Stat::BitWidthFreq
| Stat::TrailingZeroFreq
| Stat::IsConstant
| Stat::Max
| Stat::Min
| Stat::TrueCount
| Stat::NullCount
| Stat::UncompressedSizeInBytes
)
}
pub fn has_same_dtype_as_array(&self) -> bool {
matches!(self, Stat::Min | Stat::Max)
}
pub fn dtype(&self, data_type: &DType) -> DType {
match self {
Stat::BitWidthFreq => DType::List(
Arc::new(DType::Primitive(PType::U64, NonNullable)),
NonNullable,
),
Stat::TrailingZeroFreq => DType::List(
Arc::new(DType::Primitive(PType::U64, NonNullable)),
NonNullable,
),
Stat::IsConstant => DType::Bool(NonNullable),
Stat::IsSorted => DType::Bool(NonNullable),
Stat::IsStrictSorted => DType::Bool(NonNullable),
Stat::Max => data_type.clone(),
Stat::Min => data_type.clone(),
Stat::RunCount => DType::Primitive(PType::U64, NonNullable),
Stat::TrueCount => DType::Primitive(PType::U64, NonNullable),
Stat::NullCount => DType::Primitive(PType::U64, NonNullable),
Stat::UncompressedSizeInBytes => DType::Primitive(PType::U64, NonNullable),
}
}
pub fn name(&self) -> &str {
match self {
Self::BitWidthFreq => "bit_width_frequency",
Self::TrailingZeroFreq => "trailing_zero_frequency",
Self::IsConstant => "is_constant",
Self::IsSorted => "is_sorted",
Self::IsStrictSorted => "is_strict_sorted",
Self::Max => "max",
Self::Min => "min",
Self::RunCount => "run_count",
Self::TrueCount => "true_count",
Self::NullCount => "null_count",
Self::UncompressedSizeInBytes => "uncompressed_size_in_bytes",
}
}
}
/// Encodes a set of statistics as a bitset, returned as raw bytes.
///
/// The bitset has one bit per `Stat` variant (rounded up to whole bytes); the bit at
/// position `u8::from(stat)` is set for every stat in `stats`. Duplicates in `stats` are
/// harmless since setting a bit twice has no further effect.
pub fn as_stat_bitset_bytes(stats: &[Stat]) -> Vec<u8> {
    let variant_count = cardinality::<Stat>();
    let zeroed = MutableBuffer::from_len_zeroed(variant_count.div_ceil(8));
    let mut bits = BooleanBufferBuilder::new_from_buffer(zeroed, variant_count);

    stats
        .iter()
        .for_each(|&stat| bits.set_bit(usize::from(u8::from(stat)), true));

    // Reuse the underlying allocation when possible; otherwise fall back to copying.
    match bits.finish().into_inner().into_vec::<u8>() {
        Ok(bytes) => bytes,
        Err(buffer) => buffer.to_vec(),
    }
}
/// Decodes a bitset (as produced by [`as_stat_bitset_bytes`]) back into a list of stats.
///
/// Every bit of `bytes` is inspected. A set bit whose index does not fit in a `u8` is
/// logged at debug level and skipped; a set bit whose index fits in a `u8` but does not
/// correspond to a `Stat` variant is skipped silently. Stats are returned in ascending
/// bit-index order.
pub fn stats_from_bitset_bytes(bytes: &[u8]) -> Vec<Stat> {
    let bitset = BooleanBuffer::new(Buffer::from(bytes), 0, bytes.len() * 8);

    let mut decoded = Vec::new();
    for i in bitset.set_indices() {
        match u8::try_from(i) {
            Ok(raw) => {
                if let Ok(stat) = Stat::try_from(raw) {
                    decoded.push(stat);
                }
            }
            Err(_) => {
                debug!("invalid stat encountered: {i}");
            }
        }
    }
    decoded
}
/// Formats a statistic as its snake_case name (see [`Stat::name`]).
impl Display for Stat {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.name())
    }
}
/// Read/write access to the statistics attached to an array.
///
/// All methods take `&self`, including the ones that modify the stored statistics.
/// NOTE(review): implementors presumably rely on interior mutability — confirm.
pub trait Statistics {
    /// Returns the value of `stat`, if one is available.
    ///
    /// NOTE(review): presumably this only looks up an already-stored value and never
    /// computes one — confirm against the implementations.
    fn get(&self, stat: Stat) -> Option<Scalar>;

    /// Returns the statistics as a [`StatsSet`].
    fn to_set(&self) -> StatsSet;

    /// Records `value` as the value of `stat`.
    fn set(&self, stat: Stat, value: Scalar);

    /// Removes any value held for `stat`.
    fn clear(&self, stat: Stat);

    /// Returns the value of `stat`, or `None` if no value could be produced.
    ///
    /// NOTE(review): presumably computes the stat when it is not already stored — confirm.
    fn compute(&self, stat: Stat) -> Option<Scalar>;

    /// Calls [`Statistics::compute`] for each requested stat and gathers the results.
    ///
    /// Stats for which `compute` returns `None` are left out of the returned set. This
    /// default implementation always returns `Ok`.
    fn compute_all(&self, stats: &[Stat]) -> VortexResult<StatsSet> {
        let mut computed = StatsSet::default();
        for &stat in stats {
            let Some(value) = self.compute(stat) else {
                continue;
            };
            computed.set(stat, value);
        }
        Ok(computed)
    }

    /// Keeps only the listed stats.
    ///
    /// NOTE(review): presumably every stat not in `stats` is cleared — confirm.
    fn retain_only(&self, stats: &[Stat]);
}
/// Implemented by types that expose a [`Statistics`] object.
pub trait ArrayStatistics {
    /// Returns the statistics accessor for this value.
    fn statistics(&self) -> &dyn Statistics;

    /// Takes statistics over from `parent`.
    ///
    /// NOTE(review): which stats are copied, and under what conditions, is decided by the
    /// implementation — confirm there.
    fn inherit_statistics(&self, parent: &dyn Statistics);
}
/// Encoding-level hook for computing statistics over an array of type `Array`.
pub trait StatisticsVTable<Array: ?Sized> {
    /// Computes the requested statistic for the given array, returning the resulting
    /// statistics as a [`StatsSet`].
    ///
    /// The default implementation ignores both arguments and returns an empty set.
    fn compute_statistics(&self, _array: &Array, _stat: Stat) -> VortexResult<StatsSet> {
        Ok(StatsSet::default())
    }
}
/// Blanket adapter: any encoding that can compute statistics over its own typed array can
/// also do so over an untyped [`ArrayData`], by converting the array and delegating.
impl<E: Encoding + 'static> StatisticsVTable<ArrayData> for E
where
    E: StatisticsVTable<E::Array>,
    for<'a> &'a E::Array: TryFrom<&'a ArrayData, Error = VortexError>,
{
    /// Converts `array` to `&E::Array`, downcasts the array's encoding to `E`, and forwards
    /// to the typed implementation.
    ///
    /// # Errors
    ///
    /// Returns the conversion error if `array` cannot be viewed as `E::Array`, or a
    /// "Mismatched encoding" error if the array's encoding is not an `E`.
    fn compute_statistics(&self, array: &ArrayData, stat: Stat) -> VortexResult<StatsSet> {
        // The conversion is attempted first, so its error wins if both checks would fail.
        let typed_array = <&E::Array>::try_from(array)?;

        let Some(typed_encoding) = array.encoding().as_any().downcast_ref::<E>() else {
            return Err(vortex_err!("Mismatched encoding"));
        };

        <E as StatisticsVTable<E::Array>>::compute_statistics(typed_encoding, typed_array, stat)
    }
}
/// Typed convenience accessors layered on top of the object-safe [`Statistics`] trait.
impl dyn Statistics + '_ {
    /// Returns the value from [`Statistics::get`] for `stat`, converted to `U`.
    ///
    /// Returns `None` when `get` returns `None`.
    ///
    /// # Panics
    ///
    /// Panics if the scalar cannot be converted to `U`.
    pub fn get_as<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
        &self,
        stat: Stat,
    ) -> Option<U> {
        let scalar = self.get(stat)?;
        match U::try_from(&scalar) {
            Ok(value) => Some(value),
            Err(err) => vortex_panic!(
                err,
                "Failed to cast stat {} to {}",
                stat,
                std::any::type_name::<U>()
            ),
        }
    }

    /// Like [`get_as`](Self::get_as), but first casts the scalar to the non-nullable
    /// primitive dtype of `U`.
    ///
    /// Returns `None` when `get` returns `None` or the scalar is not valid.
    ///
    /// # Panics
    ///
    /// Panics if the cast or the subsequent conversion to `U` fails.
    pub fn get_as_cast<U: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
        &self,
        stat: Stat,
    ) -> Option<U> {
        let scalar = self.get(stat)?;
        if !scalar.is_valid() {
            return None;
        }

        let target = DType::Primitive(U::PTYPE, NonNullable);
        match scalar
            .cast(&target)
            .and_then(|casted| U::try_from(&casted))
        {
            Ok(value) => Some(value),
            Err(err) => vortex_panic!(err, "Failed to cast stat {} to {}", stat, U::PTYPE),
        }
    }

    /// Returns the value from [`Statistics::compute`] for `stat`, converted to `U`.
    ///
    /// Returns `None` when `compute` returns `None`.
    ///
    /// # Panics
    ///
    /// Panics if the scalar cannot be converted to `U`.
    pub fn compute_as<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
        &self,
        stat: Stat,
    ) -> Option<U> {
        let scalar = self.compute(stat)?;
        match U::try_from(&scalar) {
            Ok(value) => Some(value),
            Err(err) => vortex_panic!(
                err,
                "Failed to compute stat {} as {}",
                stat,
                std::any::type_name::<U>()
            ),
        }
    }

    /// Like [`compute_as`](Self::compute_as), but first casts the scalar to the
    /// non-nullable primitive dtype of `U`.
    ///
    /// Returns `None` when `compute` returns `None` or the scalar is not valid.
    ///
    /// # Panics
    ///
    /// Panics if the cast or the subsequent conversion to `U` fails.
    pub fn compute_as_cast<U: NativePType + for<'a> TryFrom<&'a Scalar, Error = VortexError>>(
        &self,
        stat: Stat,
    ) -> Option<U> {
        let scalar = self.compute(stat)?;
        if !scalar.is_valid() {
            return None;
        }

        let target = DType::Primitive(U::PTYPE, NonNullable);
        match scalar
            .cast(&target)
            .and_then(|casted| U::try_from(&casted))
        {
            Ok(value) => Some(value),
            Err(err) => {
                vortex_panic!(err, "Failed to compute stat {} as cast {}", stat, U::PTYPE)
            }
        }
    }

    /// Computes [`Stat::Min`] and converts it to `U`; see [`compute_as`](Self::compute_as).
    pub fn compute_min<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(&self) -> Option<U> {
        self.compute_as::<U>(Stat::Min)
    }

    /// Computes [`Stat::Max`] and converts it to `U`; see [`compute_as`](Self::compute_as).
    pub fn compute_max<U: for<'a> TryFrom<&'a Scalar, Error = VortexError>>(&self) -> Option<U> {
        self.compute_as::<U>(Stat::Max)
    }

    /// Computes [`Stat::IsStrictSorted`] as a `bool`.
    pub fn compute_is_strict_sorted(&self) -> Option<bool> {
        self.compute_as::<bool>(Stat::IsStrictSorted)
    }

    /// Computes [`Stat::IsSorted`] as a `bool`.
    pub fn compute_is_sorted(&self) -> Option<bool> {
        self.compute_as::<bool>(Stat::IsSorted)
    }

    /// Computes [`Stat::IsConstant`] as a `bool`.
    pub fn compute_is_constant(&self) -> Option<bool> {
        self.compute_as::<bool>(Stat::IsConstant)
    }

    /// Computes [`Stat::TrueCount`] as a `usize`.
    pub fn compute_true_count(&self) -> Option<usize> {
        self.compute_as::<usize>(Stat::TrueCount)
    }

    /// Computes [`Stat::NullCount`] as a `usize`.
    pub fn compute_null_count(&self) -> Option<usize> {
        self.compute_as::<usize>(Stat::NullCount)
    }

    /// Computes [`Stat::RunCount`] as a `usize`.
    pub fn compute_run_count(&self) -> Option<usize> {
        self.compute_as::<usize>(Stat::RunCount)
    }

    /// Computes [`Stat::BitWidthFreq`] as a vector of counts.
    pub fn compute_bit_width_freq(&self) -> Option<Vec<usize>> {
        self.compute_as::<Vec<usize>>(Stat::BitWidthFreq)
    }

    /// Computes [`Stat::TrailingZeroFreq`] as a vector of counts.
    pub fn compute_trailing_zero_freq(&self) -> Option<Vec<usize>> {
        self.compute_as::<Vec<usize>>(Stat::TrailingZeroFreq)
    }

    /// Computes [`Stat::UncompressedSizeInBytes`] as a `usize`.
    pub fn compute_uncompressed_size_in_bytes(&self) -> Option<usize> {
        self.compute_as::<usize>(Stat::UncompressedSizeInBytes)
    }
}
pub fn trailing_zeros(array: &ArrayData) -> u8 {
let tz_freq = array
.statistics()
.compute_trailing_zero_freq()
.unwrap_or_else(|| vec![0]);
tz_freq
.iter()
.enumerate()
.find_or_first(|(_, &v)| v > 0)
.map(|(i, _)| i)
.unwrap_or(0)
.try_into()
.vortex_expect("tz_freq must fit in u8")
}
#[cfg(test)]
mod test {
    use enum_iterator::all;

    use crate::array::PrimitiveArray;
    use crate::stats::{ArrayStatistics, Stat};

    /// Computing the minimum of an all-null array must yield `None` rather than panic.
    #[test]
    fn min_of_nulls_is_not_panic() {
        let min = PrimitiveArray::from_nullable_vec::<i32>(vec![None, None, None, None])
            .statistics()
            .compute_as_cast::<i64>(Stat::Min);

        assert_eq!(min, None);
    }

    /// Pins the commutativity classification of every `Stat` variant.
    #[test]
    fn commutativity() {
        // Commutative statistics.
        assert!(Stat::BitWidthFreq.is_commutative());
        assert!(Stat::TrailingZeroFreq.is_commutative());
        assert!(Stat::IsConstant.is_commutative());
        assert!(Stat::Min.is_commutative());
        assert!(Stat::Max.is_commutative());
        assert!(Stat::TrueCount.is_commutative());
        assert!(Stat::NullCount.is_commutative());
        // Previously unchecked, so a regression for this variant would have gone unnoticed.
        assert!(Stat::UncompressedSizeInBytes.is_commutative());

        // Non-commutative statistics.
        assert!(!Stat::IsStrictSorted.is_commutative());
        assert!(!Stat::IsSorted.is_commutative());
        assert!(!Stat::RunCount.is_commutative());
    }

    /// Only `Min` and `Max` share the array's dtype; every other variant must not.
    #[test]
    fn has_same_dtype_as_array() {
        assert!(Stat::Min.has_same_dtype_as_array());
        assert!(Stat::Max.has_same_dtype_as_array());

        for stat in all::<Stat>().filter(|s| !matches!(s, Stat::Min | Stat::Max)) {
            assert!(!stat.has_same_dtype_as_array());
        }
    }
}