1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
use std::fmt::{Debug, Display};
use std::sync::Arc;

use enum_iterator::all;
use serde::{Deserialize, Serialize};
use vortex_dtype::{DType, ExtDType, ExtID};
use vortex_error::{VortexExpect as _, VortexResult};

use crate::encoding::ids;
use crate::stats::{ArrayStatistics as _, Stat, StatisticsVTable, StatsSet};
use crate::validity::{ArrayValidity, LogicalValidity, ValidityVTable};
use crate::variants::{ExtensionArrayTrait, VariantsVTable};
use crate::visitor::{ArrayVisitor, VisitorVTable};
use crate::{impl_encoding, ArrayDType, ArrayData, ArrayLen, ArrayTrait, Canonical, IntoCanonical};

mod compute;

// Invokes the crate's `impl_encoding!` macro with the string "vortex.ext", the
// `ids::EXTENSION` value, and the identifier `Extension`.
// NOTE(review): `ExtensionArray` and `ExtensionEncoding` are used throughout this file but
// are not declared in it, so they are presumably generated by this macro — confirm against
// the macro definition.
impl_encoding!("vortex.ext", ids::EXTENSION, Extension);

/// Metadata value for the extension encoding.
///
/// This is a unit struct: it carries no fields. `ExtensionArray::new` passes an instance
/// of it to `try_from_parts` when building the array.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExtensionMetadata;

impl Display for ExtensionMetadata {
    /// Writes the metadata using its derived `Debug` representation, forwarding the
    /// formatter unchanged.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        <Self as Debug>::fmt(self, formatter)
    }
}

impl ExtensionArray {
    /// Builds an extension array of type `ext_dtype` on top of the given `storage` array.
    ///
    /// The resulting array has dtype `DType::Extension(ext_dtype)`, the same length as
    /// `storage`, and `storage` as its only child.
    ///
    /// # Panics
    ///
    /// Panics if `ext_dtype.storage_dtype()` is not equal to `storage.dtype()`. The result
    /// of `try_from_parts` is also unwrapped with `vortex_expect`, so a construction
    /// failure is not returned to the caller.
    pub fn new(ext_dtype: Arc<ExtDType>, storage: ArrayData) -> Self {
        assert_eq!(
            ext_dtype.storage_dtype(),
            storage.dtype(),
            "ExtensionArray: storage_dtype must match storage array DType",
        );

        // Capture the length before `storage` is moved into the children list.
        let len = storage.len();
        let dtype = DType::Extension(ext_dtype);

        let built = Self::try_from_parts(
            dtype,
            len,
            ExtensionMetadata,
            [storage].into(),
            Default::default(),
        );
        built.vortex_expect("Invalid ExtensionArray")
    }

    /// Returns the storage array, i.e. child `0`, looked up with the extension dtype's
    /// storage dtype and this array's length.
    ///
    /// The lookup result is unwrapped with `vortex_expect`.
    pub fn storage(&self) -> ArrayData {
        let storage_dtype = self.ext_dtype().storage_dtype();
        let row_count = self.len();

        self.as_ref()
            .child(0, storage_dtype, row_count)
            .vortex_expect("Missing storage array for ExtensionArray")
    }

    /// Returns the identifier of this array's extension dtype.
    #[allow(dead_code)]
    #[inline]
    pub fn id(&self) -> &ExtID {
        let ext_dtype = self.ext_dtype();
        ext_dtype.id()
    }
}

// Marks `ExtensionArray` as implementing `ArrayTrait`. The impl body is empty: no items
// are defined or overridden here.
impl ArrayTrait for ExtensionArray {}

impl VariantsVTable<ExtensionArray> for ExtensionEncoding {
    /// Exposes `array` as an extension-array trait object.
    ///
    /// Always returns `Some`, since `ExtensionArray` implements `ExtensionArrayTrait`.
    fn as_extension_array<'a>(
        &self,
        array: &'a ExtensionArray,
    ) -> Option<&'a dyn ExtensionArrayTrait> {
        let as_trait_object: &'a dyn ExtensionArrayTrait = array;
        Some(as_trait_object)
    }
}

impl ExtensionArrayTrait for ExtensionArray {
    /// Returns the storage array by delegating to [`ExtensionArray::storage`].
    fn storage_data(&self) -> ArrayData {
        Self::storage(self)
    }
}

impl IntoCanonical for ExtensionArray {
    /// Wraps this array in `Canonical::Extension` without modifying it.
    ///
    /// This implementation never returns an error.
    fn into_canonical(self) -> VortexResult<Canonical> {
        let canonical = Canonical::Extension(self);
        Ok(canonical)
    }
}

impl ValidityVTable<ExtensionArray> for ExtensionEncoding {
    /// Reports whether the element at `index` is valid, as answered by the storage array.
    fn is_valid(&self, array: &ExtensionArray, index: usize) -> bool {
        let storage = array.storage();
        storage.is_valid(index)
    }

    /// Returns the logical validity of the storage array.
    fn logical_validity(&self, array: &ExtensionArray) -> LogicalValidity {
        let storage = array.storage();
        storage.logical_validity()
    }
}

impl VisitorVTable<ExtensionArray> for ExtensionEncoding {
    /// Visits the array's single child, the storage array, under the name `"storage"`.
    fn accept(&self, array: &ExtensionArray, visitor: &mut dyn ArrayVisitor) -> VortexResult<()> {
        let storage = array.storage();
        visitor.visit_child("storage", &storage)
    }
}

impl StatisticsVTable<ExtensionArray> for ExtensionEncoding {
    /// Computes `stat` on the storage array and re-types the results for this array.
    ///
    /// The storage array's statistics are computed via `compute_all(&[stat])`. Every stat
    /// in the returned set for which `has_same_dtype_as_array()` is true is then cast to
    /// this array's dtype; all other stats are returned unchanged.
    ///
    /// # Errors
    ///
    /// Propagates any error from computing the storage statistics or from casting a value.
    fn compute_statistics(&self, array: &ExtensionArray, stat: Stat) -> VortexResult<StatsSet> {
        let storage = array.storage();
        let mut stats = storage.statistics().compute_all(&[stat])?;

        // Stats such as min/max carry the storage dtype and must be cast to the extension
        // array's dtype; the remaining stats need no change.
        for candidate in all::<Stat>() {
            if !candidate.has_same_dtype_as_array() {
                continue;
            }

            if let Some(storage_value) = stats.get(candidate) {
                let retyped = storage_value.cast(array.dtype())?;
                stats.set(candidate, retyped);
            }
        }

        Ok(stats)
    }
}

#[cfg(test)]
mod tests {
    use vortex_dtype::PType;
    use vortex_scalar::Scalar;

    use super::*;
    use crate::array::PrimitiveArray;
    use crate::validity::Validity;
    use crate::IntoArrayData as _;

    /// Min/max of an extension array come back as extension scalars, while the null count
    /// stays a plain `u64` scalar.
    #[test]
    fn compute_statistics() {
        let timestamp_dtype = Arc::new(ExtDType::new(
            ExtID::new("timestamp".into()),
            DType::from(PType::I64).into(),
            None,
        ));

        // Non-nullable i64 storage holding 1..=5.
        let storage =
            PrimitiveArray::from_vec(vec![1i64, 2, 3, 4, 5], Validity::NonNullable).into_array();
        let array = ExtensionArray::new(timestamp_dtype.clone(), storage);

        let requested = [Stat::Min, Stat::Max, Stat::NullCount];
        let stats = array.statistics().compute_all(&requested).unwrap();

        let num_stats = stats.clone().into_iter().count();
        assert!(
            num_stats >= 3,
            "Expected at least 3 stats, got {}",
            num_stats
        );

        let expected_min = Scalar::extension(timestamp_dtype.clone(), Scalar::from(1_i64));
        assert_eq!(stats.get(Stat::Min), Some(&expected_min));

        let expected_max = Scalar::extension(timestamp_dtype, Scalar::from(5_i64));
        assert_eq!(stats.get(Stat::Max), Some(&expected_max));

        assert_eq!(stats.get(Stat::NullCount), Some(&0u64.into()));
    }
}