Skip to content

Commit

Permalink
Better scalar compare using collect_bool (#792)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Sep 11, 2024
1 parent 13107a3 commit d1ec381
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions vortex-array/src/array/primitive/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
use arrow_buffer::bit_util::ceil;
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, MutableBuffer};
use arrow_buffer::{BooleanBuffer, MutableBuffer};
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_error::{VortexExpect, VortexResult};
use vortex_scalar::PrimitiveScalar;

use crate::array::primitive::PrimitiveArray;
use crate::array::{BoolArray, ConstantArray};
use crate::compute::{MaybeCompareFn, Operator};
use crate::validity::ArrayValidity;
use crate::{Array, IntoArray};

impl MaybeCompareFn for PrimitiveArray {
Expand Down Expand Up @@ -42,23 +41,28 @@ fn primitive_const_compare(
other: ConstantArray,
operator: Operator,
) -> VortexResult<Array> {
let mut builder = BooleanBufferBuilder::new(this.len());
let primitive_scalar =
PrimitiveScalar::try_from(other.scalar()).vortex_expect("Expected a primitive scalar");

match_each_native_ptype!(this.ptype(), |$T| {
let op_fn = operator.to_fn::<$T>();
let buffer = match_each_native_ptype!(this.ptype(), |$T| {
let typed_value = primitive_scalar.typed_value::<$T>().unwrap();
for v in this.maybe_null_slice::<$T>() {
builder.append(op_fn(*v, typed_value));
}
primitive_value_compare::<$T>(this, typed_value, operator)
});

let validity = this
.validity()
.and(other.logical_validity().into_validity())?
.into_nullable();
Ok(BoolArray::try_new(builder.finish(), validity)?.into_array())
Ok(BoolArray::try_new(buffer, this.validity().into_nullable())?.into_array())
}

fn primitive_value_compare<T: NativePType>(
this: &PrimitiveArray,
value: T,
op: Operator,
) -> BooleanBuffer {
let op_fn = op.to_fn::<T>();
let slice = this.maybe_null_slice::<T>();

BooleanBuffer::collect_bool(this.len(), |idx| {
op_fn(unsafe { *slice.get_unchecked(idx) }, value)
})
}

fn apply_predicate<T: NativePType, F: Fn(T, T) -> bool>(
Expand All @@ -78,7 +82,9 @@ fn apply_predicate<T: NativePType, F: Fn(T, T) -> bool>(
let mut packed_block = 0_u64;
for bit_idx in 0..BLOCK_SIZE {
let idx = bit_idx + block * BLOCK_SIZE;
let r = f(lhs[idx], rhs[idx]);
let r = f(unsafe { *lhs.get_unchecked(idx) }, unsafe {
*rhs.get_unchecked(idx)
});
packed_block |= (r as u64) << bit_idx;
}

Expand Down

0 comments on commit d1ec381

Please sign in to comment.