Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SearchSorted can return whether search resulted in exact match #226

Merged
merged 6 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 14 additions & 6 deletions vortex-array/src/array/primitive/compute/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
use vortex_error::VortexResult;

use crate::array::primitive::compute::PrimitiveTrait;
use crate::compute::search_sorted::SearchSorted;
use crate::compute::search_sorted::{SearchResult, SearchSorted};
use crate::compute::search_sorted::{SearchSortedFn, SearchSortedSide};
use crate::ptype::NativePType;
use crate::scalar::Scalar;

impl<T: NativePType> SearchSortedFn for &dyn PrimitiveTrait<T> {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<usize> {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult> {
let pvalue: T = value.try_into()?;
Ok(self.typed_data().search_sorted(&pvalue, side))
}
Expand All @@ -24,19 +24,27 @@ mod test {
let values = vec![1u16, 2, 3].into_array();

assert_eq!(
search_sorted(&values, 0, SearchSortedSide::Left).unwrap(),
search_sorted(&values, 0, SearchSortedSide::Left)
.unwrap()
.to_index(),
0
);
assert_eq!(
search_sorted(&values, 1, SearchSortedSide::Left).unwrap(),
search_sorted(&values, 1, SearchSortedSide::Left)
.unwrap()
.to_index(),
0
);
assert_eq!(
search_sorted(&values, 1, SearchSortedSide::Right).unwrap(),
search_sorted(&values, 1, SearchSortedSide::Right)
.unwrap()
.to_index(),
1
);
assert_eq!(
search_sorted(&values, 4, SearchSortedSide::Left).unwrap(),
search_sorted(&values, 4, SearchSortedSide::Left)
.unwrap()
.to_index(),
3
);
}
Expand Down
55 changes: 12 additions & 43 deletions vortex-array/src/array/sparse/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ use crate::array::{Array, ArrayRef};
use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn};
use crate::compute::flatten::{flatten_primitive, FlattenFn, FlattenedArray};
use crate::compute::scalar_at::{scalar_at, ScalarAtFn};
use crate::compute::search_sorted::{search_sorted, SearchSortedSide};
use crate::compute::slice::SliceFn;
use crate::compute::take::{take, TakeFn};
use crate::compute::ArrayCompute;
Expand Down Expand Up @@ -182,51 +181,21 @@ fn take_search_sorted(
array: &SparseArray,
indices: &PrimitiveArray,
) -> VortexResult<(PrimitiveArray, PrimitiveArray)> {
// adjust the input indices (to take) by the internal index offset of the array
let adjusted_indices = match_each_integer_ptype!(indices.ptype(), |$P| {
indices.typed_data::<$P>()
.iter()
.map(|i| *i as usize + array.indices_offset())
.collect::<Vec<_>>()
});

// TODO(robert): Use binary search instead of search_sorted + take and index validation to avoid extra work
// search_sorted for the adjusted indices (need to validate that they are an exact match still)
let physical_indices = adjusted_indices
.iter()
.map(|i| search_sorted(array.indices(), *i, SearchSortedSide::Left).map(|s| s as u64))
.collect::<VortexResult<Vec<_>>>()?;

// filter out indices that are out of bounds, which will cause the take to fail
let (adjusted_indices, physical_indices): (Vec<usize>, Vec<u64>) = adjusted_indices
.iter()
.zip_eq(physical_indices)
.filter(|(_, phys_idx)| *phys_idx < array.indices().len() as u64)
.unzip();

let physical_indices = PrimitiveArray::from(physical_indices);
let taken_indices = flatten_primitive(&take(array.indices(), &physical_indices)?)?;
let exact_matches: Vec<bool> = match_each_integer_ptype!(taken_indices.ptype(), |$P| {
taken_indices
let resolved = match_each_integer_ptype!(indices.ptype(), |$P| {
indices
.typed_data::<$P>()
.iter()
.zip_eq(adjusted_indices)
.map(|(taken_idx, adj_idx)| *taken_idx as usize == adj_idx)
.collect()
.enumerate()
.map(|(pos, i)| {
array
.find_index(*i as usize)
.map(|r| r.map(|ii| (pos as u64, ii as u64)))
})
.filter_map_ok(|r| r)
.collect::<VortexResult<Vec<_>>>()?
});
let (positions, patch_indices): (Vec<u64>, Vec<u64>) = physical_indices
.typed_data::<u64>()
.iter()
.enumerate()
.filter_map(|(i, phy_idx)| {
// search_sorted != binary search, so we need to filter out indices that weren't found
if exact_matches[i] {
Some((i as u64, *phy_idx))
} else {
None
}
})
.unzip();

let (positions, patch_indices): (Vec<u64>, Vec<u64>) = resolved.into_iter().unzip();
Ok((
PrimitiveArray::from(positions),
PrimitiveArray::from(patch_indices),
Expand Down
6 changes: 4 additions & 2 deletions vortex-array/src/array/sparse/compute/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ use crate::compute::slice::{slice, SliceFn};
impl SliceFn for SparseArray {
fn slice(&self, start: usize, stop: usize) -> VortexResult<ArrayRef> {
// Find the index of the first patch index that is greater than or equal to the offset of this array
let index_start_index = search_sorted(self.indices(), start, SearchSortedSide::Left)?;
let index_end_index = search_sorted(self.indices(), stop, SearchSortedSide::Left)?;
let index_start_index =
search_sorted(self.indices(), start, SearchSortedSide::Left)?.to_index();
let index_end_index =
search_sorted(self.indices(), stop, SearchSortedSide::Left)?.to_index();

Ok(SparseArray::try_new_with_offset(
slice(self.indices(), index_start_index, index_end_index)?,
Expand Down
23 changes: 6 additions & 17 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use crate::array::constant::ConstantArray;
use crate::array::{Array, ArrayRef};
use crate::compress::EncodingCompression;
use crate::compute::flatten::flatten_primitive;
use crate::compute::scalar_at::scalar_at;
use crate::compute::search_sorted::{search_sorted, SearchSortedSide};
use crate::compute::ArrayCompute;
use crate::encoding::{Encoding, EncodingId, EncodingRef, ENCODINGS};
Expand Down Expand Up @@ -92,22 +91,12 @@ impl SparseArray {

/// Returns the position of a given index in the indices array if it exists.
pub fn find_index(&self, index: usize) -> VortexResult<Option<usize>> {
let true_index = self.indices_offset + index;

// TODO(ngates): replace this with a binary search that tells us if we get an exact match.
let idx = search_sorted(self.indices(), true_index, SearchSortedSide::Left)?;
if idx >= self.indices().len() {
return Ok(None);
}

// If the value at this index is equal to the true index, then it exists in the
// indices array.
let patch_index: usize = scalar_at(self.indices(), idx)?.try_into()?;
if true_index == patch_index {
Ok(Some(idx))
} else {
Ok(None)
}
search_sorted(
self.indices(),
self.indices_offset + index,
SearchSortedSide::Left,
)
.map(|r| r.to_found())
}

/// Return indices as a vector of usize with the indices_offset applied.
Expand Down
190 changes: 153 additions & 37 deletions vortex-array/src/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,37 @@ pub enum SearchSortedSide {
Right,
}

/// Outcome of a sorted search: an index tagged with whether an element equal
/// to the searched value was present.
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SearchResult {
    /// An equal element exists; carries the index produced by the search.
    Found(usize),
    /// No equal element exists; carries the index produced by the search.
    NotFound(usize),
}

impl SearchResult {
    /// Returns the carried index for `Found`, and `None` for `NotFound`.
    pub fn to_found(self) -> Option<usize> {
        if let Self::Found(idx) = self {
            Some(idx)
        } else {
            None
        }
    }

    /// Returns the carried index regardless of which variant this is.
    pub fn to_index(self) -> usize {
        match self {
            Self::Found(idx) | Self::NotFound(idx) => idx,
        }
    }
}

pub trait SearchSortedFn {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<usize>;
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult>;
}

pub fn search_sorted<T: Into<Scalar>>(
array: &dyn Array,
target: T,
side: SearchSortedSide,
) -> VortexResult<usize> {
) -> VortexResult<SearchResult> {
let scalar = target.into().cast(array.dtype())?;
array.with_compute(|c| {
if let Some(search_sorted) = c.search_sorted() {
Expand Down Expand Up @@ -65,56 +87,97 @@ pub trait Len {
}

pub trait SearchSorted<T> {
fn search_sorted(&self, value: &T, side: SearchSortedSide) -> usize
fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult
where
Self: IndexOrd<T>,
{
match side {
SearchSortedSide::Left => self.search_sorted_by(|idx| {
if self.index_lt(idx, value) {
Less
} else {
Greater
}
}),
SearchSortedSide::Right => self.search_sorted_by(|idx| {
if self.index_le(idx, value) {
Less
} else {
Greater
}
}),
SearchSortedSide::Left => self.search_sorted_by(
|idx| self.index_cmp(idx, value).unwrap_or(Less),
|idx| {
if self.index_lt(idx, value) {
Less
} else {
Greater
}
},
side,
),
SearchSortedSide::Right => self.search_sorted_by(
|idx| self.index_cmp(idx, value).unwrap_or(Less),
|idx| {
if self.index_le(idx, value) {
Less
} else {
Greater
}
},
side,
),
}
}

fn search_sorted_by<F: FnMut(usize) -> Ordering>(&self, f: F) -> usize;
/// `find` is used to locate the element if it exists. When it does, `side_find` is then used to
/// find the desired index amongst the equal values, according to `side`.
fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
&self,
find: F,
side_find: N,
side: SearchSortedSide,
) -> SearchResult;
}

impl<S: IndexOrd<T> + Len + ?Sized, T> SearchSorted<T> for S {
// Code adapted from Rust standard library slice::binary_search_by
fn search_sorted_by<F: FnMut(usize) -> Ordering>(&self, mut f: F) -> usize {
// INVARIANTS:
// - 0 <= left <= left + size = right <= self.len()
// - f returns Less for everything in self[..left]
// - f returns Greater for everything in self[right..]
let mut size = self.len();
let mut left = 0;
let mut right = size;
while left < right {
let mid = left + size / 2;
let cmp = f(mid);

left = if cmp == Less { mid + 1 } else { left };
right = if cmp == Greater { mid } else { right };
if cmp == Equal {
return mid;
fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
&self,
find: F,
side_find: N,
side: SearchSortedSide,
) -> SearchResult {
match search_sorted_side_idx(find, 0, self.len()) {
SearchResult::Found(found) => {
let idx_search = match side {
SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found),
SearchSortedSide::Right => search_sorted_side_idx(side_find, found, self.len()),
};
match idx_search {
SearchResult::NotFound(i) => SearchResult::Found(i),
_ => unreachable!(
"searching amongst equal values should never return Found result"
),
}
}
s => s,
}
}
}

// Code adapted from Rust standard library slice::binary_search_by
fn search_sorted_side_idx<F: FnMut(usize) -> Ordering>(
mut find: F,
from: usize,
to: usize,
) -> SearchResult {
// INVARIANTS:
// - from <= left <= left + size = right <= to
// - find returns Less for every index in from..left
// - find returns Greater for every index in right..to
let mut size = to - from;
let mut left = from;
let mut right = to;
while left < right {
let mid = left + size / 2;
let cmp = find(mid);

size = right - left;
left = if cmp == Less { mid + 1 } else { left };
right = if cmp == Greater { mid } else { right };
if cmp == Equal {
return SearchResult::Found(mid);
}

left
size = right - left;
}

SearchResult::NotFound(left)
}

impl IndexOrd<Scalar> for &dyn Array {
Expand Down Expand Up @@ -142,3 +205,56 @@ impl<T> Len for [T] {
self.len()
}
}

#[cfg(test)]
mod test {
    use crate::compute::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};

    /// Left search with the run of equal values in the middle of the array.
    #[test]
    fn left_side_equal() {
        let haystack = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&2, SearchSortedSide::Left);
        let at = hit.to_index();
        assert_eq!(haystack[at], 2);
        assert_eq!(hit, SearchResult::Found(2));
    }

    /// Right search with the run of equal values in the middle of the array.
    #[test]
    fn right_side_equal() {
        let haystack = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&2, SearchSortedSide::Right);
        let before = hit.to_index() - 1;
        assert_eq!(haystack[before], 2);
        assert_eq!(hit, SearchResult::Found(6));
    }

    /// Left search with the run of equal values at the start of the array.
    #[test]
    fn left_side_equal_beginning() {
        let haystack = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&0, SearchSortedSide::Left);
        let at = hit.to_index();
        assert_eq!(haystack[at], 0);
        assert_eq!(hit, SearchResult::Found(0));
    }

    /// Right search with the run of equal values at the start of the array.
    #[test]
    fn right_side_equal_beginning() {
        let haystack = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
        let hit = haystack.search_sorted(&0, SearchSortedSide::Right);
        let before = hit.to_index() - 1;
        assert_eq!(haystack[before], 0);
        assert_eq!(hit, SearchResult::Found(4));
    }

    /// Left search with the run of equal values at the end of the array.
    #[test]
    fn left_side_equal_end() {
        let haystack = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let hit = haystack.search_sorted(&9, SearchSortedSide::Left);
        let at = hit.to_index();
        assert_eq!(haystack[at], 9);
        assert_eq!(hit, SearchResult::Found(9));
    }

    /// Right search with the run of equal values at the end of the array.
    #[test]
    fn right_side_equal_end() {
        let haystack = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
        let hit = haystack.search_sorted(&9, SearchSortedSide::Right);
        let before = hit.to_index() - 1;
        assert_eq!(haystack[before], 9);
        assert_eq!(hit, SearchResult::Found(13));
    }
}
Loading