[wip] sparse array take golf #1163

Closed
wants to merge 1 commit into from
165 changes: 35 additions & 130 deletions encodings/fastlanes/src/bitpacking/compute/take.rs
@@ -1,9 +1,8 @@
use std::cmp::min;

use fastlanes::BitPacking;
use itertools::Itertools;
use vortex::aliases::hash_map::HashMap;
use vortex::array::{PrimitiveArray, SparseArray};
use vortex::compute::{slice, take, TakeFn};
use vortex::compute::{take, TakeFn};
use vortex::{Array, ArrayDType, IntoArray, IntoArrayVariant, IntoCanonical};
use vortex_dtype::{
match_each_integer_ptype, match_each_unsigned_integer_ptype, NativePType, PType,
@@ -16,7 +15,6 @@ use crate::{unpack_single_primitive, BitPackedArray};
// all 1024 elements takes ~8.8x as long as unpacking a single element on an M2 Macbook Air.
// see https://github.com/spiraldb/vortex/pull/190#issue-2223752833
const UNPACK_CHUNK_THRESHOLD: usize = 8;
const BULK_PATCH_THRESHOLD: usize = 64;

impl TakeFn for BitPackedArray {
fn take(&self, indices: &Array) -> VortexResult<Array> {
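For context on `UNPACK_CHUNK_THRESHOLD` above: unpacking a full 1024-element chunk costs roughly 8.8x a single-element unpack, so bulk unpacking only pays off once about nine or more of the requested indices fall into the same chunk. A minimal standalone sketch of that break-even rule (the `should_bulk_unpack` helper and its `main` are illustrative, not part of the crate):

```rust
/// Decide whether to bulk-unpack a 1024-element chunk or unpack elements one by one.
/// `hits` is how many requested indices fall into this chunk.
fn should_bulk_unpack(hits: usize) -> bool {
    // Per the benchmark cited above (M2 MacBook Air), a full-chunk unpack costs
    // ~8.8x a single unpack, so once more than ~8 elements are needed from a chunk,
    // one bulk unpack is cheaper than `hits` individual unpacks.
    const UNPACK_CHUNK_THRESHOLD: usize = 8;
    hits > UNPACK_CHUNK_THRESHOLD
}

fn main() {
    assert!(!should_bulk_unpack(3)); // few hits: unpack each element individually
    assert!(should_bulk_unpack(20)); // many hits: one bulk unpack of the whole chunk
}
```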
@@ -57,86 +55,47 @@ fn take_primitive<T: NativePType + BitPacking, I: NativePType>(
let packed = array.packed_slice::<T>();
let patches = array.patches().map(SparseArray::try_from).transpose()?;

// Group indices by 1024-element chunk, *without* allocating on the heap
let chunked_indices = &indices
.maybe_null_slice::<I>()
.iter()
.map(|i| {
i.to_usize()
.vortex_expect("index must be expressible as usize")
+ offset
})
.chunk_by(|idx| idx / 1024);
let mut chunk_counts: HashMap<usize, usize> = HashMap::new();
for index in indices.maybe_null_slice::<I>() {
let value = index.to_usize().vortex_expect("index must be expressible as usize") + offset;
let entry = chunk_counts.entry(value / 1024).or_insert_with(|| 0_usize);
*entry += 1;
}
let mut unpacked_chunks: HashMap<usize, [T; 1024]> = HashMap::new();
for (chunk, count) in chunk_counts {
if count > UNPACK_CHUNK_THRESHOLD {
let chunk_size = 128 * bit_width / size_of::<T>();
let packed_chunk = &packed[chunk * chunk_size..][..chunk_size];
let mut unpacked = [T::zero(); 1024];
unsafe {
BitPacking::unchecked_unpack(bit_width, packed_chunk, &mut unpacked);
}
unpacked_chunks.insert(chunk, unpacked);
}
}

let mut output = Vec::with_capacity(indices.len());
let mut unpacked = [T::zero(); 1024];

let mut batch_count = 0_usize;
for (chunk, offsets) in chunked_indices {
batch_count += 1;
let chunk_size = 128 * bit_width / size_of::<T>();
let packed_chunk = &packed[chunk * chunk_size..][..chunk_size];

// array_chunks produces a fixed-size array, doesn't heap allocate
let mut have_unpacked = false;
let mut offset_chunk_iter = offsets
// relativize indices to the start of the chunk
.map(|i| i % 1024)
.array_chunks::<UNPACK_CHUNK_THRESHOLD>();

// this loop only runs if we have at least UNPACK_CHUNK_THRESHOLD offsets
for offset_chunk in &mut offset_chunk_iter {
if !have_unpacked {
unsafe {
BitPacking::unchecked_unpack(bit_width, packed_chunk, &mut unpacked);
}
have_unpacked = true;
for index in indices.maybe_null_slice::<I>() {
let value = index.to_usize().vortex_expect("index must be expressible as usize") + offset;
let chunk = value / 1024;
let in_chunk_index = value % 1024;
match unpacked_chunks.get(&chunk) {
Some(unpacked) => {
output.push(unpacked[in_chunk_index]);
}

for index in offset_chunk {
output.push(unpacked[index]);
}
}

// if we have a remainder (i.e., < UNPACK_CHUNK_THRESHOLD leftover offsets), we need to handle it
if let Some(remainder) = offset_chunk_iter.into_remainder() {
if have_unpacked {
// we already bulk unpacked this chunk, so we can just push the remaining elements
for index in remainder {
output.push(unpacked[index]);
}
} else {
// we had fewer than UNPACK_CHUNK_THRESHOLD offsets in the first place,
// so we need to unpack each one individually
for index in remainder {
output.push(unsafe {
unpack_single_primitive::<T>(packed_chunk, bit_width, index)
});
}
None => {
let chunk_size = 128 * bit_width / size_of::<T>();
let packed_chunk = &packed[chunk * chunk_size..][..chunk_size];
output.push(unsafe {
unpack_single_primitive::<T>(packed_chunk, bit_width, in_chunk_index)
});
}
}
}

if let Some(ref patches) = patches {
patch_for_take_primitive::<T, I>(patches, indices, offset, batch_count, &mut output)?;
}

Ok(output)
}

fn patch_for_take_primitive<T: NativePType, I: NativePType>(
patches: &SparseArray,
indices: &PrimitiveArray,
offset: usize,
batch_count: usize,
output: &mut [T],
) -> VortexResult<()> {
#[inline]
fn inner_patch<T: NativePType>(
patches: &SparseArray,
indices: &PrimitiveArray,
output: &mut [T],
) -> VortexResult<()> {
let taken_patches = take(patches.as_ref(), indices.as_ref())?;
let taken_patches = SparseArray::try_from(taken_patches)?;
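The rewritten `take_primitive` above is a two-pass scheme: first count how many requested indices land in each 1024-element chunk and bulk-unpack only the chunks that clear `UNPACK_CHUNK_THRESHOLD`, then walk the indices again, serving each one from a cached chunk when available and falling back to a single-element unpack otherwise. In the real code each packed chunk occupies `128 * bit_width / size_of::<T>()` elements of the packed buffer (1024 values * bit_width bits = 128 * bit_width bytes); the sketch below sidesteps bit-packing by pretending the data is stored uncompressed, and its `unpack_chunk`/`unpack_single` helpers are stand-ins rather than the crate's `BitPacking` API:

```rust
use std::collections::HashMap;

const CHUNK: usize = 1024;
const UNPACK_CHUNK_THRESHOLD: usize = 8;

// Stand-in for the bulk unpack kernel: in the sketch the data is not actually packed.
fn unpack_chunk(packed: &[u32], chunk: usize) -> [u32; CHUNK] {
    let mut out = [0u32; CHUNK];
    out.copy_from_slice(&packed[chunk * CHUNK..][..CHUNK]);
    out
}

// Stand-in for the single-element unpack kernel.
fn unpack_single(packed: &[u32], chunk: usize, offset: usize) -> u32 {
    packed[chunk * CHUNK + offset]
}

fn take_sketch(packed: &[u32], indices: &[usize]) -> Vec<u32> {
    // Pass 1: count indices per chunk, then bulk-unpack only the "hot" chunks.
    let mut counts: HashMap<usize, usize> = HashMap::new();
    for &i in indices {
        *counts.entry(i / CHUNK).or_insert(0) += 1;
    }
    let unpacked: HashMap<usize, [u32; CHUNK]> = counts
        .into_iter()
        .filter(|&(_, n)| n > UNPACK_CHUNK_THRESHOLD)
        .map(|(chunk, _)| (chunk, unpack_chunk(packed, chunk)))
        .collect();

    // Pass 2: answer each index from a cached chunk, or unpack it individually.
    indices
        .iter()
        .map(|&i| match unpacked.get(&(i / CHUNK)) {
            Some(chunk) => chunk[i % CHUNK],
            None => unpack_single(packed, i / CHUNK, i % CHUNK),
        })
        .collect()
}

fn main() {
    // Two chunks of "packed" (here: already decoded) data.
    let packed: Vec<u32> = (0..2048).collect();
    // Many hits in chunk 0 (bulk unpacked), one stray hit in chunk 1 (single unpack).
    let indices: Vec<usize> = (0..16).chain([1500]).collect();
    assert_eq!(take_sketch(&packed, &indices)[16], 1500);
}
```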

@@ -153,63 +112,9 @@ fn patch_for_take_primitive<T: NativePType, I: NativePType>(
.for_each(|(idx, val)| {
output[idx] = *val;
});

Ok(())
}

// if we have a small number of relatively large batches, we gain by slicing and then patching inside the loop
// if we have a large number of relatively small batches, the overhead isn't worth it, and we're better off with a bulk patch
// roughly, if we have an average of less than 64 elements per batch, we prefer bulk patching
let prefer_bulk_patch = batch_count * BULK_PATCH_THRESHOLD > indices.len();
if prefer_bulk_patch {
return inner_patch(patches, indices, output);
}

let min_index = patches.min_index().unwrap_or_default();
let max_index = patches.max_index().unwrap_or_default();

// Group indices into 1024-element chunks and relativise them to the beginning of each chunk
let chunked_indices = &indices
.maybe_null_slice::<I>()
.iter()
.map(|i| {
i.to_usize()
.vortex_expect("index must be expressible as usize")
+ offset
})
.filter(|i| *i >= min_index && *i <= max_index) // short-circuit
.chunk_by(|idx| idx / 1024);

for (chunk, offsets) in chunked_indices {
// NOTE: we need to subtract the array offset before slicing into the patches.
// This is because BitPackedArray is rounded to block boundaries, but patches
// is sliced exactly.
let patches_start = if chunk == 0 {
0
} else {
(chunk * 1024) - offset
};
let patches_end = min((chunk + 1) * 1024 - offset, patches.len());
let patches_slice = slice(patches.as_ref(), patches_start, patches_end)?;
let patches_slice = SparseArray::try_from(patches_slice)?;
if patches_slice.is_empty() {
continue;
}

let min_index = patches_slice.min_index().unwrap_or_default();
let max_index = patches_slice.max_index().unwrap_or_default();
let offsets = offsets
.map(|i| (i % 1024) as u16)
.filter(|i| *i as usize >= min_index && *i as usize <= max_index)
.collect_vec();
if offsets.is_empty() {
continue;
}

inner_patch(&patches_slice, &PrimitiveArray::from(offsets), output)?;
}

Ok(())
Ok(output)
}

#[cfg(test)]
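The patch path shrinks in the same spirit: instead of slicing the patch `SparseArray` per 1024-element batch (the removed heuristic preferred a single bulk patch whenever batches averaged fewer than `BULK_PATCH_THRESHOLD = 64` indices), the surviving path appears to do one `take` over the patches and scatter the resulting (position, value) pairs onto the densely unpacked output. A minimal sketch of that scatter step over plain slices (`apply_patches` is a hypothetical helper, not the `SparseArray` API):

```rust
/// Overwrite `output` at each patched position with its patch value.
/// `patches` holds (index-into-output, value) pairs, i.e. the result of
/// taking the sparse patch array at the requested indices.
fn apply_patches<T: Copy>(output: &mut [T], patches: &[(usize, T)]) {
    for &(idx, val) in patches {
        output[idx] = val;
    }
}

fn main() {
    let mut taken = vec![0u32, 1, 2, 3];
    // Say positions 1 and 3 had exception values stored as patches.
    apply_patches(&mut taken, &[(1, 100), (3, 300)]);
    assert_eq!(taken, vec![0, 100, 2, 300]);
}
```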