Skip to content

Commit

Permalink
Array comparison compute function (#336)
Browse files Browse the repository at this point in the history
Necessary for implementing right-hand-side field references in array
predicates
  • Loading branch information
jdcasale authored May 21, 2024
1 parent 8b6606a commit 48ec35c
Show file tree
Hide file tree
Showing 8 changed files with 326 additions and 1 deletion.
6 changes: 5 additions & 1 deletion vortex-array/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,4 +64,8 @@ harness = false

[[bench]]
name = "filter_indices"
harness = false
harness = false

[[bench]]
name = "compare"
harness = false
66 changes: 66 additions & 0 deletions vortex-array/benches/compare.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
use criterion::{black_box, criterion_group, criterion_main, Criterion};
use itertools::Itertools;
use rand::distributions::Uniform;
use rand::{thread_rng, Rng};
use vortex::array::bool::BoolArray;
use vortex::IntoArray;
use vortex_error::VortexError;
use vortex_expr::operators::Operator;

fn filter_bool_indices(c: &mut Criterion) {
let mut group = c.benchmark_group("compare");

let mut rng = thread_rng();
let range = Uniform::new(0u8, 1);
let arr = BoolArray::from(
(0..10_000_000)
.map(|_| rng.sample(range) == 0)
.collect_vec(),
)
.into_array();
let arr2 = BoolArray::from(
(0..10_000_000)
.map(|_| rng.sample(range) == 0)
.collect_vec(),
)
.into_array();

group.bench_function("compare_bool", |b| {
b.iter(|| {
let indices =
vortex::compute::compare::compare(&arr, &arr2, Operator::GreaterThanOrEqualTo)
.unwrap();
black_box(indices);
Ok::<(), VortexError>(())
});
});
}

fn filter_indices(c: &mut Criterion) {
let mut group = c.benchmark_group("compare");

let mut rng = thread_rng();
let range = Uniform::new(0i64, 100_000_000);
let arr = (0..10_000_000)
.map(|_| rng.sample(range))
.collect_vec()
.into_array();

let arr2 = (0..10_000_000)
.map(|_| rng.sample(range))
.collect_vec()
.into_array();

group.bench_function("compare_int", |b| {
b.iter(|| {
let indices =
vortex::compute::compare::compare(&arr, &arr2, Operator::GreaterThanOrEqualTo)
.unwrap();
black_box(indices);
Ok::<(), VortexError>(())
});
});
}

criterion_group!(benches, filter_indices, filter_bool_indices);
criterion_main!(benches);
87 changes: 87 additions & 0 deletions vortex-array/src/array/bool/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
use std::ops::{BitAnd, BitOr, BitXor, Not};

use vortex_error::VortexResult;
use vortex_expr::operators::Operator;

use crate::array::bool::BoolArray;
use crate::compute::compare::CompareFn;
use crate::{Array, ArrayTrait, IntoArray};

impl CompareFn for BoolArray {
fn compare(&self, other: &Array, op: Operator) -> VortexResult<Array> {
let flattened = other.clone().flatten_bool()?;
let lhs = self.boolean_buffer();
let rhs = flattened.boolean_buffer();
let result_buf = match op {
Operator::EqualTo => lhs.bitxor(&rhs).not(),
Operator::NotEqualTo => lhs.bitxor(&rhs),

Operator::GreaterThan => lhs.bitand(&rhs.not()),
Operator::GreaterThanOrEqualTo => lhs.bitor(&rhs.not()),
Operator::LessThan => lhs.not().bitand(&rhs),
Operator::LessThanOrEqualTo => lhs.not().bitor(&rhs),
};
Ok(BoolArray::from(
self.validity()
.to_logical(self.len())
.to_null_buffer()?
.map(|nulls| result_buf.bitand(&nulls.into_inner()))
.unwrap_or(result_buf),
)
.into_array())
}
}

#[cfg(test)]
mod test {
use itertools::Itertools;

use super::*;
use crate::compute::compare::compare;
use crate::validity::Validity;

fn to_int_indices(indices_bits: BoolArray) -> Vec<u64> {
let filtered = indices_bits
.boolean_buffer()
.iter()
.enumerate()
.flat_map(|(idx, v)| if v { Some(idx as u64) } else { None })
.collect_vec();
filtered
}

#[test]
fn test_basic_comparisons() -> VortexResult<()> {
let arr = BoolArray::from_vec(
vec![true, true, false, true, false],
Validity::Array(BoolArray::from(vec![false, true, true, true, true]).into_array()),
)
.into_array();

let matches = compare(&arr, &arr, Operator::EqualTo)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [1u64, 2, 3, 4]);

let matches = compare(&arr, &arr, Operator::NotEqualTo)?.flatten_bool()?;
let empty: [u64; 0] = [];
assert_eq!(to_int_indices(matches), empty);

let other = BoolArray::from_vec(
vec![false, false, false, true, true],
Validity::Array(BoolArray::from(vec![false, true, true, true, true]).into_array()),
)
.into_array();

let matches = compare(&arr, &other, Operator::LessThanOrEqualTo)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [2u64, 3, 4]);

let matches = compare(&arr, &other, Operator::LessThan)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [4u64]);

let matches = compare(&other, &arr, Operator::GreaterThanOrEqualTo)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [2u64, 3, 4]);

let matches = compare(&other, &arr, Operator::GreaterThan)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [4u64]);
Ok(())
}
}
6 changes: 6 additions & 0 deletions vortex-array/src/array/bool/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use crate::array::bool::BoolArray;
use crate::compute::as_arrow::AsArrowArray;
use crate::compute::as_contiguous::AsContiguousFn;
use crate::compute::compare::CompareFn;
use crate::compute::fill::FillForwardFn;
use crate::compute::scalar_at::ScalarAtFn;
use crate::compute::slice::SliceFn;
Expand All @@ -9,6 +10,7 @@ use crate::compute::ArrayCompute;

mod as_arrow;
mod as_contiguous;
mod compare;
mod fill;
mod flatten;
mod scalar_at;
Expand All @@ -24,6 +26,10 @@ impl ArrayCompute for BoolArray {
Some(self)
}

fn compare(&self) -> Option<&dyn CompareFn> {
Some(self)
}

fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
Some(self)
}
Expand Down
117 changes: 117 additions & 0 deletions vortex-array/src/array/primitive/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use std::ops::BitAnd;

use arrow_buffer::BooleanBuffer;
use vortex_dtype::{match_each_native_ptype, NativePType};
use vortex_error::VortexResult;
use vortex_expr::operators::Operator;

use crate::array::bool::BoolArray;
use crate::array::primitive::PrimitiveArray;
use crate::compute::compare::CompareFn;
use crate::{Array, ArrayTrait, IntoArray};

impl CompareFn for PrimitiveArray {
fn compare(&self, other: &Array, predicate: Operator) -> VortexResult<Array> {
let flattened = other.clone().flatten_primitive()?;

let matching_idxs = match_each_native_ptype!(self.ptype(), |$T| {
let predicate_fn = &predicate.to_predicate::<$T>();
apply_predicate(self.typed_data::<$T>(), flattened.typed_data::<$T>(), predicate_fn)
});

let present = self
.validity()
.to_logical(self.len())
.to_present_null_buffer()?
.into_inner();
let present_other = flattened
.validity()
.to_logical(self.len())
.to_present_null_buffer()?
.into_inner();

Ok(BoolArray::from(matching_idxs.bitand(&present).bitand(&present_other)).into_array())
}
}

fn apply_predicate<T: NativePType, F: Fn(&T, &T) -> bool>(
lhs: &[T],
rhs: &[T],
f: F,
) -> BooleanBuffer {
let matches = lhs.iter().zip(rhs.iter()).map(|(lhs, rhs)| f(lhs, rhs));
BooleanBuffer::from_iter(matches)
}

#[cfg(test)]
mod test {
use itertools::Itertools;

use super::*;
use crate::compute::compare::compare;

fn to_int_indices(indices_bits: BoolArray) -> Vec<u64> {
let filtered = indices_bits
.boolean_buffer()
.iter()
.enumerate()
.flat_map(|(idx, v)| if v { Some(idx as u64) } else { None })
.collect_vec();
filtered
}

#[test]
fn test_basic_comparisons() -> VortexResult<()> {
let arr = PrimitiveArray::from_nullable_vec(vec![
Some(1i32),
Some(2),
Some(3),
Some(4),
None,
Some(5),
Some(6),
Some(7),
Some(8),
None,
Some(9),
None,
])
.into_array();

let matches = compare(&arr, &arr, Operator::EqualTo)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);

let matches = compare(&arr, &arr, Operator::NotEqualTo)?.flatten_bool()?;
let empty: [u64; 0] = [];
assert_eq!(to_int_indices(matches), empty);

let other = PrimitiveArray::from_nullable_vec(vec![
Some(1i32),
Some(2),
Some(3),
Some(4),
None,
Some(6),
Some(7),
Some(8),
Some(9),
None,
Some(10),
None,
])
.into_array();

let matches = compare(&arr, &other, Operator::LessThanOrEqualTo)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);

let matches = compare(&arr, &other, Operator::LessThan)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]);

let matches = compare(&other, &arr, Operator::GreaterThanOrEqualTo)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]);

let matches = compare(&other, &arr, Operator::GreaterThan)?.flatten_bool()?;
assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]);
Ok(())
}
}
10 changes: 10 additions & 0 deletions vortex-array/src/array/primitive/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use crate::array::primitive::PrimitiveArray;
use crate::compute::as_arrow::AsArrowArray;
use crate::compute::as_contiguous::AsContiguousFn;
use crate::compute::cast::CastFn;
use crate::compute::compare::CompareFn;
use crate::compute::fill::FillForwardFn;
use crate::compute::filter_indices::FilterIndicesFn;
use crate::compute::scalar_at::ScalarAtFn;
use crate::compute::scalar_subtract::SubtractScalarFn;
use crate::compute::search_sorted::SearchSortedFn;
Expand All @@ -13,6 +15,7 @@ use crate::compute::ArrayCompute;
mod as_arrow;
mod as_contiguous;
mod cast;
mod compare;
mod fill;
mod filter_indices;
mod scalar_at;
Expand All @@ -34,9 +37,16 @@ impl ArrayCompute for PrimitiveArray {
Some(self)
}

fn compare(&self) -> Option<&dyn CompareFn> {
Some(self)
}

fn fill_forward(&self) -> Option<&dyn FillForwardFn> {
Some(self)
}
fn filter_indices(&self) -> Option<&dyn FilterIndicesFn> {
Some(self)
}

fn scalar_at(&self) -> Option<&dyn ScalarAtFn> {
Some(self)
Expand Down
29 changes: 29 additions & 0 deletions vortex-array/src/compute/compare.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
use vortex_dtype::DType;
use vortex_error::{vortex_err, VortexResult};
use vortex_expr::operators::Operator;

use crate::{Array, ArrayDType};

pub trait CompareFn {
fn compare(&self, array: &Array, predicate: Operator) -> VortexResult<Array>;
}

pub fn compare(array: &Array, other: &Array, predicate: Operator) -> VortexResult<Array> {
if let Some(matching_indices) =
array.with_dyn(|c| c.compare().map(|t| t.compare(other, predicate)))
{
return matching_indices;
}
// if compare is not implemented for the given array type, but the array has a numeric
// DType, we can flatten the array and apply filter to the flattened primitive array
match array.dtype() {
DType::Primitive(..) => {
let flat = array.clone().flatten_primitive()?;
flat.compare(other, predicate)
}
_ => Err(vortex_err!(
NotImplemented: "compare",
array.encoding().id()
)),
}
}
Loading

0 comments on commit 48ec35c

Please sign in to comment.