From db5314c5c2680861683b7dcb8f69cc27aa7ac8ed Mon Sep 17 00:00:00 2001 From: Raphael Taylor-Davies <1781103+tustvold@users.noreply.github.com> Date: Tue, 5 Sep 2023 16:01:09 +0100 Subject: [PATCH] Fix List Sorting, Revert Removal of Rank Kernels (#4747) * Revert "Remove rank kernels" This reverts commit c06786faaf750de7c899dd7750111c2d684e307b. * Fix child_rank --- arrow-ord/src/lib.rs | 1 + arrow-ord/src/rank.rs | 195 +++++++++++++++++++++++++++++++++++ arrow-ord/src/sort.rs | 41 ++++++-- arrow/benches/sort_kernel.rs | 21 ++++ arrow/src/compute/kernels.rs | 2 +- arrow/src/compute/mod.rs | 1 + 6 files changed, 251 insertions(+), 10 deletions(-) create mode 100644 arrow-ord/src/rank.rs diff --git a/arrow-ord/src/lib.rs b/arrow-ord/src/lib.rs index 19ad8229417f..8fe4ecbc05aa 100644 --- a/arrow-ord/src/lib.rs +++ b/arrow-ord/src/lib.rs @@ -48,4 +48,5 @@ pub mod cmp; pub mod comparison; pub mod ord; pub mod partition; +pub mod rank; pub mod sort; diff --git a/arrow-ord/src/rank.rs b/arrow-ord/src/rank.rs new file mode 100644 index 000000000000..1e79156a71a3 --- /dev/null +++ b/arrow-ord/src/rank.rs @@ -0,0 +1,195 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow_array::cast::AsArray; +use arrow_array::types::*; +use arrow_array::{downcast_primitive_array, Array, ArrowNativeTypeOp, GenericByteArray}; +use arrow_buffer::NullBuffer; +use arrow_schema::{ArrowError, DataType, SortOptions}; +use std::cmp::Ordering; + +/// Assigns a rank to each value in `array` based on its position in the sorted order +/// +/// Where values are equal, they will be assigned the highest of their ranks, +/// leaving gaps in the overall rank assignment +/// +/// ``` +/// # use arrow_array::StringArray; +/// # use arrow_ord::rank::rank; +/// let array = StringArray::from(vec![Some("foo"), None, Some("foo"), None, Some("bar")]); +/// let ranks = rank(&array, None).unwrap(); +/// assert_eq!(ranks, &[5, 2, 5, 2, 3]); +/// ``` +pub fn rank( + array: &dyn Array, + options: Option, +) -> Result, ArrowError> { + let options = options.unwrap_or_default(); + let ranks = downcast_primitive_array! { + array => primitive_rank(array.values(), array.nulls(), options), + DataType::Utf8 => bytes_rank(array.as_bytes::(), options), + DataType::LargeUtf8 => bytes_rank(array.as_bytes::(), options), + DataType::Binary => bytes_rank(array.as_bytes::(), options), + DataType::LargeBinary => bytes_rank(array.as_bytes::(), options), + d => return Err(ArrowError::ComputeError(format!("{d:?} not supported in rank"))) + }; + Ok(ranks) +} + +#[inline(never)] +fn primitive_rank( + values: &[T], + nulls: Option<&NullBuffer>, + options: SortOptions, +) -> Vec { + let len: u32 = values.len().try_into().unwrap(); + let to_sort = match nulls.filter(|n| n.null_count() > 0) { + Some(n) => n + .valid_indices() + .map(|idx| (values[idx], idx as u32)) + .collect(), + None => values.iter().copied().zip(0..len).collect(), + }; + rank_impl(values.len(), to_sort, options, T::compare, T::is_eq) +} + +#[inline(never)] +fn bytes_rank( + array: &GenericByteArray, + options: SortOptions, +) -> Vec { + let to_sort: Vec<(&[u8], u32)> = match array.nulls().filter(|n| n.null_count() > 0) { + Some(n) => n + .valid_indices() + .map(|idx| (array.value(idx).as_ref(), idx as u32)) + .collect(), + None => (0..array.len()) + .map(|idx| (array.value(idx).as_ref(), idx as u32)) + .collect(), + }; + rank_impl(array.len(), to_sort, options, Ord::cmp, PartialEq::eq) +} + +fn rank_impl( + len: usize, + mut valid: Vec<(T, u32)>, + options: SortOptions, + compare: C, + eq: E, +) -> Vec +where + T: Copy, + C: Fn(T, T) -> Ordering, + E: Fn(T, T) -> bool, +{ + // We can use an unstable sort as we combine equal values later + valid.sort_unstable_by(|a, b| compare(a.0, b.0)); + if options.descending { + valid.reverse(); + } + + let (mut valid_rank, null_rank) = match options.nulls_first { + true => (len as u32, (len - valid.len()) as u32), + false => (valid.len() as u32, len as u32), + }; + + let mut out: Vec<_> = vec![null_rank; len]; + if let Some(v) = valid.last() { + out[v.1 as usize] = valid_rank; + } + + let mut count = 1; // Number of values in rank + for w in valid.windows(2).rev() { + match eq(w[0].0, w[1].0) { + true => { + count += 1; + out[w[0].1 as usize] = valid_rank; + } + false => { + valid_rank -= count; + count = 1; + out[w[0].1 as usize] = valid_rank + } + } + } + + out +} + +#[cfg(test)] +mod tests { + use super::*; + use arrow_array::*; + + #[test] + fn test_primitive() { + let descending = SortOptions { + descending: true, + nulls_first: true, + }; + + let nulls_last = SortOptions { + descending: false, + nulls_first: false, + }; + + let nulls_last_descending = SortOptions { + descending: true, + nulls_first: false, + }; + + let a = Int32Array::from(vec![Some(1), Some(1), None, Some(3), Some(3), Some(4)]); + let res = rank(&a, None).unwrap(); + assert_eq!(res, &[3, 3, 1, 5, 5, 6]); + + let res = rank(&a, Some(descending)).unwrap(); + assert_eq!(res, &[6, 6, 1, 4, 4, 2]); + + let res = rank(&a, Some(nulls_last)).unwrap(); + assert_eq!(res, &[2, 2, 6, 4, 4, 5]); + + let res = rank(&a, Some(nulls_last_descending)).unwrap(); + assert_eq!(res, &[5, 5, 6, 3, 3, 1]); + + // Test with non-zero null values + let nulls = NullBuffer::from(vec![true, true, false, true, false, false]); + let a = Int32Array::new(vec![1, 4, 3, 4, 5, 5].into(), Some(nulls)); + let res = rank(&a, None).unwrap(); + assert_eq!(res, &[4, 6, 3, 6, 3, 3]); + } + + #[test] + fn test_bytes() { + let v = vec!["foo", "fo", "bar", "bar"]; + let values = StringArray::from(v.clone()); + let res = rank(&values, None).unwrap(); + assert_eq!(res, &[4, 3, 2, 2]); + + let values = LargeStringArray::from(v.clone()); + let res = rank(&values, None).unwrap(); + assert_eq!(res, &[4, 3, 2, 2]); + + let v: Vec<&[u8]> = vec![&[1, 2], &[0], &[1, 2, 3], &[1, 2]]; + let values = LargeBinaryArray::from(v.clone()); + let res = rank(&values, None).unwrap(); + assert_eq!(res, &[3, 1, 4, 3]); + + let values = BinaryArray::from(v); + let res = rank(&values, None).unwrap(); + assert_eq!(res, &[3, 1, 4, 3]); + } +} diff --git a/arrow-ord/src/sort.rs b/arrow-ord/src/sort.rs index 6c8c3b8facef..a477d6c261b3 100644 --- a/arrow-ord/src/sort.rs +++ b/arrow-ord/src/sort.rs @@ -30,6 +30,7 @@ use arrow_select::take::take; use std::cmp::Ordering; use std::sync::Arc; +use crate::rank::rank; pub use arrow_schema::SortOptions; /// Sort the `ArrayRef` using `SortOptions`. @@ -400,14 +401,7 @@ fn child_rank(values: &dyn Array, options: SortOptions) -> Result, Arro descending: false, nulls_first: options.nulls_first != options.descending, }); - - let sorted_value_indices = sort_to_indices(values, value_options, None)?; - let sorted_indices = sorted_value_indices.values(); - let mut out: Vec<_> = vec![0_u32; sorted_indices.len()]; - for (ix, val) in sorted_indices.iter().enumerate() { - out[*val as usize] = ix as u32; - } - Ok(out) + rank(values, value_options) } // Sort run array and return sorted run array. @@ -800,7 +794,9 @@ impl LexicographicalComparator { #[cfg(test)] mod tests { use super::*; - use arrow_array::builder::PrimitiveRunBuilder; + use arrow_array::builder::{ + FixedSizeListBuilder, Int64Builder, ListBuilder, PrimitiveRunBuilder, + }; use arrow_buffer::i256; use half::f16; use rand::rngs::StdRng; @@ -3991,4 +3987,31 @@ mod tests { // NULL.cmp(4) assert_eq!(comparator.compare(2, 3), Ordering::Less); } + + #[test] + fn sort_list_equal() { + let a = { + let mut builder = FixedSizeListBuilder::new(Int64Builder::new(), 2); + for value in [[1, 5], [0, 3], [1, 3]] { + builder.values().append_slice(&value); + builder.append(true); + } + builder.finish() + }; + + let sort_indices = sort_to_indices(&a, None, None).unwrap(); + assert_eq!(sort_indices.values(), &[1, 2, 0]); + + let a = { + let mut builder = ListBuilder::new(Int64Builder::new()); + for value in [[1, 5], [0, 3], [1, 3]] { + builder.values().append_slice(&value); + builder.append(true); + } + builder.finish() + }; + + let sort_indices = sort_to_indices(&a, None, None).unwrap(); + assert_eq!(sort_indices.values(), &[1, 2, 0]); + } } diff --git a/arrow/benches/sort_kernel.rs b/arrow/benches/sort_kernel.rs index dd55076647a5..63e10e0528ba 100644 --- a/arrow/benches/sort_kernel.rs +++ b/arrow/benches/sort_kernel.rs @@ -27,6 +27,7 @@ use arrow::compute::{lexsort, sort, sort_to_indices, SortColumn}; use arrow::datatypes::{Int16Type, Int32Type}; use arrow::util::bench_util::*; use arrow::{array::*, datatypes::Float32Type}; +use arrow_ord::rank::rank; fn create_f32_array(size: usize, with_nulls: bool) -> ArrayRef { let null_density = if with_nulls { 0.5 } else { 0.0 }; @@ -213,6 +214,26 @@ fn add_benchmark(c: &mut Criterion) { c.bench_function("lexsort (f32, f32) nulls 2^12 limit 2^12", |b| { b.iter(|| bench_lexsort(&arr_a, &arr_b, Some(2usize.pow(12)))) }); + + let arr = create_f32_array(2usize.pow(12), false); + c.bench_function("rank f32 2^12", |b| { + b.iter(|| black_box(rank(&arr, None).unwrap())) + }); + + let arr = create_f32_array(2usize.pow(12), true); + c.bench_function("rank f32 nulls 2^12", |b| { + b.iter(|| black_box(rank(&arr, None).unwrap())) + }); + + let arr = create_string_array_with_len::(2usize.pow(12), 0.0, 10); + c.bench_function("rank string[10] 2^12", |b| { + b.iter(|| black_box(rank(&arr, None).unwrap())) + }); + + let arr = create_string_array_with_len::(2usize.pow(12), 0.5, 10); + c.bench_function("rank string[10] nulls 2^12", |b| { + b.iter(|| black_box(rank(&arr, None).unwrap())) + }); } criterion_group!(benches, add_benchmark); diff --git a/arrow/src/compute/kernels.rs b/arrow/src/compute/kernels.rs index dba41625020b..35ad80e009cc 100644 --- a/arrow/src/compute/kernels.rs +++ b/arrow/src/compute/kernels.rs @@ -22,7 +22,7 @@ pub use arrow_arith::{ }; pub use arrow_cast::cast; pub use arrow_cast::parse as cast_utils; -pub use arrow_ord::{cmp, partition, sort}; +pub use arrow_ord::{cmp, partition, rank, sort}; pub use arrow_select::{concat, filter, interleave, nullif, take, window, zip}; pub use arrow_string::{concat_elements, length, regexp, substring}; diff --git a/arrow/src/compute/mod.rs b/arrow/src/compute/mod.rs index 7cfe787b08cf..47a9d149aadb 100644 --- a/arrow/src/compute/mod.rs +++ b/arrow/src/compute/mod.rs @@ -30,6 +30,7 @@ pub use self::kernels::filter::*; pub use self::kernels::interleave::*; pub use self::kernels::nullif::*; pub use self::kernels::partition::*; +pub use self::kernels::rank::*; pub use self::kernels::regexp::*; pub use self::kernels::sort::*; pub use self::kernels::take::*;