diff --git a/encodings/zigzag/src/compute.rs b/encodings/zigzag/src/compute.rs index b68b54252..8f26cae34 100644 --- a/encodings/zigzag/src/compute.rs +++ b/encodings/zigzag/src/compute.rs @@ -1,4 +1,7 @@ -use vortex_array::compute::{scalar_at, slice, ComputeVTable, ScalarAtFn, SliceFn}; +use vortex_array::compute::{ + filter, scalar_at, slice, take, ComputeVTable, FilterFn, FilterMask, ScalarAtFn, SliceFn, + TakeFn, +}; use vortex_array::variants::PrimitiveArrayTrait; use vortex_array::{ArrayDType, ArrayData, IntoArrayData}; use vortex_dtype::match_each_unsigned_integer_ptype; @@ -9,6 +12,10 @@ use zigzag::{ZigZag as ExternalZigZag, ZigZag}; use crate::{ZigZagArray, ZigZagEncoding}; impl ComputeVTable for ZigZagEncoding { + fn filter_fn(&self) -> Option<&dyn FilterFn> { + Some(self) + } + fn scalar_at_fn(&self) -> Option<&dyn ScalarAtFn> { Some(self) } @@ -16,6 +23,10 @@ impl ComputeVTable for ZigZagEncoding { fn slice_fn(&self) -> Option<&dyn SliceFn> { Some(self) } + + fn take_fn(&self) -> Option<&dyn TakeFn> { + Some(self) + } } impl ScalarAtFn for ZigZagEncoding { @@ -41,6 +52,26 @@ impl ScalarAtFn for ZigZagEncoding { } } +impl SliceFn for ZigZagEncoding { + fn slice(&self, array: &ZigZagArray, start: usize, stop: usize) -> VortexResult { + Ok(ZigZagArray::try_new(slice(array.encoded(), start, stop)?)?.into_array()) + } +} + +impl TakeFn for ZigZagEncoding { + fn take(&self, array: &ZigZagArray, indices: &ArrayData) -> VortexResult { + let encoded = take(array.encoded(), indices)?; + Ok(ZigZagArray::try_new(encoded)?.into_array()) + } +} + +impl FilterFn for ZigZagEncoding { + fn filter(&self, array: &ZigZagArray, mask: FilterMask) -> VortexResult { + let encoded = filter(&array.encoded(), mask)?; + Ok(ZigZagArray::try_new(encoded)?.into_array()) + } +} + trait ZigZagEncoded { type Int: ZigZag; } @@ -61,18 +92,14 @@ impl ZigZagEncoded for u64 { type Int = i64; } -impl SliceFn for ZigZagEncoding { - fn slice(&self, array: &ZigZagArray, start: usize, stop: usize) -> VortexResult { - Ok(ZigZagArray::try_new(slice(array.encoded(), start, stop)?)?.into_array()) - } -} - #[cfg(test)] mod tests { - use vortex_array::array::PrimitiveArray; - use vortex_array::compute::{scalar_at, search_sorted, SearchResult, SearchSortedSide}; + use vortex_array::array::{BooleanBuffer, PrimitiveArray}; + use vortex_array::compute::{ + filter, scalar_at, search_sorted, take, SearchResult, SearchSortedSide, + }; use vortex_array::validity::Validity; - use vortex_array::IntoArrayData; + use vortex_array::{IntoArrayData, IntoArrayVariant}; use vortex_dtype::Nullability; use vortex_scalar::Scalar; @@ -100,4 +127,33 @@ mod tests { Scalar::primitive(-160, Nullability::Nullable) ); } + + #[test] + fn take_zigzag() { + let zigzag = + ZigZagArray::encode(&PrimitiveArray::from(vec![-189, -160, 1]).into_array()).unwrap(); + let indices = PrimitiveArray::from(vec![0, 2]).into_array(); + let actual = take(zigzag, indices).unwrap().into_primitive().unwrap(); + let expected = ZigZagArray::encode(&PrimitiveArray::from(vec![-189, 1]).into_array()) + .unwrap() + .into_primitive() + .unwrap(); + assert_eq!(actual.into_buffer(), expected.into_buffer()); + } + + #[test] + fn filter_zigzag() { + let zigzag = + ZigZagArray::encode(&PrimitiveArray::from(vec![-189, -160, 1]).into_array()).unwrap(); + let filter_mask = BooleanBuffer::from(vec![true, false, true]).into(); + let actual = filter(&zigzag.into_array(), filter_mask) + .unwrap() + .into_primitive() + .unwrap(); + let expected = ZigZagArray::encode(&PrimitiveArray::from(vec![-189, 1]).into_array()) + .unwrap() + .into_primitive() + .unwrap(); + assert_eq!(actual.into_buffer(), expected.into_buffer()); + } }