Skip to content

Commit

Permalink
fix: RunEndBool array take respects validity (#1684)
Browse files Browse the repository at this point in the history
  • Loading branch information
robert3005 authored Dec 13, 2024
1 parent 18268bc commit 65e0571
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 12 deletions.
40 changes: 34 additions & 6 deletions encodings/runend-bool/src/compute/mod.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
mod invert;

use arrow_buffer::BooleanBuffer;
use vortex_array::array::BoolArray;
use vortex_array::compute::{slice, ComputeVTable, InvertFn, ScalarAtFn, SliceFn, TakeFn};
use vortex_array::variants::PrimitiveArrayTrait;
use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant, ToArrayData};
use vortex_array::{ArrayDType, ArrayData, ArrayLen, IntoArrayData, IntoArrayVariant};
use vortex_dtype::match_each_integer_ptype;
use vortex_error::{vortex_bail, VortexResult};
use vortex_scalar::Scalar;
Expand Down Expand Up @@ -53,10 +54,15 @@ impl TakeFn<RunEndBoolArray> for RunEndBoolEncoding {
.collect::<VortexResult<Vec<_>>>()?
});
let start = array.start();
Ok(
BoolArray::from_iter(physical_indices.iter().map(|&it| value_at_index(it, start)))
.to_array(),
BoolArray::try_new(
BooleanBuffer::from_iter(
physical_indices
.into_iter()
.map(|it| value_at_index(it, start)),
),
array.validity().take(indices)?,
)
.map(|a| a.into_array())
}
}

Expand Down Expand Up @@ -90,9 +96,11 @@ impl SliceFn<RunEndBoolArray> for RunEndBoolEncoding {

#[cfg(test)]
mod tests {
use vortex_array::compute::{scalar_at, slice};
use arrow_buffer::BooleanBuffer;
use vortex_array::array::PrimitiveArray;
use vortex_array::compute::{scalar_at, slice, take};
use vortex_array::validity::Validity;
use vortex_array::{ArrayLen, IntoArrayData};
use vortex_array::{ArrayDType, ArrayLen, IntoArrayData, IntoArrayVariant};
use vortex_dtype::Nullability;
use vortex_scalar::Scalar;

Expand Down Expand Up @@ -124,4 +132,24 @@ mod tests {
Scalar::bool(false, Nullability::Nullable)
);
}

#[test]
fn take_nullable() {
let re_array = RunEndBoolArray::try_new(
vec![7_u64, 10].into_array(),
false,
Validity::from(BooleanBuffer::from(vec![
false, false, true, true, true, true, true, true, false, false,
])),
)
.unwrap();

let taken = take(&re_array, PrimitiveArray::from(vec![6, 9])).unwrap();
let taken_bool = taken.into_bool().unwrap();
assert_eq!(taken_bool.dtype(), re_array.dtype());
assert_eq!(
taken_bool.boolean_buffer(),
BooleanBuffer::from(vec![false, true])
);
}
}
3 changes: 1 addition & 2 deletions vortex-array/src/array/bool/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::sync::Arc;

use arrow_array::BooleanArray;
use arrow_buffer::{BooleanBufferBuilder, MutableBuffer};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use vortex_buffer::Buffer;
use vortex_dtype::{DType, Nullability};
Expand Down Expand Up @@ -129,7 +128,7 @@ impl BoolArray {
first_byte_bit_offset,
}),
Some(Buffer::from(inner)),
validity.into_array().into_iter().collect_vec().into(),
validity.into_array().into_iter().collect(),
StatsSet::default(),
)?
.try_into()
Expand Down
16 changes: 12 additions & 4 deletions vortex-array/src/compute/take.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use log::info;
use vortex_error::{vortex_bail, vortex_err, VortexError, VortexResult};

use crate::encoding::Encoding;
Expand Down Expand Up @@ -72,17 +71,26 @@ pub fn take(
// If TakeFn defined for the encoding, delegate to TakeFn.
// If we know from stats that indices are all valid, we can avoid all bounds checks.
if let Some(take_fn) = array.encoding().take_fn() {
return if checked_indices {
let result = if checked_indices {
// SAFETY: indices are all inbounds per stats.
// TODO(aduffy): this means stats must be trusted, can still trigger UB if stats are bad.
unsafe { take_fn.take_unchecked(array, indices) }
} else {
take_fn.take(array, indices)
};
}?;
if array.dtype() != result.dtype() {
vortex_bail!(
"TakeFn {} changed array dtype from {} to {}",
array.encoding().id(),
array.dtype(),
result.dtype()
);
}
return Ok(result);
}

// Otherwise, flatten and try again.
info!("TakeFn not implemented for {}, flattening", array);
log::debug!("No take implementation found for {}", array.encoding().id());
let canonical = array.clone().into_canonical()?.into_array();
let canonical_take_fn = canonical
.encoding()
Expand Down

0 comments on commit 65e0571

Please sign in to comment.