diff --git a/encodings/fastlanes/src/bitpacking/compute/take.rs b/encodings/fastlanes/src/bitpacking/compute/take.rs
index 95ab8cbc0f..51f2b2042f 100644
--- a/encodings/fastlanes/src/bitpacking/compute/take.rs
+++ b/encodings/fastlanes/src/bitpacking/compute/take.rs
@@ -1,9 +1,8 @@
-use std::cmp::min;
-
 use fastlanes::BitPacking;
 use itertools::Itertools;
+use vortex::aliases::hash_map::HashMap;
 use vortex::array::{PrimitiveArray, SparseArray};
-use vortex::compute::{slice, take, TakeFn};
+use vortex::compute::{take, TakeFn};
 use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant, IntoCanonical};
 use vortex_dtype::{
     match_each_integer_ptype, match_each_unsigned_integer_ptype, NativePType, PType,
@@ -16,7 +15,6 @@ use crate::{unpack_single_primitive, BitPackedArray};
 // all 1024 elements takes ~8.8x as long as unpacking a single element on an M2 Macbook Air.
 // see https://github.com/spiraldb/vortex/pull/190#issue-2223752833
 const UNPACK_CHUNK_THRESHOLD: usize = 8;
-const BULK_PATCH_THRESHOLD: usize = 64;
 
 impl TakeFn for BitPackedArray {
     fn take(&self, indices: &Array) -> VortexResult<Array> {
@@ -57,86 +55,47 @@ fn take_primitive<T: NativePType + BitPacking, I: NativePType>(
     let packed = array.packed_slice::<T>();
 
     let patches = array.patches().map(SparseArray::try_from).transpose()?;
 
-    // Group indices by 1024-element chunk, *without* allocating on the heap
-    let chunked_indices = &indices
-        .maybe_null_slice::<I>()
-        .iter()
-        .map(|i| {
-            i.to_usize()
-                .vortex_expect("index must be expressible as usize")
-                + offset
-        })
-        .chunk_by(|idx| idx / 1024);
+    let mut chunk_counts: HashMap<usize, usize> = HashMap::new();
+    // Count how many of the requested indices land in each 1024-element chunk.
+    for index in indices.maybe_null_slice::<I>() {
+        let value = index.to_usize().vortex_expect("index must be expressible as usize") + offset;
+        let entry = chunk_counts.entry(value / 1024).or_insert_with(|| 0_usize);
+        *entry += 1;
+    }
+    let mut unpacked_chunks: HashMap<usize, [T; 1024]> = HashMap::new();
+    for (chunk, count) in chunk_counts {
+        if count > UNPACK_CHUNK_THRESHOLD {
+            let chunk_size = 128 * bit_width / size_of::<T>();
+            let packed_chunk = &packed[chunk * chunk_size..][..chunk_size];
+            let mut unpacked = [T::zero(); 1024];
+            unsafe {
+                BitPacking::unchecked_unpack(bit_width, packed_chunk, &mut unpacked);
+            }
+            unpacked_chunks.insert(chunk, unpacked);
+        }
+    }
 
     let mut output = Vec::with_capacity(indices.len());
-    let mut unpacked = [T::zero(); 1024];
-
-    let mut batch_count = 0_usize;
-    for (chunk, offsets) in chunked_indices {
-        batch_count += 1;
-        let chunk_size = 128 * bit_width / size_of::<T>();
-        let packed_chunk = &packed[chunk * chunk_size..][..chunk_size];
-        // array_chunks produced a fixed size array, doesn't heap allocate
-        let mut have_unpacked = false;
-        let mut offset_chunk_iter = offsets
-            // relativize indices to the start of the chunk
-            .map(|i| i % 1024)
-            .array_chunks::<UNPACK_CHUNK_THRESHOLD>();
-
-        // this loop only runs if we have at least UNPACK_CHUNK_THRESHOLD offsets
-        for offset_chunk in &mut offset_chunk_iter {
-            if !have_unpacked {
-                unsafe {
-                    BitPacking::unchecked_unpack(bit_width, packed_chunk, &mut unpacked);
-                }
-                have_unpacked = true;
+    for index in indices.maybe_null_slice::<I>() {
+        let value = index.to_usize().vortex_expect("index must be expressible as usize") + offset;
+        let chunk = value / 1024;
+        let in_chunk_index = value % 1024;
+        match unpacked_chunks.get(&chunk) {
+            Some(unpacked) => {
+                output.push(unpacked[in_chunk_index]);
             }
-
-            for index in offset_chunk {
-                output.push(unpacked[index]);
-            }
-        }
-
-        // if we have a remainder (i.e., < UNPACK_CHUNK_THRESHOLD leftover offsets), we need to handle it
-        if let Some(remainder) = offset_chunk_iter.into_remainder() {
-            if have_unpacked {
-                // we already bulk unpacked this chunk, so we can just push the remaining elements
-                for index in remainder {
-                    output.push(unpacked[index]);
-                }
-            } else {
-                // we had fewer than UNPACK_CHUNK_THRESHOLD offsets in the first place,
-                // so we need to unpack each one individually
-                for index in remainder {
-                    output.push(unsafe {
-                        unpack_single_primitive::<T>(packed_chunk, bit_width, index)
-                    });
-                }
-            }
+            None => {
+                let chunk_size = 128 * bit_width / size_of::<T>();
+                let packed_chunk = &packed[chunk * chunk_size..][..chunk_size];
+                output.push(unsafe {
+                    unpack_single_primitive::<T>(packed_chunk, bit_width, in_chunk_index)
+                });
+            }
         }
     }
 
     if let Some(ref patches) = patches {
-        patch_for_take_primitive::<T, I>(patches, indices, offset, batch_count, &mut output)?;
-    }
-
-    Ok(output)
-}
-
-fn patch_for_take_primitive<T: NativePType, I: NativePType>(
-    patches: &SparseArray,
-    indices: &PrimitiveArray,
-    offset: usize,
-    batch_count: usize,
-    output: &mut [T],
-) -> VortexResult<()> {
-    #[inline]
-    fn inner_patch<T: NativePType>(
-        patches: &SparseArray,
-        indices: &PrimitiveArray,
-        output: &mut [T],
-    ) -> VortexResult<()> {
         let taken_patches = take(patches.as_ref(), indices.as_ref())?;
         let taken_patches = SparseArray::try_from(taken_patches)?;
 
@@ -153,63 +112,9 @@ fn patch_for_take_primitive<T: NativePType, I: NativePType>(
             .for_each(|(idx, val)| {
                 output[idx] = *val;
             });
-
-        Ok(())
-    }
-
-    // if we have a small number of relatively large batches, we gain by slicing and then patching inside the loop
-    // if we have a large number of relatively small batches, the overhead isn't worth it, and we're better off with a bulk patch
-    // roughly, if we have an average of less than 64 elements per batch, we prefer bulk patching
-    let prefer_bulk_patch = batch_count * BULK_PATCH_THRESHOLD > indices.len();
-    if prefer_bulk_patch {
-        return inner_patch(patches, indices, output);
-    }
-
-    let min_index = patches.min_index().unwrap_or_default();
-    let max_index = patches.max_index().unwrap_or_default();
-
-    // Group indices into 1024-element chunks and relativise them to the beginning of each chunk
-    let chunked_indices = &indices
-        .maybe_null_slice::<I>()
-        .iter()
-        .map(|i| {
-            i.to_usize()
-                .vortex_expect("index must be expressible as usize")
-                + offset
-        })
-        .filter(|i| *i >= min_index && *i <= max_index) // short-circuit
-        .chunk_by(|idx| idx / 1024);
-
-    for (chunk, offsets) in chunked_indices {
-        // NOTE: we need to subtract the array offset before slicing into the patches.
-        // This is because BitPackedArray is rounded to block boundaries, but patches
-        // is sliced exactly.
-        let patches_start = if chunk == 0 {
-            0
-        } else {
-            (chunk * 1024) - offset
-        };
-        let patches_end = min((chunk + 1) * 1024 - offset, patches.len());
-        let patches_slice = slice(patches.as_ref(), patches_start, patches_end)?;
-        let patches_slice = SparseArray::try_from(patches_slice)?;
-        if patches_slice.is_empty() {
-            continue;
-        }
-
-        let min_index = patches_slice.min_index().unwrap_or_default();
-        let max_index = patches_slice.max_index().unwrap_or_default();
-        let offsets = offsets
-            .map(|i| (i % 1024) as u16)
-            .filter(|i| *i as usize >= min_index && *i as usize <= max_index)
-            .collect_vec();
-        if offsets.is_empty() {
-            continue;
-        }
-
-        inner_patch(&patches_slice, &PrimitiveArray::from(offsets), output)?;
     }
 
-    Ok(())
+    Ok(output)
 }
 
 #[cfg(test)]
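
For readers following the new control flow: the patch replaces the chunked-iterator walk with three passes — count how many requested indices fall into each 1024-element fastlanes chunk, bulk-unpack only the chunks hit more than UNPACK_CHUNK_THRESHOLD times, then gather each value either from a cached unpacked chunk or via a single-element unpack. The sketch below is a minimal, self-contained illustration of that strategy; `PackedArray`, `unpack_chunk`, and `unpack_single` are toy stand-ins for the bit-packed buffer, `BitPacking::unchecked_unpack`, and `unpack_single_primitive`, not the actual vortex/fastlanes API.

```rust
use std::collections::HashMap;

// Illustrative stand-ins; the real code works on fastlanes' 1024-element packed chunks.
const CHUNK_LEN: usize = 1024;
const UNPACK_CHUNK_THRESHOLD: usize = 8;

// Toy "packed" array: each chunk is just a Vec<u32>. In the real encoding this
// would be a bit-packed buffer that fastlanes decodes.
struct PackedArray {
    chunks: Vec<Vec<u32>>,
}

impl PackedArray {
    // Stand-in for BitPacking::unchecked_unpack: decode a whole chunk at once.
    fn unpack_chunk(&self, chunk: usize) -> [u32; CHUNK_LEN] {
        let mut out = [0u32; CHUNK_LEN];
        out.copy_from_slice(&self.chunks[chunk]);
        out
    }

    // Stand-in for unpack_single_primitive: decode one element of one chunk.
    fn unpack_single(&self, chunk: usize, in_chunk_index: usize) -> u32 {
        self.chunks[chunk][in_chunk_index]
    }

    // The take strategy from the patch, in miniature.
    fn take(&self, indices: &[usize]) -> Vec<u32> {
        // Pass 1: count how many requested indices land in each chunk.
        let mut chunk_counts: HashMap<usize, usize> = HashMap::new();
        for &i in indices {
            *chunk_counts.entry(i / CHUNK_LEN).or_insert(0) += 1;
        }

        // Pass 2: bulk-unpack only the chunks referenced often enough to
        // amortize decoding all 1024 values.
        let mut unpacked_chunks: HashMap<usize, [u32; CHUNK_LEN]> = HashMap::new();
        for (chunk, count) in chunk_counts {
            if count > UNPACK_CHUNK_THRESHOLD {
                unpacked_chunks.insert(chunk, self.unpack_chunk(chunk));
            }
        }

        // Pass 3: gather, preferring the cached chunk and falling back to a
        // single-element unpack for cold chunks.
        indices
            .iter()
            .map(|&i| {
                let (chunk, in_chunk_index) = (i / CHUNK_LEN, i % CHUNK_LEN);
                match unpacked_chunks.get(&chunk) {
                    Some(unpacked) => unpacked[in_chunk_index],
                    None => self.unpack_single(chunk, in_chunk_index),
                }
            })
            .collect()
    }
}

fn main() {
    // Four chunks of 1024 values each, where value = chunk * 10_000 + position.
    let array = PackedArray {
        chunks: (0..4u32)
            .map(|c| (0..CHUNK_LEN as u32).map(|i| c * 10_000 + i).collect())
            .collect(),
    };
    // Nine hits in chunk 0 (bulk-unpacked) and one hit in chunk 3 (single unpack).
    let indices: Vec<usize> = (0..9).chain([3 * CHUNK_LEN + 7]).collect();
    println!("{:?}", array.take(&indices));
}
```

The threshold mirrors the comment kept at the top of the file: bulk-unpacking all 1024 elements costs roughly 8.8x a single-element unpack, so a chunk only pays for wholesale decoding once it is referenced more than about eight times.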