Skip to content

Commit

Permalink
better
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 committed Apr 12, 2024
1 parent cfcb608 commit 34ec013
Show file tree
Hide file tree
Showing 4 changed files with 176 additions and 65 deletions.
18 changes: 7 additions & 11 deletions vortex-array/src/array/primitive/compute/search_sorted.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
use vortex_error::VortexResult;

use crate::array::primitive::compute::PrimitiveTrait;
use crate::compute::search_sorted::SearchSorted;
use crate::compute::search_sorted::{SearchResult, SearchSorted};
use crate::compute::search_sorted::{SearchSortedFn, SearchSortedSide};
use crate::ptype::NativePType;
use crate::scalar::Scalar;

impl<T: NativePType> SearchSortedFn for &dyn PrimitiveTrait<T> {
fn search_sorted(
&self,
value: &Scalar,
side: SearchSortedSide,
) -> VortexResult<Result<usize, usize>> {
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult> {
let pvalue: T = value.try_into()?;
Ok(self.typed_data().search_sorted(&pvalue, side))
}
Expand All @@ -30,25 +26,25 @@ mod test {
assert_eq!(
search_sorted(&values, 0, SearchSortedSide::Left)
.unwrap()
.unwrap_or_else(|o| o),
.to_index(),
0
);
assert_eq!(
search_sorted(&values, 1, SearchSortedSide::Left)
.unwrap()
.unwrap_or_else(|o| o),
.to_index(),
0
);
assert_eq!(
search_sorted(&values, 1, SearchSortedSide::Right)
.unwrap()
.unwrap_or_else(|o| o),
1
.to_index(),
0
);
assert_eq!(
search_sorted(&values, 4, SearchSortedSide::Left)
.unwrap()
.unwrap_or_else(|o| o),
.to_index(),
3
);
}
Expand Down
8 changes: 4 additions & 4 deletions vortex-array/src/array/sparse/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ impl SparseArray {
search_sorted(
self.indices(),
self.indices_offset + index,
SearchSortedSide::Exact,
SearchSortedSide::Left,
)
.map(|r| r.ok())
.map(|r| r.to_option())
}

/// Return indices as a vector of usize with the indices_offset applied.
Expand Down Expand Up @@ -140,9 +140,9 @@ impl Array for SparseArray {

// Find the index of the first patch index that is greater than or equal to the offset of this array
let index_start_index =
search_sorted(self.indices(), start, SearchSortedSide::Left)?.unwrap_or_else(|o| o);
search_sorted(self.indices(), start, SearchSortedSide::Left)?.to_index();
let index_end_index =
search_sorted(self.indices(), stop, SearchSortedSide::Left)?.unwrap_or_else(|o| o);
search_sorted(self.indices(), stop, SearchSortedSide::Left)?.to_index();

Ok(SparseArray {
indices_offset: self.indices_offset + start,
Expand Down
207 changes: 160 additions & 47 deletions vortex-array/src/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,39 @@ use crate::scalar::Scalar;
pub enum SearchSortedSide {
Left,
Right,
Exact,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SearchResult {
Found(usize),
NotFound(usize),
}

impl SearchResult {
pub fn to_option(self) -> Option<usize> {
match self {
SearchResult::Found(i) => Some(i),
SearchResult::NotFound(_) => None,
}
}

pub fn to_index(self) -> usize {
match self {
SearchResult::Found(i) => i,
SearchResult::NotFound(i) => i,
}
}
}

pub trait SearchSortedFn {
fn search_sorted(
&self,
value: &Scalar,
side: SearchSortedSide,
) -> VortexResult<Result<usize, usize>>;
fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult<SearchResult>;
}

pub fn search_sorted<T: Into<Scalar>>(
array: &dyn Array,
target: T,
side: SearchSortedSide,
) -> VortexResult<Result<usize, usize>> {
) -> VortexResult<SearchResult> {
let scalar = target.into().cast(array.dtype())?;
array.with_compute(|c| {
if let Some(search_sorted) = c.search_sorted() {
Expand Down Expand Up @@ -70,59 +87,102 @@ pub trait Len {
}

pub trait SearchSorted<T> {
fn search_sorted(&self, value: &T, side: SearchSortedSide) -> Result<usize, usize>
fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult
where
Self: IndexOrd<T>,
{
match side {
SearchSortedSide::Left => self.search_sorted_by(|idx| {
if self.index_lt(idx, value) {
Less
} else {
Greater
}
}),
SearchSortedSide::Right => self.search_sorted_by(|idx| {
if self.index_le(idx, value) {
Less
} else {
Greater
}
}),
SearchSortedSide::Exact => {
self.search_sorted_by(|idx| self.index_cmp(idx, value).unwrap_or(Greater))
}
SearchSortedSide::Left => self.search_sorted_by(
|idx| self.index_cmp(idx, value).unwrap_or(Less),
|idx| {
if self.index_lt(idx, value) {
Less
} else {
Greater
}
},
side,
),
SearchSortedSide::Right => self.search_sorted_by(
|idx| self.index_cmp(idx, value).unwrap_or(Less),
|idx| {
if self.index_le(idx, value) {
Less
} else {
Greater
}
},
side,
),
}
}

fn search_sorted_by<F: FnMut(usize) -> Ordering>(&self, f: F) -> Result<usize, usize>;
/// find function is used to find the element if it exists, if element exists side_find will be used to find desired index amongst equal values
fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
&self,
find: F,
side_find: N,
side: SearchSortedSide,
) -> SearchResult;
}

impl<S: IndexOrd<T> + Len + ?Sized, T> SearchSorted<T> for S {
// Code adapted from Rust standard library slice::binary_search_by
fn search_sorted_by<F: FnMut(usize) -> Ordering>(&self, mut f: F) -> Result<usize, usize> {
// INVARIANTS:
// - 0 <= left <= left + size = right <= self.len()
// - f returns Less for everything in self[..left]
// - f returns Greater for everything in self[right..]
let mut size = self.len();
let mut left = 0;
let mut right = size;
while left < right {
let mid = left + size / 2;
let cmp = f(mid);

left = if cmp == Less { mid + 1 } else { left };
right = if cmp == Greater { mid } else { right };
if cmp == Equal {
return Ok(mid);
}

size = right - left;
fn search_sorted_by<F: FnMut(usize) -> Ordering, N: FnMut(usize) -> Ordering>(
&self,
find: F,
side_find: N,
side: SearchSortedSide,
) -> SearchResult {
match search_sorted_side_idx(find, 0, self.len()) {
SearchResult::Found(found) => match side {
SearchSortedSide::Left => match search_sorted_side_idx(side_find, 0, found) {
SearchResult::NotFound(i) => SearchResult::Found(i),
_ => unreachable!(
"searching amongst equal values should never return Found result"
),
},
// Right side search returns index one past the result we want, subtract here
SearchSortedSide::Right => {
match search_sorted_side_idx(side_find, found, self.len()) {
SearchResult::NotFound(i) => SearchResult::Found(i - 1),
_ => unreachable!(
"searching amongst equal values should never return Found result"
),
}
}
},
s => s,
}
}
}

// Code adapted from Rust standard library slice::binary_search_by
fn search_sorted_side_idx<F: FnMut(usize) -> Ordering>(
mut find: F,
from: usize,
to: usize,
) -> SearchResult {
// INVARIANTS:
// - from <= left <= left + size = right <= to
// - f returns Less for everything in self[..left]
// - f returns Greater for everything in self[right..]
let mut size = to - from;
let mut left = from;
let mut right = to;
while left < right {
let mid = left + size / 2;
let cmp = find(mid);

left = if cmp == Less { mid + 1 } else { left };
right = if cmp == Greater { mid } else { right };
if cmp == Equal {
return SearchResult::Found(mid);
}

Err(left)
size = right - left;
}

SearchResult::NotFound(left)
}

impl IndexOrd<Scalar> for &dyn Array {
Expand Down Expand Up @@ -150,3 +210,56 @@ impl<T> Len for [T] {
self.len()
}
}

#[cfg(test)]
mod test {
use crate::compute::search_sorted::{SearchResult, SearchSorted, SearchSortedSide};

#[test]
fn left_side_equal() {
let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
let res = arr.search_sorted(&2, SearchSortedSide::Left);
assert_eq!(arr[res.to_index()], 2);
assert_eq!(res, SearchResult::Found(2));
}

#[test]
fn right_side_equal() {
let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9];
let res = arr.search_sorted(&2, SearchSortedSide::Right);
assert_eq!(arr[res.to_index()], 2);
assert_eq!(res, SearchResult::Found(5));
}

#[test]
fn left_side_equal_beginning() {
let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
let res = arr.search_sorted(&0, SearchSortedSide::Left);
assert_eq!(arr[res.to_index()], 0);
assert_eq!(res, SearchResult::Found(0));
}

#[test]
fn right_side_equal_beginning() {
let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9];
let res = arr.search_sorted(&0, SearchSortedSide::Right);
assert_eq!(arr[res.to_index()], 0);
assert_eq!(res, SearchResult::Found(3));
}

#[test]
fn left_side_equal_end() {
let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
let res = arr.search_sorted(&9, SearchSortedSide::Left);
assert_eq!(arr[res.to_index()], 9);
assert_eq!(res, SearchResult::Found(9));
}

#[test]
fn right_side_equal_end() {
let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9];
let res = arr.search_sorted(&9, SearchSortedSide::Right);
// assert_eq!(arr[res.to_index()], 9);
assert_eq!(res, SearchResult::Found(12));
}
}
8 changes: 5 additions & 3 deletions vortex-ree/src/ree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock};
use vortex::array::{check_slice_bounds, Array, ArrayKind, ArrayRef};
use vortex::compress::EncodingCompression;
use vortex::compute::scalar_at::scalar_at;
use vortex::compute::search_sorted::{search_sorted, SearchSortedSide};
use vortex::compute::search_sorted::{search_sorted, SearchResult, SearchSortedSide};
use vortex::compute::ArrayCompute;
use vortex::encoding::{Encoding, EncodingId, EncodingRef};
use vortex::formatter::{ArrayDisplay, ArrayFormatter};
Expand Down Expand Up @@ -58,8 +58,10 @@ impl REEArray {
}

pub fn find_physical_index(&self, index: usize) -> VortexResult<usize> {
search_sorted(self.ends(), index + self.offset, SearchSortedSide::Right)
.map(|r| r.unwrap_or_else(|o| o))
search_sorted(self.ends(), index + self.offset, SearchSortedSide::Right).map(|r| match r {
SearchResult::Found(i) => i + 1,
SearchResult::NotFound(i) => i,
})
}

pub fn encode(array: &dyn Array) -> VortexResult<ArrayRef> {
Expand Down

0 comments on commit 34ec013

Please sign in to comment.