use core::fmt::Formatter;
use std::fmt::Display;

use rand::rngs::StdRng;
use rand::SeedableRng as _;
use vortex_array::aliases::hash_set::HashSet;
use vortex_array::array::{ChunkedArray, ConstantEncoding};
use vortex_array::compress::{
check_dtype_unchanged, check_statistics_unchanged, check_validity_unchanged,
CompressionStrategy,
};
use vortex_array::compute::slice;
use vortex_array::encoding::{Encoding, EncodingRef};
use vortex_array::patches::Patches;
use vortex_array::validity::Validity;
use vortex_array::{ArrayDType, ArrayData, IntoCanonical};
use vortex_error::{VortexExpect as _, VortexResult};

use super::compressors::chunked::DEFAULT_CHUNKED_COMPRESSOR;
use super::compressors::struct_::StructCompressor;
use super::{CompressConfig, Objective, DEFAULT_COMPRESSORS};
use crate::compressors::constant::ConstantCompressor;
use crate::compressors::{CompressedArray, CompressionTree, CompressorRef, EncodingCompressor};
use crate::downscale::downscale_integer_array;
use crate::sampling::stratified_slices;
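
/// A [`CompressionStrategy`] that chooses an encoding for an array by compressing small,
/// stratified samples with each candidate compressor and keeping the compression tree that
/// scores best against the configured objective.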
#[derive(Debug, Clone)]
pub struct SamplingCompressor<'a> {
compressors: HashSet<CompressorRef<'a>>,
options: CompressConfig,
path: Vec<String>,
depth: u8,
disabled_compressors: HashSet<CompressorRef<'a>>,
}

impl Display for SamplingCompressor<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "[{}|{}]", self.depth, self.path.join("."))
}
}

impl CompressionStrategy for SamplingCompressor<'_> {
#[allow(clippy::same_name_method)]
fn compress(&self, array: &ArrayData) -> VortexResult<ArrayData> {
Self::compress(self, array, None).map(CompressedArray::into_array)
}

    fn used_encodings(&self) -> HashSet<EncodingRef> {
self.compressors
.iter()
.flat_map(|c| c.used_encodings())
.collect()
}
}

impl Default for SamplingCompressor<'_> {
fn default() -> Self {
Self::new(HashSet::from_iter(DEFAULT_COMPRESSORS))
}
}

impl<'a> SamplingCompressor<'a> {
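    /// Creates a compressor that uses `compressors` as its candidate set, with default options.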
pub fn new(compressors: HashSet<CompressorRef<'a>>) -> Self {
Self::new_with_options(compressors, Default::default())
}
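
    /// Creates a compressor that uses `compressors` as its candidate set, configured by `options`.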
pub fn new_with_options(
compressors: HashSet<CompressorRef<'a>>,
options: CompressConfig,
) -> Self {
Self {
compressors,
options,
path: Vec::new(),
depth: 0,
disabled_compressors: HashSet::new(),
}
}
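
    /// Returns a copy of this compressor with `name` appended to its path, which identifies it
    /// in log messages.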
pub fn named(&self, name: &str) -> Self {
let mut cloned = self.clone();
cloned.path.push(name.into());
cloned
}
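
    /// Returns a copy of this compressor for compressing an auxiliary array: `name` is appended
    /// to the path and any disabled compressors are re-enabled.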
pub fn auxiliary(&self, name: &str) -> Self {
let mut cloned = self.clone();
cloned.path.push(name.into());
cloned.disabled_compressors = HashSet::new();
cloned
}
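
    /// Returns a copy of this compressor with its depth increased by the cost of `compression`.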
pub fn for_compressor(&self, compression: &dyn EncodingCompressor) -> Self {
let mut cloned = self.clone();
cloned.depth += compression.cost();
cloned
}
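
    /// Returns the [`CompressConfig`] for this compressor.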
#[inline]
pub fn options(&self) -> &CompressConfig {
&self.options
}
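
    /// Returns a copy of this compressor with `compressor` disabled.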
pub fn excluding(&self, compressor: CompressorRef<'a>) -> Self {
let mut cloned = self.clone();
cloned.disabled_compressors.insert(compressor);
cloned
}
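
    /// Returns a copy of this compressor with `compressor` added to the candidate set.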
pub fn including(&self, compressor: CompressorRef<'a>) -> Self {
let mut cloned = self.clone();
cloned.compressors.insert(compressor);
cloned
}
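
    /// Returns a copy of this compressor whose candidate set is replaced by `compressors`.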
pub fn including_only(&self, compressors: &[CompressorRef<'a>]) -> Self {
let mut cloned = self.clone();
cloned.compressors.clear();
cloned.compressors.extend(compressors);
cloned
}
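
    /// Returns true if `compressor` is in the candidate set and has not been disabled.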
pub fn is_enabled(&self, compressor: CompressorRef<'a>) -> bool {
self.compressors.contains(compressor) && !self.disabled_compressors.contains(compressor)
}
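
    /// Compresses `arr`, first attempting to reuse the `like` compression tree (if one is given)
    /// before falling back to a full sampled search.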
#[allow(clippy::same_name_method)]
pub fn compress(
&self,
arr: &ArrayData,
like: Option<&CompressionTree<'a>>,
) -> VortexResult<CompressedArray<'a>> {
if arr.is_empty() {
return Ok(CompressedArray::uncompressed(arr.clone()));
}
if let Some(l) = like {
if let Some(compressed) = l.compress(arr, self) {
let compressed = compressed?;
check_validity_unchanged(arr, compressed.as_ref());
check_dtype_unchanged(arr, compressed.as_ref());
check_statistics_unchanged(arr, compressed.as_ref());
return Ok(compressed);
} else {
log::debug!("{} cannot compress {} like {}", self, arr, l);
}
}
let compressed = self.compress_array(arr)?;
check_validity_unchanged(arr, compressed.as_ref());
check_dtype_unchanged(arr, compressed.as_ref());
check_statistics_unchanged(arr, compressed.as_ref());
Ok(compressed)
}
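
    /// Compresses the array backing [`Validity::Array`]; all other variants are returned unchanged.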
pub fn compress_validity(&self, validity: Validity) -> VortexResult<Validity> {
match validity {
Validity::Array(a) => Ok(Validity::Array(self.compress(&a, None)?.into_array())),
a => Ok(a),
}
}
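
    /// Compresses the index and value arrays of `patches` independently, downscaling the indices
    /// to a narrower integer type first where possible.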
pub fn compress_patches(&self, patches: Patches) -> VortexResult<Patches> {
Ok(Patches::new(
patches.array_len(),
self.compress(&downscale_integer_array(patches.indices().clone())?, None)?
.into_array(),
self.compress(patches.values(), None)?.into_array(),
))
}
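
    /// Chooses a compressor for `array` by compressing stratified samples with every eligible
    /// candidate and applying the winning compression tree to the full array. Already-constant
    /// arrays are returned unchanged, and chunked, struct, and constant-valued arrays are
    /// delegated to their dedicated compressors before any sampling happens.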
pub(crate) fn compress_array(&self, array: &ArrayData) -> VortexResult<CompressedArray<'a>> {
let mut rng = StdRng::seed_from_u64(self.options.rng_seed);
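        // Arrays that are already constant-encoded cannot be compressed any further.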
if array.is_encoding(ConstantEncoding::ID) {
return Ok(CompressedArray::uncompressed(array.clone()));
}
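        // Chunked and struct arrays are handled by their dedicated compressors rather than
        // being sampled directly.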
if let Some(cc) = DEFAULT_CHUNKED_COMPRESSOR.can_compress(array) {
return cc.compress(array, None, self.clone());
}
if let Some(cc) = StructCompressor.can_compress(array) {
return cc.compress(array, None, self.clone());
}
if self.is_enabled(&ConstantCompressor) && ConstantCompressor.can_compress(array).is_some()
{
return ConstantCompressor.compress(array, None, self.clone());
}
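        // Partition the eligible compressors into those whose cost still fits within max_cost
        // at this depth, and those that would be too expensive to recurse into.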
let (mut candidates, too_deep) = self
.compressors
.iter()
.filter(|&encoding| !self.disabled_compressors.contains(encoding))
.filter(|&encoding| encoding.can_compress(array).is_some())
.partition::<Vec<&dyn EncodingCompressor>, _>(|&encoding| {
self.depth + encoding.cost() <= self.options.max_cost
});
if !too_deep.is_empty() {
log::debug!(
"{} skipping encodings due to depth/cost: {}",
self,
too_deep
.iter()
.map(|x| x.id())
.collect::<Vec<_>>()
.join(", ")
);
}
log::debug!("{} candidates for {}: {:?}", self, array, candidates);
if candidates.is_empty() {
log::debug!(
"{} no compressors for array with dtype: {} and encoding: {}",
self,
array.dtype(),
array.encoding().id(),
);
return Ok(CompressedArray::uncompressed(array.clone()));
}
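        // When there are alternatives, don't bother re-compressing with the array's current encoding.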
if candidates.len() > 1 {
candidates.retain(|&compression| compression.id() != array.encoding().id().as_ref());
}
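        // Arrays no larger than the total sample size are evaluated in full.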
if array.len() <= (self.options.sample_size as usize * self.options.sample_count as usize) {
return find_best_compression(candidates, array, self);
}
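        // Otherwise, build a sample by concatenating stratified slices of the array and
        // canonicalizing the result.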
let sample = ChunkedArray::try_new(
stratified_slices(
array.len(),
self.options.sample_size,
self.options.sample_count,
&mut rng,
)
.into_iter()
.map(|(start, stop)| slice(array, start, stop))
.collect::<VortexResult<Vec<ArrayData>>>()?,
array.dtype().clone(),
)?
.into_canonical()?
.into();
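        // Pick the best compression tree on the sample, then apply it to the full array.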
let best = find_best_compression(candidates, &sample, self)?
.into_path()
.map(|best_compressor| {
log::debug!(
"{} Compressing array {} with {}",
self,
array,
best_compressor
);
best_compressor.compress_unchecked(array, self)
})
.transpose()?;
Ok(best.unwrap_or_else(|| CompressedArray::uncompressed(array.clone())))
}
}
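
/// Compresses `sample` with each candidate compressor and returns the result with the best
/// objective value among those that actually reduce its size. The best raw compression ratio is
/// tracked separately so the two can be compared in the debug logs.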
pub(crate) fn find_best_compression<'a>(
candidates: Vec<&'a dyn EncodingCompressor>,
sample: &ArrayData,
ctx: &SamplingCompressor<'a>,
) -> VortexResult<CompressedArray<'a>> {
let mut best = None;
let mut best_objective = ctx.options().objective.starting_value();
let mut best_objective_ratio = 1.0;
let mut best_compression_ratio = 1.0;
let mut best_compression_ratio_sample = None;
for compression in candidates {
log::debug!(
"{} trying candidate {} for {}",
ctx,
compression.id(),
sample
);
if compression.can_compress(sample).is_none() {
continue;
}
let compressed_sample =
compression.compress(sample, None, ctx.for_compressor(compression))?;
let ratio = (compressed_sample.nbytes() as f64) / (sample.nbytes() as f64);
let objective = Objective::evaluate(&compressed_sample, sample.nbytes(), ctx.options());
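        // Track the best raw compression ratio separately from the best objective value so the
        // two can be compared in the logs below.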
if ratio < best_compression_ratio {
best_compression_ratio = ratio;
if ratio < best_objective_ratio && objective >= best_objective {
best_compression_ratio_sample = Some(compressed_sample.clone());
}
}
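        // Only accept a candidate as the new best if it improves the objective and actually
        // shrinks the sample.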
if objective < best_objective && ratio < 1.0 {
best_objective = objective;
best_objective_ratio = ratio;
best = Some(compressed_sample);
}
log::debug!(
"{} with {}: ratio ({}), objective fn value ({}); best so far: ratio ({}), objective fn value ({})",
ctx,
compression.id(),
ratio,
objective,
best_compression_ratio,
best_objective
);
}
let best = best.unwrap_or_else(|| CompressedArray::uncompressed(sample.clone()));
if best_compression_ratio < best_objective_ratio && best_compression_ratio_sample.is_some() {
let best_ratio_sample =
best_compression_ratio_sample.vortex_expect("already checked that this Option is Some");
log::debug!(
"{} best objective fn value ({}) has ratio {} from {}",
ctx,
best_objective,
best_compression_ratio,
best.array().tree_display()
);
log::debug!(
"{} best ratio ({}) has objective fn value {} from {}",
ctx,
best_compression_ratio,
best_objective,
best_ratio_sample.array().tree_display()
);
}
log::debug!(
"{} best compression ({} bytes, {} objective fn value, {} compression ratio",
ctx,
best.nbytes(),
best_objective,
best_compression_ratio,
);
Ok(best)
}

#[cfg(test)]
mod tests {
use itertools::Itertools;
use vortex_alp::ALPRDEncoding;
use vortex_array::array::PrimitiveArray;
use vortex_array::encoding::Encoding;
use vortex_array::IntoArrayData;
use crate::SamplingCompressor;

    #[test]
fn test_default() {
let values = (0..4096)
.map(|x| (x as f64) / 1234567890.0f64)
.collect_vec();
let array = PrimitiveArray::from(values).into_array();
let compressed = SamplingCompressor::default()
.compress(&array, None)
.unwrap()
.into_array();
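        // The default compressor should pick ALP-RD for these high-precision floating point values.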
assert_eq!(compressed.encoding().id(), ALPRDEncoding::ID);
}
}