1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
use std::any::Any;
use std::fmt::Display;

use itertools::Itertools;
use vortex_array::aliases::hash_set::HashSet;
use vortex_array::ArrayData;
use vortex_dtype::field::Field;
use vortex_error::{vortex_err, VortexResult};

use crate::{unbox_any, VortexExpr};

/// A column-selection expression evaluated against struct arrays.
///
/// The two variants are complementary: one names the columns to project onto,
/// the other names the columns to drop.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Select {
    /// Project the struct onto exactly these fields.
    Include(Vec<Field>),
    /// Project the struct onto every field except these.
    Exclude(Vec<Field>),
}

impl Select {
    pub fn include(columns: Vec<Field>) -> Self {
        Self::Include(columns)
    }

    pub fn exclude(columns: Vec<Field>) -> Self {
        Self::Exclude(columns)
    }
}

impl Display for Select {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Select::Include(fields) => write!(f, "Include({})", fields.iter().format(",")),
            Select::Exclude(fields) => write!(f, "Exclude({})", fields.iter().format(",")),
        }
    }
}

impl VortexExpr for Select {
    fn as_any(&self) -> &dyn Any {
        self
    }

    /// Projects `batch`, which must be a struct array, according to this selection.
    ///
    /// * `Include` hands the listed fields straight to `project`.
    /// * `Exclude` first resolves every excluded field to a column name (a `Field::Index`
    ///   is looked up in the struct's names). It then projects onto the remaining columns,
    ///   in their original order. Excluded names that are not present in the struct are
    ///   silently ignored.
    ///
    /// # Errors
    ///
    /// Fails if `batch` is not a struct array, if an excluded `Field::Index` is out of
    /// range, or if `project` itself fails.
    fn evaluate(&self, batch: &ArrayData) -> VortexResult<ArrayData> {
        let st = match batch.as_struct_array() {
            Some(st) => st,
            None => return Err(vortex_err!("Not a struct array")),
        };

        let excluded_fields = match self {
            Self::Include(columns) => return st.project(columns),
            Self::Exclude(columns) => columns,
        };

        let names = st.names();

        // Normalise every exclusion to a plain column name so it can be compared against
        // `names`, whether it was given by name or by position.
        let excluded = excluded_fields
            .iter()
            .map(|field| match field {
                Field::Name(name) => Ok(&**name),
                Field::Index(idx) => names
                    .get(*idx)
                    .map(|name| &**name)
                    .ok_or_else(|| vortex_err!("Column doesn't exist")),
            })
            .collect::<VortexResult<HashSet<&str>>>()?;

        let kept = names
            .iter()
            .map(|name| &**name)
            .filter(|name| !excluded.contains(name))
            .map(Field::from)
            .collect::<Vec<_>>();
        st.project(&kept)
    }

    /// Adds every field mentioned by this selection to `references`.
    ///
    /// NOTE(review): an `Exclude` reports the fields it drops, in the same way an `Include`
    /// reports the fields it keeps. Callers therefore cannot tell the two apart from the
    /// returned set alone; a wrapper type around `Field` would be needed to express that.
    fn collect_references<'a>(&'a self, references: &mut HashSet<&'a Field>) {
        let fields = match self {
            Self::Include(fields) | Self::Exclude(fields) => fields,
        };
        references.extend(fields.iter());
    }
}

impl PartialEq<dyn Any> for Select {
    /// Compares against a type-erased value: equal only when `other` (after `unbox_any`)
    /// is a `Select` that compares equal to `self`. Any other concrete type yields `false`.
    fn eq(&self, other: &dyn Any) -> bool {
        if let Some(that) = unbox_any(other).downcast_ref::<Self>() {
            self == that
        } else {
            false
        }
    }
}

#[cfg(test)]
mod tests {
    use vortex_array::array::{PrimitiveArray, StructArray};
    use vortex_array::IntoArrayData;
    use vortex_dtype::field::Field;

    use crate::{Select, VortexExpr};

    /// Two-column struct `{a: [0, 1, 2], b: [4, 5, 6]}` shared by every test.
    fn test_array() -> StructArray {
        StructArray::from_fields(&[
            ("a", PrimitiveArray::from(vec![0, 1, 2]).into_array()),
            ("b", PrimitiveArray::from(vec![4, 5, 6]).into_array()),
        ])
        .unwrap()
    }

    /// Evaluates `select` against [`test_array`] and returns the resulting column names.
    fn selected_names(select: &Select) -> Vec<String> {
        let st = test_array();
        let selected = select.evaluate(st.as_ref()).unwrap();
        selected
            .as_struct_array()
            .unwrap()
            .names()
            .iter()
            .map(|name| name.to_string())
            .collect()
    }

    #[test]
    fn include_columns() {
        let select = Select::include(vec![Field::from("a")]);
        assert_eq!(selected_names(&select), vec!["a"]);
    }

    #[test]
    fn exclude_columns() {
        let select = Select::exclude(vec![Field::from("a")]);
        assert_eq!(selected_names(&select), vec!["b"]);
    }

    #[test]
    fn exclude_columns_by_index() {
        // Positional exclusions are resolved against the struct's field names.
        let select = Select::exclude(vec![Field::Index(0)]);
        assert_eq!(selected_names(&select), vec!["b"]);
    }

    #[test]
    fn exclude_unknown_name_is_ignored() {
        // Excluding a name that the struct does not have leaves every column in place.
        let select = Select::exclude(vec![Field::from("c")]);
        assert_eq!(selected_names(&select), vec!["a", "b"]);
    }

    #[test]
    fn exclude_out_of_range_index_errors() {
        let st = test_array();
        let select = Select::exclude(vec![Field::Index(2)]);
        assert!(select.evaluate(st.as_ref()).is_err());
    }
}