diff --git a/datafusion/physical-expr-common/src/binary_view_map.rs b/datafusion/physical-expr-common/src/binary_view_map.rs index c9b89d71145b..c030f78e2edd 100644 --- a/datafusion/physical-expr-common/src/binary_view_map.rs +++ b/datafusion/physical-expr-common/src/binary_view_map.rs @@ -19,13 +19,8 @@ //! StringArray / LargeStringArray / BinaryArray / LargeBinaryArray. use ahash::RandomState; use arrow::array::cast::AsArray; -use arrow::array::{ - Array, ArrayRef, BinaryViewArray, BooleanBufferBuilder, GenericByteViewArray, - StringViewArray, -}; -use arrow::buffer::{Buffer, NullBuffer, ScalarBuffer}; +use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder}; use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType}; -use arrow_data::ByteView; use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt}; use std::fmt::Debug; @@ -127,8 +122,8 @@ where map: hashbrown::raw::RawTable>, /// Total size of the map in bytes map_size: usize, - view_buffers: Vec, - views: Vec, + + builder: GenericByteViewBuilder, /// random state used to generate hashes random_state: RandomState, /// buffer that stores hash values (reused across batches to save allocations) @@ -137,10 +132,6 @@ where /// NOTE null_index is the logical index in the final array, not the index /// in the buffer null: Option<(V, usize)>, - - // the length of the input array. Used to determine if we want to gc the output array - // to avoid holding the view_buffers too long. - input_len: usize, } /// The size, in number of entries, of the initial hash table @@ -155,12 +146,10 @@ where output_type, map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY), map_size: 0, - views: Vec::new(), - view_buffers: Vec::new(), + builder: GenericByteViewBuilder::new(), random_state: RandomState::new(), hashes_buffer: vec![], null: None, - input_len: 0, } } @@ -258,25 +247,18 @@ where // step 2: insert each value into the set, if not already present let values = values.as_byte_view::(); - self.input_len += values.len(); - - let buffer_offset = self.view_buffers.len(); - self.view_buffers.extend_from_slice(values.data_buffers()); - // Ensure lengths are equivalent assert_eq!(values.len(), batch_hashes.len()); - for (view_idx, (value, &hash)) in - values.iter().zip(batch_hashes.iter()).enumerate() - { + for (value, &hash) in values.iter().zip(batch_hashes.iter()) { // handle null value let Some(value) = value else { let payload = if let Some(&(payload, _offset)) = self.null.as_ref() { payload } else { let payload = make_payload_fn(None); - let null_index = self.views.len(); - self.views.push(0); + let null_index = self.builder.len(); + self.builder.append_null(); self.null = Some((payload, null_index)); payload }; @@ -288,32 +270,13 @@ where let value: &[u8] = value.as_ref(); let entry = self.map.get_mut(hash, |header| { - let v = unsafe { self.views.get_unchecked(header.view_idx) }; + let v = self.builder.get_value(header.view_idx); - let len = *v as u32; - if len as usize != value.len() { + if v.len() != value.len() { return false; } - // We should probably change arrow-rs to provide a - // GenericByteViewArray::value_from_view() method - let b = if len <= 12 { - unsafe { - GenericByteViewArray::::inline_value( - v, - len as usize, - ) - } - } else { - let view = ByteView::from(*v); - - let data = unsafe { - self.view_buffers.get_unchecked(view.buffer_index as usize) - }; - let offset = view.offset as usize; - unsafe { data.get_unchecked(offset..offset + len as usize) } - }; - b == value + v == value }); let payload = if let Some(entry) = entry { @@ -322,21 +285,14 @@ where // no existing value, make a new one. let payload = make_payload_fn(Some(value)); - let inner_view_idx = self.views.len(); + let inner_view_idx = self.builder.len(); let new_header = Entry { view_idx: inner_view_idx, payload, }; - let view = if value.len() <= 12 { - unsafe { *values.views().get_unchecked(view_idx) } - } else { - let v = unsafe { *values.views().get_unchecked(view_idx) }; - let mut v = ByteView::from(v); - v.buffer_index += buffer_offset as u32; - v.as_u128() - }; - self.views.push(view); + self.builder.append_value(value); + self.map .insert_accounted(new_header, |_| hash, &mut self.map_size); payload @@ -352,40 +308,20 @@ where /// The values are guaranteed to be returned in the same order in which /// they were first seen. pub fn into_state(self) -> ArrayRef { - // Only make a `NullBuffer` if there was a null value - let nulls = self.null.map(|(_payload, null_index)| { - let num_values = self.views.len(); - single_null_buffer(num_values, null_index) - }); - - let len = self.views.len(); - let b = ScalarBuffer::new(Buffer::from_vec(self.views), 0, len); + let mut builder = self.builder; match self.output_type { OutputType::BinaryView => { - // SAFETY: the offsets were constructed correctly - let mut array = unsafe { - BinaryViewArray::new_unchecked(b, self.view_buffers, nulls) - }; - if array.len() < (self.input_len / 2) { - // arrow gc by default will deduplicate strings, it should not do that. - // todo: file a ticket to change it. - array = array.gc(); - } + let array = builder.finish(); + Arc::new(array) } OutputType::Utf8View => { // SAFETY: - // 1. the offsets were constructed safely - // - // 2. we asserted the input arrays were all the correct type and + // we asserted the input arrays were all the correct type and // thus since all the values that went in were valid (e.g. utf8) // so are all the values that come out - let mut array = unsafe { - StringViewArray::new_unchecked(b, self.view_buffers, nulls) - }; - if array.len() < (self.input_len / 2) { - array = array.gc(); - } + let array = builder.finish(); + let array = unsafe { array.to_string_view_unchecked() }; Arc::new(array) } } @@ -409,19 +345,12 @@ where /// Return the total size, in bytes, of memory used to store the data in /// this set, not including `self` pub fn size(&self) -> usize { - // view buffers are from upstream string view, not technically from us. - self.map_size + self.views.allocated_size() + self.hashes_buffer.allocated_size() + self.map_size + + self.builder.allocated_size() + + self.hashes_buffer.allocated_size() } } -/// Returns a `NullBuffer` with a single null value at the given index -fn single_null_buffer(num_values: usize, null_index: usize) -> NullBuffer { - let mut bool_builder = BooleanBufferBuilder::new(num_values); - bool_builder.append_n(num_values, true); - bool_builder.set_bit(null_index, false); - NullBuffer::from(bool_builder.finish()) -} - impl Debug for ArrowBytesViewMap where V: Debug + PartialEq + Eq + Clone + Copy + Default, @@ -430,7 +359,7 @@ where f.debug_struct("ArrowBytesMap") .field("map", &"") .field("map_size", &self.map_size) - .field("views", &self.views) + .field("view_builder", &self.builder) .field("random_state", &self.random_state) .field("hashes_buffer", &self.hashes_buffer) .finish() @@ -452,6 +381,7 @@ where #[cfg(test)] mod tests { + use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray}; use hashbrown::HashMap; use super::*; @@ -608,7 +538,7 @@ mod tests { "BAR larger than 12 bytes.".repeat(1000), "more unique.".repeat(1000), "more unique2.".repeat(1000), - "BAZ".repeat(3000), + "FOO".repeat(3000), ]); let total_strings2_len = strings2 .iter() @@ -631,15 +561,14 @@ mod tests { // inserting the same strings should not affect the size set.insert(&values1); assert_eq!(set.size(), size_after_values1); + assert_eq!(set.len(), 5); // inserting the large strings should increase the reported size set.insert(&values2); let size_after_values2 = set.size(); assert!(size_after_values2 > size_after_values1); - // the consumed size should be less than the sum of the sizes of the strings - // bc the strings are deduplicated - assert!(size_after_values2 < total_strings1_len + total_strings2_len); + assert_eq!(set.len(), 10); } #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]