Skip to content

Commit

Permalink
Remove primitive compare impl (#1337)
Browse files Browse the repository at this point in the history
And just delegate to Arrow compute instead.

I'm curious if our implementation did anything different? I don't think
so...
  • Loading branch information
gatesn authored Nov 16, 2024
1 parent d3a28f4 commit a1d5349
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 179 deletions.
194 changes: 17 additions & 177 deletions vortex-array/src/array/primitive/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -1,194 +1,34 @@
use arrow_buffer::bit_util::ceil;
use arrow_buffer::{BooleanBuffer, MutableBuffer};
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_error::{vortex_err, VortexExpect, VortexResult};
use vortex_scalar::PrimitiveScalar;
use vortex_error::VortexResult;

use crate::array::primitive::PrimitiveArray;
use crate::array::{BoolArray, ConstantArray};
use crate::compute::{MaybeCompareFn, Operator};
use crate::variants::PrimitiveArrayTrait;
use crate::{ArrayDType, ArrayData, IntoArrayData};
use crate::array::ConstantArray;
use crate::compute::{arrow_compare, MaybeCompareFn, Operator};
use crate::stats::{ArrayStatistics, Stat};
use crate::ArrayData;

impl MaybeCompareFn for PrimitiveArray {
fn maybe_compare(
&self,
other: &ArrayData,
operator: Operator,
) -> Option<VortexResult<ArrayData>> {
if let Ok(const_array) = ConstantArray::try_from(other) {
return Some(primitive_const_compare(self, const_array, operator));
// If the RHS is constant, then delegate to Arrow.
// TODO(ngates): remove these dual checks once we make stats not a hashmap
// https://github.com/spiraldb/vortex/issues/1309
if ConstantArray::try_from(other).is_ok()
|| other
.statistics()
.get_as::<bool>(Stat::IsConstant)
.unwrap_or(false)
{
return Some(arrow_compare(self.as_ref(), other, operator));
}

// If the RHS is primitive, then delegate to Arrow.
if let Ok(primitive) = PrimitiveArray::try_from(other) {
let match_mask = match_each_native_ptype!(self.ptype(), |$T| {
apply_predicate(self.maybe_null_slice::<$T>(), primitive.maybe_null_slice::<$T>(), operator.to_fn::<$T>())
});

let validity = self
.validity()
.and(primitive.validity())
.map(|v| v.into_nullable());

return Some(
validity
.and_then(|v| BoolArray::try_new(match_mask, v))
.map(|a| a.into_array()),
);
return Some(arrow_compare(self.as_ref(), primitive.as_ref(), operator));
}

None
}
}

/// Compares every element of `this` against the single value held by the constant array
/// `other`, using `operator`.
///
/// The resulting boolean array reuses the validity of `this`, converted to its nullable form.
///
/// # Errors
///
/// Returns an error when the constant's value cannot be read as the primitive type of `this`,
/// or when the boolean array cannot be constructed.
///
/// # Panics
///
/// Panics if `other` does not hold a primitive scalar.
fn primitive_const_compare(
    this: &PrimitiveArray,
    other: ConstantArray,
    operator: Operator,
) -> VortexResult<ArrayData> {
    let constant = PrimitiveScalar::try_new(other.dtype(), other.scalar_value())
        .vortex_expect("Expected a primitive scalar");

    // Dispatch on the array's native type; bail out if the constant is of a different type.
    let matches = match_each_native_ptype!(this.ptype(), |$T| {
        match constant.typed_value::<$T>() {
            Some(rhs) => primitive_value_compare::<$T>(this, rhs, operator),
            None => return Err(vortex_err!("Type mismatch between array and constant")),
        }
    });

    BoolArray::try_new(matches, this.validity().into_nullable()).map(|bools| bools.into_array())
}

/// Builds a bitmap with one bit per element of `this`, where bit `i` is the result of applying
/// the comparison `op` to element `i` (left-hand side) and `value` (right-hand side).
///
/// Validity is not consulted here; every slot of the underlying slice is compared.
fn primitive_value_compare<T: NativePType>(
    this: &PrimitiveArray,
    value: T,
    op: Operator,
) -> BooleanBuffer {
    let predicate = op.to_fn::<T>();
    let elements = this.maybe_null_slice::<T>();

    BooleanBuffer::collect_bool(this.len(), |position| {
        // SAFETY: relies on `position < this.len()` and on `elements` holding `this.len()`
        // values. NOTE(review): neither guarantee is visible in this file; confirm against
        // `collect_bool` and `maybe_null_slice`.
        let element = unsafe { *elements.get_unchecked(position) };
        predicate(element, value)
    })
}

/// Evaluates `f` on each pair of corresponding elements of `lhs` and `rhs` and packs the
/// outcomes into a bitmap of `lhs.len()` bits, where bit `i` holds `f(lhs[i], rhs[i])`.
///
/// Results are accumulated 64 at a time into `u64` words (lowest bit first) and appended to the
/// output buffer; a final partial word covers any leftover elements.
///
/// # Panics
///
/// Panics if `lhs` and `rhs` have different lengths.
fn apply_predicate<T: NativePType, F: Fn(T, T) -> bool>(
    lhs: &[T],
    rhs: &[T],
    f: F,
) -> BooleanBuffer {
    const BLOCK_SIZE: usize = u64::BITS as usize;

    // Previously a shorter `rhs` was read out of bounds through `get_unchecked`; reject
    // mismatched inputs up front instead.
    assert_eq!(
        lhs.len(),
        rhs.len(),
        "apply_predicate requires inputs of equal length"
    );

    let len = lhs.len();

    // Packs the predicate results for up to `BLOCK_SIZE` element pairs into a single word.
    let pack = |left: &[T], right: &[T]| -> u64 {
        left.iter()
            .zip(right)
            .enumerate()
            .fold(0_u64, |word, (bit_idx, (&l, &r))| {
                word | (u64::from(f(l, r)) << bit_idx)
            })
    };

    // One 8-byte word per started block of `BLOCK_SIZE` elements.
    let mut buffer = MutableBuffer::new(ceil(len, BLOCK_SIZE) * 8);

    let lhs_blocks = lhs.chunks_exact(BLOCK_SIZE);
    let rhs_blocks = rhs.chunks_exact(BLOCK_SIZE);
    let lhs_tail = lhs_blocks.remainder();
    let rhs_tail = rhs_blocks.remainder();

    for (left, right) in lhs_blocks.zip(rhs_blocks) {
        let packed_block = pack(left, right);
        // SAFETY: capacity for `ceil(len, BLOCK_SIZE)` words was reserved above, and this loop
        // pushes exactly `len / BLOCK_SIZE` of them.
        unsafe {
            buffer.push_unchecked(packed_block);
        }
    }

    if !lhs_tail.is_empty() {
        let packed_block = pack(lhs_tail, rhs_tail);
        // SAFETY: a non-empty tail means `ceil(len, BLOCK_SIZE) == len / BLOCK_SIZE + 1`, so
        // room for this last word was reserved as well.
        unsafe {
            buffer.push_unchecked(packed_block);
        }
    }

    BooleanBuffer::new(buffer.into(), 0, len)
}

#[cfg(test)]
#[allow(clippy::panic_in_result_fn)]
mod test {
    use itertools::Itertools;

    use super::*;
    use crate::compute::compare;
    use crate::IntoArrayVariant;

    /// Returns the positions of `indices_bits` that are both valid and `true`.
    fn to_int_indices(indices_bits: BoolArray) -> Vec<u64> {
        let bits = indices_bits.boolean_buffer();
        bits.iter()
            .enumerate()
            .filter(|&(position, bit)| indices_bits.validity().is_valid(position) & bit)
            .map(|(position, _)| position as u64)
            .collect_vec()
    }

    /// Compares `lhs` with `rhs` and returns the positions where the result is valid and `true`.
    fn matching(lhs: &ArrayData, rhs: &ArrayData, op: Operator) -> VortexResult<Vec<u64>> {
        Ok(to_int_indices(compare(lhs, rhs, op)?.into_bool()?))
    }

    #[test]
    fn test_basic_comparisons() -> VortexResult<()> {
        // Positions 4, 9 and 11 are null in both arrays.
        let arr = PrimitiveArray::from_nullable_vec(vec![
            Some(1i32), Some(2), Some(3), Some(4),
            None,
            Some(5), Some(6), Some(7), Some(8),
            None,
            Some(9),
            None,
        ])
        .into_array();

        // Every non-null slot equals itself, and none differs from itself.
        assert_eq!(
            matching(&arr, &arr, Operator::Eq)?,
            [0u64, 1, 2, 3, 5, 6, 7, 8, 10]
        );
        assert_eq!(matching(&arr, &arr, Operator::NotEq)?, Vec::<u64>::new());

        // Same as `arr` for the first four slots, one larger for the remaining non-null ones.
        let other = PrimitiveArray::from_nullable_vec(vec![
            Some(1i32), Some(2), Some(3), Some(4),
            None,
            Some(6), Some(7), Some(8), Some(9),
            None,
            Some(10),
            None,
        ])
        .into_array();

        assert_eq!(
            matching(&arr, &other, Operator::Lte)?,
            [0u64, 1, 2, 3, 5, 6, 7, 8, 10]
        );
        assert_eq!(matching(&arr, &other, Operator::Lt)?, [5u64, 6, 7, 8, 10]);

        assert_eq!(
            matching(&other, &arr, Operator::Gte)?,
            [0u64, 1, 2, 3, 5, 6, 7, 8, 10]
        );
        assert_eq!(matching(&other, &arr, Operator::Gt)?, [5u64, 6, 7, 8, 10]);

        Ok(())
    }
}
13 changes: 11 additions & 2 deletions vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,17 @@ pub fn compare(
}

// Fallback to arrow on canonical types
let lhs = Datum::try_from(left.clone())?;
let rhs = Datum::try_from(right.clone())?;
arrow_compare(left, right, operator)
}

/// Implementation of `CompareFn` using the Arrow crate.
pub(crate) fn arrow_compare(
lhs: &ArrayData,
rhs: &ArrayData,
operator: Operator,
) -> VortexResult<ArrayData> {
let lhs = Datum::try_from(lhs.clone())?;
let rhs = Datum::try_from(rhs.clone())?;

let array = match operator {
Operator::Eq => cmp::eq(&lhs, &rhs)?,
Expand Down
1 change: 1 addition & 0 deletions vortex-array/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
//! from Arrow.
pub use boolean::{and, and_kleene, or, or_kleene, AndFn, OrFn};
pub(crate) use compare::arrow_compare;
pub use compare::{compare, scalar_cmp, CompareFn, MaybeCompareFn, Operator};
pub use filter::{filter, FilterFn};
pub use search_sorted::*;
Expand Down

0 comments on commit a1d5349

Please sign in to comment.