From d1ec381d580ee36640f93e6dc53fa4a22c2bc79a Mon Sep 17 00:00:00 2001
From: Robert Kruszewski <github@robertk.io>
Date: Wed, 11 Sep 2024 18:42:13 +0100
Subject: [PATCH] Better scalar compare using collect_bool (#792)

---
 .../src/array/primitive/compute/compare.rs    | 34 +++++++++++--------
 1 file changed, 20 insertions(+), 14 deletions(-)

diff --git a/vortex-array/src/array/primitive/compute/compare.rs b/vortex-array/src/array/primitive/compute/compare.rs
index 840f339596..cf6e5fc4f4 100644
--- a/vortex-array/src/array/primitive/compute/compare.rs
+++ b/vortex-array/src/array/primitive/compute/compare.rs
@@ -1,5 +1,5 @@
 use arrow_buffer::bit_util::ceil;
-use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, MutableBuffer};
+use arrow_buffer::{BooleanBuffer, MutableBuffer};
 use vortex_dtype::{match_each_native_ptype, NativePType};
 use vortex_error::{VortexExpect, VortexResult};
 use vortex_scalar::PrimitiveScalar;
@@ -7,7 +7,6 @@ use vortex_scalar::PrimitiveScalar;
 use crate::array::primitive::PrimitiveArray;
 use crate::array::{BoolArray, ConstantArray};
 use crate::compute::{MaybeCompareFn, Operator};
-use crate::validity::ArrayValidity;
 use crate::{Array, IntoArray};
 
 impl MaybeCompareFn for PrimitiveArray {
@@ -42,23 +41,28 @@ fn primitive_const_compare(
     other: ConstantArray,
     operator: Operator,
 ) -> VortexResult<Array> {
-    let mut builder = BooleanBufferBuilder::new(this.len());
     let primitive_scalar =
         PrimitiveScalar::try_from(other.scalar()).vortex_expect("Expected a primitive scalar");
 
-    match_each_native_ptype!(this.ptype(), |$T| {
-        let op_fn = operator.to_fn::<$T>();
+    let buffer = match_each_native_ptype!(this.ptype(), |$T| {
         let typed_value = primitive_scalar.typed_value::<$T>().unwrap();
-        for v in this.maybe_null_slice::<$T>() {
-            builder.append(op_fn(*v, typed_value));
-        }
+        primitive_value_compare::<$T>(this, typed_value, operator)
     });
 
-    let validity = this
-        .validity()
-        .and(other.logical_validity().into_validity())?
-        .into_nullable();
-    Ok(BoolArray::try_new(builder.finish(), validity)?.into_array())
+    Ok(BoolArray::try_new(buffer, this.validity().into_nullable())?.into_array())
+}
+
+fn primitive_value_compare<T: NativePType>(
+    this: &PrimitiveArray,
+    value: T,
+    op: Operator,
+) -> BooleanBuffer {
+    let op_fn = op.to_fn::<T>();
+    let slice = this.maybe_null_slice::<T>();
+
+    BooleanBuffer::collect_bool(this.len(), |idx| {
+        op_fn(unsafe { *slice.get_unchecked(idx) }, value)
+    })
 }
 
 fn apply_predicate<T: NativePType, F: Fn(T, T) -> bool>(
@@ -78,7 +82,9 @@ fn apply_predicate<T: NativePType, F: Fn(T, T) -> bool>(
         let mut packed_block = 0_u64;
         for bit_idx in 0..BLOCK_SIZE {
             let idx = bit_idx + block * BLOCK_SIZE;
-            let r = f(lhs[idx], rhs[idx]);
+            let r = f(unsafe { *lhs.get_unchecked(idx) }, unsafe {
+                *rhs.get_unchecked(idx)
+            });
             packed_block |= (r as u64) << bit_idx;
         }