diff --git a/vortex-array/src/array/primitive/compute/compare.rs b/vortex-array/src/array/primitive/compute/compare.rs index f2b0066e90..635a3d92ef 100644 --- a/vortex-array/src/array/primitive/compute/compare.rs +++ b/vortex-array/src/array/primitive/compute/compare.rs @@ -1,14 +1,10 @@ -use arrow_buffer::bit_util::ceil; -use arrow_buffer::{BooleanBuffer, MutableBuffer}; -use vortex_dtype::{match_each_native_ptype, NativePType}; -use vortex_error::{vortex_err, VortexExpect, VortexResult}; -use vortex_scalar::PrimitiveScalar; +use vortex_error::VortexResult; use crate::array::primitive::PrimitiveArray; -use crate::array::{BoolArray, ConstantArray}; -use crate::compute::{MaybeCompareFn, Operator}; -use crate::variants::PrimitiveArrayTrait; -use crate::{ArrayDType, ArrayData, IntoArrayData}; +use crate::array::ConstantArray; +use crate::compute::{arrow_compare, MaybeCompareFn, Operator}; +use crate::stats::{ArrayStatistics, Stat}; +use crate::ArrayData; impl MaybeCompareFn for PrimitiveArray { fn maybe_compare( @@ -16,179 +12,23 @@ impl MaybeCompareFn for PrimitiveArray { other: &ArrayData, operator: Operator, ) -> Option> { - if let Ok(const_array) = ConstantArray::try_from(other) { - return Some(primitive_const_compare(self, const_array, operator)); + // If the RHS is constant, then delegate to Arrow. + // TODO(ngates): remove these dual checks once we make stats not a hashmap + // https://github.com/spiraldb/vortex/issues/1309 + if ConstantArray::try_from(other).is_ok() + || other + .statistics() + .get_as::(Stat::IsConstant) + .unwrap_or(false) + { + return Some(arrow_compare(self.as_ref(), other, operator)); } + // If the RHS is primitive, then delegate to Arrow. 
if let Ok(primitive) = PrimitiveArray::try_from(other) { - let match_mask = match_each_native_ptype!(self.ptype(), |$T| { - apply_predicate(self.maybe_null_slice::<$T>(), primitive.maybe_null_slice::<$T>(), operator.to_fn::<$T>()) - }); - - let validity = self - .validity() - .and(primitive.validity()) - .map(|v| v.into_nullable()); - - return Some( - validity - .and_then(|v| BoolArray::try_new(match_mask, v)) - .map(|a| a.into_array()), - ); + return Some(arrow_compare(self.as_ref(), primitive.as_ref(), operator)); } None } } - -fn primitive_const_compare( - this: &PrimitiveArray, - other: ConstantArray, - operator: Operator, -) -> VortexResult { - let primitive_scalar = PrimitiveScalar::try_new(other.dtype(), other.scalar_value()) - .vortex_expect("Expected a primitive scalar"); - - let buffer = match_each_native_ptype!(this.ptype(), |$T| { - let typed_value = primitive_scalar.typed_value::<$T>() - .ok_or_else(|| vortex_err!("Type mismatch between array and constant"))?; - primitive_value_compare::<$T>(this, typed_value, operator) - }); - - Ok(BoolArray::try_new(buffer, this.validity().into_nullable())?.into_array()) -} - -fn primitive_value_compare( - this: &PrimitiveArray, - value: T, - op: Operator, -) -> BooleanBuffer { - let op_fn = op.to_fn::(); - let slice = this.maybe_null_slice::(); - - BooleanBuffer::collect_bool(this.len(), |idx| { - op_fn(unsafe { *slice.get_unchecked(idx) }, value) - }) -} - -fn apply_predicate bool>( - lhs: &[T], - rhs: &[T], - f: F, -) -> BooleanBuffer { - const BLOCK_SIZE: usize = u64::BITS as usize; - - let len = lhs.len(); - let reminder = len % BLOCK_SIZE; - let block_count = len / BLOCK_SIZE; - - let mut buffer = MutableBuffer::new(ceil(len, BLOCK_SIZE) * 8); - - for block in 0..block_count { - let mut packed_block = 0_u64; - for bit_idx in 0..BLOCK_SIZE { - let idx = bit_idx + block * BLOCK_SIZE; - let r = f(unsafe { *lhs.get_unchecked(idx) }, unsafe { - *rhs.get_unchecked(idx) - }); - packed_block |= (r as u64) << bit_idx; - 
} - - unsafe { - buffer.push_unchecked(packed_block); - } - } - - if reminder != 0 { - let mut packed_block = 0_u64; - for bit_idx in 0..reminder { - let idx = bit_idx + block_count * BLOCK_SIZE; - let r = f(lhs[idx], rhs[idx]); - packed_block |= (r as u64) << bit_idx; - } - - unsafe { - buffer.push_unchecked(packed_block); - } - } - - BooleanBuffer::new(buffer.into(), 0, len) -} - -#[cfg(test)] -#[allow(clippy::panic_in_result_fn)] -mod test { - use itertools::Itertools; - - use super::*; - use crate::compute::compare; - use crate::IntoArrayVariant; - - fn to_int_indices(indices_bits: BoolArray) -> Vec { - let filtered = indices_bits - .boolean_buffer() - .iter() - .enumerate() - .filter_map(|(idx, v)| { - let valid_and_true = indices_bits.validity().is_valid(idx) & v; - valid_and_true.then_some(idx as u64) - }) - .collect_vec(); - filtered - } - - #[test] - fn test_basic_comparisons() -> VortexResult<()> { - let arr = PrimitiveArray::from_nullable_vec(vec![ - Some(1i32), - Some(2), - Some(3), - Some(4), - None, - Some(5), - Some(6), - Some(7), - Some(8), - None, - Some(9), - None, - ]) - .into_array(); - - let matches = compare(&arr, &arr, Operator::Eq)?.into_bool()?; - assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]); - - let matches = compare(&arr, &arr, Operator::NotEq)?.into_bool()?; - let empty: [u64; 0] = []; - assert_eq!(to_int_indices(matches), empty); - - let other = PrimitiveArray::from_nullable_vec(vec![ - Some(1i32), - Some(2), - Some(3), - Some(4), - None, - Some(6), - Some(7), - Some(8), - Some(9), - None, - Some(10), - None, - ]) - .into_array(); - - let matches = compare(&arr, &other, Operator::Lte)?.into_bool()?; - assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]); - - let matches = compare(&arr, &other, Operator::Lt)?.into_bool()?; - assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]); - - let matches = compare(&other, &arr, Operator::Gte)?.into_bool()?; - assert_eq!(to_int_indices(matches), [0u64, 1, 
2, 3, 5, 6, 7, 8, 10]); - - let matches = compare(&other, &arr, Operator::Gt)?.into_bool()?; - assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]); - Ok(()) - } -} diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs index 5877267493..126a271ec2 100644 --- a/vortex-array/src/compute/compare.rs +++ b/vortex-array/src/compute/compare.rs @@ -127,8 +127,17 @@ pub fn compare( } // Fallback to arrow on canonical types - let lhs = Datum::try_from(left.clone())?; - let rhs = Datum::try_from(right.clone())?; + arrow_compare(left, right, operator) +} + +/// Implementation of `CompareFn` using the Arrow crate. +pub(crate) fn arrow_compare( + lhs: &ArrayData, + rhs: &ArrayData, + operator: Operator, +) -> VortexResult { + let lhs = Datum::try_from(lhs.clone())?; + let rhs = Datum::try_from(rhs.clone())?; let array = match operator { Operator::Eq => cmp::eq(&lhs, &rhs)?, diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 07e7714bf6..9ced04dccd 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -8,6 +8,7 @@ //! from Arrow. pub use boolean::{and, and_kleene, or, or_kleene, AndFn, OrFn}; +pub(crate) use compare::arrow_compare; pub use compare::{compare, scalar_cmp, CompareFn, MaybeCompareFn, Operator}; pub use filter::{filter, FilterFn}; pub use search_sorted::*;