diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index 5582c6e9af..ee86e25810 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -1,12 +1,15 @@ -use vortex_error::{VortexResult, VortexUnwrap as _}; +use vortex_dtype::match_each_integer_ptype; +use vortex_error::{VortexExpect, VortexResult, VortexUnwrap as _}; use vortex_scalar::Scalar; use crate::array::sparse::SparseArray; +use crate::array::PrimitiveArray; use crate::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn}; use crate::compute::{ - search_sorted, ArrayCompute, SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn, + search_sorted, take, ArrayCompute, FilterFn, SearchResult, SearchSortedFn, SearchSortedSide, + SliceFn, TakeFn, }; -use crate::ArrayDType; +use crate::{Array, ArrayDType, IntoArray, IntoArrayVariant}; mod slice; mod take; @@ -27,6 +30,10 @@ impl ArrayCompute for SparseArray { fn take(&self) -> Option<&dyn TakeFn> { Some(self) } + + fn filter(&self) -> Option<&dyn FilterFn> { + Some(self) + } } impl ScalarAtFn for SparseArray { @@ -68,21 +75,61 @@ impl SearchSortedFn for SparseArray { } } +impl FilterFn for SparseArray { + fn filter(&self, predicate: &Array) -> VortexResult { + let buffer = predicate.clone().into_bool()?.boolean_buffer(); + let mut coordinate_indices: Vec = Vec::new(); + let mut value_indices = Vec::new(); + let mut last_inserted_index = 0; + + let flat_indices = self + .indices() + .into_primitive() + .vortex_expect("Failed to convert SparseArray indices to primitive array"); + match_each_integer_ptype!(flat_indices.ptype(), |$P| { + let indices = flat_indices + .maybe_null_slice::<$P>() + .iter() + .map(|v| (*v as usize) - self.indices_offset()); + for (value_idx, coordinate) in indices.enumerate() { + if buffer.value(coordinate) { + // We count the number of truthy values between this coordinate and the previous truthy one + let adjusted_coordinate = buffer.slice(last_inserted_index, coordinate - last_inserted_index).count_set_bits() as u64; + coordinate_indices.push(adjusted_coordinate + coordinate_indices.last().copied().unwrap_or_default()); + last_inserted_index = coordinate; + value_indices.push(value_idx as u64); + } + } + }); + + Ok(SparseArray::try_new( + PrimitiveArray::from(coordinate_indices).into_array(), + take(&self.values(), PrimitiveArray::from(value_indices).array())?, + buffer.count_set_bits(), + self.fill_value().clone(), + )? + .into_array()) + } +} + #[cfg(test)] mod test { + use rstest::{fixture, rstest}; use vortex_dtype::{DType, Nullability, PType}; use vortex_scalar::Scalar; use crate::array::primitive::PrimitiveArray; use crate::array::sparse::SparseArray; - use crate::compute::{search_sorted, slice, SearchResult, SearchSortedSide}; + use crate::array::BoolArray; + use crate::compute::{filter, search_sorted, slice, SearchResult, SearchSortedSide}; use crate::validity::Validity; - use crate::{Array, IntoArray}; + use crate::{Array, IntoArray, IntoArrayVariant}; + #[fixture] fn array() -> Array { SparseArray::try_new( PrimitiveArray::from(vec![2u64, 9, 15]).into_array(), - PrimitiveArray::from_vec(vec![33, 44, 55], Validity::AllValid).into_array(), + PrimitiveArray::from_vec(vec![33_i32, 44, 55], Validity::AllValid).into_array(), 20, Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), ) @@ -90,33 +137,33 @@ mod test { .into_array() } - #[test] - pub fn search_larger_than() { - let res = search_sorted(&array(), 66, SearchSortedSide::Left).unwrap(); + #[rstest] + fn search_larger_than(array: Array) { + let res = search_sorted(&array, 66, SearchSortedSide::Left).unwrap(); assert_eq!(res, SearchResult::NotFound(16)); } - #[test] - pub fn search_less_than() { - let res = search_sorted(&array(), 22, SearchSortedSide::Left).unwrap(); + #[rstest] + fn search_less_than(array: Array) { + let res = search_sorted(&array, 22, SearchSortedSide::Left).unwrap(); assert_eq!(res, SearchResult::NotFound(2)); } - #[test] - pub fn search_found() { - let res = search_sorted(&array(), 44, SearchSortedSide::Left).unwrap(); + #[rstest] + fn search_found(array: Array) { + let res = search_sorted(&array, 44, SearchSortedSide::Left).unwrap(); assert_eq!(res, SearchResult::Found(9)); } - #[test] - pub fn search_not_found_right() { - let res = search_sorted(&array(), 56, SearchSortedSide::Right).unwrap(); + #[rstest] + fn search_not_found_right(array: Array) { + let res = search_sorted(&array, 56, SearchSortedSide::Right).unwrap(); assert_eq!(res, SearchResult::NotFound(16)); } - #[test] - pub fn search_sliced() { - let array = slice(&array(), 7, 20).unwrap(); + #[rstest] + fn search_sliced(array: Array) { + let array = slice(&array, 7, 20).unwrap(); assert_eq!( search_sorted(&array, 22, SearchSortedSide::Left).unwrap(), SearchResult::NotFound(2) @@ -124,7 +171,7 @@ mod test { } #[test] - pub fn search_right() { + fn search_right() { let array = SparseArray::try_new( PrimitiveArray::from(vec![0u64]).into_array(), PrimitiveArray::from_vec(vec![0u8], Validity::AllValid).into_array(), @@ -143,4 +190,43 @@ mod test { SearchResult::NotFound(1) ); } + + #[rstest] + fn test_filter(array: Array) { + let mut predicate = vec![false, false, true]; + predicate.extend_from_slice(&[false; 17]); + let predicate = BoolArray::from_vec(predicate, Validity::NonNullable).into_array(); + + let filtered_array = filter(&array, &predicate).unwrap(); + let filtered_array = SparseArray::try_from(filtered_array).unwrap(); + + assert_eq!(filtered_array.len(), 1); + assert_eq!(filtered_array.values().len(), 1); + assert_eq!(filtered_array.indices().len(), 1); + } + + #[test] + fn true_fill_value() { + let predicate = BoolArray::from_vec( + vec![false, true, false, true, false, true, true], + Validity::NonNullable, + ) + .into_array(); + let array = SparseArray::try_new( + PrimitiveArray::from(vec![0_u64, 3, 6]).into_array(), + PrimitiveArray::from_vec(vec![33_i32, 44, 55], Validity::AllValid).into_array(), + 7, + Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)), + ) + .unwrap() + .into_array(); + + let filtered_array = filter(&array, &predicate).unwrap(); + let filtered_array = SparseArray::try_from(filtered_array).unwrap(); + + assert_eq!(filtered_array.len(), 4); + let primitive = filtered_array.indices().into_primitive().unwrap(); + + assert_eq!(primitive.maybe_null_slice::(), &[1, 3]); + } }