Skip to content

Commit

Permalink
Define consistent float ordering (#808)
Browse files Browse the repository at this point in the history
We use total_cmp instead of partial_cmp to compare floating-point
numbers and use the bit pattern to determine float equality.
  • Loading branch information
robert3005 authored Sep 16, 2024
1 parent 343ed1c commit 02b752c
Show file tree
Hide file tree
Showing 11 changed files with 164 additions and 109 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions fuzz/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ cargo-fuzz = true
[dependencies]
arrow-buffer = { workspace = true }
libfuzzer-sys = { workspace = true }
num-traits = { workspace = true }
vortex-array = { workspace = true, features = ["arbitrary"] }
vortex-buffer = { workspace = true }
vortex-dtype = { workspace = true, features = ["arbitrary"] }
Expand Down
38 changes: 18 additions & 20 deletions fuzz/fuzz_targets/array_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use vortex::array::{
};
use vortex::compute::unary::scalar_at;
use vortex::compute::{search_sorted, slice, take, SearchResult, SearchSortedSide};
use vortex::encoding::{EncodingId, EncodingRef};
use vortex::encoding::EncodingRef;
use vortex::{Array, IntoCanonical};
use vortex_fuzz::{sort_canonical_array, Action, FuzzArrayAction};
use vortex_sampling_compressor::SamplingCompressor;
Expand Down Expand Up @@ -81,7 +81,8 @@ fn assert_search_sorted(
assert_eq!(
search_result,
expected,
"Expected to find {s} at {expected} in {} from {side} but instead found it at {search_result} in step {step}",
"Expected to find {s}({}) at {expected} in {} from {side} but instead found it at {search_result} in step {step}",
s.dtype(),
array.encoding().id()
);
}
Expand All @@ -92,19 +93,18 @@ fn assert_array_eq(lhs: &Array, rhs: &Array, step: usize) {
let l = scalar_at(lhs, idx).unwrap();
let r = scalar_at(rhs, idx).unwrap();

fuzzing_scalar_cmp(l, r, lhs.encoding().id(), rhs.encoding().id(), idx, step);
assert_eq!(l.is_valid(), r.is_valid());
assert!(
equal_scalar_values(l.value(), r.value()),
"{l} != {r} at index {idx}, lhs is {} rhs is {} in step {step}",
lhs.encoding().id(),
rhs.encoding().id()
);
}
}

fn fuzzing_scalar_cmp(
l: Scalar,
r: Scalar,
lhs_encoding: EncodingId,
rhs_encoding: EncodingId,
idx: usize,
step: usize,
) {
let equal_values = match (l.value(), r.value()) {
fn equal_scalar_values(l: &ScalarValue, r: &ScalarValue) -> bool {
match (l, r) {
(ScalarValue::Primitive(l), ScalarValue::Primitive(r))
if l.ptype().is_float() && r.ptype().is_float() =>
{
Expand All @@ -115,12 +115,10 @@ fn fuzzing_scalar_cmp(
_ => unreachable!(),
}
}
_ => l.value() == r.value(),
};

assert!(
equal_values,
"{l} != {r} at index {idx}, lhs is {lhs_encoding} rhs is {rhs_encoding} in step {step}",
);
assert_eq!(l.is_valid(), r.is_valid());
(ScalarValue::List(lc), ScalarValue::List(rc)) => lc
.iter()
.zip(rc.iter())
.all(|(l, r)| equal_scalar_values(l, r)),
_ => l == r,
}
}
49 changes: 37 additions & 12 deletions fuzz/src/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use vortex::compute::{IndexOrd, Len, SearchResult, SearchSorted, SearchSortedSid
use vortex::validity::ArrayValidity;
use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant};
use vortex_buffer::{Buffer, BufferString};
use vortex_dtype::{match_each_native_ptype, DType};
use vortex_dtype::{match_each_native_ptype, DType, NativePType};
use vortex_scalar::Scalar;

struct SearchNullableSlice<T>(Vec<Option<T>>);
Expand All @@ -32,6 +32,29 @@ impl<T> Len for SearchNullableSlice<T> {
}
}

/// Nullable primitive values wrapped for the fuzzer's reference search.
///
/// Elements are ordered with the `compare` method available through the
/// `NativePType` bound instead of the element type's own comparison operators.
struct SearchPrimitiveSlice<T>(Vec<Option<T>>);

impl<T: NativePType> IndexOrd<Option<T>> for SearchPrimitiveSlice<T> {
    /// Orders the slot at `idx` relative to the searched value `elem`.
    ///
    /// A null (`None`) slot is reported as `Ordering::Greater`; a valid slot is
    /// ordered by `compare`. The result is always `Some`.
    ///
    /// # Panics
    ///
    /// Hits `unreachable!` when `elem` is `None`: searching for a null is not supported.
    fn index_cmp(&self, idx: usize, elem: &Option<T>) -> Option<Ordering> {
        let needle = match elem {
            Some(v) => v,
            None => unreachable!("Can't search for None"),
        };

        // SAFETY: Used in search_sorted_by same as the standard library. The search_sorted ensures idx is in bounds
        let slot = unsafe { self.0.get_unchecked(idx) };

        Some(match slot {
            Some(stored) => stored.compare(*needle),
            // Null slots order after every searched value.
            None => Ordering::Greater,
        })
    }
}

// `Len` for the fuzzing search wrapper: the length of the wrapped vector,
// counting null (`None`) slots as well as valid ones.
impl<T> Len for SearchPrimitiveSlice<T> {
fn len(&self) -> usize {
self.0.len()
}
}

pub fn search_sorted_canonical_array(
array: &Array,
scalar: &Scalar,
Expand All @@ -55,24 +78,26 @@ pub fn search_sorted_canonical_array(
let to_find = scalar.try_into().unwrap();
SearchNullableSlice(opt_values).search_sorted(&Some(to_find), side)
}
DType::Primitive(p, _) => match_each_native_ptype!(p, |$P| {
DType::Primitive(p, _) => {
let primitive_array = array.clone().into_primitive().unwrap();
let validity = primitive_array
.logical_validity()
.into_array()
.into_bool()
.unwrap()
.boolean_buffer();
let opt_values = primitive_array
.maybe_null_slice::<$P>()
.iter()
.copied()
.zip(validity.iter())
.map(|(b, v)| v.then_some(b))
.collect::<Vec<_>>();
let to_find: $P = scalar.try_into().unwrap();
SearchNullableSlice(opt_values).search_sorted(&Some(to_find), side)
}),
match_each_native_ptype!(p, |$P| {
let opt_values = primitive_array
.maybe_null_slice::<$P>()
.iter()
.copied()
.zip(validity.iter())
.map(|(b, v)| v.then_some(b))
.collect::<Vec<_>>();
let to_find: $P = scalar.try_into().unwrap();
SearchPrimitiveSlice(opt_values).search_sorted(&Some(to_find), side)
})
}
DType::Utf8(_) | DType::Binary(_) => {
let utf8 = array.clone().into_varbin().unwrap();
let opt_values = utf8
Expand Down
76 changes: 29 additions & 47 deletions fuzz/src/sort.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use vortex::array::{BoolArray, PrimitiveArray, VarBinArray};
use vortex::compute::unary::scalar_at;
use vortex::validity::ArrayValidity;
use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant};
use vortex_dtype::{match_each_float_ptype, match_each_integer_ptype, DType};
use vortex_dtype::{match_each_native_ptype, DType, NativePType};

use crate::take::take_canonical_array;

Expand All @@ -32,52 +32,25 @@ pub fn sort_canonical_array(array: &Array) -> Array {
}
DType::Primitive(p, _) => {
let primitive_array = array.clone().into_primitive().unwrap();
if p.is_int() {
match_each_integer_ptype!(p, |$P| {
let mut opt_values = primitive_array
.maybe_null_slice::<$P>()
.iter()
.copied()
.zip(
primitive_array
.logical_validity()
.into_array()
.into_bool()
.unwrap()
.boolean_buffer()
.iter(),
)
.map(|(p, v)| v.then_some(p))
.collect::<Vec<_>>();
sort_opt_slice(&mut opt_values);
PrimitiveArray::from_nullable_vec(opt_values).into_array()
})
} else {
match_each_float_ptype!(p, |$F| {
let mut opt_values = primitive_array
.maybe_null_slice::<$F>()
.iter()
.copied()
.zip(
primitive_array
.logical_validity()
.into_array()
.into_bool()
.unwrap()
.boolean_buffer()
.iter(),
)
.map(|(p, v)| v.then_some(p))
.collect::<Vec<_>>();
opt_values.sort_by(|a, b| match (a, b) {
(Some(v), Some(w)) => v.to_bits().cmp(&w.to_bits()),
(None, None) => Ordering::Equal,
(None, Some(_)) => Ordering::Greater,
(Some(_), None) => Ordering::Less,
});
PrimitiveArray::from_nullable_vec(opt_values).into_array()
})
}
match_each_native_ptype!(p, |$P| {
let mut opt_values = primitive_array
.maybe_null_slice::<$P>()
.iter()
.copied()
.zip(
primitive_array
.logical_validity()
.into_array()
.into_bool()
.unwrap()
.boolean_buffer()
.iter(),
)
.map(|(p, v)| v.then_some(p))
.collect::<Vec<_>>();
sort_primitive_slice(&mut opt_values);
PrimitiveArray::from_nullable_vec(opt_values).into_array()
})
}
DType::Utf8(_) | DType::Binary(_) => {
let utf8 = array.clone().into_varbin().unwrap();
Expand All @@ -101,6 +74,15 @@ pub fn sort_canonical_array(array: &Array) -> Array {
}
}

fn sort_primitive_slice<T: NativePType>(s: &mut [Option<T>]) {
s.sort_by(|a, b| match (a, b) {
(Some(v), Some(w)) => v.compare(*w),
(None, None) => Ordering::Equal,
(None, Some(_)) => Ordering::Greater,
(Some(_), None) => Ordering::Less,
});
}

/// Reverse sorting of Option<T> such that None is last (Greatest)
fn sort_opt_slice<T: Ord>(s: &mut [Option<T>]) {
s.sort_by(|a, b| match (a, b) {
Expand Down
31 changes: 28 additions & 3 deletions vortex-array/src/array/primitive/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ impl SearchSortedFn for PrimitiveArray {
match self.validity() {
Validity::NonNullable | Validity::AllValid => {
let pvalue: $T = value.try_into()?;
Ok(self.maybe_null_slice::<$T>().search_sorted(&pvalue, side))
Ok(SearchSortedPrimitive::new(self).search_sorted(&pvalue, side))
}
Validity::AllInvalid => Ok(SearchResult::NotFound(0)),
Validity::Array(_) => {
Expand All @@ -27,15 +27,40 @@ impl SearchSortedFn for PrimitiveArray {
}
}

struct SearchSortedNullsLast<'a, T> {
/// Borrowed view over a primitive array's values, used as a `search_sorted` target.
///
/// Its `IndexOrd` impl orders elements with `NativePType::compare`; validity
/// is not tracked by this type.
struct SearchSortedPrimitive<'a, T> {
values: &'a [T],
}

impl<'a, T: NativePType> SearchSortedPrimitive<'a, T> {
    /// Borrows `array`'s `maybe_null_slice` as the values to search.
    ///
    /// Validity is not consulted here.
    pub fn new(array: &'a PrimitiveArray) -> Self {
        let values = array.maybe_null_slice();
        Self { values }
    }
}

impl<'a, T: NativePType> IndexOrd<T> for SearchSortedPrimitive<'a, T> {
    /// Orders the stored value at `idx` relative to `elem` using `compare`.
    ///
    /// Always returns `Some`. `idx` is not bounds-checked.
    fn index_cmp(&self, idx: usize, elem: &T) -> Option<Ordering> {
        // SAFETY: Used in search_sorted_by same as the standard library. The search_sorted ensures idx is in bounds
        let stored = unsafe { self.values.get_unchecked(idx) };
        Some(stored.compare(*elem))
    }
}

// `Len` for the primitive search view: the number of elements in the borrowed slice.
impl<'a, T> Len for SearchSortedPrimitive<'a, T> {
fn len(&self) -> usize {
self.values.len()
}
}

/// Search target that pairs a `SearchSortedPrimitive` value view with the array's `Validity`.
///
/// NOTE(review): the name suggests null entries are ordered after all valid
/// values; the comparison impl is not visible in this view, so confirm there.
struct SearchSortedNullsLast<'a, T> {
values: SearchSortedPrimitive<'a, T>,
validity: Validity,
}

impl<'a, T: NativePType> SearchSortedNullsLast<'a, T> {
pub fn new(array: &'a PrimitiveArray) -> Self {
Self {
values: array.maybe_null_slice(),
values: SearchSortedPrimitive::new(array),
validity: array.validity(),
}
}
Expand Down
8 changes: 4 additions & 4 deletions vortex-array/src/array/primitive/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -197,15 +197,15 @@ impl<T: NativePType> Accessor<T> for PrimitiveArray {
ArrayValidity::is_valid(self, index)
}

fn array_validity(&self) -> Validity {
self.validity()
}

#[inline]
fn value_unchecked(&self, index: usize) -> T {
self.maybe_null_slice::<T>()[index]
}

fn array_validity(&self) -> Validity {
self.validity()
}

#[inline]
fn decode_batch(&self, start_idx: usize) -> Vec<T> {
let batch_size = <Self as Accessor<T>>::batch_size(self, start_idx);
Expand Down
12 changes: 6 additions & 6 deletions vortex-array/src/array/primitive/stats.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::cmp::Ordering;
use std::collections::HashMap;
use std::mem::size_of;

Expand Down Expand Up @@ -189,26 +190,25 @@ impl<T: PStatsType> StatsAccumulator<T> {
self.nan_count += 1;
}

if self.prev == next {
if next.is_eq(self.prev) {
self.is_strict_sorted = false;
} else {
if next < self.prev {
if matches!(next.compare(self.prev), Ordering::Less) {
self.is_sorted = false;
}
self.run_count += 1;
}
if next < self.min {
if matches!(next.compare(self.min), Ordering::Less) {
self.min = next;
} else if next > self.max {
} else if matches!(next.compare(self.max), Ordering::Greater) {
self.max = next;
}
self.prev = next;
}

pub fn finish(self) -> StatsSet {
let is_constant = (self.min == self.max && self.null_count == 0 && self.nan_count == 0)
|| self.null_count == self.len
|| self.nan_count == self.len;
|| self.null_count == self.len;

StatsSet::from(HashMap::from([
(Stat::Min, self.min.into()),
Expand Down
Loading

0 comments on commit 02b752c

Please sign in to comment.