From d1ec381d580ee36640f93e6dc53fa4a22c2bc79a Mon Sep 17 00:00:00 2001 From: Robert Kruszewski <github@robertk.io> Date: Wed, 11 Sep 2024 18:42:13 +0100 Subject: [PATCH] Better scalar compare using collect_bool (#792) --- .../src/array/primitive/compute/compare.rs | 34 +++++++++++-------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/vortex-array/src/array/primitive/compute/compare.rs b/vortex-array/src/array/primitive/compute/compare.rs index 840f339596..cf6e5fc4f4 100644 --- a/vortex-array/src/array/primitive/compute/compare.rs +++ b/vortex-array/src/array/primitive/compute/compare.rs @@ -1,5 +1,5 @@ use arrow_buffer::bit_util::ceil; -use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, MutableBuffer}; +use arrow_buffer::{BooleanBuffer, MutableBuffer}; use vortex_dtype::{match_each_native_ptype, NativePType}; use vortex_error::{VortexExpect, VortexResult}; use vortex_scalar::PrimitiveScalar; @@ -7,7 +7,6 @@ use vortex_scalar::PrimitiveScalar; use crate::array::primitive::PrimitiveArray; use crate::array::{BoolArray, ConstantArray}; use crate::compute::{MaybeCompareFn, Operator}; -use crate::validity::ArrayValidity; use crate::{Array, IntoArray}; impl MaybeCompareFn for PrimitiveArray { @@ -42,23 +41,28 @@ fn primitive_const_compare( other: ConstantArray, operator: Operator, ) -> VortexResult<Array> { - let mut builder = BooleanBufferBuilder::new(this.len()); let primitive_scalar = PrimitiveScalar::try_from(other.scalar()).vortex_expect("Expected a primitive scalar"); - match_each_native_ptype!(this.ptype(), |$T| { - let op_fn = operator.to_fn::<$T>(); + let buffer = match_each_native_ptype!(this.ptype(), |$T| { let typed_value = primitive_scalar.typed_value::<$T>().unwrap(); - for v in this.maybe_null_slice::<$T>() { - builder.append(op_fn(*v, typed_value)); - } + primitive_value_compare::<$T>(this, typed_value, operator) }); - let validity = this - .validity() - .and(other.logical_validity().into_validity())? - .into_nullable(); - Ok(BoolArray::try_new(builder.finish(), validity)?.into_array()) + Ok(BoolArray::try_new(buffer, this.validity().into_nullable())?.into_array()) +} + +fn primitive_value_compare<T: NativePType>( + this: &PrimitiveArray, + value: T, + op: Operator, +) -> BooleanBuffer { + let op_fn = op.to_fn::<T>(); + let slice = this.maybe_null_slice::<T>(); + + BooleanBuffer::collect_bool(this.len(), |idx| { + op_fn(unsafe { *slice.get_unchecked(idx) }, value) + }) } fn apply_predicate<T: NativePType, F: Fn(T, T) -> bool>( @@ -78,7 +82,9 @@ fn apply_predicate<T: NativePType, F: Fn(T, T) -> bool>( let mut packed_block = 0_u64; for bit_idx in 0..BLOCK_SIZE { let idx = bit_idx + block * BLOCK_SIZE; - let r = f(lhs[idx], rhs[idx]); + let r = f(unsafe { *lhs.get_unchecked(idx) }, unsafe { + *rhs.get_unchecked(idx) + }); packed_block |= (r as u64) << bit_idx; }