diff --git a/vortex-array/src/array/primitive/compute/search_sorted.rs b/vortex-array/src/array/primitive/compute/search_sorted.rs index 42697d872b..f82d9364c7 100644 --- a/vortex-array/src/array/primitive/compute/search_sorted.rs +++ b/vortex-array/src/array/primitive/compute/search_sorted.rs @@ -1,17 +1,13 @@ use vortex_error::VortexResult; use crate::array::primitive::compute::PrimitiveTrait; -use crate::compute::search_sorted::SearchSorted; +use crate::compute::search_sorted::{SearchResult, SearchSorted}; use crate::compute::search_sorted::{SearchSortedFn, SearchSortedSide}; use crate::ptype::NativePType; use crate::scalar::Scalar; impl SearchSortedFn for &dyn PrimitiveTrait { - fn search_sorted( - &self, - value: &Scalar, - side: SearchSortedSide, - ) -> VortexResult> { + fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { let pvalue: T = value.try_into()?; Ok(self.typed_data().search_sorted(&pvalue, side)) } @@ -30,25 +26,25 @@ mod test { assert_eq!( search_sorted(&values, 0, SearchSortedSide::Left) .unwrap() - .unwrap_or_else(|o| o), + .to_index(), 0 ); assert_eq!( search_sorted(&values, 1, SearchSortedSide::Left) .unwrap() - .unwrap_or_else(|o| o), + .to_index(), 0 ); assert_eq!( search_sorted(&values, 1, SearchSortedSide::Right) .unwrap() - .unwrap_or_else(|o| o), - 1 + .to_index(), + 0 ); assert_eq!( search_sorted(&values, 4, SearchSortedSide::Left) .unwrap() - .unwrap_or_else(|o| o), + .to_index(), 3 ); } diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index 6e3cd58281..b54ad2261e 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -94,9 +94,9 @@ impl SparseArray { search_sorted( self.indices(), self.indices_offset + index, - SearchSortedSide::Exact, + SearchSortedSide::Left, ) - .map(|r| r.ok()) + .map(|r| r.to_option()) } /// Return indices as a vector of usize with the indices_offset applied. @@ -140,9 +140,9 @@ impl Array for SparseArray { // Find the index of the first patch index that is greater than or equal to the offset of this array let index_start_index = - search_sorted(self.indices(), start, SearchSortedSide::Left)?.unwrap_or_else(|o| o); + search_sorted(self.indices(), start, SearchSortedSide::Left)?.to_index(); let index_end_index = - search_sorted(self.indices(), stop, SearchSortedSide::Left)?.unwrap_or_else(|o| o); + search_sorted(self.indices(), stop, SearchSortedSide::Left)?.to_index(); Ok(SparseArray { indices_offset: self.indices_offset + start, diff --git a/vortex-array/src/compute/search_sorted.rs b/vortex-array/src/compute/search_sorted.rs index 5ecf3c73fd..f2b82b76f1 100644 --- a/vortex-array/src/compute/search_sorted.rs +++ b/vortex-array/src/compute/search_sorted.rs @@ -11,22 +11,39 @@ use crate::scalar::Scalar; pub enum SearchSortedSide { Left, Right, - Exact, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum SearchResult { + Found(usize), + NotFound(usize), +} + +impl SearchResult { + pub fn to_option(self) -> Option { + match self { + SearchResult::Found(i) => Some(i), + SearchResult::NotFound(_) => None, + } + } + + pub fn to_index(self) -> usize { + match self { + SearchResult::Found(i) => i, + SearchResult::NotFound(i) => i, + } + } } pub trait SearchSortedFn { - fn search_sorted( - &self, - value: &Scalar, - side: SearchSortedSide, - ) -> VortexResult>; + fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult; } pub fn search_sorted>( array: &dyn Array, target: T, side: SearchSortedSide, -) -> VortexResult> { +) -> VortexResult { let scalar = target.into().cast(array.dtype())?; array.with_compute(|c| { if let Some(search_sorted) = c.search_sorted() { @@ -70,59 +87,102 @@ pub trait Len { } pub trait SearchSorted { - fn search_sorted(&self, value: &T, side: SearchSortedSide) -> Result + fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult where Self: IndexOrd, { match side { - SearchSortedSide::Left => self.search_sorted_by(|idx| { - if self.index_lt(idx, value) { - Less - } else { - Greater - } - }), - SearchSortedSide::Right => self.search_sorted_by(|idx| { - if self.index_le(idx, value) { - Less - } else { - Greater - } - }), - SearchSortedSide::Exact => { - self.search_sorted_by(|idx| self.index_cmp(idx, value).unwrap_or(Greater)) - } + SearchSortedSide::Left => self.search_sorted_by( + |idx| self.index_cmp(idx, value).unwrap_or(Less), + |idx| { + if self.index_lt(idx, value) { + Less + } else { + Greater + } + }, + side, + ), + SearchSortedSide::Right => self.search_sorted_by( + |idx| self.index_cmp(idx, value).unwrap_or(Less), + |idx| { + if self.index_le(idx, value) { + Less + } else { + Greater + } + }, + side, + ), } } - fn search_sorted_by Ordering>(&self, f: F) -> Result; + /// find function is used to find the element if it exists, if element exists side_find will be used to find desired index amongst equal values + fn search_sorted_by Ordering, N: FnMut(usize) -> Ordering>( + &self, + find: F, + side_find: N, + side: SearchSortedSide, + ) -> SearchResult; } impl + Len + ?Sized, T> SearchSorted for S { - // Code adapted from Rust standard library slice::binary_search_by - fn search_sorted_by Ordering>(&self, mut f: F) -> Result { - // INVARIANTS: - // - 0 <= left <= left + size = right <= self.len() - // - f returns Less for everything in self[..left] - // - f returns Greater for everything in self[right..] - let mut size = self.len(); - let mut left = 0; - let mut right = size; - while left < right { - let mid = left + size / 2; - let cmp = f(mid); - - left = if cmp == Less { mid + 1 } else { left }; - right = if cmp == Greater { mid } else { right }; - if cmp == Equal { - return Ok(mid); - } - - size = right - left; + fn search_sorted_by Ordering, N: FnMut(usize) -> Ordering>( + &self, + find: F, + side_find: N, + side: SearchSortedSide, + ) -> SearchResult { + match search_sorted_side_idx(find, 0, self.len()) { + SearchResult::Found(found) => match side { + SearchSortedSide::Left => match search_sorted_side_idx(side_find, 0, found) { + SearchResult::NotFound(i) => SearchResult::Found(i), + _ => unreachable!( + "searching amongst equal values should never return Found result" + ), + }, + // Right side search returns index one past the result we want, subtract here + SearchSortedSide::Right => { + match search_sorted_side_idx(side_find, found, self.len()) { + SearchResult::NotFound(i) => SearchResult::Found(i - 1), + _ => unreachable!( + "searching amongst equal values should never return Found result" + ), + } + } + }, + s => s, + } + } +} + +// Code adapted from Rust standard library slice::binary_search_by +fn search_sorted_side_idx Ordering>( + mut find: F, + from: usize, + to: usize, +) -> SearchResult { + // INVARIANTS: + // - from <= left <= left + size = right <= to + // - f returns Less for everything in self[..left] + // - f returns Greater for everything in self[right..] + let mut size = to - from; + let mut left = from; + let mut right = to; + while left < right { + let mid = left + size / 2; + let cmp = find(mid); + + left = if cmp == Less { mid + 1 } else { left }; + right = if cmp == Greater { mid } else { right }; + if cmp == Equal { + return SearchResult::Found(mid); } - Err(left) + size = right - left; } + + SearchResult::NotFound(left) } impl IndexOrd for &dyn Array { @@ -150,3 +210,56 @@ impl Len for [T] { self.len() } } + +#[cfg(test)] +mod test { + use crate::compute::search_sorted::{SearchResult, SearchSorted, SearchSortedSide}; + + #[test] + fn left_side_equal() { + let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9]; + let res = arr.search_sorted(&2, SearchSortedSide::Left); + assert_eq!(arr[res.to_index()], 2); + assert_eq!(res, SearchResult::Found(2)); + } + + #[test] + fn right_side_equal() { + let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9]; + let res = arr.search_sorted(&2, SearchSortedSide::Right); + assert_eq!(arr[res.to_index()], 2); + assert_eq!(res, SearchResult::Found(5)); + } + + #[test] + fn left_side_equal_beginning() { + let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let res = arr.search_sorted(&0, SearchSortedSide::Left); + assert_eq!(arr[res.to_index()], 0); + assert_eq!(res, SearchResult::Found(0)); + } + + #[test] + fn right_side_equal_beginning() { + let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let res = arr.search_sorted(&0, SearchSortedSide::Right); + assert_eq!(arr[res.to_index()], 0); + assert_eq!(res, SearchResult::Found(3)); + } + + #[test] + fn left_side_equal_end() { + let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9]; + let res = arr.search_sorted(&9, SearchSortedSide::Left); + assert_eq!(arr[res.to_index()], 9); + assert_eq!(res, SearchResult::Found(9)); + } + + #[test] + fn right_side_equal_end() { + let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9]; + let res = arr.search_sorted(&9, SearchSortedSide::Right); + // assert_eq!(arr[res.to_index()], 9); + assert_eq!(res, SearchResult::Found(12)); + } +} diff --git a/vortex-ree/src/ree.rs b/vortex-ree/src/ree.rs index 5902bc3399..a74fd39590 100644 --- a/vortex-ree/src/ree.rs +++ b/vortex-ree/src/ree.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock}; use vortex::array::{check_slice_bounds, Array, ArrayKind, ArrayRef}; use vortex::compress::EncodingCompression; use vortex::compute::scalar_at::scalar_at; -use vortex::compute::search_sorted::{search_sorted, SearchSortedSide}; +use vortex::compute::search_sorted::{search_sorted, SearchResult, SearchSortedSide}; use vortex::compute::ArrayCompute; use vortex::encoding::{Encoding, EncodingId, EncodingRef}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; @@ -58,8 +58,10 @@ impl REEArray { } pub fn find_physical_index(&self, index: usize) -> VortexResult { - search_sorted(self.ends(), index + self.offset, SearchSortedSide::Right) - .map(|r| r.unwrap_or_else(|o| o)) + search_sorted(self.ends(), index + self.offset, SearchSortedSide::Right).map(|r| match r { + SearchResult::Found(i) => i + 1, + SearchResult::NotFound(i) => i, + }) } pub fn encode(array: &dyn Array) -> VortexResult {