Skip to content

Commit

Permalink
implement FilterFn for SparseArray (#799)
Browse files Browse the repository at this point in the history
Introduces a `FilterFn` implementation for `SparseArray`. @robert3005 and
I ran into a
[flaky failure](https://github.com/spiraldb/vortex/actions/runs/10833498279/job/30060436243?pr=797)
in ALP/slicing in which canonicalization changed the underlying encoding,
so this change should take care of that.
  • Loading branch information
AdamGS authored Sep 12, 2024
1 parent 83dffc5 commit 5240edd
Showing 1 changed file with 108 additions and 22 deletions.
130 changes: 108 additions & 22 deletions vortex-array/src/array/sparse/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
use vortex_error::{VortexResult, VortexUnwrap as _};
use vortex_dtype::match_each_integer_ptype;
use vortex_error::{VortexExpect, VortexResult, VortexUnwrap as _};
use vortex_scalar::Scalar;

use crate::array::sparse::SparseArray;
use crate::array::PrimitiveArray;
use crate::compute::unary::{scalar_at, scalar_at_unchecked, ScalarAtFn};
use crate::compute::{
search_sorted, ArrayCompute, SearchResult, SearchSortedFn, SearchSortedSide, SliceFn, TakeFn,
search_sorted, take, ArrayCompute, FilterFn, SearchResult, SearchSortedFn, SearchSortedSide,
SliceFn, TakeFn,
};
use crate::ArrayDType;
use crate::{Array, ArrayDType, IntoArray, IntoArrayVariant};

mod slice;
mod take;
Expand All @@ -27,6 +30,10 @@ impl ArrayCompute for SparseArray {
// Exposes this array as its own `TakeFn` implementation by returning `self`.
fn take(&self) -> Option<&dyn TakeFn> {
Some(self)
}

// Exposes this array as its own `FilterFn` implementation by returning `self`.
fn filter(&self) -> Option<&dyn FilterFn> {
Some(self)
}
}

impl ScalarAtFn for SparseArray {
Expand Down Expand Up @@ -68,63 +75,103 @@ impl SearchSortedFn for SparseArray {
}
}

impl FilterFn for SparseArray {
    /// Filters this sparse array with a boolean `predicate`.
    ///
    /// The result is a new `SparseArray` whose length is the number of set bits in the
    /// predicate. Every stored entry whose position is selected by the predicate is kept:
    /// its index is rewritten to the position it occupies in the filtered output, and its
    /// value is gathered from `self.values()` with `take`. The fill value is carried over
    /// unchanged.
    ///
    /// # Errors
    ///
    /// Propagates failures from converting `predicate` to a bool array, from `take`, and
    /// from `SparseArray::try_new`.
    ///
    /// The conversion of the indices to a primitive array is asserted with `vortex_expect`
    /// rather than propagated.
    fn filter(&self, predicate: &Array) -> VortexResult<Array> {
        let mask = predicate.clone().into_bool()?.boolean_buffer();

        let resolved_indices = self
            .indices()
            .into_primitive()
            .vortex_expect("Failed to convert SparseArray indices to primitive array");

        // Output positions of the stored entries that survive the filter.
        let mut kept_positions: Vec<u64> = Vec::new();
        // Offsets into `self.values()` of those same surviving entries.
        let mut kept_value_offsets: Vec<u64> = Vec::new();
        // Input position of the most recently kept entry (0 before anything is kept).
        let mut previous_kept = 0;
        // Output position of the most recently kept entry (0 before anything is kept).
        let mut output_position = 0_u64;

        match_each_integer_ptype!(resolved_indices.ptype(), |$P| {
            for (value_offset, raw_index) in resolved_indices
                .maybe_null_slice::<$P>()
                .iter()
                .enumerate()
            {
                let position = (*raw_index as usize) - self.indices_offset();
                if !mask.value(position) {
                    continue;
                }

                // Advance by the number of selected positions in `[previous_kept, position)`.
                // For every entry after the first kept one, that range starts on the
                // previously kept (and therefore selected) position.
                output_position += mask
                    .slice(previous_kept, position - previous_kept)
                    .count_set_bits() as u64;
                kept_positions.push(output_position);
                kept_value_offsets.push(value_offset as u64);
                previous_kept = position;
            }
        });

        let filtered_indices = PrimitiveArray::from(kept_positions).into_array();
        let filtered_values = take(
            &self.values(),
            PrimitiveArray::from(kept_value_offsets).array(),
        )?;
        let filtered = SparseArray::try_new(
            filtered_indices,
            filtered_values,
            mask.count_set_bits(),
            self.fill_value().clone(),
        )?;

        Ok(filtered.into_array())
    }
}

#[cfg(test)]
mod test {
use rstest::{fixture, rstest};
use vortex_dtype::{DType, Nullability, PType};
use vortex_scalar::Scalar;

use crate::array::primitive::PrimitiveArray;
use crate::array::sparse::SparseArray;
use crate::compute::{search_sorted, slice, SearchResult, SearchSortedSide};
use crate::array::BoolArray;
use crate::compute::{filter, search_sorted, slice, SearchResult, SearchSortedSide};
use crate::validity::Validity;
use crate::{Array, IntoArray};
use crate::{Array, IntoArray, IntoArrayVariant};

// Fixture for the tests taking an `array: Array` parameter: a `SparseArray` of length 20
// with u64 indices 2, 9 and 15, values 33, 44 and 55, and a null `I32` fill scalar.
// NOTE(review): the two consecutive `PrimitiveArray::from_vec` lines look like the removed
// and added sides of the diff this text was captured from — confirm that only the
// `33_i32` variant is meant to remain.
#[fixture]
fn array() -> Array {
SparseArray::try_new(
PrimitiveArray::from(vec![2u64, 9, 15]).into_array(),
PrimitiveArray::from_vec(vec![33, 44, 55], Validity::AllValid).into_array(),
PrimitiveArray::from_vec(vec![33_i32, 44, 55], Validity::AllValid).into_array(),
20,
Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
)
.unwrap()
.into_array()
}

// A left-side search for 66 is asserted to return `SearchResult::NotFound(16)`.
// NOTE(review): both the `#[test]` header calling `array()` and the `#[rstest]` header
// taking `array` as a parameter appear here — presumably the old and new sides of the
// diff; confirm which one is current.
#[test]
pub fn search_larger_than() {
let res = search_sorted(&array(), 66, SearchSortedSide::Left).unwrap();
#[rstest]
fn search_larger_than(array: Array) {
let res = search_sorted(&array, 66, SearchSortedSide::Left).unwrap();
assert_eq!(res, SearchResult::NotFound(16));
}

// A left-side search for 22 is asserted to return `SearchResult::NotFound(2)`.
// NOTE(review): both the `#[test]` header calling `array()` and the `#[rstest]` header
// taking `array` as a parameter appear here — presumably the old and new sides of the
// diff; confirm which one is current.
#[test]
pub fn search_less_than() {
let res = search_sorted(&array(), 22, SearchSortedSide::Left).unwrap();
#[rstest]
fn search_less_than(array: Array) {
let res = search_sorted(&array, 22, SearchSortedSide::Left).unwrap();
assert_eq!(res, SearchResult::NotFound(2));
}

// A left-side search for 44 is asserted to return `SearchResult::Found(9)`.
// NOTE(review): both the `#[test]` header calling `array()` and the `#[rstest]` header
// taking `array` as a parameter appear here — presumably the old and new sides of the
// diff; confirm which one is current.
#[test]
pub fn search_found() {
let res = search_sorted(&array(), 44, SearchSortedSide::Left).unwrap();
#[rstest]
fn search_found(array: Array) {
let res = search_sorted(&array, 44, SearchSortedSide::Left).unwrap();
assert_eq!(res, SearchResult::Found(9));
}

// A right-side search for 56 is asserted to return `SearchResult::NotFound(16)`.
// NOTE(review): both the `#[test]` header calling `array()` and the `#[rstest]` header
// taking `array` as a parameter appear here — presumably the old and new sides of the
// diff; confirm which one is current.
#[test]
pub fn search_not_found_right() {
let res = search_sorted(&array(), 56, SearchSortedSide::Right).unwrap();
#[rstest]
fn search_not_found_right(array: Array) {
let res = search_sorted(&array, 56, SearchSortedSide::Right).unwrap();
assert_eq!(res, SearchResult::NotFound(16));
}

// Slices the array with bounds 7 and 20, then asserts that a left-side search for 22 on
// the slice returns `SearchResult::NotFound(2)`.
// NOTE(review): both the `#[test]` header calling `array()` and the `#[rstest]` header
// taking `array` as a parameter appear here — presumably the old and new sides of the
// diff; confirm which one is current.
#[test]
pub fn search_sliced() {
let array = slice(&array(), 7, 20).unwrap();
#[rstest]
fn search_sliced(array: Array) {
let array = slice(&array, 7, 20).unwrap();
assert_eq!(
search_sorted(&array, 22, SearchSortedSide::Left).unwrap(),
SearchResult::NotFound(2)
);
}

#[test]
pub fn search_right() {
fn search_right() {
let array = SparseArray::try_new(
PrimitiveArray::from(vec![0u64]).into_array(),
PrimitiveArray::from_vec(vec![0u8], Validity::AllValid).into_array(),
Expand All @@ -143,4 +190,43 @@ mod test {
SearchResult::NotFound(1)
);
}

/// Filtering the injected `array` with a 20-element mask that is set only at position 2
/// must yield a sparse array of length 1 holding exactly one index and one value.
#[rstest]
fn test_filter(array: Array) {
    // 20 entries, all `false` except position 2.
    let mut mask = vec![false; 20];
    mask[2] = true;
    let mask = BoolArray::from_vec(mask, Validity::NonNullable).into_array();

    let result = SparseArray::try_from(filter(&array, &mask).unwrap()).unwrap();

    assert_eq!(result.len(), 1);
    assert_eq!(result.values().len(), 1);
    assert_eq!(result.indices().len(), 1);
}

/// Filters a 7-element sparse array (stored values at 0, 3 and 6, null fill) with a mask
/// selecting positions 1, 3, 5 and 6, and checks that the result has length 4 and that the
/// surviving indices are rewritten to their positions in the filtered output.
#[test]
fn true_fill_value() {
    let sparse = SparseArray::try_new(
        PrimitiveArray::from(vec![0_u64, 3, 6]).into_array(),
        PrimitiveArray::from_vec(vec![33_i32, 44, 55], Validity::AllValid).into_array(),
        7,
        Scalar::null(DType::Primitive(PType::I32, Nullability::Nullable)),
    )
    .unwrap()
    .into_array();
    let mask = BoolArray::from_vec(
        vec![false, true, false, true, false, true, true],
        Validity::NonNullable,
    )
    .into_array();

    let result = SparseArray::try_from(filter(&sparse, &mask).unwrap()).unwrap();
    assert_eq!(result.len(), 4);

    // Input positions 3 and 6 land at output positions 1 and 3.
    let remapped = result.indices().into_primitive().unwrap();
    assert_eq!(remapped.maybe_null_slice::<u64>(), &[1, 3]);
}
}

0 comments on commit 5240edd

Please sign in to comment.