Skip to content

Commit

Permalink
Fix List Sorting, Revert Removal of Rank Kernels (#4747)
Browse files Browse the repository at this point in the history
* Revert "Remove rank kernels"

This reverts commit c06786f.

* Fix child_rank
  • Loading branch information
tustvold authored Sep 5, 2023
1 parent b66c57c commit db5314c
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 10 deletions.
1 change: 1 addition & 0 deletions arrow-ord/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,5 @@ pub mod cmp;
pub mod comparison;
pub mod ord;
pub mod partition;
pub mod rank;
pub mod sort;
195 changes: 195 additions & 0 deletions arrow-ord/src/rank.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::{downcast_primitive_array, Array, ArrowNativeTypeOp, GenericByteArray};
use arrow_buffer::NullBuffer;
use arrow_schema::{ArrowError, DataType, SortOptions};
use std::cmp::Ordering;

/// Computes, for every element of `array`, its rank within the sorted order.
///
/// Elements that compare equal all receive the highest rank of their group, so
/// the ranks directly below it stay unused and the overall assignment has gaps.
/// Null elements form one such group, placed according to `options`.
///
/// When `options` is `None`, [`SortOptions::default`] is used.
///
/// # Errors
///
/// Returns [`ArrowError::ComputeError`] when the data type of `array` is neither
/// handled by `downcast_primitive_array!` nor one of `Utf8`, `LargeUtf8`,
/// `Binary` or `LargeBinary`.
///
/// ```
/// # use arrow_array::StringArray;
/// # use arrow_ord::rank::rank;
/// let array = StringArray::from(vec![Some("foo"), None, Some("foo"), None, Some("bar")]);
/// let ranks = rank(&array, None).unwrap();
/// assert_eq!(ranks, &[5, 2, 5, 2, 3]);
/// ```
pub fn rank(
    array: &dyn Array,
    options: Option<SortOptions>,
) -> Result<Vec<u32>, ArrowError> {
    let sort_options = options.unwrap_or_default();
    // Dispatch on the concrete array type; the final arm bails out early for
    // every data type that has no ranking implementation.
    Ok(downcast_primitive_array! {
        array => primitive_rank(array.values(), array.nulls(), sort_options),
        DataType::Utf8 => bytes_rank(array.as_bytes::<Utf8Type>(), sort_options),
        DataType::LargeUtf8 => bytes_rank(array.as_bytes::<LargeUtf8Type>(), sort_options),
        DataType::Binary => bytes_rank(array.as_bytes::<BinaryType>(), sort_options),
        DataType::LargeBinary => bytes_rank(array.as_bytes::<LargeBinaryType>(), sort_options),
        d => return Err(ArrowError::ComputeError(format!("{d:?} not supported in rank")))
    })
}

/// Ranks the elements of a primitive array given its raw `values` and its
/// optional validity buffer `nulls`.
///
/// Every valid value is paired with its index and the pairs are handed to
/// `rank_impl`, using `T::compare` for ordering and `T::is_eq` for tie detection.
/// A validity buffer that reports zero nulls is treated as absent.
///
/// # Panics
///
/// Panics if `values.len()` does not fit in a `u32`.
#[inline(never)]
fn primitive_rank<T: ArrowNativeTypeOp>(
    values: &[T],
    nulls: Option<&NullBuffer>,
    options: SortOptions,
) -> Vec<u32> {
    let total = values.len();
    // Checked once here, which makes the `as u32` index casts below lossless.
    let total_u32: u32 = total.try_into().unwrap();

    let pairs: Vec<(T, u32)> = if let Some(validity) = nulls.filter(|n| n.null_count() > 0) {
        // Only valid slots take part in the sort.
        validity
            .valid_indices()
            .map(|i| (values[i], i as u32))
            .collect()
    } else {
        values.iter().copied().zip(0..total_u32).collect()
    };

    rank_impl(total, pairs, options, T::compare, T::is_eq)
}

#[inline(never)]
fn bytes_rank<T: ByteArrayType>(
array: &GenericByteArray<T>,
options: SortOptions,
) -> Vec<u32> {
let to_sort: Vec<(&[u8], u32)> = match array.nulls().filter(|n| n.null_count() > 0) {
Some(n) => n
.valid_indices()
.map(|idx| (array.value(idx).as_ref(), idx as u32))
.collect(),
None => (0..array.len())
.map(|idx| (array.value(idx).as_ref(), idx as u32))
.collect(),
};
rank_impl(array.len(), to_sort, options, Ord::cmp, PartialEq::eq)
}

/// Shared ranking routine used by `primitive_rank` and `bytes_rank`.
///
/// * `len` - total number of slots in the array, nulls included
/// * `valid` - one `(value, index)` pair per non-null slot
/// * `options` - sort direction and placement of nulls
/// * `compare` - ordering used to sort the values
/// * `eq` - equality used to detect runs of tied values
///
/// Returns `len` ranks indexed like the input. Tied values share the highest
/// rank of their group; every slot not listed in `valid` receives the null rank.
fn rank_impl<T, C, E>(
    len: usize,
    mut valid: Vec<(T, u32)>,
    options: SortOptions,
    compare: C,
    eq: E,
) -> Vec<u32>
where
    T: Copy,
    C: Fn(T, T) -> Ordering,
    E: Fn(T, T) -> bool,
{
    // Stability is irrelevant here: tied values end up sharing one rank anyway.
    valid.sort_unstable_by(|&(lhs, _), &(rhs, _)| compare(lhs, rhs));
    if options.descending {
        valid.reverse();
    }

    // `top_rank` is the rank of the last valid value in sorted order;
    // `null_rank` is the single rank shared by all nulls.
    let (top_rank, null_rank) = if options.nulls_first {
        (len as u32, (len - valid.len()) as u32)
    } else {
        (valid.len() as u32, len as u32)
    };

    // Start with every slot holding the null rank, then overwrite valid slots.
    let mut ranks = vec![null_rank; len];

    // Walk from the highest-ranked value downwards. Whenever the value changes,
    // step the rank down by the size of the tie group that was just completed.
    let mut current_rank = top_rank;
    let mut run_len: u32 = 0;
    let mut previous: Option<T> = None;
    for &(value, index) in valid.iter().rev() {
        match previous {
            Some(above) if !eq(value, above) => {
                current_rank -= run_len;
                run_len = 1;
            }
            _ => run_len += 1,
        }
        ranks[index as usize] = current_rank;
        previous = Some(value);
    }

    ranks
}

#[cfg(test)]
mod tests {
    use super::*;
    use arrow_array::*;

    /// Ranks an `Int32Array` containing ties and a null under each combination
    /// of sort direction and null placement.
    #[test]
    fn test_primitive() {
        let a = Int32Array::from(vec![Some(1), Some(1), None, Some(3), Some(3), Some(4)]);

        // Each case pairs the sort options with the ranks expected for `a`.
        let cases: [(Option<SortOptions>, [u32; 6]); 4] = [
            // Default options
            (None, [3, 3, 1, 5, 5, 6]),
            // Descending, nulls first
            (
                Some(SortOptions {
                    descending: true,
                    nulls_first: true,
                }),
                [6, 6, 1, 4, 4, 2],
            ),
            // Ascending, nulls last
            (
                Some(SortOptions {
                    descending: false,
                    nulls_first: false,
                }),
                [2, 2, 6, 4, 4, 5],
            ),
            // Descending, nulls last
            (
                Some(SortOptions {
                    descending: true,
                    nulls_first: false,
                }),
                [5, 5, 6, 3, 3, 1],
            ),
        ];
        for (options, expected) in cases {
            assert_eq!(rank(&a, options).unwrap(), expected);
        }

        // Null slots whose underlying values are non-zero must still rank as nulls
        let validity = NullBuffer::from(vec![true, true, false, true, false, false]);
        let a = Int32Array::new(vec![1, 4, 3, 4, 5, 5].into(), Some(validity));
        assert_eq!(rank(&a, None).unwrap(), [4_u32, 6, 3, 6, 3, 3]);
    }

    /// Ranks string and binary arrays, with both 32-bit and 64-bit offsets.
    #[test]
    fn test_bytes() {
        let strings = vec!["foo", "fo", "bar", "bar"];
        let expected = [4_u32, 3, 2, 2];
        assert_eq!(
            rank(&StringArray::from(strings.clone()), None).unwrap(),
            expected
        );
        assert_eq!(
            rank(&LargeStringArray::from(strings), None).unwrap(),
            expected
        );

        let binary: Vec<&[u8]> = vec![&[1, 2], &[0], &[1, 2, 3], &[1, 2]];
        let expected = [3_u32, 1, 4, 3];
        assert_eq!(
            rank(&LargeBinaryArray::from(binary.clone()), None).unwrap(),
            expected
        );
        assert_eq!(rank(&BinaryArray::from(binary), None).unwrap(), expected);
    }
}
41 changes: 32 additions & 9 deletions arrow-ord/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ use arrow_select::take::take;
use std::cmp::Ordering;
use std::sync::Arc;

use crate::rank::rank;
pub use arrow_schema::SortOptions;

/// Sort the `ArrayRef` using `SortOptions`.
Expand Down Expand Up @@ -400,14 +401,7 @@ fn child_rank(values: &dyn Array, options: SortOptions) -> Result<Vec<u32>, Arro
descending: false,
nulls_first: options.nulls_first != options.descending,
});

let sorted_value_indices = sort_to_indices(values, value_options, None)?;
let sorted_indices = sorted_value_indices.values();
let mut out: Vec<_> = vec![0_u32; sorted_indices.len()];
for (ix, val) in sorted_indices.iter().enumerate() {
out[*val as usize] = ix as u32;
}
Ok(out)
rank(values, value_options)
}

// Sort run array and return sorted run array.
Expand Down Expand Up @@ -800,7 +794,9 @@ impl LexicographicalComparator {
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::builder::PrimitiveRunBuilder;
use arrow_array::builder::{
FixedSizeListBuilder, Int64Builder, ListBuilder, PrimitiveRunBuilder,
};
use arrow_buffer::i256;
use half::f16;
use rand::rngs::StdRng;
Expand Down Expand Up @@ -3991,4 +3987,31 @@ mod tests {
// NULL.cmp(4)
assert_eq!(comparator.compare(2, 3), Ordering::Less);
}

/// Sorting lists whose rows tie on the first element (`[1, 5]` vs `[1, 3]`)
/// must fall through to the second element, for both fixed-size and
/// variable-size lists.
#[test]
fn sort_list_equal() {
    let rows = [[1_i64, 5], [0, 3], [1, 3]];
    let expected = [1_u32, 2, 0];

    // Fixed-size list of two Int64 values per row
    let mut fixed_builder = FixedSizeListBuilder::new(Int64Builder::new(), 2);
    for row in rows {
        fixed_builder.values().append_slice(&row);
        fixed_builder.append(true);
    }
    let fixed = fixed_builder.finish();
    let indices = sort_to_indices(&fixed, None, None).unwrap();
    assert_eq!(indices.values(), &expected);

    // Variable-size list holding the same rows
    let mut list_builder = ListBuilder::new(Int64Builder::new());
    for row in rows {
        list_builder.values().append_slice(&row);
        list_builder.append(true);
    }
    let list = list_builder.finish();
    let indices = sort_to_indices(&list, None, None).unwrap();
    assert_eq!(indices.values(), &expected);
}
}
21 changes: 21 additions & 0 deletions arrow/benches/sort_kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use arrow::compute::{lexsort, sort, sort_to_indices, SortColumn};
use arrow::datatypes::{Int16Type, Int32Type};
use arrow::util::bench_util::*;
use arrow::{array::*, datatypes::Float32Type};
use arrow_ord::rank::rank;

fn create_f32_array(size: usize, with_nulls: bool) -> ArrayRef {
let null_density = if with_nulls { 0.5 } else { 0.0 };
Expand Down Expand Up @@ -213,6 +214,26 @@ fn add_benchmark(c: &mut Criterion) {
c.bench_function("lexsort (f32, f32) nulls 2^12 limit 2^12", |b| {
b.iter(|| bench_lexsort(&arr_a, &arr_b, Some(2usize.pow(12))))
});

let arr = create_f32_array(2usize.pow(12), false);
c.bench_function("rank f32 2^12", |b| {
b.iter(|| black_box(rank(&arr, None).unwrap()))
});

let arr = create_f32_array(2usize.pow(12), true);
c.bench_function("rank f32 nulls 2^12", |b| {
b.iter(|| black_box(rank(&arr, None).unwrap()))
});

let arr = create_string_array_with_len::<i32>(2usize.pow(12), 0.0, 10);
c.bench_function("rank string[10] 2^12", |b| {
b.iter(|| black_box(rank(&arr, None).unwrap()))
});

let arr = create_string_array_with_len::<i32>(2usize.pow(12), 0.5, 10);
c.bench_function("rank string[10] nulls 2^12", |b| {
b.iter(|| black_box(rank(&arr, None).unwrap()))
});
}

criterion_group!(benches, add_benchmark);
Expand Down
2 changes: 1 addition & 1 deletion arrow/src/compute/kernels.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ pub use arrow_arith::{
};
pub use arrow_cast::cast;
pub use arrow_cast::parse as cast_utils;
pub use arrow_ord::{cmp, partition, sort};
pub use arrow_ord::{cmp, partition, rank, sort};
pub use arrow_select::{concat, filter, interleave, nullif, take, window, zip};
pub use arrow_string::{concat_elements, length, regexp, substring};

Expand Down
1 change: 1 addition & 0 deletions arrow/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub use self::kernels::filter::*;
pub use self::kernels::interleave::*;
pub use self::kernels::nullif::*;
pub use self::kernels::partition::*;
pub use self::kernels::rank::*;
pub use self::kernels::regexp::*;
pub use self::kernels::sort::*;
pub use self::kernels::take::*;
Expand Down

0 comments on commit db5314c

Please sign in to comment.