use std::fmt::{Debug, Display};
use std::sync::Arc;
use fsst::{Decompressor, Symbol};
use serde::{Deserialize, Serialize};
use vortex_array::array::{VarBinArray, VarBinEncoding};
use vortex_array::encoding::{ids, Encoding};
use vortex_array::stats::{StatisticsVTable, StatsSet};
use vortex_array::validity::{ArrayValidity, LogicalValidity, Validity, ValidityVTable};
use vortex_array::variants::{BinaryArrayTrait, Utf8ArrayTrait, VariantsVTable};
use vortex_array::visitor::{ArrayVisitor, VisitorVTable};
use vortex_array::{impl_encoding, ArrayDType, ArrayData, ArrayLen, ArrayTrait, IntoCanonical};
use vortex_dtype::{DType, Nullability, PType};
use vortex_error::{vortex_bail, VortexExpect, VortexResult};
// Declares the FSST encoding with the id string "vortex.fsst" and encoding id `ids::FSST`.
// NOTE(review): this macro presumably generates the `FSSTArray` / `FSSTEncoding` types used
// throughout this file — confirm against `vortex_array::impl_encoding!`.
impl_encoding!("vortex.fsst", ids::FSST, FSST);
/// Required dtype of the `symbols` child: non-nullable `u64`, one packed symbol per element.
static SYMBOLS_DTYPE: DType = DType::Primitive(PType::U64, Nullability::NonNullable);
/// Required dtype of the `symbol_lengths` child: non-nullable `u8`, one length per symbol.
static SYMBOL_LENS_DTYPE: DType = DType::Primitive(PType::U8, Nullability::NonNullable);
/// Serialized metadata for an [`FSSTArray`]: the information needed to rebuild the dtypes and
/// lengths of its children.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FSSTMetadata {
    /// Number of entries in both the `symbols` and `symbol_lengths` children.
    symbols_len: usize,
    /// Nullability of the `codes` child, whose dtype is always `DType::Binary`.
    codes_nullability: Nullability,
    /// Primitive type of the (non-nullable) `uncompressed_lengths` child.
    uncompressed_lengths_ptype: PType,
}
impl Display for FSSTMetadata {
    /// Renders the metadata with its derived `Debug` representation, forwarding the caller's
    /// formatter (and therefore its flags) unchanged.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <Self as Debug>::fmt(self, formatter)
    }
}
impl FSSTArray {
pub fn try_new(
dtype: DType,
symbols: ArrayData,
symbol_lengths: ArrayData,
codes: ArrayData,
uncompressed_lengths: ArrayData,
) -> VortexResult<Self> {
if symbols.dtype() != &SYMBOLS_DTYPE {
vortex_bail!(InvalidArgument: "symbols array must be of type u64")
}
if symbol_lengths.dtype() != &SYMBOL_LENS_DTYPE {
vortex_bail!(InvalidArgument: "symbol_lengths array must be of type u8")
}
if symbols.len() > 255 {
vortex_bail!(InvalidArgument: "symbols array must have length <= 255");
}
if symbols.len() != symbol_lengths.len() {
vortex_bail!(InvalidArgument: "symbols and symbol_lengths arrays must have same length");
}
if uncompressed_lengths.len() != codes.len() {
vortex_bail!(InvalidArgument: "uncompressed_lengths must be same len as codes");
}
if !uncompressed_lengths.dtype().is_int() || uncompressed_lengths.dtype().is_nullable() {
vortex_bail!(InvalidArgument: "uncompressed_lengths must have integer type and cannot be nullable, found {}", uncompressed_lengths.dtype());
}
if codes.encoding().id() != VarBinEncoding::ID {
vortex_bail!(
InvalidArgument: "codes must have varbin encoding, was {}",
codes.encoding().id()
);
}
if !matches!(codes.dtype(), DType::Binary(_)) {
vortex_bail!(InvalidArgument: "codes array must be DType::Binary type");
}
let symbols_len = symbols.len();
let len = codes.len();
let uncompressed_lengths_ptype = PType::try_from(uncompressed_lengths.dtype())?;
let codes_nullability = codes.dtype().nullability();
let children = Arc::new([symbols, symbol_lengths, codes, uncompressed_lengths]);
Self::try_from_parts(
dtype,
len,
FSSTMetadata {
symbols_len,
codes_nullability,
uncompressed_lengths_ptype,
},
children,
StatsSet::default(),
)
}
pub fn symbols(&self) -> ArrayData {
self.as_ref()
.child(0, &SYMBOLS_DTYPE, self.metadata().symbols_len)
.vortex_expect("FSSTArray symbols child")
}
pub fn symbol_lengths(&self) -> ArrayData {
self.as_ref()
.child(1, &SYMBOL_LENS_DTYPE, self.metadata().symbols_len)
.vortex_expect("FSSTArray symbol_lengths child")
}
pub fn codes(&self) -> ArrayData {
self.as_ref()
.child(2, &self.codes_dtype(), self.len())
.vortex_expect("FSSTArray codes child")
}
#[inline]
pub fn codes_dtype(&self) -> DType {
DType::Binary(self.metadata().codes_nullability)
}
pub fn uncompressed_lengths(&self) -> ArrayData {
self.as_ref()
.child(3, &self.uncompressed_lengths_dtype(), self.len())
.vortex_expect("FSST uncompressed_lengths child")
}
#[inline]
pub fn uncompressed_lengths_dtype(&self) -> DType {
DType::Primitive(
self.metadata().uncompressed_lengths_ptype,
Nullability::NonNullable,
)
}
pub fn validity(&self) -> Validity {
VarBinArray::try_from(self.codes())
.vortex_expect("FSSTArray must have a codes child array")
.validity()
}
pub(crate) fn with_decompressor<F, R>(&self, apply: F) -> VortexResult<R>
where
F: FnOnce(Decompressor) -> VortexResult<R>,
{
let symbols_array = self
.symbols()
.into_canonical()
.map_err(|err| err.with_context("Failed to canonicalize symbols array"))?
.into_primitive()
.map_err(|err| err.with_context("Symbols must be a Primitive Array"))?;
let symbols = symbols_array.maybe_null_slice::<u64>();
let symbol_lengths_array = self
.symbol_lengths()
.into_canonical()
.map_err(|err| err.with_context("Failed to canonicalize symbol_lengths array"))?
.into_primitive()
.map_err(|err| err.with_context("Symbol lengths must be a Primitive Array"))?;
let symbol_lengths = symbol_lengths_array.maybe_null_slice::<u8>();
let symbols = unsafe { std::mem::transmute::<&[u64], &[Symbol]>(symbols) };
let decompressor = Decompressor::new(symbols, symbol_lengths);
apply(decompressor)
}
}
impl VisitorVTable<FSSTArray> for FSSTEncoding {
fn accept(&self, array: &FSSTArray, visitor: &mut dyn ArrayVisitor) -> VortexResult<()> {
visitor.visit_child("symbols", &array.symbols())?;
visitor.visit_child("symbol_lengths", &array.symbol_lengths())?;
visitor.visit_child("codes", &array.codes())?;
visitor.visit_child("uncompressed_lengths", &array.uncompressed_lengths())
}
}
// No methods are overridden: FSST statistics rely entirely on `StatisticsVTable`'s provided methods.
impl StatisticsVTable<FSSTArray> for FSSTEncoding {}
impl ValidityVTable<FSSTArray> for FSSTEncoding {
    /// An element is valid exactly when the matching entry of the `codes` child is valid.
    fn is_valid(&self, array: &FSSTArray, index: usize) -> bool {
        let codes = array.codes();
        codes.is_valid(index)
    }

    /// Logical validity is taken verbatim from the `codes` child.
    fn logical_validity(&self, array: &FSSTArray) -> LogicalValidity {
        let codes = array.codes();
        codes.logical_validity()
    }
}
impl VariantsVTable<FSSTArray> for FSSTEncoding {
    /// Always exposes the array through the UTF-8 interface; the dtype is not consulted here.
    fn as_utf8_array<'a>(&self, array: &'a FSSTArray) -> Option<&'a dyn Utf8ArrayTrait> {
        let utf8_view: &'a dyn Utf8ArrayTrait = array;
        Some(utf8_view)
    }

    /// Always exposes the array through the binary interface; the dtype is not consulted here.
    fn as_binary_array<'a>(&self, array: &'a FSSTArray) -> Option<&'a dyn BinaryArrayTrait> {
        let binary_view: &'a dyn BinaryArrayTrait = array;
        Some(binary_view)
    }
}
// Empty impls: no methods are overridden, so any trait-provided defaults apply to `FSSTArray`.
impl Utf8ArrayTrait for FSSTArray {}
impl BinaryArrayTrait for FSSTArray {}
impl ArrayTrait for FSSTArray {}