diff --git a/fuzz/src/filter.rs b/fuzz/src/filter.rs index 4450ab963..fe7feb6d6 100644 --- a/fuzz/src/filter.rs +++ b/fuzz/src/filter.rs @@ -7,15 +7,27 @@ use vortex_dtype::{match_each_native_ptype, DType}; use vortex_error::VortexExpect; pub fn filter_canonical_array(array: &ArrayData, filter: &[bool]) -> ArrayData { + let validity = if array.dtype().is_nullable() { + let validity_buff = array + .logical_validity() + .into_array() + .into_bool() + .unwrap() + .boolean_buffer(); + Validity::from_iter( + filter + .iter() + .zip(validity_buff.iter()) + .filter(|(f, _)| **f) + .map(|(_, v)| v), + ) + } else { + Validity::NonNullable + }; + match array.dtype() { DType::Bool(_) => { let bool_array = array.clone().into_bool().unwrap(); - let vec_validity = bool_array - .logical_validity() - .into_array() - .into_bool() - .unwrap() - .boolean_buffer(); BoolArray::try_new( BooleanBuffer::from_iter( filter @@ -24,25 +36,13 @@ pub fn filter_canonical_array(array: &ArrayData, filter: &[bool]) -> ArrayData { .filter(|(f, _)| **f) .map(|(_, v)| v), ), - Validity::from_iter( - filter - .iter() - .zip(vec_validity.iter()) - .filter(|(f, _)| **f) - .map(|(_, v)| v), - ), + validity, ) .vortex_expect("Validity length cannot mismatch") .into_array() } DType::Primitive(p, _) => match_each_native_ptype!(p, |$P| { let primitive_array = array.clone().into_primitive().unwrap(); - let vec_validity = primitive_array - .logical_validity() - .into_array() - .into_bool() - .unwrap() - .boolean_buffer(); PrimitiveArray::from_vec( filter .iter() @@ -50,13 +50,7 @@ pub fn filter_canonical_array(array: &ArrayData, filter: &[bool]) -> ArrayData { .filter(|(f, _)| **f) .map(|(_, v)| v) .collect::>(), - Validity::from_iter( - filter - .iter() - .zip(vec_validity.iter()) - .filter(|(f, _)| **f) - .map(|(_, v)| v) - ), + validity, ) .into_array() }), @@ -78,24 +72,12 @@ pub fn filter_canonical_array(array: &ArrayData, filter: &[bool]) -> ArrayData { .children() .map(|c| filter_canonical_array(&c, filter)) .collect::>(); - let vec_validity = struct_array - .logical_validity() - .into_array() - .into_bool() - .unwrap() - .boolean_buffer(); StructArray::try_new( struct_array.names().clone(), filtered_children, filter.iter().filter(|b| **b).map(|b| *b as usize).sum(), - Validity::from_iter( - filter - .iter() - .zip(vec_validity.iter()) - .filter(|(f, _)| **f) - .map(|(_, v)| v), - ), + validity, ) .unwrap() .into_array() diff --git a/vortex-array/src/validity.rs b/vortex-array/src/validity.rs index 6fea9246c..f26e96ec8 100644 --- a/vortex-array/src/validity.rs +++ b/vortex-array/src/validity.rs @@ -405,7 +405,7 @@ impl FromIterator for Validity { impl FromIterator for Validity { fn from_iter>(iter: T) -> Self { - Self::Array(BoolArray::from_iter(iter).into_array()) + Validity::from(BooleanBuffer::from_iter(iter)) } }