Skip to content

Commit

Permalink
Fix regression in search_sorted when Patches replaced SparseArray (#1624
Browse files Browse the repository at this point in the history
)
  • Loading branch information
robert3005 authored Dec 9, 2024
1 parent 4edfc74 commit e8cd434
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 20 deletions.
17 changes: 17 additions & 0 deletions encodings/fastlanes/src/bitpacking/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,23 @@ mod test {
assert_eq!(found, SearchResult::Found(0));
}

#[test]
fn test_search_sorted_nulls_not_found() {
let bitpacked = BitPackedArray::encode(
PrimitiveArray::from_nullable_vec(vec![Some(0u8), Some(107u8), None, None]).as_ref(),
0,
)
.unwrap();

let found = search_sorted(
bitpacked.as_ref(),
Scalar::primitive(127u8, Nullability::Nullable),
SearchSortedSide::Left,
)
.unwrap();
assert_eq!(found, SearchResult::NotFound(2));
}

#[test]
fn test_search_sorted_many() {
// Test search_sorted_many with an array that contains several null values.
Expand Down
91 changes: 79 additions & 12 deletions vortex-array/src/patches.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,21 @@ impl Patches {
target: T,
side: SearchSortedSide,
) -> VortexResult<SearchResult> {
Ok(match search_sorted(self.values(), target.into(), side)? {
SearchResult::Found(idx) => SearchResult::Found(if idx == self.indices().len() {
self.array_len()
} else {
usize::try_from(&scalar_at(self.indices(), idx)?)?
}),
SearchResult::NotFound(idx) => SearchResult::NotFound(if idx == self.indices().len() {
self.array_len()
} else {
usize::try_from(&scalar_at(self.indices(), idx)?)?
}),
search_sorted(self.values(), target.into(), side).and_then(|sr| {
let sidx = sr.to_offsets_index(self.indices().len());
let index = usize::try_from(&scalar_at(self.indices(), sidx)?)?;
Ok(match sr {
// If we reached the end of patched values when searching then the result is one after the last patch index
SearchResult::Found(i) => SearchResult::Found(if i == self.indices().len() {
index + 1
} else {
index
}),
// If the result is NotFound we should return index that's one after the nearest not found index for the corresponding value
SearchResult::NotFound(i) => {
SearchResult::NotFound(if i == 0 { index } else { index + 1 })
}
})
})
}

Expand Down Expand Up @@ -260,9 +264,12 @@ impl Patches {

#[cfg(test)]
mod test {
use rstest::{fixture, rstest};

use crate::array::PrimitiveArray;
use crate::compute::FilterMask;
use crate::compute::{FilterMask, SearchResult, SearchSortedSide};
use crate::patches::Patches;
use crate::validity::Validity;
use crate::{IntoArrayData, IntoArrayVariant};

#[test]
Expand All @@ -283,4 +290,64 @@ mod test {
assert_eq!(indices.maybe_null_slice::<u64>(), &[0, 1]);
assert_eq!(values.maybe_null_slice::<i32>(), &[100, 200]);
}

#[fixture]
fn patches() -> Patches {
Patches::new(
20,
PrimitiveArray::from(vec![2u64, 9, 15]).into_array(),
PrimitiveArray::from_vec(vec![33_i32, 44, 55], Validity::AllValid).into_array(),
)
}

#[rstest]
fn search_larger_than(patches: Patches) {
let res = patches.search_sorted(66, SearchSortedSide::Left).unwrap();
assert_eq!(res, SearchResult::NotFound(16));
}

#[rstest]
fn search_less_than(patches: Patches) {
let res = patches.search_sorted(22, SearchSortedSide::Left).unwrap();
assert_eq!(res, SearchResult::NotFound(2));
}

#[rstest]
fn search_found(patches: Patches) {
let res = patches.search_sorted(44, SearchSortedSide::Left).unwrap();
assert_eq!(res, SearchResult::Found(9));
}

#[rstest]
fn search_not_found_right(patches: Patches) {
let res = patches.search_sorted(56, SearchSortedSide::Right).unwrap();
assert_eq!(res, SearchResult::NotFound(16));
}

#[rstest]
fn search_sliced(patches: Patches) {
let sliced = patches.slice(7, 20).unwrap().unwrap();
assert_eq!(
sliced.search_sorted(22, SearchSortedSide::Left).unwrap(),
SearchResult::NotFound(2)
);
}

#[test]
fn search_right() {
let patches = Patches::new(
2,
PrimitiveArray::from(vec![0u64]).into_array(),
PrimitiveArray::from_vec(vec![0u8], Validity::AllValid).into_array(),
);

assert_eq!(
patches.search_sorted(0, SearchSortedSide::Right).unwrap(),
SearchResult::Found(1)
);
assert_eq!(
patches.search_sorted(1, SearchSortedSide::Right).unwrap(),
SearchResult::NotFound(1)
);
}
}
12 changes: 9 additions & 3 deletions vortex-sampling-compressor/src/compressors/alp.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use vortex_alp::{alp_encode_components, match_each_alp_float_ptype, ALPArray, ALPEncoding};
use vortex_alp::{
alp_encode_components, match_each_alp_float_ptype, ALPArray, ALPEncoding, ALPRDEncoding,
};
use vortex_array::aliases::hash_set::HashSet;
use vortex_array::array::PrimitiveArray;
use vortex_array::encoding::{Encoding, EncodingRef};
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayData, IntoArrayData, IntoArrayVariant};
use vortex_dtype::PType;
use vortex_error::VortexResult;
use vortex_fastlanes::BitPackedEncoding;

use super::alp_rd::ALPRDCompressor;
use crate::compressors::{CompressedArray, CompressionTree, EncodingCompressor};
Expand Down Expand Up @@ -41,7 +44,6 @@ impl EncodingCompressor for ALPCompressor {
like: Option<CompressionTree<'a>>,
ctx: SamplingCompressor<'a>,
) -> VortexResult<CompressedArray<'a>> {
// TODO(robert): Fill forward nulls?
let parray = array.clone().into_primitive()?;

let (exponents, encoded, patches) = match_each_alp_float_ptype!(
Expand Down Expand Up @@ -72,6 +74,10 @@ impl EncodingCompressor for ALPCompressor {
}

fn used_encodings(&self) -> HashSet<EncodingRef> {
HashSet::from([&ALPEncoding as EncodingRef])
HashSet::from([
&ALPEncoding as EncodingRef,
&ALPRDEncoding,
&BitPackedEncoding,
])
}
}
13 changes: 8 additions & 5 deletions vortex-sampling-compressor/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ use compressors::varbin::VarBinCompressor;
use compressors::{CompressedArray, CompressorRef};
use vortex_alp::{ALPEncoding, ALPRDEncoding};
use vortex_array::array::{
PrimitiveEncoding, SparseEncoding, StructEncoding, VarBinEncoding, VarBinViewEncoding,
ListEncoding, PrimitiveEncoding, SparseEncoding, StructEncoding, VarBinEncoding,
VarBinViewEncoding,
};
use vortex_array::encoding::EncodingRef;
use vortex_array::Context;
Expand All @@ -32,6 +33,7 @@ use vortex_zigzag::ZigZagEncoding;
use crate::compressors::alp::ALPCompressor;
use crate::compressors::date_time_parts::DateTimePartsCompressor;
use crate::compressors::dict::DictCompressor;
use crate::compressors::list::ListCompressor;
use crate::compressors::r#for::FoRCompressor;
use crate::compressors::runend::DEFAULT_RUN_END_COMPRESSOR;
use crate::compressors::runend_bool::RunEndBoolCompressor;
Expand All @@ -48,8 +50,6 @@ mod sampling_compressor;

pub use sampling_compressor::*;

use crate::compressors::list::ListCompressor;

pub const DEFAULT_COMPRESSORS: [CompressorRef; 15] = [
&ALPCompressor as CompressorRef,
&BITPACK_WITH_PATCHES,
Expand All @@ -72,7 +72,7 @@ pub const DEFAULT_COMPRESSORS: [CompressorRef; 15] = [
];

#[cfg(not(target_arch = "wasm32"))]
pub const ALL_COMPRESSORS: [CompressorRef; 17] = [
pub const ALL_COMPRESSORS: [CompressorRef; 18] = [
&ALPCompressor as CompressorRef,
&BITPACK_WITH_PATCHES,
&DEFAULT_CHUNKED_COMPRESSOR,
Expand All @@ -88,12 +88,13 @@ pub const ALL_COMPRESSORS: [CompressorRef; 17] = [
&DEFAULT_RUN_END_COMPRESSOR,
&SparseCompressor,
&StructCompressor,
&ListCompressor,
&VarBinCompressor,
&ZigZagCompressor,
];

#[cfg(target_arch = "wasm32")]
pub const ALL_COMPRESSORS: [CompressorRef; 15] = [
pub const ALL_COMPRESSORS: [CompressorRef; 16] = [
&ALPCompressor as CompressorRef,
&BITPACK_WITH_PATCHES,
&DEFAULT_CHUNKED_COMPRESSOR,
Expand All @@ -110,6 +111,7 @@ pub const ALL_COMPRESSORS: [CompressorRef; 15] = [
&DEFAULT_RUN_END_COMPRESSOR,
&SparseCompressor,
&StructCompressor,
&ListCompressor,
&VarBinCompressor,
&ZigZagCompressor,
];
Expand All @@ -135,6 +137,7 @@ pub static ALL_ENCODINGS_CONTEXT: LazyLock<Arc<Context>> = LazyLock::new(|| {
&RunEndBoolEncoding,
&SparseEncoding,
&StructEncoding,
&ListEncoding,
&VarBinEncoding,
&VarBinViewEncoding,
&ZigZagEncoding,
Expand Down

0 comments on commit e8cd434

Please sign in to comment.