1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
#![allow(clippy::cast_possible_truncation)]

pub use array::*;
use vortex_array::patches::Patches;
use vortex_array::validity::Validity;
use vortex_array::variants::PrimitiveArrayTrait;

mod array;
mod compute;
mod variants;

use std::ops::{Shl, Shr};

use itertools::Itertools;
use num_traits::{Float, One, PrimInt};
use vortex_array::aliases::hash_map::HashMap;
use vortex_array::array::PrimitiveArray;
use vortex_array::{ArrayDType, IntoArrayData};
use vortex_dtype::{DType, NativePType};
use vortex_error::{vortex_bail, VortexExpect, VortexResult, VortexUnwrap};
use vortex_fastlanes::bitpack_encode_unchecked;

use crate::match_each_alp_float_ptype;

/// Number of bits needed to represent `$value`: `ilog2(value) + 1`, with a
/// minimum of one bit so that a value of zero still occupies a single bit.
macro_rules! bit_width {
    ($value:expr) => {
        if $value == 0 {
            1
        } else {
            // ilog2(v) + 1 is the 1-indexed position of the highest set bit.
            $value.ilog2().wrapping_add(1) as usize
        }
    };
}

/// Max number of bits to cut from the MSB section of each float.
const CUT_LIMIT: usize = 16;

/// Maximum number of distinct left-part bit patterns kept in the dictionary.
const MAX_DICT_SIZE: u8 = 8;

/// Seals [`ALPRDFloat`] so it can only ever be implemented for `f32` and `f64`.
mod private {
    pub trait Sealed {}

    impl Sealed for f32 {}
    impl Sealed for f64 {}
}

/// Main trait for ALP-RD encodable floating point numbers.
///
/// Like the paper, we limit this to the IEEE 754 single-precision (`f32`) and double-precision
/// (`f64`) floating point types.
pub trait ALPRDFloat: private::Sealed + Float + Copy + NativePType {
    /// The unsigned integer type with the same bit-width as the floating-point type.
    type UINT: NativePType + PrimInt + One + Copy;

    /// Number of bits the value occupies in registers.
    const BITS: usize = size_of::<Self>() * 8;

    /// Bit-wise transmute from the unsigned integer type to the floating-point type.
    fn from_bits(bits: Self::UINT) -> Self;

    /// Bit-wise transmute into the unsigned integer type.
    fn to_bits(value: Self) -> Self::UINT;

    /// Truncating conversion from the unsigned integer type to `u16`.
    fn to_u16(bits: Self::UINT) -> u16;

    /// Type-widening conversion from `u16` to the unsigned integer type.
    fn from_u16(value: u16) -> Self::UINT;
}

impl ALPRDFloat for f64 {
    type UINT = u64;

    fn to_bits(value: Self) -> Self::UINT {
        f64::to_bits(value)
    }

    fn from_bits(bits: Self::UINT) -> Self {
        f64::from_bits(bits)
    }

    fn from_u16(value: u16) -> Self::UINT {
        // Widening conversion is lossless, so `From` is available.
        u64::from(value)
    }

    fn to_u16(bits: Self::UINT) -> u16 {
        // Truncating conversion: keep only the low 16 bits.
        (bits & 0xFFFF) as u16
    }
}

impl ALPRDFloat for f32 {
    type UINT = u32;

    fn to_bits(value: Self) -> Self::UINT {
        f32::to_bits(value)
    }

    fn from_bits(bits: Self::UINT) -> Self {
        f32::from_bits(bits)
    }

    fn from_u16(value: u16) -> Self::UINT {
        // Widening conversion is lossless, so `From` is available.
        u32::from(value)
    }

    fn to_u16(bits: Self::UINT) -> u16 {
        // Truncating conversion: keep only the low 16 bits.
        (bits & 0xFFFF) as u16
    }
}

/// Encoder for ALP-RD ("real doubles") values.
///
/// The encoder calculates its parameters from a single sample of floating-point values,
/// and then can be applied to many vectors.
///
/// ALP-RD uses the algorithm outlined in Section 3.4 of the paper. The crux of it is that the front
/// (most significant) bits of many double vectors tend to be the same, i.e. most doubles in a
/// vector often use the same exponent and front bits. Compression proceeds by finding the best
/// prefix of up to 16 bits that can be collapsed into a dictionary of
/// up to 8 elements. Each double can then be broken into the front/left `L` bits, which neatly
/// bit-packs down to 1-3 bits per element (depending on the actual dictionary size).
/// The remaining `R` bits naturally bit-pack.
///
/// In the ideal case, this scheme allows us to store a sequence of doubles in 49 bits-per-value.
///
/// Our implementation draws on the MIT-licensed [C++ implementation] provided by the original authors.
///
/// [C++ implementation]: https://github.com/cwida/ALP/blob/main/include/alp/rd.hpp
pub struct RDEncoder {
    /// Bit width of the right (least-significant) part of each value.
    right_bit_width: u8,
    /// Reverse dictionary: index is the code, value is the left-bit pattern it encodes.
    codes: Vec<u16>,
}

impl RDEncoder {
    /// Build a new encoder from a sample of doubles.
    ///
    /// The dictionary and right bit-width learned from the sample are reused when
    /// encoding subsequent vectors via [`RDEncoder::encode`].
    pub fn new<T>(sample: &[T]) -> Self
    where
        T: ALPRDFloat + NativePType,
        T::UINT: NativePType,
    {
        let dictionary = find_best_dictionary::<T>(sample);

        let mut codes = vec![0; dictionary.dictionary.len()];
        dictionary.dictionary.into_iter().for_each(|(bits, code)| {
            // write the reverse mapping into the codes vector.
            codes[code as usize] = bits
        });

        Self {
            right_bit_width: dictionary.right_bit_width,
            codes,
        }
    }

    /// Encode a set of floating point values with ALP-RD.
    ///
    /// Each value will be split into a left and right component, which are compressed individually.
    pub fn encode(&self, array: &PrimitiveArray) -> ALPRDArray {
        match_each_alp_float_ptype!(array.ptype(), |$P| {
            self.encode_generic::<$P>(array)
        })
    }

    /// Monomorphized encoding kernel for a concrete float type `T`.
    ///
    /// Splits each value into dictionary-coded left bits (exceptions tracked as
    /// patches) and bit-packed right bits, then assembles the [`ALPRDArray`].
    fn encode_generic<T>(&self, array: &PrimitiveArray) -> ALPRDArray
    where
        T: ALPRDFloat + NativePType,
        T::UINT: NativePType,
    {
        assert!(
            !self.codes.is_empty(),
            "codes lookup table must be populated before RD encoding"
        );

        let doubles = array.maybe_null_slice::<T>();

        let mut left_parts: Vec<u16> = Vec::with_capacity(doubles.len());
        let mut right_parts: Vec<T::UINT> = Vec::with_capacity(doubles.len());
        // Exception buffers are pre-sized with a rough guess of 25% exceptions.
        let mut exceptions_pos: Vec<u64> = Vec::with_capacity(doubles.len() / 4);
        let mut exceptions: Vec<u16> = Vec::with_capacity(doubles.len() / 4);

        // mask for right-parts
        let right_mask = T::UINT::one().shl(self.right_bit_width as _) - T::UINT::one();
        let max_code = self.codes.len() - 1;
        let left_bit_width = bit_width!(max_code);

        // Split each value: the low `right_bit_width` bits go right, the rest go left.
        for v in doubles.iter().copied() {
            right_parts.push(T::to_bits(v) & right_mask);
            left_parts.push(<T as ALPRDFloat>::to_u16(
                T::to_bits(v).shr(self.right_bit_width as _),
            ));
        }

        // dict-encode the left-parts, keeping track of exceptions
        for (idx, left) in left_parts.iter_mut().enumerate() {
            // TODO: revisit if we need to change the branch order for perf.
            if let Some(code) = self.codes.iter().position(|v| *v == *left) {
                *left = code as u16;
            } else {
                exceptions.push(*left);
                exceptions_pos.push(idx as _);

                // Exceptions are stored as code 0; the real value is restored via patches.
                *left = 0u16;
            }
        }

        // Bit-pack down the encoded left-parts array that have been dictionary encoded.
        let primitive_left = PrimitiveArray::from_vec(left_parts, array.validity());
        // SAFETY: by construction, all values in left_parts can be packed to left_bit_width.
        let packed_left = unsafe {
            bitpack_encode_unchecked(primitive_left, left_bit_width as _)
                .vortex_unwrap()
                .into_array()
        };

        // Bit-pack the right-parts down to right_bit_width.
        let primitive_right = PrimitiveArray::from(right_parts);
        // SAFETY: by construction, all values in right_parts are right_bit_width + leading zeros.
        let packed_right = unsafe {
            bitpack_encode_unchecked(primitive_right, self.right_bit_width as _)
                .vortex_unwrap()
                .into_array()
        };

        // Build the Patches holding the exceptional left-parts, if any occurred.
        let exceptions = (!exceptions_pos.is_empty()).then(|| {
            // Positions were appended in ascending order, so the last one is the maximum.
            let max_exc_pos = exceptions_pos.last().copied().unwrap_or_default();
            let bw = bit_width!(max_exc_pos) as u8;

            let exc_pos_array = PrimitiveArray::from(exceptions_pos);
            // SAFETY: We calculate bw such that it is wide enough to hold the largest position index.
            let packed_pos = unsafe {
                bitpack_encode_unchecked(exc_pos_array, bw)
                    .vortex_unwrap()
                    .into_array()
            };

            let exc_array =
                PrimitiveArray::from_vec(exceptions, Validity::NonNullable).into_array();
            Patches::new(doubles.len(), packed_pos, exc_array)
        });

        ALPRDArray::try_new(
            DType::Primitive(T::PTYPE, packed_left.dtype().nullability()),
            packed_left,
            &self.codes,
            packed_right,
            self.right_bit_width,
            exceptions,
        )
        .vortex_expect("ALPRDArray construction in encode")
    }
}

/// Decode a vector of ALP-RD encoded values back into their original floating point format.
///
/// # Errors
///
/// Returns an error if the provided `left_parts` and `right_parts` differ in length.
/// (The previous doc claimed this panics; it bails with a `VortexError` instead.)
///
/// # Panics
///
/// Panics if any code in `left_parts` is out of range for `left_parts_dict`.
pub fn alp_rd_decode<T: ALPRDFloat>(
    left_parts: &[u16],
    left_parts_dict: &[u16],
    right_bit_width: u8,
    right_parts: &[T::UINT],
    left_parts_patches: Option<Patches>,
) -> VortexResult<Vec<T>> {
    if left_parts.len() != right_parts.len() {
        vortex_bail!("alp_rd_decode: left_parts.len != right_parts.len");
    }

    // Expand the dictionary codes back into the original left-bit patterns.
    // Index the dictionary slice directly — the previous copy into a local Vec
    // was a needless allocation.
    let left_parts_decoded: Vec<u16> = left_parts
        .iter()
        .map(|code| left_parts_dict[*code as usize])
        .collect();

    // Apply patches to restore the left-parts that were exceptions at encode time.
    let patched_left_parts = match left_parts_patches {
        Some(patches) => PrimitiveArray::from(left_parts_decoded)
            .patch(patches)?
            .into_maybe_null_slice::<u16>(),
        None => left_parts_decoded,
    };

    // recombine the left-and-right parts, adjusting by the right_bit_width.
    Ok(patched_left_parts
        .into_iter()
        .zip(right_parts.iter().copied())
        .map(|(left, right)| {
            let left = <T as ALPRDFloat>::from_u16(left);
            T::from_bits((left << (right_bit_width as usize)) | right)
        })
        .collect())
}

/// Find the best "cut point" for a set of floating point values such that we can
/// cast them all to the relevant value instead.
fn find_best_dictionary<T: ALPRDFloat>(samples: &[T]) -> ALPRDDictionary {
    let mut best_est_size = f64::MAX;
    let mut best_dict = ALPRDDictionary::default();

    for p in 1..=16 {
        let candidate_right_bw = (T::BITS - p) as u8;
        let (dictionary, exception_count) =
            build_left_parts_dictionary::<T>(samples, candidate_right_bw, MAX_DICT_SIZE);
        let estimated_size = estimate_compression_size(
            dictionary.right_bit_width,
            dictionary.left_bit_width,
            exception_count,
            samples.len(),
        );
        if estimated_size < best_est_size {
            best_est_size = estimated_size;
            best_dict = dictionary;
        }
    }

    best_dict
}

/// Build dictionary of the leftmost bits.
///
/// Returns the dictionary (mapping the most frequent left-bit patterns to codes)
/// together with the number of sample values whose left bits did not make it into
/// the dictionary and will therefore be stored as exceptions.
fn build_left_parts_dictionary<T: ALPRDFloat>(
    samples: &[T],
    right_bw: u8,
    max_dict_size: u8,
) -> (ALPRDDictionary, usize) {
    assert!(
        right_bw >= (T::BITS - CUT_LIMIT) as _,
        "left-parts must be <= 16 bits"
    );

    // Count the number of occurrences of each left bit pattern
    let counts = samples
        .iter()
        .copied()
        .map(|v| <T as ALPRDFloat>::to_u16(T::to_bits(v).shr(right_bw as _)))
        .counts();

    // Sorted counts: sort by negative count so that heavy hitters sort first.
    let mut sorted_bit_counts: Vec<(u16, usize)> = counts.into_iter().collect_vec();
    sorted_bit_counts.sort_by_key(|(_, count)| count.wrapping_neg());

    // Assign the most-frequently occurring left-bits as dictionary codes, up to `max_dict_size`...
    let mut dictionary = HashMap::with_capacity(max_dict_size as _);
    let mut code = 0u16;
    while code < (max_dict_size as _) && (code as usize) < sorted_bit_counts.len() {
        let (bits, _) = sorted_bit_counts[code as usize];
        dictionary.insert(bits, code);
        code += 1;
    }

    // ...and the rest are exceptions.
    let exception_count: usize = sorted_bit_counts
        .iter()
        .skip(code as _)
        .map(|(_, count)| *count)
        .sum();

    // Left bit-width is determined based on the actual dictionary size.
    // `saturating_sub` guards against an empty sample set (empty dictionary), which
    // would otherwise underflow: panic in debug builds, wrap to usize::MAX in release.
    let max_code = dictionary.len().saturating_sub(1);
    let left_bw = bit_width!(max_code) as u8;

    (
        ALPRDDictionary {
            dictionary,
            right_bit_width: right_bw,
            left_bit_width: left_bw,
        },
        exception_count,
    )
}

/// Estimate the bits-per-value when using these compression settings.
///
/// The estimate is the per-value cost of the packed left and right parts, plus the
/// cost of all exceptions (position + value) amortized over the sample size.
fn estimate_compression_size(
    right_bw: u8,
    left_bw: u8,
    exception_count: usize,
    sample_n: usize,
) -> f64 {
    // Each exception stores a two-byte position and a two-byte left-part value
    // (left parts are at most 16 front bits).
    const EXC_POSITION_SIZE: usize = 16;
    const EXC_SIZE: usize = 16;

    let packed_bits = f64::from(right_bw) + f64::from(left_bw);
    let exception_bits = (exception_count * (EXC_POSITION_SIZE + EXC_SIZE)) as f64;
    packed_bits + exception_bits / (sample_n as f64)
}

/// The ALP-RD dictionary, encoding the "left parts" and their dictionary encoding.
///
/// The `Default` value has an empty dictionary and zero bit-widths; it serves only
/// as the initial candidate when searching for the best dictionary.
#[derive(Debug, Default)]
struct ALPRDDictionary {
    /// Items in the dictionary are bit patterns, along with their 16-bit encoding.
    dictionary: HashMap<u16, u16>,
    /// The (compressed) left bit width. This is after bit-packing the dictionary codes.
    left_bit_width: u8,
    /// The right bit width. This is the bit-packed width of each of the "real double" values.
    right_bit_width: u8,
}