diff --git a/.github/workflows/arrow.yml b/.github/workflows/arrow.yml index 279e276a7912..8203c15afc6c 100644 --- a/.github/workflows/arrow.yml +++ b/.github/workflows/arrow.yml @@ -80,8 +80,8 @@ jobs: run: cargo test -p arrow-json --all-features - name: Test arrow-string with all features run: cargo test -p arrow-string --all-features - - name: Test arrow-ord with all features except SIMD - run: cargo test -p arrow-ord --features dyn_cmp_dict + - name: Test arrow-ord with all features + run: cargo test -p arrow-ord --all-features - name: Test arrow-arith with all features except SIMD run: cargo test -p arrow-arith - name: Test arrow-row with all features @@ -145,8 +145,6 @@ jobs: rust-version: nightly - name: Test arrow-array with SIMD run: cargo test -p arrow-array --features simd - - name: Test arrow-ord with SIMD - run: cargo test -p arrow-ord --features simd - name: Test arrow-arith with SIMD run: cargo test -p arrow-arith --features simd - name: Test arrow with SIMD @@ -206,8 +204,8 @@ jobs: run: cargo clippy -p arrow-json --all-targets --all-features -- -D warnings - name: Clippy arrow-string with all features run: cargo clippy -p arrow-string --all-targets --all-features -- -D warnings - - name: Clippy arrow-ord with all features except SIMD - run: cargo clippy -p arrow-ord --all-targets --features dyn_cmp_dict -- -D warnings + - name: Clippy arrow-ord with all features + run: cargo clippy -p arrow-ord --all-targets --all-features -- -D warnings - name: Clippy arrow-arith with all features except SIMD run: cargo clippy -p arrow-arith --all-targets -- -D warnings - name: Clippy arrow-row with all features diff --git a/.github/workflows/miri.sh b/.github/workflows/miri.sh index faf9f028d281..ec8712660c74 100755 --- a/.github/workflows/miri.sh +++ b/.github/workflows/miri.sh @@ -15,4 +15,4 @@ cargo miri test -p arrow-data --features ffi cargo miri test -p arrow-schema --features ffi cargo miri test -p arrow-array cargo miri test -p arrow-arith --features simd 
-cargo miri test -p arrow-ord --features simd +cargo miri test -p arrow-ord diff --git a/arrow-flight/src/sql/metadata/db_schemas.rs b/arrow-flight/src/sql/metadata/db_schemas.rs index 7b10e1c14299..20780a116032 100644 --- a/arrow-flight/src/sql/metadata/db_schemas.rs +++ b/arrow-flight/src/sql/metadata/db_schemas.rs @@ -22,8 +22,8 @@ use std::sync::Arc; use arrow_arith::boolean::and; -use arrow_array::{builder::StringBuilder, ArrayRef, RecordBatch}; -use arrow_ord::comparison::eq_utf8_scalar; +use arrow_array::{builder::StringBuilder, ArrayRef, RecordBatch, Scalar, StringArray}; +use arrow_ord::cmp::eq; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use arrow_select::{filter::filter_record_batch, take::take}; use arrow_string::like::like_utf8_scalar; @@ -129,7 +129,8 @@ impl GetDbSchemasBuilder { } if let Some(catalog_filter_name) = catalog_filter { - filters.push(eq_utf8_scalar(&catalog_name, &catalog_filter_name)?); + let scalar = StringArray::from_iter_values([catalog_filter_name]); + filters.push(eq(&catalog_name, &Scalar::new(&scalar))?); } // `AND` any filters together diff --git a/arrow-flight/src/sql/metadata/sql_info.rs b/arrow-flight/src/sql/metadata/sql_info.rs index b37ac85308f4..88c97227814d 100644 --- a/arrow-flight/src/sql/metadata/sql_info.rs +++ b/arrow-flight/src/sql/metadata/sql_info.rs @@ -33,10 +33,9 @@ use arrow_array::builder::{ ArrayBuilder, BooleanBuilder, Int32Builder, Int64Builder, Int8Builder, ListBuilder, MapBuilder, StringBuilder, UInt32Builder, }; -use arrow_array::cast::downcast_array; -use arrow_array::RecordBatch; +use arrow_array::{RecordBatch, Scalar}; use arrow_data::ArrayData; -use arrow_ord::comparison::eq_scalar; +use arrow_ord::cmp::eq; use arrow_schema::{DataType, Field, Fields, Schema, SchemaRef, UnionFields, UnionMode}; use arrow_select::filter::filter_record_batch; use once_cell::sync::Lazy; @@ -425,13 +424,16 @@ impl SqlInfoData { &self, info: impl IntoIterator, ) -> Result { - let arr: UInt32Array = 
downcast_array(self.batch.column(0).as_ref()); + let arr = self.batch.column(0); let type_filter = info .into_iter() - .map(|tt| eq_scalar(&arr, tt)) + .map(|tt| { + let s = UInt32Array::from(vec![tt]); + eq(arr, &Scalar::new(&s)) + }) .collect::, _>>()? .into_iter() - // We know the arrays are of same length as they are produced fromn the same root array + // We know the arrays are of same length as they are produced from the same root array .reduce(|filter, arr| or(&filter, &arr).unwrap()); if let Some(filter) = type_filter { Ok(filter_record_batch(&self.batch, &filter)?) diff --git a/arrow-flight/src/sql/metadata/tables.rs b/arrow-flight/src/sql/metadata/tables.rs index 67193969d46d..de55f0624f2f 100644 --- a/arrow-flight/src/sql/metadata/tables.rs +++ b/arrow-flight/src/sql/metadata/tables.rs @@ -23,8 +23,8 @@ use std::sync::Arc; use arrow_arith::boolean::{and, or}; use arrow_array::builder::{BinaryBuilder, StringBuilder}; -use arrow_array::{ArrayRef, RecordBatch}; -use arrow_ord::comparison::eq_utf8_scalar; +use arrow_array::{ArrayRef, RecordBatch, Scalar, StringArray}; +use arrow_ord::cmp::eq; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use arrow_select::{filter::filter_record_batch, take::take}; use arrow_string::like::like_utf8_scalar; @@ -184,12 +184,16 @@ impl GetTablesBuilder { let mut filters = vec![]; if let Some(catalog_filter_name) = catalog_filter { - filters.push(eq_utf8_scalar(&catalog_name, &catalog_filter_name)?); + let scalar = StringArray::from_iter_values([catalog_filter_name]); + filters.push(eq(&catalog_name, &Scalar::new(&scalar))?); } let tt_filter = table_types_filter .into_iter() - .map(|tt| eq_utf8_scalar(&table_type, &tt)) + .map(|tt| { + let scalar = StringArray::from_iter_values([tt]); + eq(&table_type, &Scalar::new(&scalar)) + }) .collect::, _>>()? 
.into_iter() // We know the arrays are of same length as they are produced fromn the same root array diff --git a/arrow-flight/src/sql/metadata/xdbc_info.rs b/arrow-flight/src/sql/metadata/xdbc_info.rs index b70a3ce3cb3e..8212c847a4fa 100644 --- a/arrow-flight/src/sql/metadata/xdbc_info.rs +++ b/arrow-flight/src/sql/metadata/xdbc_info.rs @@ -27,9 +27,8 @@ use std::sync::Arc; use arrow_array::builder::{BooleanBuilder, Int32Builder, ListBuilder, StringBuilder}; -use arrow_array::cast::downcast_array; -use arrow_array::{ArrayRef, Int32Array, ListArray, RecordBatch}; -use arrow_ord::comparison::eq_scalar; +use arrow_array::{ArrayRef, Int32Array, ListArray, RecordBatch, Scalar}; +use arrow_ord::cmp::eq; use arrow_schema::{DataType, Field, Schema, SchemaRef}; use arrow_select::filter::filter_record_batch; use arrow_select::take::take; @@ -81,8 +80,8 @@ impl XdbcTypeInfoData { /// from [`CommandGetXdbcTypeInfo`] pub fn record_batch(&self, data_type: impl Into>) -> Result { if let Some(dt) = data_type.into() { - let arr: Int32Array = downcast_array(self.batch.column(1).as_ref()); - let filter = eq_scalar(&arr, dt)?; + let scalar = Int32Array::from(vec![dt]); + let filter = eq(self.batch.column(1), &Scalar::new(&scalar))?; Ok(filter_record_batch(&self.batch, &filter)?) 
} else { Ok(self.batch.clone()) diff --git a/arrow-ord/Cargo.toml b/arrow-ord/Cargo.toml index fb061b9b5499..c9c30074fe6e 100644 --- a/arrow-ord/Cargo.toml +++ b/arrow-ord/Cargo.toml @@ -44,10 +44,3 @@ half = { version = "2.1", default-features = false, features = ["num-traits"] } [dev-dependencies] rand = { version = "0.8", default-features = false, features = ["std", "std_rng"] } - -[package.metadata.docs.rs] -features = ["dyn_cmp_dict"] - -[features] -dyn_cmp_dict = [] -simd = ["arrow-array/simd"] diff --git a/arrow-ord/src/cmp.rs b/arrow-ord/src/cmp.rs new file mode 100644 index 000000000000..aad61fa8f062 --- /dev/null +++ b/arrow-ord/src/cmp.rs @@ -0,0 +1,489 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Comparison kernels for `Array`s. +//! +//! These kernels can leverage SIMD if available on your system. Currently no runtime +//! detection is provided, you should enable the specific SIMD intrinsics using +//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation +//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. +//! 
+ +use arrow_array::cast::AsArray; +use arrow_array::types::ByteArrayType; +use arrow_array::{ + downcast_primitive_array, AnyDictionaryArray, Array, ArrowNativeTypeOp, BooleanArray, + Datum, FixedSizeBinaryArray, GenericByteArray, +}; +use arrow_buffer::bit_util::ceil; +use arrow_buffer::{BooleanBuffer, MutableBuffer, NullBuffer}; +use arrow_schema::ArrowError; +use arrow_select::take::take; + +#[derive(Debug, Copy, Clone)] +enum Op { + Equal, + NotEqual, + Less, + LessEqual, + Greater, + GreaterEqual, +} + +impl std::fmt::Display for Op { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Op::Equal => write!(f, "=="), + Op::NotEqual => write!(f, "!="), + Op::Less => write!(f, "<"), + Op::LessEqual => write!(f, "<="), + Op::Greater => write!(f, ">"), + Op::GreaterEqual => write!(f, ">="), + } + } +} + +/// Perform `left == right` operation on two [`Datum`] +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. +/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::Equal, lhs, rhs) +} + +/// Perform `left != right` operation on two [`Datum`] +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. 
+/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn neq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::NotEqual, lhs, rhs) +} + +/// Perform `left < right` operation on two [`Datum`] +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. +/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn lt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::Less, lhs, rhs) +} + +/// Perform `left <= right` operation on two [`Datum`] +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. +/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn lt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::LessEqual, lhs, rhs) +} + +/// Perform `left > right` operation on two [`Datum`] +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. 
+/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn gt(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::Greater, lhs, rhs) +} + +/// Perform `left >= right` operation on two [`Datum`] +/// +/// For floating values like f32 and f64, this comparison produces an ordering in accordance to +/// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. +/// Note that totalOrder treats positive and negative zeros as different. If it is necessary +/// to treat them as equal, please normalize zeros before calling this kernel. +/// +/// Please refer to [`f32::total_cmp`] and [`f64::total_cmp`] +pub fn gt_eq(lhs: &dyn Datum, rhs: &dyn Datum) -> Result { + compare_op(Op::GreaterEqual, lhs, rhs) +} + +/// Perform `op` on the provided `Datum` +fn compare_op( + op: Op, + lhs: &dyn Datum, + rhs: &dyn Datum, +) -> Result { + use arrow_schema::DataType::*; + let (l, l_s) = lhs.get(); + let (r, r_s) = rhs.get(); + + let l_len = l.len(); + let r_len = r.len(); + let l_nulls = l.logical_nulls(); + let r_nulls = r.logical_nulls(); + + let (len, nulls) = match (l_s, r_s) { + (true, true) | (false, false) => { + if l_len != r_len { + return Err(ArrowError::InvalidArgumentError(format!( + "Cannot compare arrays of different lengths, got {l_len} vs {r_len}" + ))); + } + (l_len, NullBuffer::union(l_nulls.as_ref(), r_nulls.as_ref())) + } + (true, false) => match l_nulls.map(|x| x.null_count() != 0).unwrap_or_default() { + true => (r_len, Some(NullBuffer::new_null(r_len))), + false => (r_len, r_nulls), // Left is scalar and not null + }, + (false, true) => match r_nulls.map(|x| x.null_count() != 0).unwrap_or_default() { + true => (l_len, Some(NullBuffer::new_null(l_len))), + false => (l_len, l_nulls), // Right is scalar and not null + }, + }; + + let l_v = l.as_any_dictionary_opt(); + let l = l_v.map(|x| x.values().as_ref()).unwrap_or(l); + + let r_v = r.as_any_dictionary_opt(); + let r = r_v.map(|x| 
x.values().as_ref()).unwrap_or(r); + + let values = downcast_primitive_array! { + (l, r) => apply(op, l.values().as_ref(), l_s, l_v, r.values().as_ref(), r_s, r_v), + (Boolean, Boolean) => apply(op, l.as_boolean(), l_s, l_v, r.as_boolean(), r_s, r_v), + (Utf8, Utf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), + (LargeUtf8, LargeUtf8) => apply(op, l.as_string::(), l_s, l_v, r.as_string::(), r_s, r_v), + (Binary, Binary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), + (LargeBinary, LargeBinary) => apply(op, l.as_binary::(), l_s, l_v, r.as_binary::(), r_s, r_v), + (FixedSizeBinary(_), FixedSizeBinary(_)) => apply(op, l.as_fixed_size_binary(), l_s, l_v, r.as_fixed_size_binary(), r_s, r_v), + (l_t, r_t) => return Err(ArrowError::InvalidArgumentError(format!("Invalid comparison operation: {l_t} {op} {r_t}"))), + }.unwrap_or_else(|| { + let count = nulls.as_ref().map(|x| x.null_count()).unwrap_or_default(); + assert_eq!(count, len); // Sanity check + BooleanBuffer::new_unset(len) + }); + + assert_eq!(values.len(), len); // Sanity check + Ok(BooleanArray::new(values, nulls)) +} + +/// Perform a potentially vectored `op` on the provided `ArrayOrd` +fn apply( + op: Op, + l: T, + l_s: bool, + l_v: Option<&dyn AnyDictionaryArray>, + r: T, + r_s: bool, + r_v: Option<&dyn AnyDictionaryArray>, +) -> Option { + if l.len() == 0 || r.len() == 0 { + return None; // Handle empty dictionaries + } + + if !l_s && !r_s && (l_v.is_some() || r_v.is_some()) { + // Not scalar and at least one side has a dictionary, need to perform vectored comparison + let l_v = l_v + .map(|x| x.normalized_keys()) + .unwrap_or_else(|| (0..l.len()).collect()); + + let r_v = r_v + .map(|x| x.normalized_keys()) + .unwrap_or_else(|| (0..r.len()).collect()); + + assert_eq!(l_v.len(), r_v.len()); // Sanity check + + Some(match op { + Op::Equal => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_eq), + Op::NotEqual => apply_op_vectored(l, &l_v, r, &r_v, true, T::is_eq), 
+ Op::Less => apply_op_vectored(l, &l_v, r, &r_v, false, T::is_lt), + Op::LessEqual => apply_op_vectored(r, &r_v, l, &l_v, true, T::is_lt), + Op::Greater => apply_op_vectored(r, &r_v, l, &l_v, false, T::is_lt), + Op::GreaterEqual => apply_op_vectored(l, &l_v, r, &r_v, true, T::is_lt), + }) + } else { + let l_s = l_s.then(|| l_v.map(|x| x.normalized_keys()[0]).unwrap_or_default()); + let r_s = r_s.then(|| r_v.map(|x| x.normalized_keys()[0]).unwrap_or_default()); + + let buffer = match op { + Op::Equal => apply_op(l, l_s, r, r_s, false, T::is_eq), + Op::NotEqual => apply_op(l, l_s, r, r_s, true, T::is_eq), + Op::Less => apply_op(l, l_s, r, r_s, false, T::is_lt), + Op::LessEqual => apply_op(r, r_s, l, l_s, true, T::is_lt), + Op::Greater => apply_op(r, r_s, l, l_s, false, T::is_lt), + Op::GreaterEqual => apply_op(l, l_s, r, r_s, true, T::is_lt), + }; + + // If a side had a dictionary, and was not scalar, we need to materialize this + Some(match (l_v, r_v) { + (Some(l_v), _) if l_s.is_none() => take_bits(l_v, buffer), + (_, Some(r_v)) if r_s.is_none() => take_bits(r_v, buffer), + _ => buffer, + }) + } +} + +/// Perform a take operation on `buffer` with the given dictionary +fn take_bits(v: &dyn AnyDictionaryArray, buffer: BooleanBuffer) -> BooleanBuffer { + let array = take(&BooleanArray::new(buffer, None), v.keys(), None).unwrap(); + array.as_boolean().values().clone() +} + +/// Invokes `f` with values `0..len` collecting the boolean results into a new `BooleanBuffer` +/// +/// This is similar to [`MutableBuffer::collect_bool`] but with +/// the option to efficiently negate the result +fn collect_bool(len: usize, neg: bool, f: impl Fn(usize) -> bool) -> BooleanBuffer { + let mut buffer = MutableBuffer::new(ceil(len, 64) * 8); + + let chunks = len / 64; + let remainder = len % 64; + for chunk in 0..chunks { + let mut packed = 0; + for bit_idx in 0..64 { + let i = bit_idx + chunk * 64; + packed |= (f(i) as u64) << bit_idx; + } + if neg { + packed = !packed + } + + // 
SAFETY: Already allocated sufficient capacity + unsafe { buffer.push_unchecked(packed) } + } + + if remainder != 0 { + let mut packed = 0; + for bit_idx in 0..remainder { + let i = bit_idx + chunks * 64; + packed |= (f(i) as u64) << bit_idx; + } + if neg { + packed = !packed + } + + // SAFETY: Already allocated sufficient capacity + unsafe { buffer.push_unchecked(packed) } + } + BooleanBuffer::new(buffer.into(), 0, len) +} + +/// Applies `op` to possibly scalar `ArrayOrd`, inverting every result bit when `neg` is true +/// +/// If l is scalar `l_s` will be `Some(idx)` where `idx` is the index of the scalar value in `l` +/// If r is scalar `r_s` will be `Some(idx)` where `idx` is the index of the scalar value in `r` +fn apply_op( + l: T, + l_s: Option, + r: T, + r_s: Option, + neg: bool, + op: impl Fn(T::Item, T::Item) -> bool, +) -> BooleanBuffer { + match (l_s, r_s) { + (None, None) => { + assert_eq!(l.len(), r.len()); + collect_bool(l.len(), neg, |idx| unsafe { + op(l.value_unchecked(idx), r.value_unchecked(idx)) + }) + } + (Some(l_s), Some(r_s)) => { + let a = l.value(l_s); + let b = r.value(r_s); + std::iter::once(op(a, b) ^ neg).collect() // single-bit result bypasses collect_bool, so `neg` must be applied here + } + (Some(l_s), None) => { + let v = l.value(l_s); + collect_bool(r.len(), neg, |idx| op(v, unsafe { r.value_unchecked(idx) })) + } + (None, Some(r_s)) => { + let v = r.value(r_s); + collect_bool(l.len(), neg, |idx| op(unsafe { l.value_unchecked(idx) }, v)) + } + } +} + +/// Applies `op` to possibly scalar `ArrayOrd` with the given indices +fn apply_op_vectored( + l: T, + l_v: &[usize], + r: T, + r_v: &[usize], + neg: bool, + op: impl Fn(T::Item, T::Item) -> bool, +) -> BooleanBuffer { + assert_eq!(l_v.len(), r_v.len()); + collect_bool(l_v.len(), neg, |idx| unsafe { + let l_idx = *l_v.get_unchecked(idx); + let r_idx = *r_v.get_unchecked(idx); + op(l.value_unchecked(l_idx), r.value_unchecked(r_idx)) + }) +} + +trait ArrayOrd { + type Item: Copy + Default; + + fn len(&self) -> usize; + + fn value(&self, idx: usize) -> Self::Item { + assert!(idx < self.len()); + unsafe { 
self.value_unchecked(idx) } + } + + /// # Safety + /// + /// Safe if `idx < self.len()` + unsafe fn value_unchecked(&self, idx: usize) -> Self::Item; + + fn is_eq(l: Self::Item, r: Self::Item) -> bool; + + fn is_lt(l: Self::Item, r: Self::Item) -> bool; +} + +impl<'a> ArrayOrd for &'a BooleanArray { + type Item = bool; + + fn len(&self) -> usize { + Array::len(self) + } + + unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { + BooleanArray::value_unchecked(self, idx) + } + + fn is_eq(l: Self::Item, r: Self::Item) -> bool { + l == r + } + + fn is_lt(l: Self::Item, r: Self::Item) -> bool { + !l & r + } +} + +impl ArrayOrd for &[T] { + type Item = T; + + fn len(&self) -> usize { + (*self).len() + } + + unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { + *self.get_unchecked(idx) + } + + fn is_eq(l: Self::Item, r: Self::Item) -> bool { + l.is_eq(r) + } + + fn is_lt(l: Self::Item, r: Self::Item) -> bool { + l.is_lt(r) + } +} + +impl<'a, T: ByteArrayType> ArrayOrd for &'a GenericByteArray { + type Item = &'a [u8]; + + fn len(&self) -> usize { + Array::len(self) + } + + unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { + GenericByteArray::value_unchecked(self, idx).as_ref() + } + + fn is_eq(l: Self::Item, r: Self::Item) -> bool { + l == r + } + + fn is_lt(l: Self::Item, r: Self::Item) -> bool { + l < r + } +} + +impl<'a> ArrayOrd for &'a FixedSizeBinaryArray { + type Item = &'a [u8]; + + fn len(&self) -> usize { + Array::len(self) + } + + unsafe fn value_unchecked(&self, idx: usize) -> Self::Item { + FixedSizeBinaryArray::value_unchecked(self, idx) + } + + fn is_eq(l: Self::Item, r: Self::Item) -> bool { + l == r + } + + fn is_lt(l: Self::Item, r: Self::Item) -> bool { + l < r + } +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow_array::{DictionaryArray, Int32Array, Scalar}; + + use super::*; + + #[test] + fn test_null_dict() { + let a = DictionaryArray::new( + Int32Array::new_null(10), + Arc::new(Int32Array::new_null(0)), + 
); + let r = eq(&a, &a).unwrap(); + assert_eq!(r.null_count(), 10); + + let a = DictionaryArray::new( + Int32Array::from(vec![1, 2, 3, 4, 5, 6]), + Arc::new(Int32Array::new_null(10)), + ); + let r = eq(&a, &a).unwrap(); + assert_eq!(r.null_count(), 6); + + let scalar = DictionaryArray::new( + Int32Array::new_null(1), + Arc::new(Int32Array::new_null(0)), + ); + let r = eq(&a, &Scalar::new(&scalar)).unwrap(); + assert_eq!(r.null_count(), 6); + + let scalar = DictionaryArray::new( + Int32Array::new_null(1), + Arc::new(Int32Array::new_null(0)), + ); + let r = eq(&Scalar::new(&scalar), &Scalar::new(&scalar)).unwrap(); + assert_eq!(r.null_count(), 1); + + let a = DictionaryArray::new( + Int32Array::from(vec![0, 1, 2]), + Arc::new(Int32Array::from(vec![3, 2, 1])), + ); + let r = eq(&a, &Scalar::new(&scalar)).unwrap(); + assert_eq!(r.null_count(), 3); + } +} diff --git a/arrow-ord/src/comparison.rs b/arrow-ord/src/comparison.rs index 21583fac08ff..1a6e564283d7 100644 --- a/arrow-ord/src/comparison.rs +++ b/arrow-ord/src/comparison.rs @@ -23,15 +23,229 @@ //! [here](https://doc.rust-lang.org/stable/core/arch/) for more information. //! +use half::f16; +use std::sync::Arc; + use arrow_array::cast::*; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::i256; -use arrow_buffer::{bit_util, BooleanBuffer, Buffer, MutableBuffer, NullBuffer}; -use arrow_data::ArrayData; +use arrow_buffer::{bit_util, BooleanBuffer, MutableBuffer, NullBuffer}; use arrow_schema::{ArrowError, DataType, IntervalUnit, TimeUnit}; -use arrow_select::take::take; -use half::f16; + +/// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message. +/// Type of expression is `Result<.., ArrowError>` +macro_rules! 
try_to_type { + ($RIGHT: expr, $TY: ident) => { + try_to_type_result($RIGHT.$TY(), &format!("{:?}", $RIGHT), stringify!($TY)) + }; +} + +// Avoids creating a closure for each combination of `$RIGHT` and `$TY` +fn try_to_type_result( + value: Option, + right: &str, + ty: &str, +) -> Result { + value.ok_or_else(|| { + ArrowError::ComputeError(format!("Could not convert {right} with {ty}",)) + }) +} + +fn make_primitive_scalar( + d: &DataType, + scalar: T, +) -> Result { + match d { + DataType::Int8 => { + let right = try_to_type!(scalar, to_i8)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Int16 => { + let right = try_to_type!(scalar, to_i16)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Int32 => { + let right = try_to_type!(scalar, to_i32)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Int64 => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::UInt8 => { + let right = try_to_type!(scalar, to_u8)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::UInt16 => { + let right = try_to_type!(scalar, to_u16)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::UInt32 => { + let right = try_to_type!(scalar, to_u32)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::UInt64 => { + let right = try_to_type!(scalar, to_u64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Float16 => { + let right = try_to_type!(scalar, to_f32)?; + Ok(Arc::new(PrimitiveArray::::from(vec![ + f16::from_f32(right), + ]))) + } + DataType::Float32 => { + let right = try_to_type!(scalar, to_f32)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Float64 => { + let right = try_to_type!(scalar, to_f64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Decimal128(_, _) => { + let right = try_to_type!(scalar, to_i128)?; + Ok(Arc::new(PrimitiveArray::::from(vec![ 
+ right, + ]))) + } + DataType::Decimal256(_, _) => { + let right = try_to_type!(scalar, to_i128)?; + Ok(Arc::new(PrimitiveArray::::from(vec![ + i256::from_i128(right), + ]))) + } + DataType::Date32 => { + let right = try_to_type!(scalar, to_i32)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Date64 => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![right]))) + } + DataType::Timestamp(TimeUnit::Nanosecond, _) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Timestamp(TimeUnit::Microsecond, _) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Timestamp(TimeUnit::Millisecond, _) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Timestamp(TimeUnit::Second, _) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![ + right, + ]))) + } + DataType::Time32(TimeUnit::Second) => { + let right = try_to_type!(scalar, to_i32)?; + Ok(Arc::new(PrimitiveArray::::from(vec![ + right, + ]))) + } + DataType::Time32(TimeUnit::Millisecond) => { + let right = try_to_type!(scalar, to_i32)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Time64(TimeUnit::Microsecond) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Time64(TimeUnit::Nanosecond) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Interval(IntervalUnit::YearMonth) => { + let right = try_to_type!(scalar, to_i32)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Interval(IntervalUnit::DayTime) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![ + right, + ]))) + } + 
DataType::Interval(IntervalUnit::MonthDayNano) => { + let right = try_to_type!(scalar, to_i128)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Duration(TimeUnit::Second) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from(vec![ + right, + ]))) + } + DataType::Duration(TimeUnit::Millisecond) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Duration(TimeUnit::Microsecond) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Duration(TimeUnit::Nanosecond) => { + let right = try_to_type!(scalar, to_i64)?; + Ok(Arc::new(PrimitiveArray::::from( + vec![right], + ))) + } + DataType::Dictionary(_, v) => make_primitive_scalar(v.as_ref(), scalar), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unsupported primitive scalar data type {d:?}", + ))), + } +} + +fn make_binary_scalar(d: &DataType, scalar: &[u8]) -> Result { + match d { + DataType::Binary => Ok(Arc::new(BinaryArray::from_iter_values([scalar]))), + DataType::FixedSizeBinary(_) => Ok(Arc::new( + FixedSizeBinaryArray::try_from_iter([scalar].into_iter())?, + )), + DataType::LargeBinary => { + Ok(Arc::new(LargeBinaryArray::from_iter_values([scalar]))) + } + DataType::Dictionary(_, v) => make_binary_scalar(v.as_ref(), scalar), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unsupported binary scalar data type {d:?}", + ))), + } +} + +fn make_utf8_scalar(d: &DataType, scalar: &str) -> Result { + match d { + DataType::Utf8 => Ok(Arc::new(StringArray::from_iter_values([scalar]))), + DataType::LargeUtf8 => Ok(Arc::new(LargeStringArray::from_iter_values([scalar]))), + DataType::Dictionary(_, v) => make_utf8_scalar(v.as_ref(), scalar), + _ => Err(ArrowError::InvalidArgumentError(format!( + "Unsupported utf8 scalar data type {d:?}", + ))), + } +} /// Helper function to perform boolean lambda function on 
values from two array accessors, this /// version does not attempt to use SIMD. @@ -67,6 +281,7 @@ where /// Evaluate `op(left, right)` for [`PrimitiveArray`]s using a specified /// comparison function. +#[deprecated(note = "Use BooleanArray::from_binary")] pub fn no_simd_compare_op( left: &PrimitiveArray, right: &PrimitiveArray, @@ -81,6 +296,7 @@ where /// Evaluate `op(left, right)` for [`PrimitiveArray`] and scalar using /// a specified comparison function. +#[deprecated(note = "Use BooleanArray::from_unary")] pub fn no_simd_compare_op_scalar( left: &PrimitiveArray, right: T::Native, @@ -94,617 +310,345 @@ where } /// Perform `left == right` operation on [`StringArray`] / [`LargeStringArray`]. +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op(left, right, |a, b| a == b) -} - -fn utf8_empty( - left: &GenericStringArray, -) -> Result { - let null_bit_buffer = left.nulls().map(|b| b.inner().sliced()); - - let buffer = unsafe { - MutableBuffer::from_trusted_len_iter_bool(left.value_offsets().windows(2).map( - |offset| { - if EQ { - offset[1].as_usize() == offset[0].as_usize() - } else { - offset[1].as_usize() > offset[0].as_usize() - } - }, - )) - }; - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - left.len(), - None, - null_bit_buffer, - 0, - vec![Buffer::from(buffer)], - vec![], - ) - }; - Ok(BooleanArray::from(data)) + crate::cmp::eq(left, right) } /// Perform `left == right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. 
+#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - if right.is_empty() { - return utf8_empty::<_, true>(left); - } - compare_op_scalar(left, |a| a == right) + let right = GenericStringArray::::from(vec![right]); + crate::cmp::eq(&left, &Scalar::new(&right)) } /// Perform `left == right` operation on [`BooleanArray`] +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_bool( left: &BooleanArray, right: &BooleanArray, ) -> Result { - compare_op(left, right, |a, b| !(a ^ b)) + crate::cmp::eq(&left, &right) } /// Perform `left != right` operation on [`BooleanArray`] +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_bool( left: &BooleanArray, right: &BooleanArray, ) -> Result { - compare_op(left, right, |a, b| (a ^ b)) + crate::cmp::neq(&left, &right) } /// Perform `left < right` operation on [`BooleanArray`] +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_bool( left: &BooleanArray, right: &BooleanArray, ) -> Result { - compare_op(left, right, |a, b| ((!a) & b)) + crate::cmp::lt(&left, &right) } /// Perform `left <= right` operation on [`BooleanArray`] +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_bool( left: &BooleanArray, right: &BooleanArray, ) -> Result { - compare_op(left, right, |a, b| !(a & (!b))) + crate::cmp::lt_eq(&left, &right) } /// Perform `left > right` operation on [`BooleanArray`] +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_bool( left: &BooleanArray, right: &BooleanArray, ) -> Result { - compare_op(left, right, |a, b| (a & (!b))) + crate::cmp::gt(&left, &right) } /// Perform `left >= right` operation on [`BooleanArray`] +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_bool( left: &BooleanArray, right: &BooleanArray, ) -> Result { - compare_op(left, right, |a, b| !((!a) & b)) + crate::cmp::gt_eq(&left, &right) } /// Perform `left == right` operation on [`BooleanArray`] and a scalar +#[deprecated(note = "Use 
arrow_ord::cmp::eq")] pub fn eq_bool_scalar( left: &BooleanArray, right: bool, ) -> Result { - let values = match right { - true => left.values().clone(), - false => !left.values(), - }; - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - values.len(), - None, - left.nulls().map(|b| b.inner().sliced()), - values.offset(), - vec![values.into_inner()], - vec![], - ) - }; - - Ok(BooleanArray::from(data)) + let right = BooleanArray::from(vec![right]); + crate::cmp::eq(&left, &Scalar::new(&right)) } /// Perform `left < right` operation on [`BooleanArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_bool_scalar( left: &BooleanArray, right: bool, ) -> Result { - compare_op_scalar(left, |a: bool| !a & right) + let right = BooleanArray::from(vec![right]); + crate::cmp::lt(&left, &Scalar::new(&right)) } /// Perform `left <= right` operation on [`BooleanArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_bool_scalar( left: &BooleanArray, right: bool, ) -> Result { - compare_op_scalar(left, |a| a <= right) + let right = BooleanArray::from(vec![right]); + crate::cmp::lt_eq(&left, &Scalar::new(&right)) } /// Perform `left > right` operation on [`BooleanArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_bool_scalar( left: &BooleanArray, right: bool, ) -> Result { - compare_op_scalar(left, |a: bool| a & !right) + let right = BooleanArray::from(vec![right]); + crate::cmp::gt(&left, &Scalar::new(&right)) } /// Perform `left >= right` operation on [`BooleanArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_bool_scalar( left: &BooleanArray, right: bool, ) -> Result { - compare_op_scalar(left, |a| a >= right) + let right = BooleanArray::from(vec![right]); + crate::cmp::gt_eq(&left, &Scalar::new(&right)) } /// Perform `left != right` operation on [`BooleanArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn 
neq_bool_scalar( left: &BooleanArray, right: bool, ) -> Result { - eq_bool_scalar(left, !right) + let right = BooleanArray::from(vec![right]); + crate::cmp::neq(&left, &Scalar::new(&right)) } /// Perform `left == right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op(left, right, |a, b| a == b) + crate::cmp::eq(left, right) } /// Perform `left == right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar(left, |a| a == right) + let right = GenericBinaryArray::::from_iter_values([right]); + crate::cmp::eq(left, &Scalar::new(&right)) } /// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op(left, right, |a, b| a != b) + crate::cmp::neq(left, right) } /// Perform `left != right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar(left, |a| a != right) + let right = GenericBinaryArray::::from_iter_values([right]); + crate::cmp::neq(left, &Scalar::new(&right)) } /// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op(left, right, |a, b| a < b) + crate::cmp::lt(left, right) } /// Perform `left < right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. 
+#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar(left, |a| a < right) + let right = GenericBinaryArray::::from_iter_values([right]); + crate::cmp::lt(left, &Scalar::new(&right)) } /// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op(left, right, |a, b| a <= b) + crate::cmp::lt_eq(left, right) } /// Perform `left <= right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar(left, |a| a <= right) + let right = GenericBinaryArray::::from_iter_values([right]); + crate::cmp::lt_eq(left, &Scalar::new(&right)) } /// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op(left, right, |a, b| a > b) + crate::cmp::gt(left, right) } /// Perform `left > right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar(left, |a| a > right) + let right = GenericBinaryArray::::from_iter_values([right]); + crate::cmp::gt(left, &Scalar::new(&right)) } /// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`]. 
+#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_binary( left: &GenericBinaryArray, right: &GenericBinaryArray, ) -> Result { - compare_op(left, right, |a, b| a >= b) + crate::cmp::gt_eq(left, right) } /// Perform `left >= right` operation on [`BinaryArray`] / [`LargeBinaryArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_binary_scalar( left: &GenericBinaryArray, right: &[u8], ) -> Result { - compare_op_scalar(left, |a| a >= right) + let right = GenericBinaryArray::::from_iter_values([right]); + crate::cmp::gt_eq(left, &Scalar::new(&right)) } /// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`]. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op(left, right, |a, b| a != b) + crate::cmp::neq(left, right) } /// Perform `left != right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - if right.is_empty() { - return utf8_empty::<_, false>(left); - } - compare_op_scalar(left, |a| a != right) + let right = GenericStringArray::::from_iter_values([right]); + crate::cmp::neq(left, &Scalar::new(&right)) } /// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`]. +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op(left, right, |a, b| a < b) + crate::cmp::lt(left, right) } /// Perform `left < right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. 
+#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar(left, |a| a < right) + let right = GenericStringArray::::from_iter_values([right]); + crate::cmp::lt(left, &Scalar::new(&right)) } /// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`]. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op(left, right, |a, b| a <= b) + crate::cmp::lt_eq(left, right) } /// Perform `left <= right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar(left, |a| a <= right) + let right = GenericStringArray::::from_iter_values([right]); + crate::cmp::lt_eq(left, &Scalar::new(&right)) } /// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`]. +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op(left, right, |a, b| a > b) + crate::cmp::gt(left, right) } /// Perform `left > right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar(left, |a| a > right) + let right = GenericStringArray::::from_iter_values([right]); + crate::cmp::gt(left, &Scalar::new(&right)) } /// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`]. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_utf8( left: &GenericStringArray, right: &GenericStringArray, ) -> Result { - compare_op(left, right, |a, b| a >= b) + crate::cmp::gt_eq(left, right) } /// Perform `left >= right` operation on [`StringArray`] / [`LargeStringArray`] and a scalar. 
+#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_utf8_scalar( left: &GenericStringArray, right: &str, ) -> Result { - compare_op_scalar(left, |a| a >= right) -} - -// Avoids creating a closure for each combination of `$RIGHT` and `$TY` -fn try_to_type_result( - value: Option, - right: &str, - ty: &str, -) -> Result { - value.ok_or_else(|| { - ArrowError::ComputeError(format!("Could not convert {right} with {ty}",)) - }) -} - -/// Calls $RIGHT.$TY() (e.g. `right.to_i128()`) with a nice error message. -/// Type of expression is `Result<.., ArrowError>` -macro_rules! try_to_type { - ($RIGHT: expr, $TY: ident) => { - try_to_type_result($RIGHT.$TY(), &format!("{:?}", $RIGHT), stringify!($TY)) - }; -} - -macro_rules! dyn_compare_scalar { - // Applies `LEFT OP RIGHT` when `LEFT` is a `PrimitiveArray` - ($LEFT: expr, $RIGHT: expr, $OP: ident) => {{ - match $LEFT.data_type() { - DataType::Int8 => { - let right = try_to_type!($RIGHT, to_i8)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Int16 => { - let right = try_to_type!($RIGHT, to_i16)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Int32 => { - let right = try_to_type!($RIGHT, to_i32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Int64 => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::UInt8 => { - let right = try_to_type!($RIGHT, to_u8)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::UInt16 => { - let right = try_to_type!($RIGHT, to_u16)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::UInt32 => { - let right = try_to_type!($RIGHT, to_u32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::UInt64 => { - let right = try_to_type!($RIGHT, to_u64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Float16 => 
{ - let right = try_to_type!($RIGHT, to_f32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, f16::from_f32(right)) - } - DataType::Float32 => { - let right = try_to_type!($RIGHT, to_f32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Float64 => { - let right = try_to_type!($RIGHT, to_f64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Decimal128(_, _) => { - let right = try_to_type!($RIGHT, to_i128)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Decimal256(_, _) => { - let right = try_to_type!($RIGHT, to_i128)?; - let left = as_primitive_array::($LEFT); - $OP::(left, i256::from_i128(right)) - } - DataType::Date32 => { - let right = try_to_type!($RIGHT, to_i32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Date64 => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Timestamp(TimeUnit::Nanosecond, _) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Timestamp(TimeUnit::Microsecond, _) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Timestamp(TimeUnit::Millisecond, _) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Timestamp(TimeUnit::Second, _) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Time32(TimeUnit::Second) => { - let right = try_to_type!($RIGHT, to_i32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Time32(TimeUnit::Millisecond) => { - let right = try_to_type!($RIGHT, to_i32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Time64(TimeUnit::Microsecond) => { 
- let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Time64(TimeUnit::Nanosecond) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Interval(IntervalUnit::YearMonth) => { - let right = try_to_type!($RIGHT, to_i32)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Interval(IntervalUnit::DayTime) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Interval(IntervalUnit::MonthDayNano) => { - let right = try_to_type!($RIGHT, to_i128)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Duration(TimeUnit::Second) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Duration(TimeUnit::Millisecond) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Duration(TimeUnit::Microsecond) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - DataType::Duration(TimeUnit::Nanosecond) => { - let right = try_to_type!($RIGHT, to_i64)?; - let left = as_primitive_array::($LEFT); - $OP::(left, right) - } - _ => Err(ArrowError::ComputeError(format!( - "Unsupported data type {:?} for comparison {} with {:?}", - $LEFT.data_type(), - stringify!($OP), - $RIGHT - ))), - } - }}; - // Applies `LEFT OP RIGHT` when `LEFT` is a `DictionaryArray` with keys of type `KT` - ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{ - match $KT.as_ref() { - DataType::UInt8 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::UInt16 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) 
- } - DataType::UInt32 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::UInt64 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::Int8 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::Int16 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::Int32 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - DataType::Int64 => { - let left = as_dictionary_array::($LEFT); - unpack_dict_comparison(left, $OP(left.values(), $RIGHT)?) - } - _ => Err(ArrowError::ComputeError(format!( - "Unsupported dictionary key type {:?}", - $KT.as_ref() - ))), - } - }}; -} - -macro_rules! dyn_compare_utf8_scalar { - ($LEFT: expr, $RIGHT: expr, $KT: ident, $OP: ident) => {{ - match $KT.as_ref() { - DataType::UInt8 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - DataType::UInt16 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - DataType::UInt32 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - DataType::UInt64 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - DataType::Int8 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) 
- } - DataType::Int16 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - DataType::Int32 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - DataType::Int64 => { - let left = as_dictionary_array::($LEFT); - let values = as_string_array(left.values()); - unpack_dict_comparison(left, $OP(values, $RIGHT)?) - } - _ => Err(ArrowError::ComputeError(String::from("Unknown key type"))), - } - }}; + let right = GenericStringArray::::from_iter_values([right]); + crate::cmp::gt_eq(left, &Scalar::new(&right)) } /// Perform `left == right` operation on an array and a numeric scalar @@ -716,16 +660,13 @@ macro_rules! dyn_compare_utf8_scalar { /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { - match left.data_type() { - DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, eq_dyn_scalar) - } - _ => dyn_compare_scalar!(left, right, eq_scalar), - } + let right = make_primitive_scalar(left.data_type(), right)?; + crate::cmp::eq(&left, &Scalar::new(&right)) } /// Perform `left < right` operation on an array and a numeric scalar @@ -737,16 +678,13 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
+#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { - match left.data_type() { - DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, lt_dyn_scalar) - } - _ => dyn_compare_scalar!(left, right, lt_scalar), - } + let right = make_primitive_scalar(left.data_type(), right)?; + crate::cmp::lt(&left, &Scalar::new(&right)) } /// Perform `left <= right` operation on an array and a numeric scalar @@ -758,16 +696,13 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { - match left.data_type() { - DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, lt_eq_dyn_scalar) - } - _ => dyn_compare_scalar!(left, right, lt_eq_scalar), - } + let right = make_primitive_scalar(left.data_type(), right)?; + crate::cmp::lt_eq(&left, &Scalar::new(&right)) } /// Perform `left > right` operation on an array and a numeric scalar @@ -779,16 +714,13 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
+#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { - match left.data_type() { - DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, gt_dyn_scalar) - } - _ => dyn_compare_scalar!(left, right, gt_scalar), - } + let right = make_primitive_scalar(left.data_type(), right)?; + crate::cmp::gt(&left, &Scalar::new(&right)) } /// Perform `left >= right` operation on an array and a numeric scalar @@ -800,16 +732,13 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { - match left.data_type() { - DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, gt_eq_dyn_scalar) - } - _ => dyn_compare_scalar!(left, right, gt_eq_scalar), - } + let right = make_primitive_scalar(left.data_type(), right)?; + crate::cmp::gt_eq(&left, &Scalar::new(&right)) } /// Perform `left != right` operation on an array and a numeric scalar @@ -821,1325 +750,211 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
+#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_dyn_scalar(left: &dyn Array, right: T) -> Result where T: num::ToPrimitive + std::fmt::Debug, { - match left.data_type() { - DataType::Dictionary(key_type, _value_type) => { - dyn_compare_scalar!(left, right, key_type, neq_dyn_scalar) - } - _ => dyn_compare_scalar!(left, right, neq_scalar), - } + let right = make_primitive_scalar(left.data_type(), right)?; + crate::cmp::neq(&left, &Scalar::new(&right)) } /// Perform `left == right` operation on an array and a numeric scalar /// value. Supports BinaryArray and LargeBinaryArray +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_dyn_binary_scalar( left: &dyn Array, right: &[u8], ) -> Result { - match left.data_type() { - DataType::Binary => eq_binary_scalar(left.as_binary::(), right), - DataType::FixedSizeBinary(_) => { - let left = left.as_any().downcast_ref::().unwrap(); - compare_op_scalar(left, |a| a == right) - } - DataType::LargeBinary => eq_binary_scalar(left.as_binary::(), right), - _ => Err(ArrowError::ComputeError( - "eq_dyn_binary_scalar only supports Binary / FixedSizeBinary / LargeBinary arrays".to_string(), - )), - } + let right = make_binary_scalar(left.data_type(), right)?; + crate::cmp::eq(&left, &Scalar::new(&right)) } /// Perform `left != right` operation on an array and a numeric scalar /// value. 
Supports BinaryArray and LargeBinaryArray +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_dyn_binary_scalar( left: &dyn Array, right: &[u8], ) -> Result { - match left.data_type() { - DataType::Binary => neq_binary_scalar(left.as_binary::(), right), - DataType::LargeBinary => neq_binary_scalar(left.as_binary::(), right), - DataType::FixedSizeBinary(_) => { - let left = left.as_any().downcast_ref::().unwrap(); - compare_op_scalar(left, |a| a != right) - } - _ => Err(ArrowError::ComputeError( - "neq_dyn_binary_scalar only supports Binary / FixedSizeBinary / LargeBinary arrays" - .to_string(), - )), - } + let right = make_binary_scalar(left.data_type(), right)?; + crate::cmp::neq(&left, &Scalar::new(&right)) } /// Perform `left < right` operation on an array and a numeric scalar /// value. Supports BinaryArray and LargeBinaryArray +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_dyn_binary_scalar( left: &dyn Array, right: &[u8], ) -> Result { - match left.data_type() { - DataType::Binary => lt_binary_scalar(left.as_binary::(), right), - DataType::LargeBinary => lt_binary_scalar(left.as_binary::(), right), - _ => Err(ArrowError::ComputeError( - "lt_dyn_binary_scalar only supports Binary or LargeBinary arrays".to_string(), - )), - } + let right = make_binary_scalar(left.data_type(), right)?; + crate::cmp::lt(&left, &Scalar::new(&right)) } /// Perform `left <= right` operation on an array and a numeric scalar /// value. 
Supports BinaryArray and LargeBinaryArray +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_dyn_binary_scalar( left: &dyn Array, right: &[u8], ) -> Result { - match left.data_type() { - DataType::Binary => lt_eq_binary_scalar(left.as_binary::(), right), - DataType::LargeBinary => lt_eq_binary_scalar(left.as_binary::(), right), - _ => Err(ArrowError::ComputeError( - "lt_eq_dyn_binary_scalar only supports Binary or LargeBinary arrays" - .to_string(), - )), - } + let right = make_binary_scalar(left.data_type(), right)?; + crate::cmp::lt_eq(&left, &Scalar::new(&right)) } /// Perform `left > right` operation on an array and a numeric scalar /// value. Supports BinaryArray and LargeBinaryArray +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_dyn_binary_scalar( left: &dyn Array, right: &[u8], ) -> Result { - match left.data_type() { - DataType::Binary => gt_binary_scalar(left.as_binary::(), right), - DataType::LargeBinary => gt_binary_scalar(left.as_binary::(), right), - _ => Err(ArrowError::ComputeError( - "gt_dyn_binary_scalar only supports Binary or LargeBinary arrays".to_string(), - )), - } + let right = make_binary_scalar(left.data_type(), right)?; + crate::cmp::gt(&left, &Scalar::new(&right)) } /// Perform `left >= right` operation on an array and a numeric scalar /// value. Supports BinaryArray and LargeBinaryArray +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_dyn_binary_scalar( - left: &dyn Array, - right: &[u8], -) -> Result { - match left.data_type() { - DataType::Binary => gt_eq_binary_scalar(left.as_binary::(), right), - DataType::LargeBinary => gt_eq_binary_scalar(left.as_binary::(), right), - _ => Err(ArrowError::ComputeError( - "gt_eq_dyn_binary_scalar only supports Binary or LargeBinary arrays" - .to_string(), - )), - } -} - -/// Perform `left == right` operation on an array and a numeric scalar -/// value. 
Supports StringArrays, and DictionaryArrays that have string values -pub fn eq_dyn_utf8_scalar( - left: &dyn Array, - right: &str, -) -> Result { - let result = match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => { - dyn_compare_utf8_scalar!(left, right, key_type, eq_utf8_scalar) - } - _ => Err(ArrowError::ComputeError( - "eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), - )), - }, - DataType::Utf8 => { - eq_utf8_scalar(left.as_string::(), right) - } - DataType::LargeUtf8 => { - eq_utf8_scalar(left.as_string::(), right) - } - _ => Err(ArrowError::ComputeError( - "eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), - )), - }; - result -} - -/// Perform `left < right` operation on an array and a numeric scalar -/// value. Supports StringArrays, and DictionaryArrays that have string values -pub fn lt_dyn_utf8_scalar( - left: &dyn Array, - right: &str, -) -> Result { - let result = match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => { - dyn_compare_utf8_scalar!(left, right, key_type, lt_utf8_scalar) - } - _ => Err(ArrowError::ComputeError( - "lt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), - )), - }, - DataType::Utf8 => { - lt_utf8_scalar(left.as_string::(), right) - } - DataType::LargeUtf8 => { - lt_utf8_scalar(left.as_string::(), right) - } - _ => Err(ArrowError::ComputeError( - "lt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), - )), - }; - result -} - -/// Perform `left >= right` operation on an array and a numeric scalar -/// value. 
Supports StringArrays, and DictionaryArrays that have string values -pub fn gt_eq_dyn_utf8_scalar( - left: &dyn Array, - right: &str, -) -> Result { - let result = match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => { - dyn_compare_utf8_scalar!(left, right, key_type, gt_eq_utf8_scalar) - } - _ => Err(ArrowError::ComputeError( - "gt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), - )), - }, - DataType::Utf8 => { - gt_eq_utf8_scalar(left.as_string::(), right) - } - DataType::LargeUtf8 => { - gt_eq_utf8_scalar(left.as_string::(), right) - } - _ => Err(ArrowError::ComputeError( - "gt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), - )), - }; - result -} - -/// Perform `left <= right` operation on an array and a numeric scalar -/// value. Supports StringArrays, and DictionaryArrays that have string values -pub fn lt_eq_dyn_utf8_scalar( - left: &dyn Array, - right: &str, -) -> Result { - let result = match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => { - dyn_compare_utf8_scalar!(left, right, key_type, lt_eq_utf8_scalar) - } - _ => Err(ArrowError::ComputeError( - "lt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), - )), - }, - DataType::Utf8 => { - lt_eq_utf8_scalar(left.as_string::(), right) - } - DataType::LargeUtf8 => { - lt_eq_utf8_scalar(left.as_string::(), right) - } - _ => Err(ArrowError::ComputeError( - "lt_eq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), - )), - }; - result -} - -/// Perform `left > right` operation on an array and a numeric scalar -/// value. 
Supports StringArrays, and DictionaryArrays that have string values -pub fn gt_dyn_utf8_scalar( - left: &dyn Array, - right: &str, -) -> Result { - let result = match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => { - dyn_compare_utf8_scalar!(left, right, key_type, gt_utf8_scalar) - } - _ => Err(ArrowError::ComputeError( - "gt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), - )), - }, - DataType::Utf8 => { - gt_utf8_scalar(left.as_string::(), right) - } - DataType::LargeUtf8 => { - gt_utf8_scalar(left.as_string::(), right) - } - _ => Err(ArrowError::ComputeError( - "gt_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), - )), - }; - result -} - -/// Perform `left != right` operation on an array and a numeric scalar -/// value. Supports StringArrays, and DictionaryArrays that have string values -pub fn neq_dyn_utf8_scalar( - left: &dyn Array, - right: &str, -) -> Result { - let result = match left.data_type() { - DataType::Dictionary(key_type, value_type) => match value_type.as_ref() { - DataType::Utf8 | DataType::LargeUtf8 => { - dyn_compare_utf8_scalar!(left, right, key_type, neq_utf8_scalar) - } - _ => Err(ArrowError::ComputeError( - "neq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays or DictionaryArray with Utf8 or LargeUtf8 values".to_string(), - )), - }, - DataType::Utf8 => { - neq_utf8_scalar(left.as_string::(), right) - } - DataType::LargeUtf8 => { - neq_utf8_scalar(left.as_string::(), right) - } - _ => Err(ArrowError::ComputeError( - "neq_dyn_utf8_scalar only supports Utf8 or LargeUtf8 arrays".to_string(), - )), - }; - result -} - -/// Perform `left == right` operation on an array and a numeric scalar -/// value. 
-pub fn eq_dyn_bool_scalar( - left: &dyn Array, - right: bool, -) -> Result { - let result = match left.data_type() { - DataType::Boolean => eq_bool_scalar(left.as_boolean(), right), - _ => Err(ArrowError::ComputeError( - "eq_dyn_bool_scalar only supports BooleanArray".to_string(), - )), - }; - result -} - -/// Perform `left < right` operation on an array and a numeric scalar -/// value. Supports BooleanArrays. -pub fn lt_dyn_bool_scalar( - left: &dyn Array, - right: bool, -) -> Result { - let result = match left.data_type() { - DataType::Boolean => lt_bool_scalar(left.as_boolean(), right), - _ => Err(ArrowError::ComputeError( - "lt_dyn_bool_scalar only supports BooleanArray".to_string(), - )), - }; - result -} - -/// Perform `left > right` operation on an array and a numeric scalar -/// value. Supports BooleanArrays. -pub fn gt_dyn_bool_scalar( - left: &dyn Array, - right: bool, -) -> Result { - let result = match left.data_type() { - DataType::Boolean => gt_bool_scalar(left.as_boolean(), right), - _ => Err(ArrowError::ComputeError( - "gt_dyn_bool_scalar only supports BooleanArray".to_string(), - )), - }; - result -} - -/// Perform `left <= right` operation on an array and a numeric scalar -/// value. Supports BooleanArrays. -pub fn lt_eq_dyn_bool_scalar( - left: &dyn Array, - right: bool, -) -> Result { - let result = match left.data_type() { - DataType::Boolean => lt_eq_bool_scalar(left.as_boolean(), right), - _ => Err(ArrowError::ComputeError( - "lt_eq_dyn_bool_scalar only supports BooleanArray".to_string(), - )), - }; - result -} - -/// Perform `left >= right` operation on an array and a numeric scalar -/// value. Supports BooleanArrays. 
-pub fn gt_eq_dyn_bool_scalar( - left: &dyn Array, - right: bool, -) -> Result { - let result = match left.data_type() { - DataType::Boolean => gt_eq_bool_scalar(left.as_boolean(), right), - _ => Err(ArrowError::ComputeError( - "gt_eq_dyn_bool_scalar only supports BooleanArray".to_string(), - )), - }; - result -} - -/// Perform `left != right` operation on an array and a numeric scalar -/// value. Supports BooleanArrays. -pub fn neq_dyn_bool_scalar( - left: &dyn Array, - right: bool, -) -> Result { - let result = match left.data_type() { - DataType::Boolean => neq_bool_scalar(left.as_boolean(), right), - _ => Err(ArrowError::ComputeError( - "neq_dyn_bool_scalar only supports BooleanArray".to_string(), - )), - }; - result -} - -/// unpacks the results of comparing left.values (as a boolean) -/// -/// TODO add example -/// -fn unpack_dict_comparison( - dict: &DictionaryArray, - dict_comparison: BooleanArray, -) -> Result -where - K: ArrowDictionaryKeyType, - K::Native: num::ToPrimitive, -{ - let array = take(&dict_comparison, dict.keys(), None)? - .as_boolean() - .clone(); - Ok(array) -} - -/// Helper function to perform boolean lambda function on values from two arrays using -/// SIMD. 
-#[cfg(feature = "simd")] -fn simd_compare_op( - left: &PrimitiveArray, - right: &PrimitiveArray, - simd_op: SI, - scalar_op: SC, -) -> Result -where - T: ArrowNumericType, - SI: Fn(T::Simd, T::Simd) -> T::SimdMask, - SC: Fn(T::Native, T::Native) -> bool, -{ - use std::borrow::BorrowMut; - - let len = left.len(); - if len != right.len() { - return Err(ArrowError::ComputeError( - "Cannot perform comparison operation on arrays of different length" - .to_string(), - )); - } - - let nulls = NullBuffer::union(left.nulls(), right.nulls()); - - // we process the data in chunks so that each iteration results in one u64 of comparison result bits - const CHUNK_SIZE: usize = 64; - let lanes = T::lanes(); - - // this is currently the case for all our datatypes and allows us to always append full bytes - assert!( - lanes <= CHUNK_SIZE, - "Number of vector lanes must be at most 64" - ); - - let buffer_size = bit_util::ceil(len, 8); - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE); - let mut right_chunks = right.values().chunks_exact(CHUNK_SIZE); - - let result_chunks = result.typed_data_mut(); - let result_remainder = left_chunks - .borrow_mut() - .zip(right_chunks.borrow_mut()) - .fold(result_chunks, |result_slice, (left_slice, right_slice)| { - let mut i = 0; - let mut bitmask = 0_u64; - while i < CHUNK_SIZE { - let simd_left = T::load(&left_slice[i..]); - let simd_right = T::load(&right_slice[i..]); - let simd_result = simd_op(simd_left, simd_right); - - let m = T::mask_to_u64(&simd_result); - bitmask |= m << i; - - i += lanes; - } - let bytes = bitmask.to_le_bytes(); - result_slice[0..8].copy_from_slice(&bytes); - - &mut result_slice[8..] 
- }); - - let left_remainder = left_chunks.remainder(); - let right_remainder = right_chunks.remainder(); - - assert_eq!(left_remainder.len(), right_remainder.len()); - - if !left_remainder.is_empty() { - let remainder_bitmask = left_remainder - .iter() - .zip(right_remainder.iter()) - .enumerate() - .fold(0_u64, |mut mask, (i, (scalar_left, scalar_right))| { - let bit = scalar_op(*scalar_left, *scalar_right) as u64; - mask |= bit << i; - mask - }); - let remainder_mask_as_bytes = - &remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)]; - result_remainder.copy_from_slice(remainder_mask_as_bytes); - } - - let values = BooleanBuffer::new(result.into(), 0, len); - Ok(BooleanArray::new(values, nulls)) -} - -/// Helper function to perform boolean lambda function on values from an array and a scalar value using -/// SIMD. -#[cfg(feature = "simd")] -fn simd_compare_op_scalar( - left: &PrimitiveArray, - right: T::Native, - simd_op: SI, - scalar_op: SC, -) -> Result -where - T: ArrowNumericType, - SI: Fn(T::Simd, T::Simd) -> T::SimdMask, - SC: Fn(T::Native, T::Native) -> bool, -{ - use std::borrow::BorrowMut; - - let len = left.len(); - - // we process the data in chunks so that each iteration results in one u64 of comparison result bits - const CHUNK_SIZE: usize = 64; - let lanes = T::lanes(); - - // this is currently the case for all our datatypes and allows us to always append full bytes - assert!( - lanes <= CHUNK_SIZE, - "Number of vector lanes must be at most 64" - ); - - let buffer_size = bit_util::ceil(len, 8); - let mut result = MutableBuffer::new(buffer_size).with_bitset(buffer_size, false); - - let mut left_chunks = left.values().chunks_exact(CHUNK_SIZE); - let simd_right = T::init(right); - - let result_chunks = result.typed_data_mut(); - let result_remainder = - left_chunks - .borrow_mut() - .fold(result_chunks, |result_slice, left_slice| { - let mut i = 0; - let mut bitmask = 0_u64; - while i < CHUNK_SIZE { - let simd_left = 
T::load(&left_slice[i..]); - let simd_result = simd_op(simd_left, simd_right); - - let m = T::mask_to_u64(&simd_result); - bitmask |= m << i; - - i += lanes; - } - let bytes = bitmask.to_le_bytes(); - result_slice[0..8].copy_from_slice(&bytes); - - &mut result_slice[8..] - }); - - let left_remainder = left_chunks.remainder(); - - if !left_remainder.is_empty() { - let remainder_bitmask = left_remainder.iter().enumerate().fold( - 0_u64, - |mut mask, (i, scalar_left)| { - let bit = scalar_op(*scalar_left, right) as u64; - mask |= bit << i; - mask - }, - ); - let remainder_mask_as_bytes = - &remainder_bitmask.to_le_bytes()[0..bit_util::ceil(left_remainder.len(), 8)]; - result_remainder.copy_from_slice(remainder_mask_as_bytes); - } - - let null_bit_buffer = left.nulls().map(|b| b.inner().sliced()); - - // null count is the same as in the input since the right side of the scalar comparison cannot be null - let null_count = left.null_count(); - - let data = unsafe { - ArrayData::new_unchecked( - DataType::Boolean, - len, - Some(null_count), - null_bit_buffer, - 0, - vec![result.into()], - vec![], - ) - }; - Ok(BooleanArray::from(data)) -} - -fn cmp_primitive_array( - left: &dyn Array, - right: &dyn Array, - op: F, -) -> Result -where - F: Fn(T::Native, T::Native) -> bool, -{ - let left_array = left.as_primitive::(); - let right_array = right.as_primitive::(); - compare_op(left_array, right_array, op) -} - -#[cfg(feature = "dyn_cmp_dict")] -macro_rules! 
typed_dict_non_dict_cmp { - ($LEFT: expr, $RIGHT: expr, $LEFT_KEY_TYPE: expr, $RIGHT_TYPE: tt, $OP_BOOL: expr, $OP: expr) => {{ - match $LEFT_KEY_TYPE { - DataType::Int8 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::Int16 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::Int32 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::Int64 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt8 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt16 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt32 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt64 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_primitive::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - t => Err(ArrowError::NotYetImplemented(format!( - "Cannot compare dictionary array of key type {}", - t - ))), - } - }}; -} - -#[cfg(feature = "dyn_cmp_dict")] -macro_rules! 
typed_dict_string_array_cmp { - ($LEFT: expr, $RIGHT: expr, $LEFT_KEY_TYPE: expr, $RIGHT_TYPE: tt, $OP: expr) => {{ - match $LEFT_KEY_TYPE { - DataType::Int8 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::Int16 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::Int32 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::Int64 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt8 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt16 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt32 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - DataType::UInt64 => { - let left = as_dictionary_array::($LEFT); - cmp_dict_string_array::<_, $RIGHT_TYPE, _>(left, $RIGHT, $OP) - } - t => Err(ArrowError::NotYetImplemented(format!( - "Cannot compare dictionary array of key type {}", - t - ))), - } - }}; -} - -#[cfg(feature = "dyn_cmp_dict")] -macro_rules! 
typed_cmp_dict_non_dict { - ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{ - match ($LEFT.data_type(), $RIGHT.data_type()) { - (DataType::Dictionary(left_key_type, left_value_type), right_type) => { - match (left_value_type.as_ref(), right_type) { - (DataType::Boolean, DataType::Boolean) => { - let left = $LEFT; - downcast_dictionary_array!( - left => { - cmp_dict_boolean_array::<_, _>(left, $RIGHT, $OP) - } - _ => Err(ArrowError::NotYetImplemented(format!( - "Cannot compare dictionary array of key type {}", - left_key_type.as_ref() - ))), - ) - } - (DataType::Int8, DataType::Int8) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int8Type, $OP_BOOL, $OP) - } - (DataType::Int16, DataType::Int16) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int16Type, $OP_BOOL, $OP) - } - (DataType::Int32, DataType::Int32) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int32Type, $OP_BOOL, $OP) - } - (DataType::Int64, DataType::Int64) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Int64Type, $OP_BOOL, $OP) - } - (DataType::UInt8, DataType::UInt8) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt8Type, $OP_BOOL, $OP) - } - (DataType::UInt16, DataType::UInt16) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt16Type, $OP_BOOL, $OP) - } - (DataType::UInt32, DataType::UInt32) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt32Type, $OP_BOOL, $OP) - } - (DataType::UInt64, DataType::UInt64) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), UInt64Type, $OP_BOOL, $OP) - } - (DataType::Float16, DataType::Float16) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float16Type, $OP_BOOL, $OP_FLOAT) - } - (DataType::Float32, DataType::Float32) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float32Type, $OP_BOOL, $OP_FLOAT) - } - 
(DataType::Float64, DataType::Float64) => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Float64Type, $OP_BOOL, $OP_FLOAT) - } - (DataType::Decimal128(_, s1), DataType::Decimal128(_, s2)) if s1 == s2 => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Decimal128Type, $OP_BOOL, $OP) - } - (DataType::Decimal256(_, s1), DataType::Decimal256(_, s2)) if s1 == s2 => { - typed_dict_non_dict_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), Decimal256Type, $OP_BOOL, $OP) - } - (DataType::Utf8, DataType::Utf8) => { - typed_dict_string_array_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), i32, $OP) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - typed_dict_string_array_cmp!($LEFT, $RIGHT, left_key_type.as_ref(), i64, $OP) - } - (DataType::Binary, DataType::Binary) => { - let left = $LEFT; - downcast_dictionary_array!( - left => { - cmp_dict_binary_array::<_, i32, _>(left, $RIGHT, $OP) - } - _ => Err(ArrowError::NotYetImplemented(format!( - "Cannot compare dictionary array of key type {}", - left_key_type.as_ref() - ))), - ) - } - (DataType::LargeBinary, DataType::LargeBinary) => { - let left = $LEFT; - downcast_dictionary_array!( - left => { - cmp_dict_binary_array::<_, i64, _>(left, $RIGHT, $OP) - } - _ => Err(ArrowError::NotYetImplemented(format!( - "Cannot compare dictionary array of key type {}", - left_key_type.as_ref() - ))), - ) - } - (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( - "Comparing dictionary array of type {} with array of type {} is not yet implemented", - t1, t2 - ))), - (t1, t2) => Err(ArrowError::CastError(format!( - "Cannot compare dictionary array with array of different value types ({} and {})", - t1, t2 - ))), - } - } - _ => unreachable!("Should not reach this branch"), - } - }}; -} - -#[cfg(not(feature = "dyn_cmp_dict"))] -macro_rules! 
typed_cmp_dict_non_dict { - ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{ - Err(ArrowError::CastError(format!( - "Comparing dictionary array of type {} with array of type {} requires \"dyn_cmp_dict\" feature", - $LEFT.data_type(), $RIGHT.data_type() - ))) - }} + left: &dyn Array, + right: &[u8], +) -> Result { + let right = make_binary_scalar(left.data_type(), right)?; + crate::cmp::gt_eq(&left, &Scalar::new(&right)) } -macro_rules! typed_compares { - ($LEFT: expr, $RIGHT: expr, $OP_BOOL: expr, $OP: expr, $OP_FLOAT: expr) => {{ - match ($LEFT.data_type(), $RIGHT.data_type()) { - (DataType::Boolean, DataType::Boolean) => { - compare_op(as_boolean_array($LEFT), as_boolean_array($RIGHT), $OP_BOOL) - } - (DataType::Int8, DataType::Int8) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Int16, DataType::Int16) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Int32, DataType::Int32) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Int64, DataType::Int64) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::UInt8, DataType::UInt8) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::UInt16, DataType::UInt16) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::UInt32, DataType::UInt32) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::UInt64, DataType::UInt64) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Float16, DataType::Float16) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP_FLOAT) - } - (DataType::Float32, DataType::Float32) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP_FLOAT) - } - (DataType::Float64, DataType::Float64) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP_FLOAT) - } - (DataType::Decimal128(_, s1), DataType::Decimal128(_, s2)) if s1 == s2 => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Decimal256(_, s1), DataType::Decimal256(_, s2)) if s1 == s2 => { - cmp_primitive_array::($LEFT, 
$RIGHT, $OP) - } - (DataType::Utf8, DataType::Utf8) => { - compare_op(as_string_array($LEFT), as_string_array($RIGHT), $OP) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => compare_op( - as_largestring_array($LEFT), - as_largestring_array($RIGHT), - $OP, - ), - (DataType::FixedSizeBinary(_), DataType::FixedSizeBinary(_)) => { - let lhs = $LEFT - .as_any() - .downcast_ref::() - .unwrap(); - let rhs = $RIGHT - .as_any() - .downcast_ref::() - .unwrap(); - - compare_op(lhs, rhs, $OP) - } - (DataType::Binary, DataType::Binary) => compare_op( - as_generic_binary_array::($LEFT), - as_generic_binary_array::($RIGHT), - $OP, - ), - (DataType::LargeBinary, DataType::LargeBinary) => compare_op( - as_generic_binary_array::($LEFT), - as_generic_binary_array::($RIGHT), - $OP, - ), - ( - DataType::Timestamp(TimeUnit::Nanosecond, _), - DataType::Timestamp(TimeUnit::Nanosecond, _), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Timestamp(TimeUnit::Microsecond, _), - DataType::Timestamp(TimeUnit::Microsecond, _), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Timestamp(TimeUnit::Millisecond, _), - DataType::Timestamp(TimeUnit::Millisecond, _), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Timestamp(TimeUnit::Second, _), - DataType::Timestamp(TimeUnit::Second, _), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - (DataType::Date32, DataType::Date32) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Date64, DataType::Date64) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - (DataType::Time32(TimeUnit::Second), DataType::Time32(TimeUnit::Second)) => { - cmp_primitive_array::($LEFT, $RIGHT, $OP) - } - ( - DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Millisecond), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Microsecond), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - 
DataType::Time64(TimeUnit::Nanosecond), - DataType::Time64(TimeUnit::Nanosecond), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Interval(IntervalUnit::YearMonth), - DataType::Interval(IntervalUnit::YearMonth), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Interval(IntervalUnit::DayTime), - DataType::Interval(IntervalUnit::DayTime), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Interval(IntervalUnit::MonthDayNano), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Duration(TimeUnit::Second), - DataType::Duration(TimeUnit::Second), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Duration(TimeUnit::Millisecond), - DataType::Duration(TimeUnit::Millisecond), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Duration(TimeUnit::Microsecond), - DataType::Duration(TimeUnit::Microsecond), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - ( - DataType::Duration(TimeUnit::Nanosecond), - DataType::Duration(TimeUnit::Nanosecond), - ) => cmp_primitive_array::($LEFT, $RIGHT, $OP), - (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( - "Comparing arrays of type {} is not yet implemented", - t1 - ))), - (t1, t2) => Err(ArrowError::CastError(format!( - "Cannot compare two arrays of different types ({} and {})", - t1, t2 - ))), - } - }}; +/// Perform `left == right` operation on an array and a numeric scalar +/// value. Supports StringArrays, and DictionaryArrays that have string values +#[deprecated(note = "Use arrow_ord::cmp::eq")] +pub fn eq_dyn_utf8_scalar( + left: &dyn Array, + right: &str, +) -> Result { + let right = make_utf8_scalar(left.data_type(), right)?; + crate::cmp::eq(&left, &Scalar::new(&right)) } -/// Applies $OP to $LEFT and $RIGHT which are two dictionaries which have (the same) key type $KT -#[cfg(feature = "dyn_cmp_dict")] -macro_rules! 
typed_dict_cmp { - ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr, $KT: tt) => {{ - match ($LEFT.value_type(), $RIGHT.value_type()) { - (DataType::Boolean, DataType::Boolean) => { - cmp_dict_bool::<$KT, _>($LEFT, $RIGHT, $OP_BOOL) - } - (DataType::Int8, DataType::Int8) => { - cmp_dict::<$KT, Int8Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Int16, DataType::Int16) => { - cmp_dict::<$KT, Int16Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Int32, DataType::Int32) => { - cmp_dict::<$KT, Int32Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Int64, DataType::Int64) => { - cmp_dict::<$KT, Int64Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::UInt8, DataType::UInt8) => { - cmp_dict::<$KT, UInt8Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::UInt16, DataType::UInt16) => { - cmp_dict::<$KT, UInt16Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::UInt32, DataType::UInt32) => { - cmp_dict::<$KT, UInt32Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::UInt64, DataType::UInt64) => { - cmp_dict::<$KT, UInt64Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Float16, DataType::Float16) => { - cmp_dict::<$KT, Float16Type, _>($LEFT, $RIGHT, $OP_FLOAT) - } - (DataType::Float32, DataType::Float32) => { - cmp_dict::<$KT, Float32Type, _>($LEFT, $RIGHT, $OP_FLOAT) - } - (DataType::Float64, DataType::Float64) => { - cmp_dict::<$KT, Float64Type, _>($LEFT, $RIGHT, $OP_FLOAT) - } - (DataType::Decimal128(_, s1), DataType::Decimal128(_, s2)) if s1 == s2 => { - cmp_dict::<$KT, Decimal128Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Decimal256(_, s1), DataType::Decimal256(_, s2)) if s1 == s2 => { - cmp_dict::<$KT, Decimal256Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Utf8, DataType::Utf8) => { - cmp_dict_utf8::<$KT, i32, _>($LEFT, $RIGHT, $OP) - } - (DataType::LargeUtf8, DataType::LargeUtf8) => { - cmp_dict_utf8::<$KT, i64, _>($LEFT, $RIGHT, $OP) - } - (DataType::Binary, DataType::Binary) => { - cmp_dict_binary::<$KT, i32, _>($LEFT, $RIGHT, $OP) - } - (DataType::LargeBinary, 
DataType::LargeBinary) => { - cmp_dict_binary::<$KT, i64, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Timestamp(TimeUnit::Nanosecond, _), - DataType::Timestamp(TimeUnit::Nanosecond, _), - ) => { - cmp_dict::<$KT, TimestampNanosecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Timestamp(TimeUnit::Microsecond, _), - DataType::Timestamp(TimeUnit::Microsecond, _), - ) => { - cmp_dict::<$KT, TimestampMicrosecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Timestamp(TimeUnit::Millisecond, _), - DataType::Timestamp(TimeUnit::Millisecond, _), - ) => { - cmp_dict::<$KT, TimestampMillisecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Timestamp(TimeUnit::Second, _), - DataType::Timestamp(TimeUnit::Second, _), - ) => { - cmp_dict::<$KT, TimestampSecondType, _>($LEFT, $RIGHT, $OP) - } - (DataType::Date32, DataType::Date32) => { - cmp_dict::<$KT, Date32Type, _>($LEFT, $RIGHT, $OP) - } - (DataType::Date64, DataType::Date64) => { - cmp_dict::<$KT, Date64Type, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Time32(TimeUnit::Second), - DataType::Time32(TimeUnit::Second), - ) => { - cmp_dict::<$KT, Time32SecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Time32(TimeUnit::Millisecond), - DataType::Time32(TimeUnit::Millisecond), - ) => { - cmp_dict::<$KT, Time32MillisecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Time64(TimeUnit::Microsecond), - DataType::Time64(TimeUnit::Microsecond), - ) => { - cmp_dict::<$KT, Time64MicrosecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Time64(TimeUnit::Nanosecond), - DataType::Time64(TimeUnit::Nanosecond), - ) => { - cmp_dict::<$KT, Time64NanosecondType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Interval(IntervalUnit::YearMonth), - DataType::Interval(IntervalUnit::YearMonth), - ) => { - cmp_dict::<$KT, IntervalYearMonthType, _>($LEFT, $RIGHT, $OP) - } - ( - DataType::Interval(IntervalUnit::DayTime), - DataType::Interval(IntervalUnit::DayTime), - ) => { - cmp_dict::<$KT, IntervalDayTimeType, _>($LEFT, $RIGHT, $OP) - } - ( - 
DataType::Interval(IntervalUnit::MonthDayNano), - DataType::Interval(IntervalUnit::MonthDayNano), - ) => { - cmp_dict::<$KT, IntervalMonthDayNanoType, _>($LEFT, $RIGHT, $OP) - } - (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( - "Comparing dictionary arrays of value type {} is not yet implemented", - t1 - ))), - (t1, t2) => Err(ArrowError::CastError(format!( - "Cannot compare two dictionary arrays of different value types ({} and {})", - t1, t2 - ))), - } - }}; +/// Perform `left < right` operation on an array and a numeric scalar +/// value. Supports StringArrays, and DictionaryArrays that have string values +#[deprecated(note = "Use arrow_ord::cmp::lt")] +pub fn lt_dyn_utf8_scalar( + left: &dyn Array, + right: &str, +) -> Result { + let right = make_utf8_scalar(left.data_type(), right)?; + crate::cmp::lt(&left, &Scalar::new(&right)) } -#[cfg(feature = "dyn_cmp_dict")] -macro_rules! typed_dict_compares { - // Applies `LEFT OP RIGHT` when `LEFT` and `RIGHT` both are `DictionaryArray` - ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr) => {{ - match ($LEFT.data_type(), $RIGHT.data_type()) { - (DataType::Dictionary(left_key_type, _), DataType::Dictionary(right_key_type, _))=> { - match (left_key_type.as_ref(), right_key_type.as_ref()) { - (DataType::Int8, DataType::Int8) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, Int8Type) - } - (DataType::Int16, DataType::Int16) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, Int16Type) - } - (DataType::Int32, DataType::Int32) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, Int32Type) - } - (DataType::Int64, DataType::Int64) => { - let left = as_dictionary_array::($LEFT); - let right = 
as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, Int64Type) - } - (DataType::UInt8, DataType::UInt8) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, UInt8Type) - } - (DataType::UInt16, DataType::UInt16) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, UInt16Type) - } - (DataType::UInt32, DataType::UInt32) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, UInt32Type) - } - (DataType::UInt64, DataType::UInt64) => { - let left = as_dictionary_array::($LEFT); - let right = as_dictionary_array::($RIGHT); - typed_dict_cmp!(left, right, $OP, $OP_FLOAT, $OP_BOOL, UInt64Type) - } - (t1, t2) if t1 == t2 => Err(ArrowError::NotYetImplemented(format!( - "Comparing dictionary arrays of type {} is not yet implemented", - t1 - ))), - (t1, t2) => Err(ArrowError::CastError(format!( - "Cannot compare two dictionary arrays of different key types ({} and {})", - t1, t2 - ))), - } - } - (t1, t2) => Err(ArrowError::CastError(format!( - "Cannot compare dictionary array with non-dictionary array ({} and {})", - t1, t2 - ))), - } - }}; +/// Perform `left >= right` operation on an array and a numeric scalar +/// value. Supports StringArrays, and DictionaryArrays that have string values +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] +pub fn gt_eq_dyn_utf8_scalar( + left: &dyn Array, + right: &str, +) -> Result { + let right = make_utf8_scalar(left.data_type(), right)?; + crate::cmp::gt_eq(&left, &Scalar::new(&right)) } -#[cfg(not(feature = "dyn_cmp_dict"))] -macro_rules! 
typed_dict_compares { - ($LEFT: expr, $RIGHT: expr, $OP: expr, $OP_FLOAT: expr, $OP_BOOL: expr) => {{ - Err(ArrowError::CastError(format!( - "Comparing array of type {} with array of type {} requires \"dyn_cmp_dict\" feature", - $LEFT.data_type(), $RIGHT.data_type() - ))) - }} +/// Perform `left <= right` operation on an array and a numeric scalar +/// value. Supports StringArrays, and DictionaryArrays that have string values +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] +pub fn lt_eq_dyn_utf8_scalar( + left: &dyn Array, + right: &str, +) -> Result { + let right = make_utf8_scalar(left.data_type(), right)?; + crate::cmp::lt_eq(&left, &Scalar::new(&right)) } -/// Perform given operation on `DictionaryArray` and `PrimitiveArray`. The value -/// type of `DictionaryArray` is same as `PrimitiveArray`'s type. -#[cfg(feature = "dyn_cmp_dict")] -fn cmp_dict_primitive( - left: &DictionaryArray, - right: &dyn Array, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - T: ArrowPrimitiveType + Sync + Send, - F: Fn(T::Native, T::Native) -> bool, -{ - compare_op( - left.downcast_dict::>().unwrap(), - right.as_primitive::(), - op, - ) +/// Perform `left > right` operation on an array and a numeric scalar +/// value. Supports StringArrays, and DictionaryArrays that have string values +#[deprecated(note = "Use arrow_ord::cmp::gt")] +pub fn gt_dyn_utf8_scalar( + left: &dyn Array, + right: &str, +) -> Result { + let right = make_utf8_scalar(left.data_type(), right)?; + crate::cmp::gt(&left, &Scalar::new(&right)) } -/// Perform given operation on `DictionaryArray` and `GenericStringArray`. The value -/// type of `DictionaryArray` is same as `GenericStringArray`'s type. 
-#[cfg(feature = "dyn_cmp_dict")] -fn cmp_dict_string_array( - left: &DictionaryArray, - right: &dyn Array, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - F: Fn(&str, &str) -> bool, -{ - compare_op( - left.downcast_dict::>() - .unwrap(), - right - .as_any() - .downcast_ref::>() - .unwrap(), - op, - ) +/// Perform `left != right` operation on an array and a numeric scalar +/// value. Supports StringArrays, and DictionaryArrays that have string values +#[deprecated(note = "Use arrow_ord::cmp::neq")] +pub fn neq_dyn_utf8_scalar( + left: &dyn Array, + right: &str, +) -> Result { + let right = make_utf8_scalar(left.data_type(), right)?; + crate::cmp::neq(&left, &Scalar::new(&right)) } -/// Perform given operation on `DictionaryArray` and `BooleanArray`. The value -/// type of `DictionaryArray` is same as `BooleanArray`'s type. -#[cfg(feature = "dyn_cmp_dict")] -fn cmp_dict_boolean_array( - left: &DictionaryArray, - right: &dyn Array, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - F: Fn(bool, bool) -> bool, -{ - compare_op( - left.downcast_dict::().unwrap(), - right.as_any().downcast_ref::().unwrap(), - op, - ) +/// Perform `left == right` operation on an array and a numeric scalar +/// value. +#[deprecated(note = "Use arrow_ord::cmp::eq")] +pub fn eq_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { + let right = BooleanArray::from(vec![right]); + crate::cmp::eq(&left, &Scalar::new(&right)) } -/// Perform given operation on `DictionaryArray` and `GenericBinaryArray`. The value -/// type of `DictionaryArray` is same as `GenericBinaryArray`'s type. 
-#[cfg(feature = "dyn_cmp_dict")] -fn cmp_dict_binary_array( - left: &DictionaryArray, - right: &dyn Array, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - F: Fn(&[u8], &[u8]) -> bool, -{ - compare_op( - left.downcast_dict::>() - .unwrap(), - right - .as_any() - .downcast_ref::>() - .unwrap(), - op, - ) +/// Perform `left < right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. +#[deprecated(note = "Use arrow_ord::cmp::lt")] +pub fn lt_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { + let right = BooleanArray::from(vec![right]); + crate::cmp::lt(&left, &Scalar::new(&right)) } -/// Perform given operation on two `DictionaryArray`s which value type is -/// primitive type. Returns an error if the two arrays have different value -/// type -#[cfg(feature = "dyn_cmp_dict")] -pub fn cmp_dict( - left: &DictionaryArray, - right: &DictionaryArray, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - T: ArrowPrimitiveType + Sync + Send, - F: Fn(T::Native, T::Native) -> bool, -{ - compare_op( - left.downcast_dict::>().unwrap(), - right.downcast_dict::>().unwrap(), - op, - ) +/// Perform `left > right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. +#[deprecated(note = "Use arrow_ord::cmp::gt")] +pub fn gt_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { + let right = BooleanArray::from(vec![right]); + crate::cmp::gt(&left, &Scalar::new(&right)) } -/// Perform the given operation on two `DictionaryArray`s which value type is -/// `DataType::Boolean`. -#[cfg(feature = "dyn_cmp_dict")] -pub fn cmp_dict_bool( - left: &DictionaryArray, - right: &DictionaryArray, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - F: Fn(bool, bool) -> bool, -{ - compare_op( - left.downcast_dict::().unwrap(), - right.downcast_dict::().unwrap(), - op, - ) +/// Perform `left <= right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. 
+#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] +pub fn lt_eq_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { + let right = BooleanArray::from(vec![right]); + crate::cmp::lt_eq(&left, &Scalar::new(&right)) } -/// Perform the given operation on two `DictionaryArray`s which value type is -/// `DataType::Utf8` or `DataType::LargeUtf8`. -#[cfg(feature = "dyn_cmp_dict")] -pub fn cmp_dict_utf8( - left: &DictionaryArray, - right: &DictionaryArray, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - F: Fn(&str, &str) -> bool, -{ - compare_op( - left.downcast_dict::>() - .unwrap(), - right - .downcast_dict::>() - .unwrap(), - op, - ) +/// Perform `left >= right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] +pub fn gt_eq_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { + let right = BooleanArray::from(vec![right]); + crate::cmp::gt_eq(&left, &Scalar::new(&right)) } -/// Perform the given operation on two `DictionaryArray`s which value type is -/// `DataType::Binary` or `DataType::LargeBinary`. -#[cfg(feature = "dyn_cmp_dict")] -pub fn cmp_dict_binary( - left: &DictionaryArray, - right: &DictionaryArray, - op: F, -) -> Result -where - K: ArrowDictionaryKeyType, - F: Fn(&[u8], &[u8]) -> bool, -{ - compare_op( - left.downcast_dict::>() - .unwrap(), - right - .downcast_dict::>() - .unwrap(), - op, - ) +/// Perform `left != right` operation on an array and a numeric scalar +/// value. Supports BooleanArrays. +#[deprecated(note = "Use arrow_ord::cmp::neq")] +pub fn neq_dyn_bool_scalar( + left: &dyn Array, + right: bool, +) -> Result { + let right = BooleanArray::from(vec![right]); + crate::cmp::neq(&left, &Scalar::new(&right)) } /// Perform `left == right` operation on two (dynamic) [`Array`]s. 
@@ -2162,29 +977,9 @@ where /// let result = eq_dyn(&array1, &array2).unwrap(); /// assert_eq!(BooleanArray::from(vec![Some(true), None, Some(false)]), result); /// ``` +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result { - match left.data_type() { - DataType::Dictionary(_, _) - if matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_dict_compares!(left, right, |a, b| a == b, |a, b| a.is_eq(b), |a, b| a - == b) - } - DataType::Dictionary(_, _) - if !matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_cmp_dict_non_dict!(left, right, |a, b| a == b, |a, b| a == b, |a, b| a - .is_eq(b)) - } - _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a == b, |a, b| a == b, |a, b| b - .is_eq(a)) - } - _ => { - typed_compares!(left, right, |a, b| !(a ^ b), |a, b| a == b, |a, b| a - .is_eq(b)) - } - } + crate::cmp::eq(&left, &right) } /// Perform `left != right` operation on two (dynamic) [`Array`]s. @@ -2209,29 +1004,9 @@ pub fn eq_dyn(left: &dyn Array, right: &dyn Array) -> Result Result { - match left.data_type() { - DataType::Dictionary(_, _) - if matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_dict_compares!(left, right, |a, b| a != b, |a, b| a.is_ne(b), |a, b| a - != b) - } - DataType::Dictionary(_, _) - if !matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_cmp_dict_non_dict!(left, right, |a, b| a != b, |a, b| a != b, |a, b| a - .is_ne(b)) - } - _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a != b, |a, b| a != b, |a, b| b - .is_ne(a)) - } - _ => { - typed_compares!(left, right, |a, b| (a ^ b), |a, b| a != b, |a, b| a - .is_ne(b)) - } - } + crate::cmp::neq(&left, &right) } /// Perform `left < right` operation on two (dynamic) [`Array`]s. 
@@ -2255,30 +1030,9 @@ pub fn neq_dyn(left: &dyn Array, right: &dyn Array) -> Result Result { - match left.data_type() { - DataType::Dictionary(_, _) - if matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_dict_compares!(left, right, |a, b| a < b, |a, b| a.is_lt(b), |a, b| a - < b) - } - DataType::Dictionary(_, _) - if !matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_cmp_dict_non_dict!(left, right, |a, b| a < b, |a, b| a < b, |a, b| a - .is_lt(b)) - } - _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a > b, |a, b| a > b, |a, b| b - .is_lt(a)) - } - _ => { - typed_compares!(left, right, |a, b| ((!a) & b), |a, b| a < b, |a, b| a - .is_lt(b)) - } - } + crate::cmp::lt(&left, &right) } /// Perform `left <= right` operation on two (dynamic) [`Array`]s. @@ -2302,32 +1056,12 @@ pub fn lt_dyn(left: &dyn Array, right: &dyn Array) -> Result Result { - match left.data_type() { - DataType::Dictionary(_, _) - if matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_dict_compares!(left, right, |a, b| a <= b, |a, b| a.is_le(b), |a, b| a - <= b) - } - DataType::Dictionary(_, _) - if !matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_cmp_dict_non_dict!(left, right, |a, b| a <= b, |a, b| a <= b, |a, b| a - .is_le(b)) - } - _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a >= b, |a, b| a >= b, |a, b| b - .is_le(a)) - } - _ => { - typed_compares!(left, right, |a, b| !(a & (!b)), |a, b| a <= b, |a, b| a - .is_le(b)) - } - } + crate::cmp::lt_eq(&left, &right) } /// Perform `left > right` operation on two (dynamic) [`Array`]s. 
@@ -2350,30 +1084,9 @@ pub fn lt_eq_dyn( /// let result = gt_dyn(&array1, &array2).unwrap(); /// assert_eq!(BooleanArray::from(vec![Some(true), Some(false), None]), result); /// ``` -#[allow(clippy::bool_comparison)] +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result { - match left.data_type() { - DataType::Dictionary(_, _) - if matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_dict_compares!(left, right, |a, b| a > b, |a, b| a.is_gt(b), |a, b| a - > b) - } - DataType::Dictionary(_, _) - if !matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_cmp_dict_non_dict!(left, right, |a, b| a > b, |a, b| a > b, |a, b| a - .is_gt(b)) - } - _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a < b, |a, b| a < b, |a, b| b - .is_gt(a)) - } - _ => { - typed_compares!(left, right, |a, b| (a & (!b)), |a, b| a > b, |a, b| a - .is_gt(b)) - } - } + crate::cmp::gt(&left, &right) } /// Perform `left >= right` operation on two (dynamic) [`Array`]s. @@ -2396,32 +1109,12 @@ pub fn gt_dyn(left: &dyn Array, right: &dyn Array) -> Result Result { - match left.data_type() { - DataType::Dictionary(_, _) - if matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_dict_compares!(left, right, |a, b| a >= b, |a, b| a.is_ge(b), |a, b| a - >= b) - } - DataType::Dictionary(_, _) - if !matches!(right.data_type(), DataType::Dictionary(_, _)) => - { - typed_cmp_dict_non_dict!(left, right, |a, b| a >= b, |a, b| a >= b, |a, b| a - .is_ge(b)) - } - _ if matches!(right.data_type(), DataType::Dictionary(_, _)) => { - typed_cmp_dict_non_dict!(right, left, |a, b| a <= b, |a, b| a <= b, |a, b| b - .is_ge(a)) - } - _ => { - typed_compares!(left, right, |a, b| !((!a) & b), |a, b| a >= b, |a, b| a - .is_ge(b)) - } - } + crate::cmp::gt_eq(&left, &right) } /// Perform `left == right` operation on two [`PrimitiveArray`]s. 
@@ -2432,6 +1125,7 @@ pub fn gt_eq_dyn( /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq( left: &PrimitiveArray, right: &PrimitiveArray, @@ -2440,20 +1134,17 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op(left, right, T::eq, |a, b| a == b); - #[cfg(not(feature = "simd"))] - return compare_op(left, right, |a, b| a.is_eq(b)); + crate::cmp::eq(&left, &right) } /// Perform `left == right` operation on a [`PrimitiveArray`] and a scalar value. /// -/// If `simd` feature flag is not enabled: /// For floating values like f32 and f64, this comparison produces an ordering in accordance to /// the totalOrder predicate as defined in the IEEE 754 (2008 revision) floating point standard. /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::eq")] pub fn eq_scalar( left: &PrimitiveArray, right: T::Native, @@ -2462,10 +1153,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op_scalar(left, right, T::eq, |a, b| a == b); - #[cfg(not(feature = "simd"))] - return compare_op_scalar(left, |a| a.is_eq(right)); + let right = PrimitiveArray::::new(vec![right].into(), None); + crate::cmp::eq(&left, &Scalar::new(&right)) } /// Applies an unary and infallible comparison function to a primitive array. @@ -2488,6 +1177,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. 
/// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq( left: &PrimitiveArray, right: &PrimitiveArray, @@ -2496,10 +1186,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op(left, right, T::ne, |a, b| a != b); - #[cfg(not(feature = "simd"))] - return compare_op(left, right, |a, b| a.is_ne(b)); + crate::cmp::neq(&left, &right) } /// Perform `left != right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2510,6 +1197,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::neq")] pub fn neq_scalar( left: &PrimitiveArray, right: T::Native, @@ -2518,10 +1206,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op_scalar(left, right, T::ne, |a, b| a != b); - #[cfg(not(feature = "simd"))] - return compare_op_scalar(left, |a| a.is_ne(right)); + let right = PrimitiveArray::::new(vec![right].into(), None); + crate::cmp::neq(&left, &Scalar::new(&right)) } /// Perform `left < right` operation on two [`PrimitiveArray`]s. Null values are less than non-null @@ -2533,6 +1219,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
+#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt( left: &PrimitiveArray, right: &PrimitiveArray, @@ -2541,10 +1228,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op(left, right, T::lt, |a, b| a < b); - #[cfg(not(feature = "simd"))] - return compare_op(left, right, |a, b| a.is_lt(b)); + crate::cmp::lt(&left, &right) } /// Perform `left < right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2556,6 +1240,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::lt")] pub fn lt_scalar( left: &PrimitiveArray, right: T::Native, @@ -2564,10 +1249,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op_scalar(left, right, T::lt, |a, b| a < b); - #[cfg(not(feature = "simd"))] - return compare_op_scalar(left, |a| a.is_lt(right)); + let right = PrimitiveArray::::new(vec![right].into(), None); + crate::cmp::lt(&left, &Scalar::new(&right)) } /// Perform `left <= right` operation on two [`PrimitiveArray`]s. Null values are less than non-null @@ -2579,6 +1262,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
+#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq( left: &PrimitiveArray, right: &PrimitiveArray, @@ -2587,10 +1271,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op(left, right, T::le, |a, b| a <= b); - #[cfg(not(feature = "simd"))] - return compare_op(left, right, |a, b| a.is_le(b)); + crate::cmp::lt_eq(&left, &right) } /// Perform `left <= right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2602,6 +1283,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::lt_eq")] pub fn lt_eq_scalar( left: &PrimitiveArray, right: T::Native, @@ -2610,10 +1292,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op_scalar(left, right, T::le, |a, b| a <= b); - #[cfg(not(feature = "simd"))] - return compare_op_scalar(left, |a| a.is_le(right)); + let right = PrimitiveArray::::new(vec![right].into(), None); + crate::cmp::lt_eq(&left, &Scalar::new(&right)) } /// Perform `left > right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null @@ -2625,6 +1305,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
+#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt( left: &PrimitiveArray, right: &PrimitiveArray, @@ -2633,10 +1314,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op(left, right, T::gt, |a, b| a > b); - #[cfg(not(feature = "simd"))] - return compare_op(left, right, |a, b| a.is_gt(b)); + crate::cmp::gt(&left, &right) } /// Perform `left > right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2648,6 +1326,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::gt")] pub fn gt_scalar( left: &PrimitiveArray, right: T::Native, @@ -2656,10 +1335,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op_scalar(left, right, T::gt, |a, b| a > b); - #[cfg(not(feature = "simd"))] - return compare_op_scalar(left, |a| a.is_gt(right)); + let right = PrimitiveArray::::new(vec![right].into(), None); + crate::cmp::gt(&left, &Scalar::new(&right)) } /// Perform `left >= right` operation on two [`PrimitiveArray`]s. Non-null values are greater than null @@ -2671,6 +1348,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. 
+#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq( left: &PrimitiveArray, right: &PrimitiveArray, @@ -2679,10 +1357,7 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op(left, right, T::ge, |a, b| a >= b); - #[cfg(not(feature = "simd"))] - return compare_op(left, right, |a, b| a.is_ge(b)); + crate::cmp::gt_eq(&left, &right) } /// Perform `left >= right` operation on a [`PrimitiveArray`] and a scalar value. @@ -2694,6 +1369,7 @@ where /// Note that totalOrder treats positive and negative zeros are different. If it is necessary /// to treat them as equal, please normalize zeros before calling this kernel. /// Please refer to `f32::total_cmp` and `f64::total_cmp`. +#[deprecated(note = "Use arrow_ord::cmp::gt_eq")] pub fn gt_eq_scalar( left: &PrimitiveArray, right: T::Native, @@ -2702,10 +1378,8 @@ where T: ArrowNumericType, T::Native: ArrowNativeTypeOp, { - #[cfg(feature = "simd")] - return simd_compare_op_scalar(left, right, T::ge, |a, b| a >= b); - #[cfg(not(feature = "simd"))] - return compare_op_scalar(left, |a| a.is_ge(right)); + let right = PrimitiveArray::::new(vec![right].into(), None); + crate::cmp::gt_eq(&left, &Scalar::new(&right)) } /// Checks if a [`GenericListArray`] contains a value in the [`PrimitiveArray`] @@ -2793,14 +1467,18 @@ where // disable wrapping inside literal vectors used for test data and assertions #[rustfmt::skip::macros(vec)] #[cfg(test)] +#[allow(deprecated)] mod tests { - use super::*; + use std::sync::Arc; + use arrow_array::builder::{ ListBuilder, PrimitiveDictionaryBuilder, StringBuilder, StringDictionaryBuilder, }; - use arrow_buffer::i256; + use arrow_buffer::{i256, Buffer}; + use arrow_data::ArrayData; use arrow_schema::Field; - use std::sync::Arc; + + use super::*; /// Evaluate `KERNEL` with two vectors as inputs and assert against the expected output. 
/// `A_VEC` and `B_VEC` can be of type `Vec` or `Vec>` where `T` is the native @@ -4645,7 +3323,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_i8_array() { // Construct a value array let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); @@ -4667,7 +3344,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_u64_array() { let values = UInt64Array::from_iter_values([10_u64, 11, 12, 13, 14, 15, 16, 17]); let values = Arc::new(values) as ArrayRef; @@ -4688,7 +3364,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_utf8_array() { let test1 = vec!["a", "a", "b", "c"]; let test2 = vec!["a", "b", "b", "c"]; @@ -4716,7 +3391,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_binary_array() { let values: BinaryArray = ["hello", "", "parquet"] .into_iter() @@ -4740,7 +3414,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_interval_array() { let values = IntervalDayTimeArray::from(vec![1, 6, 10, 2, 3, 5]); let values = Arc::new(values) as ArrayRef; @@ -4761,7 +3434,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_date_array() { let values = Date32Array::from(vec![1, 6, 10, 2, 3, 5]); let values = Arc::new(values) as ArrayRef; @@ -4782,7 +3454,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_bool_array() { let values = BooleanArray::from(vec![true, false]); let values = Arc::new(values) as ArrayRef; @@ -4803,7 +3474,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_gt_dyn_dictionary_i8_array() { // Construct a value array let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); @@ -4834,7 +3504,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn 
test_lt_dyn_gt_dyn_dictionary_bool_array() { let values = BooleanArray::from(vec![true, false]); let values = Arc::new(values) as ArrayRef; @@ -4876,7 +3545,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_i8_i8_array() { let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); let keys = Int8Array::from_iter_values([2_i8, 3, 4]); @@ -4911,7 +3579,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_i8_i8_array() { let values = Int8Array::from_iter_values([10_i8, 11, 12, 13, 14, 15, 16, 17]); let keys = Int8Array::from_iter_values([2_i8, 3, 4]); @@ -4984,7 +3651,6 @@ mod tests { ); assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(eq(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -4992,7 +3658,6 @@ mod tests { ); assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(neq(&array1, &array2).unwrap(), expected); let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 10.0] @@ -5008,7 +3673,6 @@ mod tests { ); assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(eq(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5016,7 +3680,6 @@ mod tests { ); assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(neq(&array1, &array2).unwrap(), expected); let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 10.0] @@ -5033,7 +3696,6 @@ mod tests { ); assert_eq!(eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(eq(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5041,7 +3703,6 @@ mod tests { ); assert_eq!(neq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(neq(&array1, &array2).unwrap(), expected); } @@ 
-5061,7 +3722,6 @@ mod tests { ); assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(lt(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5069,7 +3729,6 @@ mod tests { ); assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(lt_eq(&array1, &array2).unwrap(), expected); let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 11.0, f32::NAN] @@ -5086,7 +3745,6 @@ mod tests { ); assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(lt(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5094,7 +3752,6 @@ mod tests { ); assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(lt_eq(&array1, &array2).unwrap(), expected); let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 11.0, f64::NAN] @@ -5111,7 +3768,6 @@ mod tests { ); assert_eq!(lt_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(lt(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5119,7 +3775,6 @@ mod tests { ); assert_eq!(lt_eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(lt_eq(&array1, &array2).unwrap(), expected); } @@ -5139,7 +3794,6 @@ mod tests { ); assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(gt(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5147,7 +3801,6 @@ mod tests { ); assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(gt_eq(&array1, &array2).unwrap(), expected); let array1: Float32Array = vec![f32::NAN, 7.0, 8.0, 8.0, 11.0, f32::NAN] @@ -5164,7 +3817,6 @@ mod tests { ); assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(gt(&array1, &array2).unwrap(), expected); let 
expected = BooleanArray::from( @@ -5172,7 +3824,6 @@ mod tests { ); assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(gt_eq(&array1, &array2).unwrap(), expected); let array1: Float64Array = vec![f64::NAN, 7.0, 8.0, 8.0, 11.0, f64::NAN] @@ -5189,7 +3840,6 @@ mod tests { ); assert_eq!(gt_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(gt(&array1, &array2).unwrap(), expected); let expected = BooleanArray::from( @@ -5197,7 +3847,6 @@ mod tests { ); assert_eq!(gt_eq_dyn(&array1, &array2).unwrap(), expected); - #[cfg(not(feature = "simd"))] assert_eq!(gt_eq(&array1, &array2).unwrap(), expected); } @@ -5207,21 +3856,12 @@ mod tests { .into_iter() .map(Some) .collect(); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] + let expected = BooleanArray::from( vec![Some(true), Some(false), Some(false), Some(false), Some(false)], ); assert_eq!(eq_dyn_scalar(&array, f32::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(true), Some(true), Some(true), Some(true), Some(true)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(false), Some(true), Some(true), Some(true), Some(true)], ); @@ -5231,21 +3871,12 @@ mod tests { .into_iter() .map(Some) .collect(); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] + let expected = BooleanArray::from( vec![Some(true), Some(false), Some(false), Some(false), Some(false)], ); assert_eq!(eq_dyn_scalar(&array, f32::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(true), Some(true), Some(true), Some(true), Some(true)], - ); - #[cfg(not(feature = "simd"))] let expected = 
BooleanArray::from( vec![Some(false), Some(true), Some(true), Some(true), Some(true)], ); @@ -5255,21 +3886,12 @@ mod tests { .into_iter() .map(Some) .collect(); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] + let expected = BooleanArray::from( vec![Some(true), Some(false), Some(false), Some(false), Some(false)], ); assert_eq!(eq_dyn_scalar(&array, f64::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(true), Some(true), Some(true), Some(true), Some(true)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(false), Some(true), Some(true), Some(true), Some(true)], ); @@ -5282,21 +3904,12 @@ mod tests { .into_iter() .map(Some) .collect(); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] + let expected = BooleanArray::from( vec![Some(false), Some(true), Some(true), Some(true), Some(true)], ); assert_eq!(lt_dyn_scalar(&array, f16::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(true), Some(true), Some(true), Some(true), Some(true)], ); @@ -5306,21 +3919,12 @@ mod tests { .into_iter() .map(Some) .collect(); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] + let expected = BooleanArray::from( vec![Some(false), Some(true), Some(true), Some(true), Some(true)], ); assert_eq!(lt_dyn_scalar(&array, f32::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), 
Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(true), Some(true), Some(true), Some(true), Some(true)], ); @@ -5330,21 +3934,12 @@ mod tests { .into_iter() .map(Some) .collect(); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] + let expected = BooleanArray::from( vec![Some(false), Some(true), Some(true), Some(true), Some(true)], ); assert_eq!(lt_dyn_scalar(&array, f64::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(true), Some(true), Some(true), Some(true), Some(true)], ); @@ -5362,11 +3957,6 @@ mod tests { ); assert_eq!(gt_dyn_scalar(&array, f16::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(true), Some(false), Some(false), Some(false), Some(false)], ); @@ -5381,11 +3971,6 @@ mod tests { ); assert_eq!(gt_dyn_scalar(&array, f32::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(true), Some(false), Some(false), Some(false), Some(false)], ); @@ -5400,11 +3985,6 @@ mod tests { ); assert_eq!(gt_dyn_scalar(&array, f64::NAN).unwrap(), expected); - #[cfg(feature = "simd")] - let expected = BooleanArray::from( - vec![Some(false), Some(false), Some(false), Some(false), Some(false)], - ); - #[cfg(not(feature = "simd"))] let expected = BooleanArray::from( vec![Some(true), Some(false), Some(false), 
Some(false), Some(false)], ); @@ -5412,7 +3992,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_to_utf8_array() { let test1 = vec!["a", "a", "b", "c"]; let test2 = vec!["a", "b", "b", "d"]; @@ -5453,7 +4032,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_to_utf8_array() { let test1 = vec!["abc", "abc", "b", "cde"]; let test2 = vec!["abc", "b", "b", "def"]; @@ -5518,7 +4096,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_to_binary_array() { let values: BinaryArray = ["hello", "", "parquet"] .into_iter() @@ -5559,7 +4136,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_to_binary_array() { let values: BinaryArray = ["hello", "", "parquet"] .into_iter() @@ -5624,7 +4200,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dict_non_dict_float_nan() { let array1: Float16Array = vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(10.0)] .into_iter() @@ -5683,7 +4258,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_lt_eq_dyn_dict_non_dict_float_nan() { let array1: Float16Array = vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(11.0), f16::NAN] .into_iter() @@ -5741,7 +4315,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_gt_dyn_gt_eq_dyn_dict_non_dict_float_nan() { let array1: Float16Array = vec![f16::NAN, f16::from_f32(7.0), f16::from_f32(8.0), f16::from_f32(8.0), f16::from_f32(11.0), f16::NAN] .into_iter() @@ -5799,7 +4372,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_eq_dyn_neq_dyn_dictionary_to_boolean_array() { let test1 = vec![Some(true), None, Some(false)]; let test2 = vec![Some(true), None, None, Some(true)]; @@ -5836,7 +4408,6 @@ mod tests { } #[test] - 
#[cfg(feature = "dyn_cmp_dict")] fn test_lt_dyn_lt_eq_dyn_gt_dyn_gt_eq_dyn_dictionary_to_boolean_array() { let test1 = vec![Some(true), None, Some(false)]; let test2 = vec![Some(true), None, None, Some(true)]; @@ -5897,7 +4468,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_cmp_dict_decimal128() { let values = Decimal128Array::from_iter_values([0, 1, 2, 3, 4, 5]); let keys = Int8Array::from_iter_values([1_i8, 2, 5, 4, 3, 0]); @@ -5934,7 +4504,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_cmp_dict_non_dict_decimal128() { let array1: Decimal128Array = Decimal128Array::from_iter_values([1, 2, 5, 4, 3, 0]); @@ -5970,7 +4539,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_cmp_dict_decimal256() { let values = Decimal256Array::from_iter_values( [0, 1, 2, 3, 4, 5].into_iter().map(i256::from_i128), @@ -6011,7 +4579,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_cmp_dict_non_dict_decimal256() { let array1: Decimal256Array = Decimal256Array::from_iter_values( [1, 2, 5, 4, 3, 0].into_iter().map(i256::from_i128), @@ -6317,7 +4884,6 @@ mod tests { } #[test] - #[cfg(not(feature = "simd"))] fn test_floating_zeros() { let a = Float32Array::from(vec![0.0_f32, -0.0]); let b = Float32Array::from(vec![-0.0_f32, 0.0]); @@ -6354,7 +4920,6 @@ mod tests { } #[test] - #[cfg(feature = "dyn_cmp_dict")] fn test_dictionary_nested_nulls() { let keys = Int32Array::from(vec![0, 1, 2]); let v1 = Arc::new(Int32Array::from(vec![Some(0), None, Some(2)])); diff --git a/arrow-ord/src/lib.rs b/arrow-ord/src/lib.rs index 62338c022384..19ad8229417f 100644 --- a/arrow-ord/src/lib.rs +++ b/arrow-ord/src/lib.rs @@ -43,6 +43,8 @@ //! ``` //! 
+pub mod cmp; +#[doc(hidden)] pub mod comparison; pub mod ord; pub mod partition; diff --git a/arrow-ord/src/partition.rs b/arrow-ord/src/partition.rs index 4a0a6730d882..52aa5ee8d0f1 100644 --- a/arrow-ord/src/partition.rs +++ b/arrow-ord/src/partition.rs @@ -23,7 +23,7 @@ use arrow_array::{Array, ArrayRef}; use arrow_buffer::BooleanBuffer; use arrow_schema::ArrowError; -use crate::comparison::neq_dyn; +use crate::cmp::neq; use crate::sort::SortColumn; /// A computed set of partitions, see [`partition`] @@ -158,7 +158,7 @@ fn find_boundaries(v: &dyn Array) -> Result { let v1 = v.slice(0, slice_len); let v2 = v.slice(1, slice_len); - let array_ne = neq_dyn(v1.as_ref(), v2.as_ref())?; + let array_ne = neq(&v1, &v2)?; // Set if values have different non-NULL values let values_ne = match array_ne.nulls().filter(|n| n.null_count() > 0) { Some(n) => n.inner() & array_ne.values(), diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml index bcf6a84311d5..9456dd4b012c 100644 --- a/arrow/Cargo.toml +++ b/arrow/Cargo.toml @@ -71,7 +71,7 @@ ipc_compression = ["ipc", "arrow-ipc/lz4", "arrow-ipc/zstd"] csv = ["arrow-csv"] ipc = ["arrow-ipc"] json = ["arrow-json"] -simd = ["arrow-array/simd", "arrow-ord/simd", "arrow-arith/simd"] +simd = ["arrow-array/simd", "arrow-arith/simd"] prettyprint = ["arrow-cast/prettyprint"] # The test utils feature enables code used in benchmarks and tests but # not the core arrow code itself. 
Be aware that `rand` must be kept as @@ -87,7 +87,7 @@ force_validate = ["arrow-data/force_validate"] ffi = ["arrow-schema/ffi", "arrow-data/ffi"] # Enable dyn-comparison of dictionary arrays with other arrays # Note: this does not impact comparison against scalars -dyn_cmp_dict = ["arrow-string/dyn_cmp_dict", "arrow-ord/dyn_cmp_dict"] +dyn_cmp_dict = ["arrow-string/dyn_cmp_dict"] chrono-tz = ["arrow-array/chrono-tz"] [dev-dependencies] diff --git a/arrow/benches/comparison_kernels.rs b/arrow/benches/comparison_kernels.rs index 73db3ffed368..b9fb6c8e3300 100644 --- a/arrow/benches/comparison_kernels.rs +++ b/arrow/benches/comparison_kernels.rs @@ -21,61 +21,16 @@ use criterion::Criterion; extern crate arrow; -use arrow::compute::*; -use arrow::datatypes::{ArrowNativeTypeOp, ArrowNumericType, IntervalMonthDayNanoType}; +use arrow::compute::kernels::cmp::*; +use arrow::datatypes::IntervalMonthDayNanoType; use arrow::util::bench_util::*; use arrow::{array::*, datatypes::Float32Type, datatypes::Int32Type}; +use arrow_array::Scalar; +use arrow_string::like::*; +use arrow_string::regexp::regexp_is_match_utf8_scalar; const SIZE: usize = 65536; -fn bench_eq(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) -where - T: ArrowNumericType, - ::Native: ArrowNativeTypeOp, -{ - eq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); -} - -fn bench_neq(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) -where - T: ArrowNumericType, - ::Native: ArrowNativeTypeOp, -{ - neq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); -} - -fn bench_lt(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) -where - T: ArrowNumericType, - ::Native: ArrowNativeTypeOp, -{ - lt(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); -} - -fn bench_lt_eq(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) -where - T: ArrowNumericType, - ::Native: ArrowNativeTypeOp, -{ - lt_eq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); -} - -fn bench_gt(arr_a: 
&PrimitiveArray, arr_b: &PrimitiveArray) -where - T: ArrowNumericType, - ::Native: ArrowNativeTypeOp, -{ - gt(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); -} - -fn bench_gt_eq(arr_a: &PrimitiveArray, arr_b: &PrimitiveArray) -where - T: ArrowNumericType, - ::Native: ArrowNativeTypeOp, -{ - gt_eq(criterion::black_box(arr_a), criterion::black_box(arr_b)).unwrap(); -} - fn bench_like_utf8_scalar(arr_a: &StringArray, value_b: &str) { like_utf8_scalar(criterion::black_box(arr_a), criterion::black_box(value_b)).unwrap(); } @@ -104,27 +59,6 @@ fn bench_regexp_is_match_utf8_scalar(arr_a: &StringArray, value_b: &str) { .unwrap(); } -#[cfg(not(feature = "dyn_cmp_dict"))] -fn dyn_cmp_dict_benchmarks(_c: &mut Criterion) {} - -#[cfg(feature = "dyn_cmp_dict")] -fn dyn_cmp_dict_benchmarks(c: &mut Criterion) { - let strings = create_string_array::(20, 0.); - let dict_arr_a = create_dict_from_values::(SIZE, 0., &strings); - let dict_arr_b = create_dict_from_values::(SIZE, 0., &strings); - - c.bench_function("eq dictionary[10] string[4])", |b| { - b.iter(|| { - cmp_dict_utf8::<_, i32, _>( - criterion::black_box(&dict_arr_a), - criterion::black_box(&dict_arr_b), - |a, b| a == b, - ) - .unwrap() - }) - }); -} - fn add_benchmark(c: &mut Criterion) { let arr_a = create_primitive_array_with_seed::(SIZE, 0.0, 42); let arr_b = create_primitive_array_with_seed::(SIZE, 0.0, 43); @@ -135,105 +69,79 @@ fn add_benchmark(c: &mut Criterion) { create_primitive_array_with_seed::(SIZE, 0.0, 43); let arr_string = create_string_array::(SIZE, 0.0); + let scalar = Float32Array::from(vec![1.0]); - c.bench_function("eq Float32", |b| b.iter(|| bench_eq(&arr_a, &arr_b))); + c.bench_function("eq Float32", |b| b.iter(|| eq(&arr_a, &arr_b))); c.bench_function("eq scalar Float32", |b| { - b.iter(|| { - eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap() - }) + b.iter(|| eq(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("neq Float32", |b| b.iter(|| 
bench_neq(&arr_a, &arr_b))); + c.bench_function("neq Float32", |b| b.iter(|| neq(&arr_a, &arr_b))); c.bench_function("neq scalar Float32", |b| { - b.iter(|| { - neq_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap() - }) + b.iter(|| neq(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("lt Float32", |b| b.iter(|| bench_lt(&arr_a, &arr_b))); + c.bench_function("lt Float32", |b| b.iter(|| lt(&arr_a, &arr_b))); c.bench_function("lt scalar Float32", |b| { - b.iter(|| { - lt_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap() - }) + b.iter(|| lt(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("lt_eq Float32", |b| b.iter(|| bench_lt_eq(&arr_a, &arr_b))); + c.bench_function("lt_eq Float32", |b| b.iter(|| lt_eq(&arr_a, &arr_b))); c.bench_function("lt_eq scalar Float32", |b| { - b.iter(|| { - lt_eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap() - }) + b.iter(|| lt_eq(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("gt Float32", |b| b.iter(|| bench_gt(&arr_a, &arr_b))); + c.bench_function("gt Float32", |b| b.iter(|| gt(&arr_a, &arr_b))); c.bench_function("gt scalar Float32", |b| { - b.iter(|| { - gt_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap() - }) + b.iter(|| gt(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("gt_eq Float32", |b| b.iter(|| bench_gt_eq(&arr_a, &arr_b))); + c.bench_function("gt_eq Float32", |b| b.iter(|| gt_eq(&arr_a, &arr_b))); c.bench_function("gt_eq scalar Float32", |b| { - b.iter(|| { - gt_eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1.0)).unwrap() - }) + b.iter(|| gt_eq(&arr_a, &Scalar::new(&scalar)).unwrap()) }); let arr_a = create_primitive_array_with_seed::(SIZE, 0.0, 42); let arr_b = create_primitive_array_with_seed::(SIZE, 0.0, 43); + let scalar = Int32Array::from(vec![1]); - c.bench_function("eq Int32", |b| b.iter(|| bench_eq(&arr_a, &arr_b))); + c.bench_function("eq Int32", 
|b| b.iter(|| eq(&arr_a, &arr_b))); c.bench_function("eq scalar Int32", |b| { - b.iter(|| { - eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1)).unwrap() - }) + b.iter(|| eq(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("neq Int32", |b| b.iter(|| bench_neq(&arr_a, &arr_b))); + c.bench_function("neq Int32", |b| b.iter(|| neq(&arr_a, &arr_b))); c.bench_function("neq scalar Int32", |b| { - b.iter(|| { - neq_scalar(criterion::black_box(&arr_a), criterion::black_box(1)).unwrap() - }) + b.iter(|| neq(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("lt Int32", |b| b.iter(|| bench_lt(&arr_a, &arr_b))); + c.bench_function("lt Int32", |b| b.iter(|| lt(&arr_a, &arr_b))); c.bench_function("lt scalar Int32", |b| { - b.iter(|| { - lt_scalar(criterion::black_box(&arr_a), criterion::black_box(1)).unwrap() - }) + b.iter(|| lt(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("lt_eq Int32", |b| b.iter(|| bench_lt_eq(&arr_a, &arr_b))); + c.bench_function("lt_eq Int32", |b| b.iter(|| lt_eq(&arr_a, &arr_b))); c.bench_function("lt_eq scalar Int32", |b| { - b.iter(|| { - lt_eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1)).unwrap() - }) + b.iter(|| lt_eq(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("gt Int32", |b| b.iter(|| bench_gt(&arr_a, &arr_b))); + c.bench_function("gt Int32", |b| b.iter(|| gt(&arr_a, &arr_b))); c.bench_function("gt scalar Int32", |b| { - b.iter(|| { - gt_scalar(criterion::black_box(&arr_a), criterion::black_box(1)).unwrap() - }) + b.iter(|| gt(&arr_a, &Scalar::new(&scalar)).unwrap()) }); - c.bench_function("gt_eq Int32", |b| b.iter(|| bench_gt_eq(&arr_a, &arr_b))); + c.bench_function("gt_eq Int32", |b| b.iter(|| gt_eq(&arr_a, &arr_b))); c.bench_function("gt_eq scalar Int32", |b| { - b.iter(|| { - gt_eq_scalar(criterion::black_box(&arr_a), criterion::black_box(1)).unwrap() - }) + b.iter(|| gt_eq(&arr_a, &Scalar::new(&scalar)).unwrap()) }); c.bench_function("eq 
MonthDayNano", |b| { - b.iter(|| bench_eq(&arr_month_day_nano_a, &arr_month_day_nano_b)) + b.iter(|| eq(&arr_month_day_nano_a, &arr_month_day_nano_b)) }); + let scalar = IntervalMonthDayNanoArray::from(vec![123]); + c.bench_function("eq scalar MonthDayNano", |b| { - b.iter(|| { - eq_scalar( - criterion::black_box(&arr_month_day_nano_a), - criterion::black_box(123), - ) - .unwrap() - }) + b.iter(|| eq(&arr_month_day_nano_b, &Scalar::new(&scalar)).unwrap()) }); c.bench_function("like_utf8 scalar equals", |b| { @@ -326,14 +234,15 @@ fn add_benchmark(c: &mut Criterion) { let strings = create_string_array::(20, 0.); let dict_arr_a = create_dict_from_values::(SIZE, 0., &strings); + let scalar = StringArray::from(vec!["test"]); c.bench_function("eq_dyn_utf8_scalar dictionary[10] string[4])", |b| { - b.iter(|| eq_dyn_utf8_scalar(&dict_arr_a, "test")) + b.iter(|| eq(&dict_arr_a, &Scalar::new(&scalar))) }); c.bench_function( "gt_eq_dyn_utf8_scalar scalar dictionary[10] string[4])", - |b| b.iter(|| gt_eq_dyn_utf8_scalar(&dict_arr_a, "test")), + |b| b.iter(|| gt_eq(&dict_arr_a, &Scalar::new(&scalar))), ); c.bench_function("like_utf8_scalar_dyn dictionary[10] string[4])", |b| { @@ -344,7 +253,13 @@ fn add_benchmark(c: &mut Criterion) { b.iter(|| ilike_utf8_scalar_dyn(&dict_arr_a, "test")) }); - dyn_cmp_dict_benchmarks(c); + let strings = create_string_array::(20, 0.); + let dict_arr_a = create_dict_from_values::(SIZE, 0., &strings); + let dict_arr_b = create_dict_from_values::(SIZE, 0., &strings); + + c.bench_function("eq dictionary[10] string[4])", |b| { + b.iter(|| eq(&dict_arr_a, &dict_arr_b).unwrap()) + }); } criterion_group!(benches, add_benchmark); diff --git a/arrow/benches/equal.rs b/arrow/benches/equal.rs index 2f4e2fada9e9..4e99bf3071c9 100644 --- a/arrow/benches/equal.rs +++ b/arrow/benches/equal.rs @@ -20,7 +20,6 @@ #[macro_use] extern crate criterion; -use arrow::compute::eq_utf8_scalar; use criterion::Criterion; extern crate arrow; @@ -32,10 +31,6 @@ fn 
bench_equal>(arr_a: &A) { criterion::black_box(arr_a == arr_a); } -fn bench_equal_utf8_scalar(arr_a: &GenericStringArray, right: &str) { - criterion::black_box(eq_utf8_scalar(arr_a, right).unwrap()); -} - fn add_benchmark(c: &mut Criterion) { let arr_a = create_primitive_array::(512, 0.0); c.bench_function("equal_512", |b| b.iter(|| bench_equal(&arr_a))); @@ -49,11 +44,6 @@ fn add_benchmark(c: &mut Criterion) { let arr_a = create_string_array::(512, 0.0); c.bench_function("equal_string_512", |b| b.iter(|| bench_equal(&arr_a))); - let arr_a = create_string_array::(512, 0.0); - c.bench_function("equal_string_scalar_empty_512", |b| { - b.iter(|| bench_equal_utf8_scalar(&arr_a, "")) - }); - let arr_a_nulls = create_string_array::(512, 0.5); c.bench_function("equal_string_nulls_512", |b| { b.iter(|| bench_equal(&arr_a_nulls)) diff --git a/arrow/src/compute/kernels.rs b/arrow/src/compute/kernels.rs index 1a79aef547d3..dba41625020b 100644 --- a/arrow/src/compute/kernels.rs +++ b/arrow/src/compute/kernels.rs @@ -22,7 +22,7 @@ pub use arrow_arith::{ }; pub use arrow_cast::cast; pub use arrow_cast::parse as cast_utils; -pub use arrow_ord::{partition, sort}; +pub use arrow_ord::{cmp, partition, sort}; pub use arrow_select::{concat, filter, interleave, nullif, take, window, zip}; pub use arrow_string::{concat_elements, length, regexp, substring}; diff --git a/parquet/examples/async_read_parquet.rs b/parquet/examples/async_read_parquet.rs index f600cd0d11e3..e59cad8055cb 100644 --- a/parquet/examples/async_read_parquet.rs +++ b/parquet/examples/async_read_parquet.rs @@ -15,7 +15,9 @@ // specific language governing permissions and limitations // under the License. 
+use arrow::compute::kernels::cmp::eq; use arrow::util::pretty::print_batches; +use arrow_array::{Int32Array, Scalar}; use futures::TryStreamExt; use parquet::arrow::arrow_reader::{ArrowPredicateFn, RowFilter}; use parquet::arrow::{ParquetRecordBatchStreamBuilder, ProjectionMask}; @@ -44,9 +46,10 @@ async fn main() -> Result<()> { // Highlight: set `RowFilter`, it'll push down filter predicates to skip IO and decode. // For more specific usage: please refer to https://github.com/apache/arrow-datafusion/blob/master/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs. + let scalar = Int32Array::from(vec![1]); let filter = ArrowPredicateFn::new( ProjectionMask::roots(file_metadata.schema_descr(), [0]), - |record_batch| arrow::compute::eq_dyn_scalar(record_batch.column(0), 1), + move |record_batch| eq(record_batch.column(0), &Scalar::new(&scalar)), ); let row_filter = RowFilter::new(vec![Box::new(filter)]); builder = builder.with_row_filter(row_filter); diff --git a/parquet/src/arrow/async_reader/mod.rs b/parquet/src/arrow/async_reader/mod.rs index c7e0f64783f1..54793c47fea1 100644 --- a/parquet/src/arrow/async_reader/mod.rs +++ b/parquet/src/arrow/async_reader/mod.rs @@ -776,10 +776,11 @@ mod tests { use crate::file::footer::parse_metadata; use crate::file::page_index::index_reader; use crate::file::properties::WriterProperties; + use arrow::compute::kernels::cmp::eq; use arrow::error::Result as ArrowResult; use arrow_array::cast::AsArray; use arrow_array::types::Int32Type; - use arrow_array::{Array, ArrayRef, Int32Array, StringArray}; + use arrow_array::{Array, ArrayRef, Int32Array, Int8Array, Scalar, StringArray}; use futures::TryStreamExt; use rand::{thread_rng, Rng}; use std::sync::Mutex; @@ -1188,14 +1189,16 @@ mod tests { }; let requests = test.requests.clone(); + let a_scalar = StringArray::from_iter_values(["b"]); let a_filter = ArrowPredicateFn::new( ProjectionMask::leaves(&parquet_schema, vec![0]), - |batch| 
arrow::compute::eq_dyn_utf8_scalar(batch.column(0), "b"), + move |batch| eq(batch.column(0), &Scalar::new(&a_scalar)), ); + let b_scalar = StringArray::from_iter_values(["4"]); let b_filter = ArrowPredicateFn::new( ProjectionMask::leaves(&parquet_schema, vec![1]), - |batch| arrow::compute::eq_dyn_utf8_scalar(batch.column(0), "4"), + move |batch| eq(batch.column(0), &Scalar::new(&b_scalar)), ); let filter = RowFilter::new(vec![Box::new(a_filter), Box::new(b_filter)]); @@ -1353,12 +1356,13 @@ mod tests { let a_filter = ArrowPredicateFn::new( ProjectionMask::leaves(&parquet_schema, vec![1]), - |batch| arrow::compute::eq_dyn_bool_scalar(batch.column(0), true), + |batch| Ok(batch.column(0).as_boolean().clone()), ); + let b_scalar = Int8Array::from(vec![2]); let b_filter = ArrowPredicateFn::new( ProjectionMask::leaves(&parquet_schema, vec![2]), - |batch| arrow::compute::eq_dyn_scalar(batch.column(0), 2_i32), + move |batch| eq(batch.column(0), &Scalar::new(&b_scalar)), ); let filter = RowFilter::new(vec![Box::new(a_filter), Box::new(b_filter)]);