Skip to content

Commit

Permalink
Simplify/idiomize the way arrays return &Array (#826)
Browse files Browse the repository at this point in the history
per @robert3005 there used to be a lifetime issue with `AsArray`, but
that's not longer true and it doesn't actually serve any other purpose +
`AsRef` makes the code more rusty.
  • Loading branch information
AdamGS authored Sep 16, 2024
1 parent 9257acc commit 4c80613
Show file tree
Hide file tree
Showing 54 changed files with 226 additions and 222 deletions.
4 changes: 2 additions & 2 deletions encodings/alp/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ impl ALPArray {
}

pub fn encoded(&self) -> Array {
self.array()
self.as_ref()
.child(0, &self.metadata().encoded_dtype, self.len())
.vortex_expect("Missing encoded child in ALPArray")
}
Expand All @@ -94,7 +94,7 @@ impl ALPArray {

pub fn patches(&self) -> Option<Array> {
self.metadata().patches_dtype.as_ref().map(|dt| {
self.array().child(1, dt, self.len()).unwrap_or_else(|| {
self.as_ref().child(1, dt, self.len()).unwrap_or_else(|| {
vortex_panic!(
"Missing patches with present metadata flag; patches dtype: {}, patches_len: {}",
dt,
Expand Down
5 changes: 2 additions & 3 deletions encodings/alp/src/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,6 @@ mod tests {
use core::f64;

use vortex::compute::unary::scalar_at;
use vortex::AsArray;

use super::*;

Expand Down Expand Up @@ -182,11 +181,11 @@ mod tests {
assert_eq!(encoded.exponents(), Exponents { e: 3, f: 0 });

for idx in 0..3 {
let s = scalar_at(encoded.as_array_ref(), idx).unwrap();
let s = scalar_at(encoded.as_ref(), idx).unwrap();
assert!(s.is_valid());
}

let s = scalar_at(encoded.as_array_ref(), 4).unwrap();
let s = scalar_at(encoded.as_ref(), 4).unwrap();
assert!(s.is_null());

let _decoded = decompress(encoded).unwrap();
Expand Down
6 changes: 3 additions & 3 deletions encodings/alp/src/compute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,12 @@ where
match encoded {
Ok(encoded) => {
let s = ConstantArray::new(encoded, alp.len());
compare(&alp.encoded(), s.array(), operator)
compare(&alp.encoded(), s.as_ref(), operator)
}
Err(exception) => {
if let Some(patches) = alp.patches().as_ref() {
let s = ConstantArray::new(exception, alp.len());
compare(patches, s.array(), operator)
compare(patches, s.as_ref(), operator)
} else {
Ok(BoolArray::from_vec(vec![false; alp.len()], Validity::AllValid).into_array())
}
Expand Down Expand Up @@ -212,7 +212,7 @@ mod tests {
);

let r = encoded
.maybe_compare(other.array(), Operator::Eq)
.maybe_compare(other.as_ref(), Operator::Eq)
.unwrap()
.unwrap()
.into_bool()
Expand Down
17 changes: 8 additions & 9 deletions encodings/bytebool/src/compute/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,6 @@ impl FillForwardFn for ByteBoolArray {
mod tests {
use vortex::compute::unary::{scalar_at, scalar_at_unchecked};
use vortex::compute::{compare, slice, Operator};
use vortex::AsArray as _;
use vortex_scalar::ScalarValue;

use super::*;
Expand All @@ -146,18 +145,18 @@ mod tests {
let original = vec![Some(true), Some(true), None, Some(false), None];
let vortex_arr = ByteBoolArray::from(original.clone());

let sliced_arr = slice(vortex_arr.as_array_ref(), 1, 4).unwrap();
let sliced_arr = slice(vortex_arr.as_ref(), 1, 4).unwrap();
let sliced_arr = ByteBoolArray::try_from(sliced_arr).unwrap();

let s = scalar_at_unchecked(sliced_arr.as_array_ref(), 0);
let s = scalar_at_unchecked(sliced_arr.as_ref(), 0);
assert_eq!(s.into_value().as_bool().unwrap(), Some(true));

let s = scalar_at(sliced_arr.as_array_ref(), 1).unwrap();
let s = scalar_at(sliced_arr.as_ref(), 1).unwrap();
assert!(!sliced_arr.is_valid(1));
assert!(s.is_null());
assert_eq!(s.into_value().as_bool().unwrap(), None);

let s = scalar_at_unchecked(sliced_arr.as_array_ref(), 2);
let s = scalar_at_unchecked(sliced_arr.as_ref(), 2);
assert_eq!(s.into_value().as_bool().unwrap(), Some(false));
}

Expand All @@ -166,10 +165,10 @@ mod tests {
let lhs = ByteBoolArray::from(vec![true; 5]);
let rhs = ByteBoolArray::from(vec![true; 5]);

let arr = compare(lhs.as_array_ref(), rhs.as_array_ref(), Operator::Eq).unwrap();
let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

for i in 0..arr.len() {
let s = scalar_at_unchecked(arr.as_array_ref(), i);
let s = scalar_at_unchecked(arr.as_ref(), i);
assert!(s.is_valid());
assert_eq!(s.value(), &ScalarValue::Bool(true));
}
Expand All @@ -180,7 +179,7 @@ mod tests {
let lhs = ByteBoolArray::from(vec![false; 5]);
let rhs = ByteBoolArray::from(vec![true; 5]);

let arr = compare(lhs.as_array_ref(), rhs.as_array_ref(), Operator::Eq).unwrap();
let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

for i in 0..arr.len() {
let s = scalar_at(&arr, i).unwrap();
Expand All @@ -194,7 +193,7 @@ mod tests {
let lhs = ByteBoolArray::from(vec![true; 5]);
let rhs = ByteBoolArray::from(vec![Some(true), Some(true), Some(true), Some(false), None]);

let arr = compare(lhs.as_array_ref(), rhs.as_array_ref(), Operator::Eq).unwrap();
let arr = compare(lhs.as_ref(), rhs.as_ref(), Operator::Eq).unwrap();

for i in 0..3 {
let s = scalar_at(&arr, i).unwrap();
Expand Down
4 changes: 2 additions & 2 deletions encodings/bytebool/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl ByteBoolArray {
pub fn validity(&self) -> Validity {
self.metadata()
.validity
.to_validity(self.array().child(0, &Validity::DTYPE, self.len()))
.to_validity(self.as_ref().child(0, &Validity::DTYPE, self.len()))
}

pub fn try_new(buffer: Buffer, validity: Validity) -> VortexResult<Self> {
Expand Down Expand Up @@ -64,7 +64,7 @@ impl ByteBoolArray {
}

pub fn buffer(&self) -> &Buffer {
self.array()
self.as_ref()
.buffer()
.vortex_expect("ByteBoolArray is missing the underlying buffer")
}
Expand Down
4 changes: 2 additions & 2 deletions encodings/bytebool/src/stats.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use vortex::stats::{ArrayStatisticsCompute, Stat, StatsSet};
use vortex::{AsArray, IntoArrayVariant};
use vortex::IntoArrayVariant;
use vortex_error::VortexResult;

use super::ByteBoolArray;
Expand All @@ -11,7 +11,7 @@ impl ArrayStatisticsCompute for ByteBoolArray {
}

// TODO(adamgs): This is slightly wasteful and could be optimized in the future
let bools = self.as_array_ref().clone().into_bool()?;
let bools = self.as_ref().clone().into_bool()?;
bools.compute_statistics(stat)
}
}
Expand Down
6 changes: 3 additions & 3 deletions encodings/datetime-parts/src/array.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,19 @@ impl DateTimePartsArray {
}

pub fn days(&self) -> Array {
self.array()
self.as_ref()
.child(0, &self.metadata().days_dtype, self.len())
.vortex_expect("DatetimePartsArray missing days array")
}

pub fn seconds(&self) -> Array {
self.array()
self.as_ref()
.child(1, &self.metadata().seconds_dtype, self.len())
.vortex_expect("DatetimePartsArray missing seconds array")
}

pub fn subsecond(&self) -> Array {
self.array()
self.as_ref()
.child(2, &self.metadata().subseconds_dtype, self.len())
.vortex_expect("DatetimePartsArray missing subsecond array")
}
Expand Down
4 changes: 2 additions & 2 deletions encodings/dict/src/dict.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,14 @@ impl DictArray {

#[inline]
pub fn values(&self) -> Array {
self.array()
self.as_ref()
.child(0, self.dtype(), self.metadata().values_len)
.vortex_expect("DictArray is missing its values child array")
}

#[inline]
pub fn codes(&self) -> Array {
self.array()
self.as_ref()
.child(1, &self.metadata().codes_dtype, self.len())
.vortex_expect("DictArray is missing its codes child array")
}
Expand Down
24 changes: 12 additions & 12 deletions encodings/fastlanes/benches/bitpacking_take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@ fn bench_take(c: &mut Criterion) {
let uncompressed = PrimitiveArray::from(values.clone());

let packed = BitPackedArray::encode(
uncompressed.array(),
uncompressed.as_ref(),
find_best_bit_width(&uncompressed).unwrap(),
)
.unwrap();

let stratified_indices: PrimitiveArray = (0..10).map(|i| i * 10_000).collect::<Vec<_>>().into();
c.bench_function("take_10_stratified", |b| {
b.iter(|| black_box(take(packed.array(), stratified_indices.array()).unwrap()));
b.iter(|| black_box(take(packed.as_ref(), stratified_indices.as_ref()).unwrap()));
});

let contiguous_indices: PrimitiveArray = (0..10).collect::<Vec<_>>().into();
c.bench_function("take_10_contiguous", |b| {
b.iter(|| black_box(take(packed.array(), contiguous_indices.array()).unwrap()));
b.iter(|| black_box(take(packed.as_ref(), contiguous_indices.as_ref()).unwrap()));
});

let rng = thread_rng();
Expand All @@ -43,12 +43,12 @@ fn bench_take(c: &mut Criterion) {
.collect_vec()
.into();
c.bench_function("take_10K_random", |b| {
b.iter(|| black_box(take(packed.array(), random_indices.array()).unwrap()));
b.iter(|| black_box(take(packed.as_ref(), random_indices.as_ref()).unwrap()));
});

let contiguous_indices: PrimitiveArray = (0..10_000).collect::<Vec<_>>().into();
c.bench_function("take_10K_contiguous", |b| {
b.iter(|| black_box(take(packed.array(), contiguous_indices.array()).unwrap()));
b.iter(|| black_box(take(packed.as_ref(), contiguous_indices.as_ref()).unwrap()));
});
}

Expand All @@ -59,7 +59,7 @@ fn bench_patched_take(c: &mut Criterion) {

let uncompressed = PrimitiveArray::from(values.clone());
let packed = BitPackedArray::encode(
uncompressed.array(),
uncompressed.as_ref(),
find_best_bit_width(&uncompressed).unwrap(),
)
.unwrap();
Expand All @@ -74,12 +74,12 @@ fn bench_patched_take(c: &mut Criterion) {

let stratified_indices: PrimitiveArray = (0..10).map(|i| i * 10_000).collect::<Vec<_>>().into();
c.bench_function("patched_take_10_stratified", |b| {
b.iter(|| black_box(take(packed.array(), stratified_indices.array()).unwrap()));
b.iter(|| black_box(take(packed.as_ref(), stratified_indices.as_ref()).unwrap()));
});

let contiguous_indices: PrimitiveArray = (0..10).collect::<Vec<_>>().into();
c.bench_function("patched_take_10_contiguous", |b| {
b.iter(|| black_box(take(packed.array(), contiguous_indices.array()).unwrap()));
b.iter(|| black_box(take(packed.as_ref(), contiguous_indices.as_ref()).unwrap()));
});

let rng = thread_rng();
Expand All @@ -91,7 +91,7 @@ fn bench_patched_take(c: &mut Criterion) {
.collect_vec()
.into();
c.bench_function("patched_take_10K_random", |b| {
b.iter(|| black_box(take(packed.array(), random_indices.array()).unwrap()));
b.iter(|| black_box(take(packed.as_ref(), random_indices.as_ref()).unwrap()));
});

let not_patch_indices: PrimitiveArray = (0u32..num_exceptions)
Expand All @@ -100,7 +100,7 @@ fn bench_patched_take(c: &mut Criterion) {
.collect_vec()
.into();
c.bench_function("patched_take_10K_contiguous_not_patches", |b| {
b.iter(|| black_box(take(packed.array(), not_patch_indices.array()).unwrap()));
b.iter(|| black_box(take(packed.as_ref(), not_patch_indices.as_ref()).unwrap()));
});

let patch_indices: PrimitiveArray = (big_base2..big_base2 + num_exceptions)
Expand All @@ -109,7 +109,7 @@ fn bench_patched_take(c: &mut Criterion) {
.collect_vec()
.into();
c.bench_function("patched_take_10K_contiguous_patches", |b| {
b.iter(|| black_box(take(packed.array(), patch_indices.array()).unwrap()));
b.iter(|| black_box(take(packed.as_ref(), patch_indices.as_ref()).unwrap()));
});

// There are currently 2 magic parameters of note:
Expand All @@ -133,7 +133,7 @@ fn bench_patched_take(c: &mut Criterion) {
.collect_vec()
.into();
c.bench_function("patched_take_10K_adversarial", |b| {
b.iter(|| black_box(take(packed.array(), adversarial_indices.array()).unwrap()));
b.iter(|| black_box(take(packed.as_ref(), adversarial_indices.as_ref()).unwrap()));
});
}

Expand Down
4 changes: 2 additions & 2 deletions encodings/fastlanes/src/bitpacking/compress.rs
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,7 @@ mod test {
let valid_values = (0..24).map(|v| v < 1 << 4).collect::<Vec<_>>();
let values =
PrimitiveArray::from_vec((0..24).collect::<Vec<_>>(), Validity::from(valid_values));
let compressed = BitPackedArray::encode(values.array(), 4).unwrap();
let compressed = BitPackedArray::encode(values.as_ref(), 4).unwrap();
assert!(compressed.patches().is_none());
assert_eq!(
(0..(1 << 4)).collect::<Vec<_>>(),
Expand Down Expand Up @@ -363,7 +363,7 @@ mod test {

fn compression_roundtrip(n: usize) {
let values = PrimitiveArray::from((0..n).map(|i| (i % 2047) as u16).collect::<Vec<_>>());
let compressed = BitPackedArray::encode(values.array(), 11).unwrap();
let compressed = BitPackedArray::encode(values.as_ref(), 11).unwrap();
let decompressed = compressed.to_array().into_primitive().unwrap();
assert_eq!(
decompressed.maybe_null_slice::<u16>(),
Expand Down
Loading

0 comments on commit 4c80613

Please sign in to comment.