
Commit

update new implementation
XiangpengHao committed Jul 12, 2024
1 parent df9af6b commit 0dbcc3e
Showing 1 changed file with 27 additions and 98 deletions.
datafusion/physical-expr-common/src/binary_view_map.rs (125 changes: 27 additions & 98 deletions)
@@ -19,13 +19,8 @@
 //! StringArray / LargeStringArray / BinaryArray / LargeBinaryArray.
 use ahash::RandomState;
 use arrow::array::cast::AsArray;
-use arrow::array::{
-    Array, ArrayRef, BinaryViewArray, BooleanBufferBuilder, GenericByteViewArray,
-    StringViewArray,
-};
-use arrow::buffer::{Buffer, NullBuffer, ScalarBuffer};
+use arrow::array::{Array, ArrayBuilder, ArrayRef, GenericByteViewBuilder};
 use arrow::datatypes::{BinaryViewType, ByteViewType, DataType, StringViewType};
-use arrow_data::ByteView;
 use datafusion_common::hash_utils::create_hashes;
 use datafusion_common::utils::proxy::{RawTableAllocExt, VecAllocExt};
 use std::fmt::Debug;
@@ -127,8 +122,8 @@ where
     map: hashbrown::raw::RawTable<Entry<V>>,
     /// Total size of the map in bytes
     map_size: usize,
-    view_buffers: Vec<Buffer>,
-    views: Vec<u128>,
+
+    builder: GenericByteViewBuilder<BinaryViewType>,
     /// random state used to generate hashes
     random_state: RandomState,
     /// buffer that stores hash values (reused across batches to save allocations)
@@ -137,10 +132,6 @@
     /// NOTE null_index is the logical index in the final array, not the index
     /// in the buffer
     null: Option<(V, usize)>,
-
-    // the length of the input array. Used to determine if we want to gc the output array
-    // to avoid holding the view_buffers too long.
-    input_len: usize,
 }
 
 /// The size, in number of entries, of the initial hash table
Expand All @@ -155,12 +146,10 @@ where
output_type,
map: hashbrown::raw::RawTable::with_capacity(INITIAL_MAP_CAPACITY),
map_size: 0,
views: Vec::new(),
view_buffers: Vec::new(),
builder: GenericByteViewBuilder::new(),
random_state: RandomState::new(),
hashes_buffer: vec![],
null: None,
input_len: 0,
}
}
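
For context, the `GenericByteViewBuilder` that replaces the hand-rolled `views` / `view_buffers` bookkeeping owns both the view entries and the data buffers: values of 12 bytes or fewer are inlined in the view, longer values are copied into a data buffer. A minimal standalone sketch of the builder API this commit leans on (a sketch assuming arrow-rs 52+, not code from this commit; `BinaryViewBuilder` is the alias for `GenericByteViewBuilder<BinaryViewType>`):

use arrow::array::{Array, ArrayBuilder, BinaryViewBuilder};

fn main() {
    let mut builder = BinaryViewBuilder::new();
    builder.append_value(b"short"); // <= 12 bytes: inlined in the view
    builder.append_value(b"a value longer than twelve bytes"); // copied to a data buffer
    builder.append_null();

    // len() comes from the ArrayBuilder trait, which is why the commit adds that import.
    assert_eq!(builder.len(), 3);

    let array = builder.finish();
    assert_eq!(array.value(0), b"short".as_slice());
    assert!(array.is_null(2));
}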

@@ -258,25 +247,18 @@
         // step 2: insert each value into the set, if not already present
         let values = values.as_byte_view::<B>();
 
-        self.input_len += values.len();
-
-        let buffer_offset = self.view_buffers.len();
-        self.view_buffers.extend_from_slice(values.data_buffers());
-
         // Ensure lengths are equivalent
         assert_eq!(values.len(), batch_hashes.len());
 
-        for (view_idx, (value, &hash)) in
-            values.iter().zip(batch_hashes.iter()).enumerate()
-        {
+        for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
             // handle null value
             let Some(value) = value else {
                 let payload = if let Some(&(payload, _offset)) = self.null.as_ref() {
                     payload
                 } else {
                     let payload = make_payload_fn(None);
-                    let null_index = self.views.len();
-                    self.views.push(0);
+                    let null_index = self.builder.len();
+                    self.builder.append_null();
                     self.null = Some((payload, null_index));
                     payload
                 };
@@ -288,32 +270,13 @@
             let value: &[u8] = value.as_ref();
 
             let entry = self.map.get_mut(hash, |header| {
-                let v = unsafe { self.views.get_unchecked(header.view_idx) };
+                let v = self.builder.get_value(header.view_idx);
 
-                let len = *v as u32;
-                if len as usize != value.len() {
+                if v.len() != value.len() {
                     return false;
                 }
 
-                // We should probably change arrow-rs to provide a
-                // GenericByteViewArray::value_from_view() method
-                let b = if len <= 12 {
-                    unsafe {
-                        GenericByteViewArray::<BinaryViewType>::inline_value(
-                            v,
-                            len as usize,
-                        )
-                    }
-                } else {
-                    let view = ByteView::from(*v);
-
-                    let data = unsafe {
-                        self.view_buffers.get_unchecked(view.buffer_index as usize)
-                    };
-                    let offset = view.offset as usize;
-                    unsafe { data.get_unchecked(offset..offset + len as usize) }
-                };
-                b == value
+                v == value
             });
 
             let payload = if let Some(entry) = entry {
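
The equality probe above no longer decodes views by hand: `GenericByteViewBuilder::get_value` returns the bytes previously appended at a given index, so the closure can compare lengths first and then the raw bytes. A rough standalone illustration of the same pattern, with a linear scan standing in for the hash-table probe (hypothetical helper, same arrow-rs API as the commit):

use arrow::array::{ArrayBuilder, BinaryViewBuilder};

/// Return the index of `needle` in the builder, appending it if absent.
fn find_or_append(builder: &mut BinaryViewBuilder, needle: &[u8]) -> usize {
    for idx in 0..builder.len() {
        // Cheap length check first, then byte-wise comparison, as in the commit.
        let existing = builder.get_value(idx);
        if existing.len() == needle.len() && existing == needle {
            return idx;
        }
    }
    let idx = builder.len();
    builder.append_value(needle);
    idx
}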
@@ -322,21 +285,14 @@
                 // no existing value, make a new one.
                 let payload = make_payload_fn(Some(value));
 
-                let inner_view_idx = self.views.len();
+                let inner_view_idx = self.builder.len();
                 let new_header = Entry {
                     view_idx: inner_view_idx,
                     payload,
                 };
-                let view = if value.len() <= 12 {
-                    unsafe { *values.views().get_unchecked(view_idx) }
-                } else {
-                    let v = unsafe { *values.views().get_unchecked(view_idx) };
-                    let mut v = ByteView::from(v);
-                    v.buffer_index += buffer_offset as u32;
-                    v.as_u128()
-                };
-                self.views.push(view);
+
+                self.builder.append_value(value);
 
                 self.map
                     .insert_accounted(new_header, |_| hash, &mut self.map_size);
                 payload
@@ -352,40 +308,20 @@
     /// The values are guaranteed to be returned in the same order in which
     /// they were first seen.
     pub fn into_state(self) -> ArrayRef {
-        // Only make a `NullBuffer` if there was a null value
-        let nulls = self.null.map(|(_payload, null_index)| {
-            let num_values = self.views.len();
-            single_null_buffer(num_values, null_index)
-        });
-
-        let len = self.views.len();
-        let b = ScalarBuffer::new(Buffer::from_vec(self.views), 0, len);
+        let mut builder = self.builder;
         match self.output_type {
             OutputType::BinaryView => {
-                // SAFETY: the offsets were constructed correctly
-                let mut array = unsafe {
-                    BinaryViewArray::new_unchecked(b, self.view_buffers, nulls)
-                };
-                if array.len() < (self.input_len / 2) {
-                    // arrow gc by default will deduplicate strings, it should not do that.
-                    // todo: file a ticket to change it.
-                    array = array.gc();
-                }
+                let array = builder.finish();
+
                 Arc::new(array)
             }
             OutputType::Utf8View => {
                 // SAFETY:
-                // 1. the offsets were constructed safely
-                //
-                // 2. we asserted the input arrays were all the correct type and
+                // we asserted the input arrays were all the correct type and
                 // thus since all the values that went in were valid (e.g. utf8)
                 // so are all the values that come out
-                let mut array = unsafe {
-                    StringViewArray::new_unchecked(b, self.view_buffers, nulls)
-                };
-                if array.len() < (self.input_len / 2) {
-                    array = array.gc();
-                }
+                let array = builder.finish();
+                let array = unsafe { array.to_string_view_unchecked() };
                 Arc::new(array)
             }
         }
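
The `Utf8View` arm now finishes the builder as a `BinaryViewArray` and reinterprets it, relying on the invariant that only valid UTF-8 was ever inserted. A condensed standalone sketch of that flow (assuming the same arrow-rs APIs the commit uses):

use std::sync::Arc;
use arrow::array::{Array, ArrayRef, BinaryViewBuilder};

fn main() {
    let mut builder = BinaryViewBuilder::new();
    builder.append_value("hello".as_bytes());
    builder.append_value("a string longer than twelve bytes".as_bytes());

    let binary = builder.finish();
    // SAFETY (sketch): both values came from &str, so the bytes are valid UTF-8.
    let strings = unsafe { binary.to_string_view_unchecked() };
    let array: ArrayRef = Arc::new(strings);
    assert_eq!(array.len(), 2);
}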
Expand All @@ -409,19 +345,12 @@ where
/// Return the total size, in bytes, of memory used to store the data in
/// this set, not including `self`
pub fn size(&self) -> usize {
// view buffers are from upstream string view, not technically from us.
self.map_size + self.views.allocated_size() + self.hashes_buffer.allocated_size()
self.map_size
+ self.builder.allocated_size()
+ self.hashes_buffer.allocated_size()
}
}

/// Returns a `NullBuffer` with a single null value at the given index
fn single_null_buffer(num_values: usize, null_index: usize) -> NullBuffer {
let mut bool_builder = BooleanBufferBuilder::new(num_values);
bool_builder.append_n(num_values, true);
bool_builder.set_bit(null_index, false);
NullBuffer::from(bool_builder.finish())
}

impl<V> Debug for ArrowBytesViewMap<V>
where
V: Debug + PartialEq + Eq + Clone + Copy + Default,
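
On the accounting in `size()`: DataFusion's `VecAllocExt::allocated_size` reports a `Vec`'s heap footprint from its capacity rather than its length, and the builder's `allocated_size` plays the same role for its internal buffers. A rough standalone equivalent for the `Vec` case (an approximation of the utility's behavior, not a copy of it):

use std::mem::size_of;

/// Approximate heap allocation of a Vec: capacity times element size.
fn vec_allocated_size<T>(v: &Vec<T>) -> usize {
    v.capacity() * size_of::<T>()
}

fn main() {
    let mut v: Vec<u128> = Vec::with_capacity(64);
    v.push(0);
    // Accounting follows capacity (at least 64 slots), not length (1 element).
    assert!(vec_allocated_size(&v) >= 64 * size_of::<u128>());
}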
@@ -430,7 +359,7 @@ where
         f.debug_struct("ArrowBytesMap")
             .field("map", &"<map>")
             .field("map_size", &self.map_size)
-            .field("views", &self.views)
+            .field("view_builder", &self.builder)
             .field("random_state", &self.random_state)
             .field("hashes_buffer", &self.hashes_buffer)
             .finish()
@@ -452,6 +381,7 @@ where
 
 #[cfg(test)]
 mod tests {
+    use arrow::array::{BinaryViewArray, GenericByteViewArray, StringViewArray};
     use hashbrown::HashMap;
 
     use super::*;
@@ -608,7 +538,7 @@ mod tests {
             "BAR larger than 12 bytes.".repeat(1000),
             "more unique.".repeat(1000),
             "more unique2.".repeat(1000),
-            "BAZ".repeat(3000),
+            "FOO".repeat(3000),
         ]);
         let total_strings2_len = strings2
             .iter()
@@ -631,15 +561,14 @@
         // inserting the same strings should not affect the size
         set.insert(&values1);
         assert_eq!(set.size(), size_after_values1);
+        assert_eq!(set.len(), 5);
 
         // inserting the large strings should increase the reported size
         set.insert(&values2);
         let size_after_values2 = set.size();
         assert!(size_after_values2 > size_after_values1);
 
-        // the consumed size should be less than the sum of the sizes of the strings
-        // bc the strings are deduplicated
-        assert!(size_after_values2 < total_strings1_len + total_strings2_len);
+        assert_eq!(set.len(), 10);
     }
 
     #[derive(Debug, PartialEq, Eq, Default, Clone, Copy)]
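
The reworked test assertions track distinct values via `len()` instead of comparing byte totals: re-inserting `values1` leaves the count at 5, and inserting `values2` brings it to 10. The property being asserted is plain deduplication, as in this minimal std-only illustration:

use std::collections::HashSet;

fn main() {
    let mut set = HashSet::new();
    for v in ["a", "b", "a", "c", "b"] {
        set.insert(v);
    }
    // Re-inserting existing values does not grow the set.
    assert_eq!(set.len(), 3);
}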
