Skip to content

Commit

Permalink
Make entry point compute functions accept generic arguments (#861)
Browse files Browse the repository at this point in the history
This change saves on a lot of `&`, `into_array` and other unnecessary
array type wrangling
  • Loading branch information
AdamGS authored Sep 18, 2024
1 parent a1f997d commit 065d444
Show file tree
Hide file tree
Showing 33 changed files with 97 additions and 94 deletions.
8 changes: 4 additions & 4 deletions encodings/alp/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ impl TakeFn for ALPArray {
fn take(&self, indices: &Array) -> VortexResult<Array> {
// TODO(ngates): wrap up indices in an array that caches decompression?
Ok(Self::try_new(
take(&self.encoded(), indices)?,
take(self.encoded(), indices)?,
self.exponents(),
self.patches().map(|p| take(&p, indices)).transpose()?,
)?
Expand All @@ -73,7 +73,7 @@ impl TakeFn for ALPArray {
impl SliceFn for ALPArray {
fn slice(&self, start: usize, end: usize) -> VortexResult<Array> {
Ok(Self::try_new(
slice(&self.encoded(), start, end)?,
slice(self.encoded(), start, end)?,
self.exponents(),
self.patches().map(|p| slice(&p, start, end)).transpose()?,
)?
Expand All @@ -84,7 +84,7 @@ impl SliceFn for ALPArray {
impl FilterFn for ALPArray {
fn filter(&self, predicate: &Array) -> VortexResult<Array> {
Ok(Self::try_new(
filter(&self.encoded(), predicate)?,
filter(self.encoded(), predicate)?,
self.exponents(),
self.patches().map(|p| filter(&p, predicate)).transpose()?,
)?
Expand Down Expand Up @@ -134,7 +134,7 @@ where
match encoded {
Ok(encoded) => {
let s = ConstantArray::new(encoded, alp.len());
compare(&alp.encoded(), s.as_ref(), operator)
compare(alp.encoded(), s.as_ref(), operator)
}
Err(exception) => {
if let Some(patches) = alp.patches().as_ref() {
Expand Down
12 changes: 6 additions & 6 deletions encodings/datetime-parts/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ impl TakeFn for DateTimePartsArray {
fn take(&self, indices: &Array) -> VortexResult<Array> {
Ok(Self::try_new(
self.dtype().clone(),
take(&self.days(), indices)?,
take(&self.seconds(), indices)?,
take(&self.subsecond(), indices)?,
take(self.days(), indices)?,
take(self.seconds(), indices)?,
take(self.subsecond(), indices)?,
)?
.into_array())
}
Expand All @@ -40,9 +40,9 @@ impl SliceFn for DateTimePartsArray {
fn slice(&self, start: usize, stop: usize) -> VortexResult<Array> {
Ok(Self::try_new(
self.dtype().clone(),
slice(&self.days(), start, stop)?,
slice(&self.seconds(), start, stop)?,
slice(&self.subsecond(), start, stop)?,
slice(self.days(), start, stop)?,
slice(self.seconds(), start, stop)?,
slice(self.subsecond(), start, stop)?,
)?
.into_array())
}
Expand Down
4 changes: 2 additions & 2 deletions encodings/dict/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ impl TakeFn for DictArray {
// Dict
// codes: 0 0 1
// dict: a b c d e f g h
let codes = take(&self.codes(), indices)?;
let codes = take(self.codes(), indices)?;
Self::try_new(codes, self.values()).map(|a| a.into_array())
}
}

impl SliceFn for DictArray {
// TODO(robert): Add function to trim the dictionary
fn slice(&self, start: usize, stop: usize) -> VortexResult<Array> {
Self::try_new(slice(&self.codes(), start, stop)?, self.values()).map(|a| a.into_array())
Self::try_new(slice(self.codes(), start, stop)?, self.values()).map(|a| a.into_array())
}
}

Expand Down
2 changes: 1 addition & 1 deletion encodings/dict/src/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ impl ArrayTrait for DictArray {}

impl IntoCanonical for DictArray {
fn into_canonical(self) -> VortexResult<Canonical> {
take(&self.values(), &self.codes())?.into_canonical()
take(self.values(), self.codes())?.into_canonical()
}
}

Expand Down
8 changes: 2 additions & 6 deletions encodings/fastlanes/src/bitpacking/compute/search_sorted.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,12 +227,8 @@ mod test {
#[test]
fn search_sliced() {
let bitpacked = slice(
&BitPackedArray::encode(
&PrimitiveArray::from(vec![1u32, 2, 3, 4, 5]).into_array(),
2,
)
.unwrap()
.into_array(),
BitPackedArray::encode(PrimitiveArray::from(vec![1u32, 2, 3, 4, 5]).as_ref(), 2)
.unwrap(),
2,
4,
)
Expand Down
4 changes: 2 additions & 2 deletions encodings/fastlanes/src/bitpacking/compute/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ impl SliceFn for BitPackedArray {
let encoded_start = (block_start / 8) * self.bit_width() / self.ptype().byte_width();
let encoded_stop = (block_stop / 8) * self.bit_width() / self.ptype().byte_width();
Self::try_new_from_offset(
slice(&self.packed(), encoded_start, encoded_stop)?,
slice(self.packed(), encoded_start, encoded_stop)?,
self.validity().slice(start, stop)?,
self.patches()
.map(|p| slice(&p, start, stop))
Expand Down Expand Up @@ -171,7 +171,7 @@ mod test {
assert_eq!(patch_indices.len(), 1);

// Slicing drops the empty patches array.
let sliced = slice(&array.into_array(), 0, 64).unwrap();
let sliced = slice(array, 0, 64).unwrap();
let sliced_bp = BitPackedArray::try_from(sliced).unwrap();
assert!(sliced_bp.patches().is_none());
}
Expand Down
6 changes: 3 additions & 3 deletions encodings/fastlanes/src/for/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ impl ArrayCompute for FoRArray {
impl TakeFn for FoRArray {
fn take(&self, indices: &Array) -> VortexResult<Array> {
Self::try_new(
take(&self.encoded(), indices)?,
take(self.encoded(), indices)?,
self.reference().clone(),
self.shift(),
)
Expand All @@ -49,7 +49,7 @@ impl TakeFn for FoRArray {
impl FilterFn for FoRArray {
fn filter(&self, predicate: &Array) -> VortexResult<Array> {
Self::try_new(
filter(&self.encoded(), predicate)?,
filter(self.encoded(), predicate)?,
self.reference().clone(),
self.shift(),
)
Expand Down Expand Up @@ -81,7 +81,7 @@ impl ScalarAtFn for FoRArray {
impl SliceFn for FoRArray {
fn slice(&self, start: usize, stop: usize) -> VortexResult<Array> {
Self::try_new(
slice(&self.encoded(), start, stop)?,
slice(self.encoded(), start, stop)?,
self.reference().clone(),
self.shift(),
)
Expand Down
12 changes: 6 additions & 6 deletions encodings/fsst/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ impl SliceFn for FSSTArray {
self.dtype().clone(),
self.symbols(),
self.symbol_lengths(),
slice(&self.codes(), start, stop)?,
slice(&self.uncompressed_lengths(), start, stop)?,
slice(self.codes(), start, stop)?,
slice(self.uncompressed_lengths(), start, stop)?,
)?
.into_array())
}
Expand All @@ -48,8 +48,8 @@ impl TakeFn for FSSTArray {
self.dtype().clone(),
self.symbols(),
self.symbol_lengths(),
take(&self.codes(), indices)?,
take(&self.uncompressed_lengths(), indices)?,
take(self.codes(), indices)?,
take(self.uncompressed_lengths(), indices)?,
)?
.into_array())
}
Expand Down Expand Up @@ -81,8 +81,8 @@ impl FilterFn for FSSTArray {
self.dtype().clone(),
self.symbols(),
self.symbol_lengths(),
filter(&self.codes(), predicate)?,
filter(&self.uncompressed_lengths(), predicate)?,
filter(self.codes(), predicate)?,
filter(self.uncompressed_lengths(), predicate)?,
)?
.into_array())
}
Expand Down
7 changes: 3 additions & 4 deletions encodings/runend-bool/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -230,14 +230,13 @@ mod test {
#[test]
fn take_bool() {
let arr = take(
&RunEndBoolArray::try_new(
RunEndBoolArray::try_new(
vec![2u32, 4, 5, 10].into_array(),
true,
Validity::NonNullable,
)
.unwrap()
.to_array(),
&vec![0, 0, 6, 4].into_array(),
.unwrap(),
vec![0, 0, 6, 4].into_array(),
)
.unwrap();

Expand Down
2 changes: 1 addition & 1 deletion encodings/runend-bool/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ impl SliceFn for RunEndBoolArray {
let slice_end = self.find_physical_index(stop)?;

Ok(Self::with_offset_and_size(
slice(&self.ends(), slice_begin, slice_end + 1)?,
slice(self.ends(), slice_begin, slice_end + 1)?,
value_at_index(slice_begin, self.start()),
self.validity().slice(slice_begin, slice_end + 1)?,
stop - start,
Expand Down
6 changes: 3 additions & 3 deletions encodings/runend/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ impl TakeFn for RunEndArray {
.map(|idx| *idx as u64)
.collect();
let physical_indices_array = PrimitiveArray::from(physical_indices).into_array();
let dense_values = take(&self.values(), &physical_indices_array)?;
let dense_values = take(self.values(), &physical_indices_array)?;

Ok(match self.validity() {
Validity::NonNullable => dense_values,
Expand Down Expand Up @@ -100,8 +100,8 @@ impl SliceFn for RunEndArray {
let slice_end = self.find_physical_index(stop)?;

Ok(Self::with_offset_and_size(
slice(&self.ends(), slice_begin, slice_end + 1)?,
slice(&self.values(), slice_begin, slice_end + 1)?,
slice(self.ends(), slice_begin, slice_end + 1)?,
slice(self.values(), slice_begin, slice_end + 1)?,
self.validity().slice(start, stop)?,
stop - start,
start + self.offset(),
Expand Down
2 changes: 1 addition & 1 deletion encodings/zigzag/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ impl ZigZagEncoded for u64 {

impl SliceFn for ZigZagArray {
fn slice(&self, start: usize, stop: usize) -> VortexResult<Array> {
Ok(Self::try_new(slice(&self.encoded(), start, stop)?)?.into_array())
Ok(Self::try_new(slice(self.encoded(), start, stop)?)?.into_array())
}
}

Expand Down
14 changes: 3 additions & 11 deletions vortex-array/src/array/bool/compute/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ mod test {
use crate::array::primitive::PrimitiveArray;
use crate::array::BoolArray;
use crate::compute::take;
use crate::IntoArray;

#[test]
fn take_nullable() {
Expand All @@ -39,17 +38,10 @@ mod test {
Some(false),
None,
Some(false),
])
.into_array();
]);

let b = BoolArray::try_from(
take(
&reference,
&PrimitiveArray::from(vec![0, 3, 4]).into_array(),
)
.unwrap(),
)
.unwrap();
let b = BoolArray::try_from(take(&reference, PrimitiveArray::from(vec![0, 3, 4])).unwrap())
.unwrap();
assert_eq!(
b.boolean_buffer(),
BoolArray::from_iter(vec![Some(false), None, Some(false)]).boolean_buffer()
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/array/chunked/canonical.rs
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ fn pack_varbin(chunks: &[Array], validity: Validity, dtype: &DType) -> VortexRes
offsets_arr.len() - 1,
))?;
let primitive_bytes =
slice(&chunk.bytes(), first_offset_value, last_offset_value)?.into_primitive()?;
slice(chunk.bytes(), first_offset_value, last_offset_value)?.into_primitive()?;
data_bytes.extend_from_slice(primitive_bytes.buffer());

let adjustment_from_previous = *offsets
Expand Down
8 changes: 4 additions & 4 deletions vortex-array/src/array/chunked/compute/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ fn filter_slices<'a>(
ChunkFilter::None => {}
// Slices => turn the slices into a boolean buffer.
ChunkFilter::Slices(slices) => {
result.push(filter(&chunk, &slices_to_predicate(slices, chunk.len()))?);
result.push(filter(&chunk, slices_to_predicate(slices, chunk.len()))?);
}
}
}
Expand Down Expand Up @@ -172,8 +172,8 @@ fn filter_indices<'a>(
.chunk(current_chunk_id)
.vortex_expect("find_chunk_idx must return valid chunk ID");
let filtered_chunk = take(
&chunk,
&PrimitiveArray::from(chunk_indices.clone()).into_array(),
chunk,
PrimitiveArray::from(chunk_indices.clone()).into_array(),
)?;
result.push(filtered_chunk);
}
Expand All @@ -192,7 +192,7 @@ fn filter_indices<'a>(
.vortex_expect("find_chunk_idx must return valid chunk ID");
let filtered_chunk = take(
&chunk,
&PrimitiveArray::from(chunk_indices.clone()).into_array(),
PrimitiveArray::from(chunk_indices.clone()).into_array(),
)?;
result.push(filtered_chunk);
}
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/array/chunked/compute/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ impl SliceFn for ChunkedArray {
})
.collect::<VortexResult<Vec<_>>>()?;
if let Some(c) = chunks.first_mut() {
*c = slice(c, offset_in_first_chunk, c.len())?;
*c = slice(&*c, offset_in_first_chunk, c.len())?;
}

if length_in_last_chunk == 0 {
chunks.pop();
} else if let Some(c) = chunks.last_mut() {
*c = slice(c, 0, length_in_last_chunk)?;
*c = slice(&*c, 0, length_in_last_chunk)?;
}

Self::try_new(chunks, self.dtype().clone()).map(|a| a.into_array())
Expand Down
8 changes: 4 additions & 4 deletions vortex-array/src/array/extension/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,12 +45,12 @@ impl MaybeCompareFn for ExtensionArray {
Scalar::new(self.storage().dtype().clone(), scalar_ext.value().clone()),
const_ext.len(),
);
compare(&self.storage(), const_storage.as_ref(), operator)
compare(self.storage(), const_storage, operator)
});
}

if let Ok(rhs_ext) = ExtensionArray::try_from(array) {
return Some(compare(&self.storage(), &rhs_ext.storage(), operator));
return Some(compare(self.storage(), rhs_ext.storage(), operator));
}

None
Expand All @@ -77,14 +77,14 @@ impl SliceFn for ExtensionArray {
fn slice(&self, start: usize, stop: usize) -> VortexResult<Array> {
Ok(Self::new(
self.ext_dtype().clone(),
slice(&self.storage(), start, stop)?,
slice(self.storage(), start, stop)?,
)
.into_array())
}
}

impl TakeFn for ExtensionArray {
fn take(&self, indices: &Array) -> VortexResult<Array> {
Ok(Self::new(self.ext_dtype().clone(), take(&self.storage(), indices)?).into_array())
Ok(Self::new(self.ext_dtype().clone(), take(self.storage(), indices)?).into_array())
}
}
5 changes: 2 additions & 3 deletions vortex-array/src/array/null/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,8 @@ mod test {
#[test]
fn test_take_nulls() {
let nulls = NullArray::new(10).into_array();
let taken =
NullArray::try_from(take(&nulls, &vec![0u64, 2, 4, 6, 8].into_array()).unwrap())
.unwrap();
let taken = NullArray::try_from(take(&nulls, vec![0u64, 2, 4, 6, 8].into_array()).unwrap())
.unwrap();

assert_eq!(taken.len(), 5);
assert!(matches!(
Expand Down
2 changes: 1 addition & 1 deletion vortex-array/src/array/sparse/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ impl FilterFn for SparseArray {

Ok(SparseArray::try_new(
PrimitiveArray::from(coordinate_indices).into_array(),
take(&self.values(), PrimitiveArray::from(value_indices).as_ref())?,
take(self.values(), PrimitiveArray::from(value_indices))?,
buffer.count_set_bits(),
self.fill_value().clone(),
)?
Expand Down
4 changes: 2 additions & 2 deletions vortex-array/src/array/sparse/compute/slice.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ impl SliceFn for SparseArray {
let index_end_index = self.search_index(stop)?.to_index();

Ok(Self::try_new_with_offset(
slice(&self.indices(), index_start_index, index_end_index)?,
slice(&self.values(), index_start_index, index_end_index)?,
slice(self.indices(), index_start_index, index_end_index)?,
slice(self.values(), index_start_index, index_end_index)?,
stop - start,
self.indices_offset() + start,
self.fill_value().clone(),
Expand Down
Loading

0 comments on commit 065d444

Please sign in to comment.