Skip to content

Commit

Permalink
Further take kernel cleanup
Browse files (browse the repository at this point in the history)
  • Loading branch information
tustvold committed Aug 25, 2023
1 parent d9381c6 commit ca0079a
Showing 1 changed file with 25 additions and 35 deletions.
60 changes: 25 additions & 35 deletions arrow-select/src/take.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,19 @@
use std::sync::Arc;

use num::{One, Zero};

use arrow_array::builder::BufferBuilder;
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::{
bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer,
ScalarBuffer,
bit_util, ArrowNativeType, BooleanBuffer, BooleanBufferBuilder, Buffer,
MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer,
};
use arrow_data::{ArrayData, ArrayDataBuilder};
use arrow_schema::{ArrowError, DataType, FieldRef};

use num::{One, Zero};

/// Take elements by index from [Array], creating a new [Array] from those indexes.
///
/// ```text
Expand Down Expand Up @@ -346,36 +346,31 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
) -> Result<GenericByteArray<T>, ArrowError> {
let data_len = indices.len();

let bytes_offset = (data_len + 1) * std::mem::size_of::<T::Offset>();
let mut offsets = MutableBuffer::new(bytes_offset);
let mut offsets = Vec::with_capacity(data_len + 1);
offsets.push(T::Offset::default());

let mut values = MutableBuffer::new(0);

let nulls;
if array.null_count() == 0 && indices.null_count() == 0 {
let nulls = if array.null_count() == 0 && indices.null_count() == 0 {
offsets.extend(indices.values().iter().map(|index| {
let s: &[u8] = array.value(index.as_usize()).as_ref();
values.extend_from_slice(s);
T::Offset::usize_as(values.len())
}));
nulls = None
None
} else if indices.null_count() == 0 {
let num_bytes = bit_util::ceil(data_len, 8);

let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
let null_slice = null_buf.as_slice_mut();
let mut null_buf = BooleanBufferBuilder::new(data_len);
null_buf.append_n(data_len, true);
offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
let index = index.as_usize();
if array.is_valid(index) {
let s: &[u8] = array.value(index).as_ref();
values.extend_from_slice(s.as_ref());
} else {
bit_util::unset_bit(null_slice, i);
null_buf.set_bit(i, false)
}
T::Offset::usize_as(values.len())
}));
nulls = Some(null_buf.into());
Some(null_buf.finish().into())
} else if array.null_count() == 0 {
offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
if indices.is_valid(i) {
Expand All @@ -384,12 +379,10 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
}
T::Offset::usize_as(values.len())
}));
nulls = indices.nulls().map(|b| b.inner().sliced());
indices.nulls().cloned()
} else {
let num_bytes = bit_util::ceil(data_len, 8);

let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true);
let null_slice = null_buf.as_slice_mut();
let mut null_buf = BooleanBufferBuilder::new(data_len);
null_buf.append_n(data_len, true);
offsets.extend(indices.values().iter().enumerate().map(|(i, index)| {
// Check that the index slot is valid before using its value: the value
// stored in a NULL index slot may not be within the bounds of `array`.
Expand All @@ -398,25 +391,21 @@ fn take_bytes<T: ByteArrayType, IndexType: ArrowPrimitiveType>(
let s: &[u8] = array.value(index).as_ref();
values.extend_from_slice(s);
} else {
// clear the validity bit so that this output slot is marked as null
bit_util::unset_bit(null_slice, i);
null_buf.set_bit(i, false)
}
T::Offset::usize_as(values.len())
}));
nulls = Some(null_buf.into())
}
Some(null_buf.finish().into())
};

T::Offset::from_usize(values.len()).expect("offset overflow");

let array_data = ArrayData::builder(T::DATA_TYPE)
.len(data_len)
.add_buffer(offsets.into())
.add_buffer(values.into())
.null_bit_buffer(nulls);

let array_data = unsafe { array_data.build_unchecked() };

Ok(GenericByteArray::from(array_data))
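// SAFETY: `offsets` starts at `T::Offset::default()` and every later entry is
// `values.len()` taken after appending to `values`, so the offsets never
// decrease and the last one equals `values.len()`; that final length was
// checked to fit in `T::Offset` just above.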
let array = unsafe {
let offsets = OffsetBuffer::new_unchecked(offsets.into());
GenericByteArray::<T>::new_unchecked(offsets, values.into(), nulls)
};
Ok(array)
}

/// `take` implementation for list arrays
Expand Down Expand Up @@ -753,10 +742,11 @@ to_indices_reinterpret!(Int64Type, UInt64Type);

#[cfg(test)]
mod tests {
use super::*;
use arrow_array::builder::*;
use arrow_schema::{Field, Fields, TimeUnit};

use super::*;

fn test_take_decimal_arrays(
data: Vec<Option<i128>>,
index: &UInt32Array,
Expand Down

0 comments on commit ca0079a

Please sign in to comment.