diff --git a/vortex-array/Cargo.toml b/vortex-array/Cargo.toml index c0cc5497e2..4c7238ab09 100644 --- a/vortex-array/Cargo.toml +++ b/vortex-array/Cargo.toml @@ -64,4 +64,8 @@ harness = false [[bench]] name = "filter_indices" -harness = false \ No newline at end of file +harness = false + +[[bench]] +name = "compare" +harness = false diff --git a/vortex-array/benches/compare.rs b/vortex-array/benches/compare.rs new file mode 100644 index 0000000000..2a00411d9e --- /dev/null +++ b/vortex-array/benches/compare.rs @@ -0,0 +1,66 @@ +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use itertools::Itertools; +use rand::distributions::Uniform; +use rand::{thread_rng, Rng}; +use vortex::array::bool::BoolArray; +use vortex::IntoArray; +use vortex_error::VortexError; +use vortex_expr::operators::Operator; + +fn filter_bool_indices(c: &mut Criterion) { + let mut group = c.benchmark_group("compare"); + + let mut rng = thread_rng(); + let range = Uniform::new(0u8, 1); + let arr = BoolArray::from( + (0..10_000_000) + .map(|_| rng.sample(range) == 0) + .collect_vec(), + ) + .into_array(); + let arr2 = BoolArray::from( + (0..10_000_000) + .map(|_| rng.sample(range) == 0) + .collect_vec(), + ) + .into_array(); + + group.bench_function("compare_bool", |b| { + b.iter(|| { + let indices = + vortex::compute::compare::compare(&arr, &arr2, Operator::GreaterThanOrEqualTo) + .unwrap(); + black_box(indices); + Ok::<(), VortexError>(()) + }); + }); +} + +fn filter_indices(c: &mut Criterion) { + let mut group = c.benchmark_group("compare"); + + let mut rng = thread_rng(); + let range = Uniform::new(0i64, 100_000_000); + let arr = (0..10_000_000) + .map(|_| rng.sample(range)) + .collect_vec() + .into_array(); + + let arr2 = (0..10_000_000) + .map(|_| rng.sample(range)) + .collect_vec() + .into_array(); + + group.bench_function("compare_int", |b| { + b.iter(|| { + let indices = + vortex::compute::compare::compare(&arr, &arr2, Operator::GreaterThanOrEqualTo) + .unwrap(); + black_box(indices); + Ok::<(), VortexError>(()) + }); + }); +} + +criterion_group!(benches, filter_indices, filter_bool_indices); +criterion_main!(benches); diff --git a/vortex-array/src/array/bool/compute/compare.rs b/vortex-array/src/array/bool/compute/compare.rs new file mode 100644 index 0000000000..cc9969b1d4 --- /dev/null +++ b/vortex-array/src/array/bool/compute/compare.rs @@ -0,0 +1,87 @@ +use std::ops::{BitAnd, BitOr, BitXor, Not}; + +use vortex_error::VortexResult; +use vortex_expr::operators::Operator; + +use crate::array::bool::BoolArray; +use crate::compute::compare::CompareFn; +use crate::{Array, ArrayTrait, IntoArray}; + +impl CompareFn for BoolArray { + fn compare(&self, other: &Array, op: Operator) -> VortexResult { + let flattened = other.clone().flatten_bool()?; + let lhs = self.boolean_buffer(); + let rhs = flattened.boolean_buffer(); + let result_buf = match op { + Operator::EqualTo => lhs.bitxor(&rhs).not(), + Operator::NotEqualTo => lhs.bitxor(&rhs), + + Operator::GreaterThan => lhs.bitand(&rhs.not()), + Operator::GreaterThanOrEqualTo => lhs.bitor(&rhs.not()), + Operator::LessThan => lhs.not().bitand(&rhs), + Operator::LessThanOrEqualTo => lhs.not().bitor(&rhs), + }; + Ok(BoolArray::from( + self.validity() + .to_logical(self.len()) + .to_null_buffer()? + .map(|nulls| result_buf.bitand(&nulls.into_inner())) + .unwrap_or(result_buf), + ) + .into_array()) + } +} + +#[cfg(test)] +mod test { + use itertools::Itertools; + + use super::*; + use crate::compute::compare::compare; + use crate::validity::Validity; + + fn to_int_indices(indices_bits: BoolArray) -> Vec { + let filtered = indices_bits + .boolean_buffer() + .iter() + .enumerate() + .flat_map(|(idx, v)| if v { Some(idx as u64) } else { None }) + .collect_vec(); + filtered + } + + #[test] + fn test_basic_comparisons() -> VortexResult<()> { + let arr = BoolArray::from_vec( + vec![true, true, false, true, false], + Validity::Array(BoolArray::from(vec![false, true, true, true, true]).into_array()), + ) + .into_array(); + + let matches = compare(&arr, &arr, Operator::EqualTo)?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [1u64, 2, 3, 4]); + + let matches = compare(&arr, &arr, Operator::NotEqualTo)?.flatten_bool()?; + let empty: [u64; 0] = []; + assert_eq!(to_int_indices(matches), empty); + + let other = BoolArray::from_vec( + vec![false, false, false, true, true], + Validity::Array(BoolArray::from(vec![false, true, true, true, true]).into_array()), + ) + .into_array(); + + let matches = compare(&arr, &other, Operator::LessThanOrEqualTo)?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [2u64, 3, 4]); + + let matches = compare(&arr, &other, Operator::LessThan)?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [4u64]); + + let matches = compare(&other, &arr, Operator::GreaterThanOrEqualTo)?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [2u64, 3, 4]); + + let matches = compare(&other, &arr, Operator::GreaterThan)?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [4u64]); + Ok(()) + } +} diff --git a/vortex-array/src/array/bool/compute/mod.rs b/vortex-array/src/array/bool/compute/mod.rs index 35dd3e4a15..b8832b8113 100644 --- a/vortex-array/src/array/bool/compute/mod.rs +++ b/vortex-array/src/array/bool/compute/mod.rs @@ -1,6 +1,7 @@ use crate::array::bool::BoolArray; use crate::compute::as_arrow::AsArrowArray; use crate::compute::as_contiguous::AsContiguousFn; +use crate::compute::compare::CompareFn; use crate::compute::fill::FillForwardFn; use crate::compute::scalar_at::ScalarAtFn; use crate::compute::slice::SliceFn; @@ -9,6 +10,7 @@ use crate::compute::ArrayCompute; mod as_arrow; mod as_contiguous; +mod compare; mod fill; mod flatten; mod scalar_at; @@ -24,6 +26,10 @@ impl ArrayCompute for BoolArray { Some(self) } + fn compare(&self) -> Option<&dyn CompareFn> { + Some(self) + } + fn fill_forward(&self) -> Option<&dyn FillForwardFn> { Some(self) } diff --git a/vortex-array/src/array/primitive/compute/compare.rs b/vortex-array/src/array/primitive/compute/compare.rs new file mode 100644 index 0000000000..4d4ce18238 --- /dev/null +++ b/vortex-array/src/array/primitive/compute/compare.rs @@ -0,0 +1,117 @@ +use std::ops::BitAnd; + +use arrow_buffer::BooleanBuffer; +use vortex_dtype::{match_each_native_ptype, NativePType}; +use vortex_error::VortexResult; +use vortex_expr::operators::Operator; + +use crate::array::bool::BoolArray; +use crate::array::primitive::PrimitiveArray; +use crate::compute::compare::CompareFn; +use crate::{Array, ArrayTrait, IntoArray}; + +impl CompareFn for PrimitiveArray { + fn compare(&self, other: &Array, predicate: Operator) -> VortexResult { + let flattened = other.clone().flatten_primitive()?; + + let matching_idxs = match_each_native_ptype!(self.ptype(), |$T| { + let predicate_fn = &predicate.to_predicate::<$T>(); + apply_predicate(self.typed_data::<$T>(), flattened.typed_data::<$T>(), predicate_fn) + }); + + let present = self + .validity() + .to_logical(self.len()) + .to_present_null_buffer()? + .into_inner(); + let present_other = flattened + .validity() + .to_logical(self.len()) + .to_present_null_buffer()? + .into_inner(); + + Ok(BoolArray::from(matching_idxs.bitand(&present).bitand(&present_other)).into_array()) + } +} + +fn apply_predicate bool>( + lhs: &[T], + rhs: &[T], + f: F, +) -> BooleanBuffer { + let matches = lhs.iter().zip(rhs.iter()).map(|(lhs, rhs)| f(lhs, rhs)); + BooleanBuffer::from_iter(matches) +} + +#[cfg(test)] +mod test { + use itertools::Itertools; + + use super::*; + use crate::compute::compare::compare; + + fn to_int_indices(indices_bits: BoolArray) -> Vec { + let filtered = indices_bits + .boolean_buffer() + .iter() + .enumerate() + .flat_map(|(idx, v)| if v { Some(idx as u64) } else { None }) + .collect_vec(); + filtered + } + + #[test] + fn test_basic_comparisons() -> VortexResult<()> { + let arr = PrimitiveArray::from_nullable_vec(vec![ + Some(1i32), + Some(2), + Some(3), + Some(4), + None, + Some(5), + Some(6), + Some(7), + Some(8), + None, + Some(9), + None, + ]) + .into_array(); + + let matches = compare(&arr, &arr, Operator::EqualTo)?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]); + + let matches = compare(&arr, &arr, Operator::NotEqualTo)?.flatten_bool()?; + let empty: [u64; 0] = []; + assert_eq!(to_int_indices(matches), empty); + + let other = PrimitiveArray::from_nullable_vec(vec![ + Some(1i32), + Some(2), + Some(3), + Some(4), + None, + Some(6), + Some(7), + Some(8), + Some(9), + None, + Some(10), + None, + ]) + .into_array(); + + let matches = compare(&arr, &other, Operator::LessThanOrEqualTo)?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]); + + let matches = compare(&arr, &other, Operator::LessThan)?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]); + + let matches = compare(&other, &arr, Operator::GreaterThanOrEqualTo)?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [0u64, 1, 2, 3, 5, 6, 7, 8, 10]); + + let matches = compare(&other, &arr, Operator::GreaterThan)?.flatten_bool()?; + assert_eq!(to_int_indices(matches), [5u64, 6, 7, 8, 10]); + Ok(()) + } +} diff --git a/vortex-array/src/array/primitive/compute/mod.rs b/vortex-array/src/array/primitive/compute/mod.rs index 8aa141a4de..8b87128a11 100644 --- a/vortex-array/src/array/primitive/compute/mod.rs +++ b/vortex-array/src/array/primitive/compute/mod.rs @@ -2,7 +2,9 @@ use crate::array::primitive::PrimitiveArray; use crate::compute::as_arrow::AsArrowArray; use crate::compute::as_contiguous::AsContiguousFn; use crate::compute::cast::CastFn; +use crate::compute::compare::CompareFn; use crate::compute::fill::FillForwardFn; +use crate::compute::filter_indices::FilterIndicesFn; use crate::compute::scalar_at::ScalarAtFn; use crate::compute::scalar_subtract::SubtractScalarFn; use crate::compute::search_sorted::SearchSortedFn; @@ -13,6 +15,7 @@ use crate::compute::ArrayCompute; mod as_arrow; mod as_contiguous; mod cast; +mod compare; mod fill; mod filter_indices; mod scalar_at; @@ -34,9 +37,16 @@ impl ArrayCompute for PrimitiveArray { Some(self) } + fn compare(&self) -> Option<&dyn CompareFn> { + Some(self) + } + fn fill_forward(&self) -> Option<&dyn FillForwardFn> { Some(self) } + fn filter_indices(&self) -> Option<&dyn FilterIndicesFn> { + Some(self) + } fn scalar_at(&self) -> Option<&dyn ScalarAtFn> { Some(self) diff --git a/vortex-array/src/compute/compare.rs b/vortex-array/src/compute/compare.rs new file mode 100644 index 0000000000..c1b00a056c --- /dev/null +++ b/vortex-array/src/compute/compare.rs @@ -0,0 +1,29 @@ +use vortex_dtype::DType; +use vortex_error::{vortex_err, VortexResult}; +use vortex_expr::operators::Operator; + +use crate::{Array, ArrayDType}; + +pub trait CompareFn { + fn compare(&self, array: &Array, predicate: Operator) -> VortexResult; +} + +pub fn compare(array: &Array, other: &Array, predicate: Operator) -> VortexResult { + if let Some(matching_indices) = + array.with_dyn(|c| c.compare().map(|t| t.compare(other, predicate))) + { + return matching_indices; + } + // if compare is not implemented for the given array type, but the array has a numeric + // DType, we can flatten the array and apply filter to the flattened primitive array + match array.dtype() { + DType::Primitive(..) => { + let flat = array.clone().flatten_primitive()?; + flat.compare(other, predicate) + } + _ => Err(vortex_err!( + NotImplemented: "compare", + array.encoding().id() + )), + } +} diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 8853d27d53..8ca8fd1815 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -1,6 +1,7 @@ use as_arrow::AsArrowArray; use as_contiguous::AsContiguousFn; use cast::CastFn; +use compare::CompareFn; use fill::FillForwardFn; use patch::PatchFn; use scalar_at::ScalarAtFn; @@ -14,6 +15,7 @@ use crate::compute::scalar_subtract::SubtractScalarFn; pub mod as_arrow; pub mod as_contiguous; pub mod cast; +pub mod compare; pub mod fill; pub mod filter_indices; pub mod patch; @@ -36,6 +38,10 @@ pub trait ArrayCompute { None } + fn compare(&self) -> Option<&dyn CompareFn> { + None + } + fn fill_forward(&self) -> Option<&dyn FillForwardFn> { None }