From 861eaa31369cd0c24b3cfc9bcc7f057adcf7b17c Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 11 Apr 2024 12:03:08 +0100 Subject: [PATCH 1/6] Add binary search function for faster sparse arrays searches --- .../array/primitive/compute/search_sorted.rs | 5 +- vortex-array/src/array/sparse/compute/mod.rs | 55 ++++--------------- vortex-array/src/array/sparse/mod.rs | 20 +------ vortex-array/src/compute/binary_search.rs | 34 ++++++++++++ vortex-array/src/compute/mod.rs | 7 +++ vortex-array/src/compute/search_sorted.rs | 16 ++++-- 6 files changed, 69 insertions(+), 68 deletions(-) create mode 100644 vortex-array/src/compute/binary_search.rs diff --git a/vortex-array/src/array/primitive/compute/search_sorted.rs b/vortex-array/src/array/primitive/compute/search_sorted.rs index 0d6ef271a5..7ae0ee6e44 100644 --- a/vortex-array/src/array/primitive/compute/search_sorted.rs +++ b/vortex-array/src/array/primitive/compute/search_sorted.rs @@ -9,7 +9,10 @@ use crate::scalar::Scalar; impl SearchSortedFn for &dyn PrimitiveTrait { fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { let pvalue: T = value.try_into()?; - Ok(self.typed_data().search_sorted(&pvalue, side)) + Ok(self + .typed_data() + .search_sorted(&pvalue, side) + .unwrap_or_else(|o| o)) } } diff --git a/vortex-array/src/array/sparse/compute/mod.rs b/vortex-array/src/array/sparse/compute/mod.rs index aff8f77150..62f84dc2b9 100644 --- a/vortex-array/src/array/sparse/compute/mod.rs +++ b/vortex-array/src/array/sparse/compute/mod.rs @@ -11,7 +11,6 @@ use crate::array::{Array, ArrayRef}; use crate::compute::as_contiguous::{as_contiguous, AsContiguousFn}; use crate::compute::flatten::{flatten_primitive, FlattenFn, FlattenedArray}; use crate::compute::scalar_at::{scalar_at, ScalarAtFn}; -use crate::compute::search_sorted::{search_sorted, SearchSortedSide}; use crate::compute::slice::SliceFn; use crate::compute::take::{take, TakeFn}; use crate::compute::ArrayCompute; @@ -182,51 +181,21 @@ fn take_search_sorted( array: &SparseArray, indices: &PrimitiveArray, ) -> VortexResult<(PrimitiveArray, PrimitiveArray)> { - // adjust the input indices (to take) by the internal index offset of the array - let adjusted_indices = match_each_integer_ptype!(indices.ptype(), |$P| { - indices.typed_data::<$P>() - .iter() - .map(|i| *i as usize + array.indices_offset()) - .collect::>() - }); - - // TODO(robert): Use binary search instead of search_sorted + take and index validation to avoid extra work - // search_sorted for the adjusted indices (need to validate that they are an exact match still) - let physical_indices = adjusted_indices - .iter() - .map(|i| search_sorted(array.indices(), *i, SearchSortedSide::Left).map(|s| s as u64)) - .collect::>>()?; - - // filter out indices that are out of bounds, which will cause the take to fail - let (adjusted_indices, physical_indices): (Vec, Vec) = adjusted_indices - .iter() - .zip_eq(physical_indices) - .filter(|(_, phys_idx)| *phys_idx < array.indices().len() as u64) - .unzip(); - - let physical_indices = PrimitiveArray::from(physical_indices); - let taken_indices = flatten_primitive(&take(array.indices(), &physical_indices)?)?; - let exact_matches: Vec = match_each_integer_ptype!(taken_indices.ptype(), |$P| { - taken_indices + let resolved = match_each_integer_ptype!(indices.ptype(), |$P| { + indices .typed_data::<$P>() .iter() - .zip_eq(adjusted_indices) - .map(|(taken_idx, adj_idx)| *taken_idx as usize == adj_idx) - .collect() + .enumerate() + .map(|(pos, i)| { + array + .find_index(*i as usize) + .map(|r| r.map(|ii| (pos as u64, ii as u64))) + }) + .filter_map_ok(|r| r) + .collect::>>()? }); - let (positions, patch_indices): (Vec, Vec) = physical_indices - .typed_data::() - .iter() - .enumerate() - .filter_map(|(i, phy_idx)| { - // search_sorted != binary search, so we need to filter out indices that weren't found - if exact_matches[i] { - Some((i as u64, *phy_idx)) - } else { - None - } - }) - .unzip(); + + let (positions, patch_indices): (Vec, Vec) = resolved.into_iter().unzip(); Ok(( PrimitiveArray::from(positions), PrimitiveArray::from(patch_indices), diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index f1d5622160..cec414b84c 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -7,9 +7,8 @@ use vortex_schema::DType; use crate::array::constant::ConstantArray; use crate::array::{Array, ArrayRef}; use crate::compress::EncodingCompression; +use crate::compute::binary_search::binary_search; use crate::compute::flatten::flatten_primitive; -use crate::compute::scalar_at::scalar_at; -use crate::compute::search_sorted::{search_sorted, SearchSortedSide}; use crate::compute::ArrayCompute; use crate::encoding::{Encoding, EncodingId, EncodingRef, ENCODINGS}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; @@ -92,22 +91,7 @@ impl SparseArray { /// Returns the position of a given index in the indices array if it exists. pub fn find_index(&self, index: usize) -> VortexResult> { - let true_index = self.indices_offset + index; - - // TODO(ngates): replace this with a binary search that tells us if we get an exact match. - let idx = search_sorted(self.indices(), true_index, SearchSortedSide::Left)?; - if idx >= self.indices().len() { - return Ok(None); - } - - // If the value at this index is equal to the true index, then it exists in the - // indices array. - let patch_index: usize = scalar_at(self.indices(), idx)?.try_into()?; - if true_index == patch_index { - Ok(Some(idx)) - } else { - Ok(None) - } + binary_search(self.indices(), self.indices_offset + index).map(|r| r.ok()) } /// Return indices as a vector of usize with the indices_offset applied. diff --git a/vortex-array/src/compute/binary_search.rs b/vortex-array/src/compute/binary_search.rs new file mode 100644 index 0000000000..d42ccf5b94 --- /dev/null +++ b/vortex-array/src/compute/binary_search.rs @@ -0,0 +1,34 @@ +use vortex_error::{vortex_err, VortexResult}; + +use crate::array::{Array, WithArrayCompute}; +use crate::compute::search_sorted::{SearchSorted, SearchSortedSide}; +use crate::scalar::Scalar; + +pub trait BinarySearchFn { + fn binary_search(&self, value: &Scalar) -> VortexResult>; +} + +pub fn binary_search>( + array: &dyn Array, + target: T, +) -> VortexResult> { + let scalar = target.into().cast(array.dtype())?; + array.with_compute(|c| { + if let Some(binary_search) = c.binary_search() { + return binary_search.binary_search(&scalar); + } + + if c.scalar_at().is_some() { + return Ok(SearchSorted::search_sorted( + &array, + &scalar, + SearchSortedSide::Exact, + )); + } + + Err(vortex_err!( + NotImplemented: "binary_search", + array.encoding().id().name() + )) + }) +} diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 7b88bc239c..31ad6cfd7b 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -9,9 +9,12 @@ use search_sorted::SearchSortedFn; use slice::SliceFn; use take::TakeFn; +use crate::compute::binary_search::BinarySearchFn; + pub mod add; pub mod as_arrow; pub mod as_contiguous; +pub mod binary_search; pub mod cast; pub mod fill; pub mod flatten; @@ -31,6 +34,10 @@ pub trait ArrayCompute { None } + fn binary_search(&self) -> Option<&dyn BinarySearchFn> { + None + } + fn cast(&self) -> Option<&dyn CastFn> { None } diff --git a/vortex-array/src/compute/search_sorted.rs b/vortex-array/src/compute/search_sorted.rs index 90f21b8023..42a8e9f889 100644 --- a/vortex-array/src/compute/search_sorted.rs +++ b/vortex-array/src/compute/search_sorted.rs @@ -11,6 +11,7 @@ use crate::scalar::Scalar; pub enum SearchSortedSide { Left, Right, + Exact, } pub trait SearchSortedFn { @@ -29,7 +30,7 @@ pub fn search_sorted>( } if c.scalar_at().is_some() { - return Ok(SearchSorted::search_sorted(&array, &scalar, side)); + return Ok(SearchSorted::search_sorted(&array, &scalar, side).unwrap_or_else(|o| o)); } Err(vortex_err!( @@ -65,7 +66,7 @@ pub trait Len { } pub trait SearchSorted { - fn search_sorted(&self, value: &T, side: SearchSortedSide) -> usize + fn search_sorted(&self, value: &T, side: SearchSortedSide) -> Result where Self: IndexOrd, { @@ -84,15 +85,18 @@ pub trait SearchSorted { Greater } }), + SearchSortedSide::Exact => { + self.search_sorted_by(|idx| self.index_cmp(idx, value).unwrap_or(Greater)) + } } } - fn search_sorted_by Ordering>(&self, f: F) -> usize; + fn search_sorted_by Ordering>(&self, f: F) -> Result; } impl + Len + ?Sized, T> SearchSorted for S { // Code adapted from Rust standard library slice::binary_search_by - fn search_sorted_by Ordering>(&self, mut f: F) -> usize { + fn search_sorted_by Ordering>(&self, mut f: F) -> Result { // INVARIANTS: // - 0 <= left <= left + size = right <= self.len() // - f returns Less for everything in self[..left] @@ -107,13 +111,13 @@ impl + Len + ?Sized, T> SearchSorted for S { left = if cmp == Less { mid + 1 } else { left }; right = if cmp == Greater { mid } else { right }; if cmp == Equal { - return mid; + return Ok(mid); } size = right - left; } - left + Err(left) } } From ad1bddbff5dd6889bccea467bfe74d52f30a17ce Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Thu, 11 Apr 2024 14:41:03 +0100 Subject: [PATCH 2/6] less --- .../array/primitive/compute/search_sorted.rs | 27 ++++++++++----- vortex-array/src/array/sparse/mod.rs | 9 +++-- vortex-array/src/compute/binary_search.rs | 34 ------------------- vortex-array/src/compute/mod.rs | 7 ---- vortex-array/src/compute/search_sorted.rs | 10 ++++-- vortex-ree/src/ree.rs | 11 +++--- 6 files changed, 36 insertions(+), 62 deletions(-) delete mode 100644 vortex-array/src/compute/binary_search.rs diff --git a/vortex-array/src/array/primitive/compute/search_sorted.rs b/vortex-array/src/array/primitive/compute/search_sorted.rs index 7ae0ee6e44..42697d872b 100644 --- a/vortex-array/src/array/primitive/compute/search_sorted.rs +++ b/vortex-array/src/array/primitive/compute/search_sorted.rs @@ -7,12 +7,13 @@ use crate::ptype::NativePType; use crate::scalar::Scalar; impl SearchSortedFn for &dyn PrimitiveTrait { - fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { + fn search_sorted( + &self, + value: &Scalar, + side: SearchSortedSide, + ) -> VortexResult> { let pvalue: T = value.try_into()?; - Ok(self - .typed_data() - .search_sorted(&pvalue, side) - .unwrap_or_else(|o| o)) + Ok(self.typed_data().search_sorted(&pvalue, side)) } } @@ -27,19 +28,27 @@ mod test { let values = vec![1u16, 2, 3].into_array(); assert_eq!( - search_sorted(&values, 0, SearchSortedSide::Left).unwrap(), + search_sorted(&values, 0, SearchSortedSide::Left) + .unwrap() + .unwrap_or_else(|o| o), 0 ); assert_eq!( - search_sorted(&values, 1, SearchSortedSide::Left).unwrap(), + search_sorted(&values, 1, SearchSortedSide::Left) + .unwrap() + .unwrap_or_else(|o| o), 0 ); assert_eq!( - search_sorted(&values, 1, SearchSortedSide::Right).unwrap(), + search_sorted(&values, 1, SearchSortedSide::Right) + .unwrap() + .unwrap_or_else(|o| o), 1 ); assert_eq!( - search_sorted(&values, 4, SearchSortedSide::Left).unwrap(), + search_sorted(&values, 4, SearchSortedSide::Left) + .unwrap() + .unwrap_or_else(|o| o), 3 ); } diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index cec414b84c..372ee98439 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -7,8 +7,8 @@ use vortex_schema::DType; use crate::array::constant::ConstantArray; use crate::array::{Array, ArrayRef}; use crate::compress::EncodingCompression; -use crate::compute::binary_search::binary_search; use crate::compute::flatten::flatten_primitive; +use crate::compute::search_sorted::{search_sorted, SearchSortedSide}; use crate::compute::ArrayCompute; use crate::encoding::{Encoding, EncodingId, EncodingRef, ENCODINGS}; use crate::formatter::{ArrayDisplay, ArrayFormatter}; @@ -91,7 +91,12 @@ impl SparseArray { /// Returns the position of a given index in the indices array if it exists. pub fn find_index(&self, index: usize) -> VortexResult> { - binary_search(self.indices(), self.indices_offset + index).map(|r| r.ok()) + search_sorted( + self.indices(), + self.indices_offset + index, + SearchSortedSide::Exact, + ) + .map(|r| r.ok()) } /// Return indices as a vector of usize with the indices_offset applied. diff --git a/vortex-array/src/compute/binary_search.rs b/vortex-array/src/compute/binary_search.rs deleted file mode 100644 index d42ccf5b94..0000000000 --- a/vortex-array/src/compute/binary_search.rs +++ /dev/null @@ -1,34 +0,0 @@ -use vortex_error::{vortex_err, VortexResult}; - -use crate::array::{Array, WithArrayCompute}; -use crate::compute::search_sorted::{SearchSorted, SearchSortedSide}; -use crate::scalar::Scalar; - -pub trait BinarySearchFn { - fn binary_search(&self, value: &Scalar) -> VortexResult>; -} - -pub fn binary_search>( - array: &dyn Array, - target: T, -) -> VortexResult> { - let scalar = target.into().cast(array.dtype())?; - array.with_compute(|c| { - if let Some(binary_search) = c.binary_search() { - return binary_search.binary_search(&scalar); - } - - if c.scalar_at().is_some() { - return Ok(SearchSorted::search_sorted( - &array, - &scalar, - SearchSortedSide::Exact, - )); - } - - Err(vortex_err!( - NotImplemented: "binary_search", - array.encoding().id().name() - )) - }) -} diff --git a/vortex-array/src/compute/mod.rs b/vortex-array/src/compute/mod.rs index 31ad6cfd7b..7b88bc239c 100644 --- a/vortex-array/src/compute/mod.rs +++ b/vortex-array/src/compute/mod.rs @@ -9,12 +9,9 @@ use search_sorted::SearchSortedFn; use slice::SliceFn; use take::TakeFn; -use crate::compute::binary_search::BinarySearchFn; - pub mod add; pub mod as_arrow; pub mod as_contiguous; -pub mod binary_search; pub mod cast; pub mod fill; pub mod flatten; @@ -34,10 +31,6 @@ pub trait ArrayCompute { None } - fn binary_search(&self) -> Option<&dyn BinarySearchFn> { - None - } - fn cast(&self) -> Option<&dyn CastFn> { None } diff --git a/vortex-array/src/compute/search_sorted.rs b/vortex-array/src/compute/search_sorted.rs index 42a8e9f889..5ecf3c73fd 100644 --- a/vortex-array/src/compute/search_sorted.rs +++ b/vortex-array/src/compute/search_sorted.rs @@ -15,14 +15,18 @@ pub enum SearchSortedSide { } pub trait SearchSortedFn { - fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult; + fn search_sorted( + &self, + value: &Scalar, + side: SearchSortedSide, + ) -> VortexResult>; } pub fn search_sorted>( array: &dyn Array, target: T, side: SearchSortedSide, -) -> VortexResult { +) -> VortexResult> { let scalar = target.into().cast(array.dtype())?; array.with_compute(|c| { if let Some(search_sorted) = c.search_sorted() { @@ -30,7 +34,7 @@ pub fn search_sorted>( } if c.scalar_at().is_some() { - return Ok(SearchSorted::search_sorted(&array, &scalar, side).unwrap_or_else(|o| o)); + return Ok(SearchSorted::search_sorted(&array, &scalar, side)); } Err(vortex_err!( diff --git a/vortex-ree/src/ree.rs b/vortex-ree/src/ree.rs index 3ea323bda8..7ec8ca19bd 100644 --- a/vortex-ree/src/ree.rs +++ b/vortex-ree/src/ree.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock}; use vortex::array::{Array, ArrayKind, ArrayRef}; use vortex::compress::EncodingCompression; use vortex::compute::scalar_at::scalar_at; -use vortex::compute::search_sorted::SearchSortedSide; +use vortex::compute::search_sorted::{search_sorted, SearchSortedSide}; use vortex::compute::ArrayCompute; use vortex::encoding::{Encoding, EncodingId, EncodingRef}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; @@ -12,7 +12,7 @@ use vortex::stats::{Stat, Stats, StatsCompute, StatsSet}; use vortex::validity::Validity; use vortex::validity::{OwnedValidity, ValidityView}; use vortex::view::{AsView, ToOwnedView}; -use vortex::{compute, impl_array, ArrayWalker}; +use vortex::{impl_array, ArrayWalker}; use vortex_error::{vortex_bail, vortex_err, VortexResult}; use vortex_schema::DType; @@ -68,11 +68,8 @@ impl REEArray { } pub fn find_physical_index(&self, index: usize) -> VortexResult { - compute::search_sorted::search_sorted( - self.ends(), - index + self.offset, - SearchSortedSide::Right, - ) + search_sorted(self.ends(), index + self.offset, SearchSortedSide::Right) + .map(|r| r.unwrap_or_else(|o| o)) } pub fn encode(array: &dyn Array) -> VortexResult { From 4a9c3f7b102d7b0d4debd986d900c14605e07530 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 12 Apr 2024 17:59:12 +0100 Subject: [PATCH 3/6] better --- .../array/primitive/compute/search_sorted.rs | 18 +- vortex-array/src/array/sparse/mod.rs | 4 +- vortex-array/src/compute/search_sorted.rs | 207 ++++++++++++++---- vortex-ree/src/ree.rs | 8 +- 4 files changed, 174 insertions(+), 63 deletions(-) diff --git a/vortex-array/src/array/primitive/compute/search_sorted.rs b/vortex-array/src/array/primitive/compute/search_sorted.rs index 42697d872b..f82d9364c7 100644 --- a/vortex-array/src/array/primitive/compute/search_sorted.rs +++ b/vortex-array/src/array/primitive/compute/search_sorted.rs @@ -1,17 +1,13 @@ use vortex_error::VortexResult; use crate::array::primitive::compute::PrimitiveTrait; -use crate::compute::search_sorted::SearchSorted; +use crate::compute::search_sorted::{SearchResult, SearchSorted}; use crate::compute::search_sorted::{SearchSortedFn, SearchSortedSide}; use crate::ptype::NativePType; use crate::scalar::Scalar; impl SearchSortedFn for &dyn PrimitiveTrait { - fn search_sorted( - &self, - value: &Scalar, - side: SearchSortedSide, - ) -> VortexResult> { + fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult { let pvalue: T = value.try_into()?; Ok(self.typed_data().search_sorted(&pvalue, side)) } @@ -30,25 +26,25 @@ mod test { assert_eq!( search_sorted(&values, 0, SearchSortedSide::Left) .unwrap() - .unwrap_or_else(|o| o), + .to_index(), 0 ); assert_eq!( search_sorted(&values, 1, SearchSortedSide::Left) .unwrap() - .unwrap_or_else(|o| o), + .to_index(), 0 ); assert_eq!( search_sorted(&values, 1, SearchSortedSide::Right) .unwrap() - .unwrap_or_else(|o| o), - 1 + .to_index(), + 0 ); assert_eq!( search_sorted(&values, 4, SearchSortedSide::Left) .unwrap() - .unwrap_or_else(|o| o), + .to_index(), 3 ); } diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index 372ee98439..59757a6f80 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -94,9 +94,9 @@ impl SparseArray { search_sorted( self.indices(), self.indices_offset + index, - SearchSortedSide::Exact, + SearchSortedSide::Left, ) - .map(|r| r.ok()) + .map(|r| r.to_option()) } /// Return indices as a vector of usize with the indices_offset applied. diff --git a/vortex-array/src/compute/search_sorted.rs b/vortex-array/src/compute/search_sorted.rs index 5ecf3c73fd..f2b82b76f1 100644 --- a/vortex-array/src/compute/search_sorted.rs +++ b/vortex-array/src/compute/search_sorted.rs @@ -11,22 +11,39 @@ use crate::scalar::Scalar; pub enum SearchSortedSide { Left, Right, - Exact, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq)] +pub enum SearchResult { + Found(usize), + NotFound(usize), +} + +impl SearchResult { + pub fn to_option(self) -> Option { + match self { + SearchResult::Found(i) => Some(i), + SearchResult::NotFound(_) => None, + } + } + + pub fn to_index(self) -> usize { + match self { + SearchResult::Found(i) => i, + SearchResult::NotFound(i) => i, + } + } } pub trait SearchSortedFn { - fn search_sorted( - &self, - value: &Scalar, - side: SearchSortedSide, - ) -> VortexResult>; + fn search_sorted(&self, value: &Scalar, side: SearchSortedSide) -> VortexResult; } pub fn search_sorted>( array: &dyn Array, target: T, side: SearchSortedSide, -) -> VortexResult> { +) -> VortexResult { let scalar = target.into().cast(array.dtype())?; array.with_compute(|c| { if let Some(search_sorted) = c.search_sorted() { @@ -70,59 +87,102 @@ pub trait Len { } pub trait SearchSorted { - fn search_sorted(&self, value: &T, side: SearchSortedSide) -> Result + fn search_sorted(&self, value: &T, side: SearchSortedSide) -> SearchResult where Self: IndexOrd, { match side { - SearchSortedSide::Left => self.search_sorted_by(|idx| { - if self.index_lt(idx, value) { - Less - } else { - Greater - } - }), - SearchSortedSide::Right => self.search_sorted_by(|idx| { - if self.index_le(idx, value) { - Less - } else { - Greater - } - }), - SearchSortedSide::Exact => { - self.search_sorted_by(|idx| self.index_cmp(idx, value).unwrap_or(Greater)) - } + SearchSortedSide::Left => self.search_sorted_by( + |idx| self.index_cmp(idx, value).unwrap_or(Less), + |idx| { + if self.index_lt(idx, value) { + Less + } else { + Greater + } + }, + side, + ), + SearchSortedSide::Right => self.search_sorted_by( + |idx| self.index_cmp(idx, value).unwrap_or(Less), + |idx| { + if self.index_le(idx, value) { + Less + } else { + Greater + } + }, + side, + ), } } - fn search_sorted_by Ordering>(&self, f: F) -> Result; + /// find function is used to find the element if it exists, if element exists side_find will be used to find desired index amongst equal values + fn search_sorted_by Ordering, N: FnMut(usize) -> Ordering>( + &self, + find: F, + side_find: N, + side: SearchSortedSide, + ) -> SearchResult; } impl + Len + ?Sized, T> SearchSorted for S { - // Code adapted from Rust standard library slice::binary_search_by - fn search_sorted_by Ordering>(&self, mut f: F) -> Result { - // INVARIANTS: - // - 0 <= left <= left + size = right <= self.len() - // - f returns Less for everything in self[..left] - // - f returns Greater for everything in self[right..] - let mut size = self.len(); - let mut left = 0; - let mut right = size; - while left < right { - let mid = left + size / 2; - let cmp = f(mid); - - left = if cmp == Less { mid + 1 } else { left }; - right = if cmp == Greater { mid } else { right }; - if cmp == Equal { - return Ok(mid); - } - - size = right - left; + fn search_sorted_by Ordering, N: FnMut(usize) -> Ordering>( + &self, + find: F, + side_find: N, + side: SearchSortedSide, + ) -> SearchResult { + match search_sorted_side_idx(find, 0, self.len()) { + SearchResult::Found(found) => match side { + SearchSortedSide::Left => match search_sorted_side_idx(side_find, 0, found) { + SearchResult::NotFound(i) => SearchResult::Found(i), + _ => unreachable!( + "searching amongst equal values should never return Found result" + ), + }, + // Right side search returns index one past the result we want, subtract here + SearchSortedSide::Right => { + match search_sorted_side_idx(side_find, found, self.len()) { + SearchResult::NotFound(i) => SearchResult::Found(i - 1), + _ => unreachable!( + "searching amongst equal values should never return Found result" + ), + } + } + }, + s => s, + } + } +} + +// Code adapted from Rust standard library slice::binary_search_by +fn search_sorted_side_idx Ordering>( + mut find: F, + from: usize, + to: usize, +) -> SearchResult { + // INVARIANTS: + // - from <= left <= left + size = right <= to + // - f returns Less for everything in self[..left] + // - f returns Greater for everything in self[right..] + let mut size = to - from; + let mut left = from; + let mut right = to; + while left < right { + let mid = left + size / 2; + let cmp = find(mid); + + left = if cmp == Less { mid + 1 } else { left }; + right = if cmp == Greater { mid } else { right }; + if cmp == Equal { + return SearchResult::Found(mid); } - Err(left) + size = right - left; } + + SearchResult::NotFound(left) } impl IndexOrd for &dyn Array { @@ -150,3 +210,56 @@ impl Len for [T] { self.len() } } + +#[cfg(test)] +mod test { + use crate::compute::search_sorted::{SearchResult, SearchSorted, SearchSortedSide}; + + #[test] + fn left_side_equal() { + let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9]; + let res = arr.search_sorted(&2, SearchSortedSide::Left); + assert_eq!(arr[res.to_index()], 2); + assert_eq!(res, SearchResult::Found(2)); + } + + #[test] + fn right_side_equal() { + let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9]; + let res = arr.search_sorted(&2, SearchSortedSide::Right); + assert_eq!(arr[res.to_index()], 2); + assert_eq!(res, SearchResult::Found(5)); + } + + #[test] + fn left_side_equal_beginning() { + let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let res = arr.search_sorted(&0, SearchSortedSide::Left); + assert_eq!(arr[res.to_index()], 0); + assert_eq!(res, SearchResult::Found(0)); + } + + #[test] + fn right_side_equal_beginning() { + let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; + let res = arr.search_sorted(&0, SearchSortedSide::Right); + assert_eq!(arr[res.to_index()], 0); + assert_eq!(res, SearchResult::Found(3)); + } + + #[test] + fn left_side_equal_end() { + let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9]; + let res = arr.search_sorted(&9, SearchSortedSide::Left); + assert_eq!(arr[res.to_index()], 9); + assert_eq!(res, SearchResult::Found(9)); + } + + #[test] + fn right_side_equal_end() { + let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9]; + let res = arr.search_sorted(&9, SearchSortedSide::Right); + // assert_eq!(arr[res.to_index()], 9); + assert_eq!(res, SearchResult::Found(12)); + } +} diff --git a/vortex-ree/src/ree.rs b/vortex-ree/src/ree.rs index 7ec8ca19bd..46e9d69198 100644 --- a/vortex-ree/src/ree.rs +++ b/vortex-ree/src/ree.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock}; use vortex::array::{Array, ArrayKind, ArrayRef}; use vortex::compress::EncodingCompression; use vortex::compute::scalar_at::scalar_at; -use vortex::compute::search_sorted::{search_sorted, SearchSortedSide}; +use vortex::compute::search_sorted::{search_sorted, SearchResult, SearchSortedSide}; use vortex::compute::ArrayCompute; use vortex::encoding::{Encoding, EncodingId, EncodingRef}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; @@ -68,8 +68,10 @@ impl REEArray { } pub fn find_physical_index(&self, index: usize) -> VortexResult { - search_sorted(self.ends(), index + self.offset, SearchSortedSide::Right) - .map(|r| r.unwrap_or_else(|o| o)) + search_sorted(self.ends(), index + self.offset, SearchSortedSide::Right).map(|r| match r { + SearchResult::Found(i) => i + 1, + SearchResult::NotFound(i) => i, + }) } pub fn encode(array: &dyn Array) -> VortexResult { From 46b99bedc4f767f9303e479baf638c7bc2d57d91 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 12 Apr 2024 18:01:31 +0100 Subject: [PATCH 4/6] found --- vortex-array/src/array/sparse/mod.rs | 2 +- vortex-array/src/compute/search_sorted.rs | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/vortex-array/src/array/sparse/mod.rs b/vortex-array/src/array/sparse/mod.rs index 59757a6f80..596b596f49 100644 --- a/vortex-array/src/array/sparse/mod.rs +++ b/vortex-array/src/array/sparse/mod.rs @@ -96,7 +96,7 @@ impl SparseArray { self.indices_offset + index, SearchSortedSide::Left, ) - .map(|r| r.to_option()) + .map(|r| r.to_found()) } /// Return indices as a vector of usize with the indices_offset applied. diff --git a/vortex-array/src/compute/search_sorted.rs b/vortex-array/src/compute/search_sorted.rs index f2b82b76f1..85062916a4 100644 --- a/vortex-array/src/compute/search_sorted.rs +++ b/vortex-array/src/compute/search_sorted.rs @@ -20,7 +20,7 @@ pub enum SearchResult { } impl SearchResult { - pub fn to_option(self) -> Option { + pub fn to_found(self) -> Option { match self { SearchResult::Found(i) => Some(i), SearchResult::NotFound(_) => None, From 31a28e5e1cbec9fa3ecbe701d52f1052bbdcfde9 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 12 Apr 2024 22:59:37 +0100 Subject: [PATCH 5/6] consistent --- .../array/primitive/compute/search_sorted.rs | 2 +- vortex-array/src/compute/search_sorted.rs | 31 ++++++++----------- vortex-ree/src/ree.rs | 8 ++--- 3 files changed, 17 insertions(+), 24 deletions(-) diff --git a/vortex-array/src/array/primitive/compute/search_sorted.rs b/vortex-array/src/array/primitive/compute/search_sorted.rs index f82d9364c7..b2095e3854 100644 --- a/vortex-array/src/array/primitive/compute/search_sorted.rs +++ b/vortex-array/src/array/primitive/compute/search_sorted.rs @@ -39,7 +39,7 @@ mod test { search_sorted(&values, 1, SearchSortedSide::Right) .unwrap() .to_index(), - 0 + 1 ); assert_eq!( search_sorted(&values, 4, SearchSortedSide::Left) diff --git a/vortex-array/src/compute/search_sorted.rs b/vortex-array/src/compute/search_sorted.rs index 85062916a4..469452aa54 100644 --- a/vortex-array/src/compute/search_sorted.rs +++ b/vortex-array/src/compute/search_sorted.rs @@ -134,23 +134,18 @@ impl + Len + ?Sized, T> SearchSorted for S { side: SearchSortedSide, ) -> SearchResult { match search_sorted_side_idx(find, 0, self.len()) { - SearchResult::Found(found) => match side { - SearchSortedSide::Left => match search_sorted_side_idx(side_find, 0, found) { + SearchResult::Found(found) => { + let idx_search = match side { + SearchSortedSide::Left => search_sorted_side_idx(side_find, 0, found), + SearchSortedSide::Right => search_sorted_side_idx(side_find, found, self.len()), + }; + match idx_search { SearchResult::NotFound(i) => SearchResult::Found(i), _ => unreachable!( "searching amongst equal values should never return Found result" ), - }, - // Right side search returns index one past the result we want, subtract here - SearchSortedSide::Right => { - match search_sorted_side_idx(side_find, found, self.len()) { - SearchResult::NotFound(i) => SearchResult::Found(i - 1), - _ => unreachable!( - "searching amongst equal values should never return Found result" - ), - } } - }, + } s => s, } } @@ -227,8 +222,8 @@ mod test { fn right_side_equal() { let arr = [0, 1, 2, 2, 2, 2, 3, 4, 5, 6, 7, 8, 9]; let res = arr.search_sorted(&2, SearchSortedSide::Right); - assert_eq!(arr[res.to_index()], 2); - assert_eq!(res, SearchResult::Found(5)); + assert_eq!(arr[res.to_index() - 1], 2); + assert_eq!(res, SearchResult::Found(6)); } #[test] @@ -243,8 +238,8 @@ mod test { fn right_side_equal_beginning() { let arr = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9]; let res = arr.search_sorted(&0, SearchSortedSide::Right); - assert_eq!(arr[res.to_index()], 0); - assert_eq!(res, SearchResult::Found(3)); + assert_eq!(arr[res.to_index() - 1], 0); + assert_eq!(res, SearchResult::Found(4)); } #[test] @@ -259,7 +254,7 @@ mod test { fn right_side_equal_end() { let arr = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 9]; let res = arr.search_sorted(&9, SearchSortedSide::Right); - // assert_eq!(arr[res.to_index()], 9); - assert_eq!(res, SearchResult::Found(12)); + assert_eq!(arr[res.to_index() - 1], 9); + assert_eq!(res, SearchResult::Found(13)); } } diff --git a/vortex-ree/src/ree.rs b/vortex-ree/src/ree.rs index 46e9d69198..c87d112d83 100644 --- a/vortex-ree/src/ree.rs +++ b/vortex-ree/src/ree.rs @@ -3,7 +3,7 @@ use std::sync::{Arc, RwLock}; use vortex::array::{Array, ArrayKind, ArrayRef}; use vortex::compress::EncodingCompression; use vortex::compute::scalar_at::scalar_at; -use vortex::compute::search_sorted::{search_sorted, SearchResult, SearchSortedSide}; +use vortex::compute::search_sorted::{search_sorted, SearchSortedSide}; use vortex::compute::ArrayCompute; use vortex::encoding::{Encoding, EncodingId, EncodingRef}; use vortex::formatter::{ArrayDisplay, ArrayFormatter}; @@ -68,10 +68,8 @@ impl REEArray { } pub fn find_physical_index(&self, index: usize) -> VortexResult { - search_sorted(self.ends(), index + self.offset, SearchSortedSide::Right).map(|r| match r { - SearchResult::Found(i) => i + 1, - SearchResult::NotFound(i) => i, - }) + search_sorted(self.ends(), index + self.offset, SearchSortedSide::Right) + .map(|s| s.to_index()) } pub fn encode(array: &dyn Array) -> VortexResult { From 463ce21530127cc62be7f97fe0402c6bd59e94f8 Mon Sep 17 00:00:00 2001 From: Robert Kruszewski Date: Fri, 12 Apr 2024 23:17:35 +0100 Subject: [PATCH 6/6] fix --- vortex-array/src/array/sparse/compute/slice.rs | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/vortex-array/src/array/sparse/compute/slice.rs b/vortex-array/src/array/sparse/compute/slice.rs index 1870c4756f..29c2ee43a0 100644 --- a/vortex-array/src/array/sparse/compute/slice.rs +++ b/vortex-array/src/array/sparse/compute/slice.rs @@ -8,8 +8,10 @@ use crate::compute::slice::{slice, SliceFn}; impl SliceFn for SparseArray { fn slice(&self, start: usize, stop: usize) -> VortexResult { // Find the index of the first patch index that is greater than or equal to the offset of this array - let index_start_index = search_sorted(self.indices(), start, SearchSortedSide::Left)?; - let index_end_index = search_sorted(self.indices(), stop, SearchSortedSide::Left)?; + let index_start_index = + search_sorted(self.indices(), start, SearchSortedSide::Left)?.to_index(); + let index_end_index = + search_sorted(self.indices(), stop, SearchSortedSide::Left)?.to_index(); Ok(SparseArray::try_new_with_offset( slice(self.indices(), index_start_index, index_end_index)?,