diff --git a/arrow-select/src/take.rs b/arrow-select/src/take.rs index 70b80e5878dd..c28e332eca15 100644 --- a/arrow-select/src/take.rs +++ b/arrow-select/src/take.rs @@ -19,19 +19,19 @@ use std::sync::Arc; +use num::{One, Zero}; + use arrow_array::builder::BufferBuilder; use arrow_array::cast::AsArray; use arrow_array::types::*; use arrow_array::*; use arrow_buffer::{ - bit_util, ArrowNativeType, BooleanBuffer, Buffer, MutableBuffer, NullBuffer, - ScalarBuffer, + bit_util, ArrowNativeType, BooleanBuffer, BooleanBufferBuilder, Buffer, + MutableBuffer, NullBuffer, OffsetBuffer, ScalarBuffer, }; -use arrow_data::{ArrayData, ArrayDataBuilder}; +use arrow_data::ArrayDataBuilder; use arrow_schema::{ArrowError, DataType, FieldRef}; -use num::{One, Zero}; - /// Take elements by index from [Array], creating a new [Array] from those indexes. /// /// ```text @@ -346,36 +346,31 @@ fn take_bytes( ) -> Result, ArrowError> { let data_len = indices.len(); - let bytes_offset = (data_len + 1) * std::mem::size_of::(); - let mut offsets = MutableBuffer::new(bytes_offset); + let mut offsets = Vec::with_capacity(data_len + 1); offsets.push(T::Offset::default()); let mut values = MutableBuffer::new(0); - - let nulls; - if array.null_count() == 0 && indices.null_count() == 0 { + let nulls = if array.null_count() == 0 && indices.null_count() == 0 { offsets.extend(indices.values().iter().map(|index| { let s: &[u8] = array.value(index.as_usize()).as_ref(); values.extend_from_slice(s); T::Offset::usize_as(values.len()) })); - nulls = None + None } else if indices.null_count() == 0 { - let num_bytes = bit_util::ceil(data_len, 8); - - let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); - let null_slice = null_buf.as_slice_mut(); + let mut null_buf = BooleanBufferBuilder::new(data_len); + null_buf.append_n(data_len, true); offsets.extend(indices.values().iter().enumerate().map(|(i, index)| { let index = index.as_usize(); if array.is_valid(index) { let s: &[u8] = array.value(index).as_ref(); values.extend_from_slice(s.as_ref()); } else { - bit_util::unset_bit(null_slice, i); + null_buf.set_bit(i, false) } T::Offset::usize_as(values.len()) })); - nulls = Some(null_buf.into()); + Some(null_buf.finish().into()) } else if array.null_count() == 0 { offsets.extend(indices.values().iter().enumerate().map(|(i, index)| { if indices.is_valid(i) { @@ -384,12 +379,10 @@ fn take_bytes( } T::Offset::usize_as(values.len()) })); - nulls = indices.nulls().map(|b| b.inner().sliced()); + indices.nulls().cloned() } else { - let num_bytes = bit_util::ceil(data_len, 8); - - let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, true); - let null_slice = null_buf.as_slice_mut(); + let mut null_buf = BooleanBufferBuilder::new(data_len); + null_buf.append_n(data_len, true); offsets.extend(indices.values().iter().enumerate().map(|(i, index)| { // check index is valid before using index. The value in // NULL index slots may not be within bounds of array @@ -398,25 +391,21 @@ fn take_bytes( let s: &[u8] = array.value(index).as_ref(); values.extend_from_slice(s); } else { - // set null bit - bit_util::unset_bit(null_slice, i); + null_buf.set_bit(i, false) } T::Offset::usize_as(values.len()) })); - nulls = Some(null_buf.into()) - } + Some(null_buf.finish().into()) + }; T::Offset::from_usize(values.len()).expect("offset overflow"); - let array_data = ArrayData::builder(T::DATA_TYPE) - .len(data_len) - .add_buffer(offsets.into()) - .add_buffer(values.into()) - .null_bit_buffer(nulls); - - let array_data = unsafe { array_data.build_unchecked() }; - - Ok(GenericByteArray::from(array_data)) + // Safety: safe by construction + let array = unsafe { + let offsets = OffsetBuffer::new_unchecked(offsets.into()); + GenericByteArray::::new_unchecked(offsets, values.into(), nulls) + }; + Ok(array) } /// `take` implementation for list arrays @@ -753,10 +742,11 @@ to_indices_reinterpret!(Int64Type, UInt64Type); #[cfg(test)] mod tests { - use super::*; use arrow_array::builder::*; use arrow_schema::{Field, Fields, TimeUnit}; + use super::*; + fn test_take_decimal_arrays( data: Vec>, index: &UInt32Array,