From cd6ccd84af7f1b37598cd09cf3e04d5217798626 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Mon, 16 Sep 2024 13:52:06 +0100 Subject: [PATCH] Fuzz filter function (#822) --- fuzz/fuzz_targets/array_ops.rs | 6 +- fuzz/src/filter.rs | 105 +++++++++++++++++++++++++++++++++ fuzz/src/lib.rs | 19 +++++- 3 files changed, 126 insertions(+), 4 deletions(-) create mode 100644 fuzz/src/filter.rs diff --git a/fuzz/fuzz_targets/array_ops.rs b/fuzz/fuzz_targets/array_ops.rs index 1e5f2f5784..b6040e4b7b 100644 --- a/fuzz/fuzz_targets/array_ops.rs +++ b/fuzz/fuzz_targets/array_ops.rs @@ -7,7 +7,7 @@ use vortex::array::{ BoolEncoding, PrimitiveEncoding, StructEncoding, VarBinEncoding, VarBinViewEncoding, }; use vortex::compute::unary::scalar_at; -use vortex::compute::{search_sorted, slice, take, SearchResult, SearchSortedSide}; +use vortex::compute::{filter, search_sorted, slice, take, SearchResult, SearchSortedSide}; use vortex::encoding::EncodingRef; use vortex::{Array, IntoCanonical}; use vortex_fuzz::{sort_canonical_array, Action, FuzzArrayAction}; @@ -56,6 +56,10 @@ fuzz_target!(|fuzz_action: FuzzArrayAction| -> Corpus { } assert_search_sorted(sorted, s, side, expected.search(), i) } + Action::Filter(mask) => { + current_array = filter(¤t_array, &mask).unwrap(); + assert_array_eq(&expected.array(), ¤t_array, i); + } } } Corpus::Keep diff --git a/fuzz/src/filter.rs b/fuzz/src/filter.rs new file mode 100644 index 0000000000..8ff0f1634b --- /dev/null +++ b/fuzz/src/filter.rs @@ -0,0 +1,105 @@ +use vortex::accessor::ArrayAccessor; +use vortex::array::{BoolArray, PrimitiveArray, StructArray, VarBinArray}; +use vortex::validity::{ArrayValidity, Validity}; +use vortex::variants::StructArrayTrait; +use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant}; +use vortex_dtype::{match_each_native_ptype, DType}; + +pub fn filter_canonical_array(array: &Array, filter: &[bool]) -> Array { + match array.dtype() { + DType::Bool(_) => { + let bool_array = array.clone().into_bool().unwrap(); + let vec_validity = bool_array + .logical_validity() + .into_array() + .into_bool() + .unwrap() + .boolean_buffer(); + BoolArray::from_vec( + filter + .iter() + .zip(bool_array.boolean_buffer().iter()) + .filter(|(f, _)| **f) + .map(|(_, v)| v) + .collect::>(), + Validity::from( + filter + .iter() + .zip(vec_validity.iter()) + .filter(|(f, _)| **f) + .map(|(_, v)| v) + .collect::>(), + ), + ) + .into_array() + } + DType::Primitive(p, _) => match_each_native_ptype!(p, |$P| { + let primitive_array = array.clone().into_primitive().unwrap(); + let vec_validity = primitive_array + .logical_validity() + .into_array() + .into_bool() + .unwrap() + .boolean_buffer(); + PrimitiveArray::from_vec( + filter + .iter() + .zip(primitive_array.maybe_null_slice::<$P>().iter().copied()) + .filter(|(f, _)| **f) + .map(|(_, v)| v) + .collect::>(), + Validity::from( + filter + .iter() + .zip(vec_validity.iter()) + .filter(|(f, _)| **f) + .map(|(_, v)| v) + .collect::>(), + ), + ) + .into_array() + }), + DType::Utf8(_) | DType::Binary(_) => { + let utf8 = array.clone().into_varbin().unwrap(); + let values = utf8 + .with_iterator(|iter| { + iter.zip(filter.iter()) + .filter(|(_, f)| **f) + .map(|(v, _)| v.map(|u| u.to_vec())) + .collect::>() + }) + .unwrap(); + VarBinArray::from_iter(values, array.dtype().clone()).into_array() + } + DType::Struct(..) => { + let struct_array = array.clone().into_struct().unwrap(); + let filtered_children = struct_array + .children() + .map(|c| filter_canonical_array(&c, filter)) + .collect::>(); + let vec_validity = struct_array + .logical_validity() + .into_array() + .into_bool() + .unwrap() + .boolean_buffer(); + + StructArray::try_new( + struct_array.names().clone(), + filtered_children, + filter.iter().filter(|b| **b).map(|b| *b as usize).sum(), + Validity::from( + filter + .iter() + .zip(vec_validity.iter()) + .filter(|(f, _)| **f) + .map(|(_, v)| v) + .collect::>(), + ), + ) + .unwrap() + .into_array() + } + _ => unreachable!("Not a canonical array"), + } +} diff --git a/fuzz/src/lib.rs b/fuzz/src/lib.rs index f3429d48ad..ac456148d1 100644 --- a/fuzz/src/lib.rs +++ b/fuzz/src/lib.rs @@ -1,3 +1,4 @@ +mod filter; mod search_sorted; mod slice; mod sort; @@ -10,14 +11,15 @@ use std::ops::Range; use libfuzzer_sys::arbitrary::Error::EmptyChoose; use libfuzzer_sys::arbitrary::{Arbitrary, Result, Unstructured}; pub use sort::sort_canonical_array; -use vortex::array::PrimitiveArray; +use vortex::array::{BoolArray, PrimitiveArray}; use vortex::compute::unary::scalar_at; use vortex::compute::{SearchResult, SearchSortedSide}; -use vortex::{Array, ArrayDType}; +use vortex::{Array, ArrayDType, IntoArray}; use vortex_sampling_compressor::SamplingCompressor; use vortex_scalar::arbitrary::random_scalar; use vortex_scalar::Scalar; +use crate::filter::filter_canonical_array; use crate::search_sorted::search_sorted_canonical_array; use crate::slice::slice_canonical_array; use crate::take::take_canonical_array; @@ -56,6 +58,7 @@ pub enum Action { Slice(Range), Take(Array), SearchSorted(Scalar, SearchSortedSide), + Filter(Array), } impl<'a> Arbitrary<'a> for FuzzArrayAction { @@ -65,7 +68,7 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction { let mut actions = Vec::new(); let action_count = u.int_in_range(1..=4)?; for _ in 0..action_count { - actions.push(match u.int_in_range(0..=3)? { + actions.push(match u.int_in_range(0..=4)? { 0 => { if actions .last() @@ -132,6 +135,16 @@ impl<'a> Arbitrary<'a> for FuzzArrayAction { )), ) } + 4 => { + let mask = (0..current_array.len()) + .map(|_| bool::arbitrary(u)) + .collect::>>()?; + let filtered = filter_canonical_array(¤t_array, &mask); + ( + Action::Filter(BoolArray::from(mask).into_array()), + ExpectedValue::Array(filtered), + ) + } _ => unreachable!(), }) }