Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Re-encode dictionaries in selection kernels (take / concat_batches) #3558

Merged
merged 24 commits into from
Sep 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions arrow-buffer/src/buffer/immutable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,14 @@ impl Buffer {
length,
})
}

/// Cheap identity check between two [`Buffer`]s.
///
/// Returns `true` only when both buffers have the same start pointer and the
/// same length. No bytes are compared, so this is cheaper than
/// `PartialEq::eq`, but it may return `false` for two buffers whose contents
/// are logically equal.
#[inline]
pub fn ptr_eq(&self, other: &Self) -> bool {
    self.length == other.length && self.ptr == other.ptr
}
}

/// Creating a `Buffer` instance by copying the memory from a `AsRef<[u8]>` into a newly
Expand Down
8 changes: 8 additions & 0 deletions arrow-buffer/src/buffer/offset.rs
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,14 @@ impl<O: ArrowNativeType> OffsetBuffer<O> {
pub fn slice(&self, offset: usize, len: usize) -> Self {
Self(self.0.slice(offset, len.saturating_add(1)))
}

/// Returns true if this [`OffsetBuffer`] is equal to `other`, using pointer comparisons
/// to determine buffer equality. This is cheaper than `PartialEq::eq` but may
/// return false when the arrays are logically equal
#[inline]
pub fn ptr_eq(&self, other: &Self) -> bool {
self.0.ptr_eq(&other.0)
}
}

impl<T: ArrowNativeType> Deref for OffsetBuffer<T> {
Expand Down
8 changes: 8 additions & 0 deletions arrow-buffer/src/buffer/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,14 @@ impl<T: ArrowNativeType> ScalarBuffer<T> {
pub fn into_inner(self) -> Buffer {
self.buffer
}

/// Returns true if this [`ScalarBuffer`] is equal to `other`, using pointer comparisons
/// to determine buffer equality. This is cheaper than `PartialEq::eq` but may
/// return false when the arrays are logically equal
#[inline]
pub fn ptr_eq(&self, other: &Self) -> bool {
self.buffer.ptr_eq(&other.buffer)
}
}

impl<T: ArrowNativeType> Deref for ScalarBuffer<T> {
Expand Down
1 change: 1 addition & 0 deletions arrow-select/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ arrow-data = { workspace = true }
arrow-schema = { workspace = true }
arrow-array = { workspace = true }
num = { version = "0.4", default-features = false, features = ["std"] }
ahash = { version = "0.8", default-features = false}

[features]
default = []
Expand Down
186 changes: 140 additions & 46 deletions arrow-select/src/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,20 @@
//! assert_eq!(arr.len(), 3);
//! ```

use crate::dictionary::{merge_dictionary_values, should_merge_dictionary_values};
use arrow_array::cast::AsArray;
use arrow_array::types::*;
use arrow_array::*;
use arrow_buffer::ArrowNativeType;
use arrow_buffer::{ArrowNativeType, BooleanBufferBuilder, NullBuffer};
use arrow_data::transform::{Capacities, MutableArrayData};
use arrow_schema::{ArrowError, DataType, SchemaRef};
use std::sync::Arc;

fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
let mut item_capacity = 0;
let mut bytes_capacity = 0;
for array in arrays {
let a = array
.as_any()
.downcast_ref::<GenericByteArray<T>>()
.unwrap();
let a = array.as_bytes::<T>();

// Guaranteed to always have at least one element
let offsets = a.value_offsets();
Expand All @@ -54,6 +54,59 @@ fn binary_capacity<T: ByteArrayType>(arrays: &[&dyn Array]) -> Capacities {
Capacities::Binary(item_capacity, Some(bytes_capacity))
}

/// Concatenates dictionary arrays with key type `K`, merging their values
/// when that is judged worthwhile.
///
/// Every input is viewed as a `DictionaryArray<K>`. If
/// `should_merge_dictionary_values` declines the merge, the inputs are handed
/// to `concat_fallback` unchanged. Otherwise the values are merged with
/// `merge_dictionary_values`, every key is rewritten through the per-input
/// mapping it returns, and a validity buffer is assembled only when at least
/// one input reports nulls.
fn concat_dictionaries<K: ArrowDictionaryKeyType>(
    arrays: &[&dyn Array],
) -> Result<ArrayRef, ArrowError> {
    let dictionaries: Vec<_> = arrays.iter().map(|a| a.as_dictionary::<K>()).collect();
    let output_len: usize = dictionaries.iter().map(|d| d.len()).sum();

    if !should_merge_dictionary_values::<K>(&dictionaries, output_len) {
        return concat_fallback(arrays, Capacities::Array(output_len));
    }

    let merged = merge_dictionary_values(&dictionaries, None)?;

    // Rewrite every key of every input through that input's mapping
    let mut key_values = Vec::with_capacity(output_len);
    let mut has_nulls = false;
    for (dict, mapping) in dictionaries.iter().zip(merged.key_mappings) {
        if dict.null_count() != 0 {
            has_nulls = true;
        }
        // `get` rather than indexing: a null slot's key may not be present
        // in the mapping, in which case the default key is stored instead
        key_values.extend(
            dict.keys()
                .values()
                .iter()
                .map(|key| mapping.get(key.as_usize()).copied().unwrap_or_default()),
        );
    }

    // Only materialise a validity buffer when some input actually has nulls
    let nulls = if has_nulls {
        let mut validity = BooleanBufferBuilder::new(output_len);
        for dict in &dictionaries {
            if let Some(n) = dict.nulls() {
                validity.append_buffer(n.inner());
            } else {
                validity.append_n(dict.len(), true);
            }
        }
        Some(NullBuffer::new(validity.finish()))
    } else {
        None
    };

    let keys = PrimitiveArray::<K>::new(key_values.into(), nulls);
    // Sanity check
    assert_eq!(keys.len(), output_len);

    // SAFETY: NOTE(review) - relies on `merge_dictionary_values` returning
    // mappings whose entries index into `merged.values`; confirm against
    // that function's contract.
    let merged_array = unsafe { DictionaryArray::new_unchecked(keys, merged.values) };
    Ok(Arc::new(merged_array))
}

/// Early-returns the result of `concat_dictionaries` for the dictionary key
/// type `$t`.
///
/// `concat_dictionaries` already yields `Result<ArrayRef, ArrowError>`, so its
/// result is returned as-is; wrapping the `ArrayRef` in a second `Arc` would
/// only add a pointless level of indirection.
macro_rules! dict_helper {
    ($t:ty, $arrays:expr) => {
        return concat_dictionaries::<$t>($arrays)
    };
}

/// Concatenate multiple [Array] of the same type into a single [ArrayRef].
pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
if arrays.is_empty() {
Expand All @@ -78,9 +131,23 @@ pub fn concat(arrays: &[&dyn Array]) -> Result<ArrayRef, ArrowError> {
DataType::LargeUtf8 => binary_capacity::<LargeUtf8Type>(arrays),
DataType::Binary => binary_capacity::<BinaryType>(arrays),
DataType::LargeBinary => binary_capacity::<LargeBinaryType>(arrays),
DataType::Dictionary(k, _) => downcast_integer! {
k.as_ref() => (dict_helper, arrays),
_ => unreachable!("illegal dictionary key type {k}")
},
_ => Capacities::Array(arrays.iter().map(|a| a.len()).sum()),
};

concat_fallback(arrays, capacity)
}

/// Concatenates arrays using MutableArrayData
///
/// This will naively concatenate dictionaries
fn concat_fallback(
tustvold marked this conversation as resolved.
Show resolved Hide resolved
arrays: &[&dyn Array],
capacity: Capacities,
) -> Result<ArrayRef, ArrowError> {
let array_data: Vec<_> = arrays.iter().map(|a| a.to_data()).collect::<Vec<_>>();
let array_data = array_data.iter().collect();
let mut mutable = MutableArrayData::with_capacities(array_data, false, capacity);
Expand Down Expand Up @@ -140,6 +207,7 @@ pub fn concat_batches<'a>(
#[cfg(test)]
mod tests {
use super::*;
use arrow_array::builder::StringDictionaryBuilder;
use arrow_array::cast::AsArray;
use arrow_schema::{Field, Schema};
use std::sync::Arc;
Expand Down Expand Up @@ -468,29 +536,10 @@ mod tests {
}

fn collect_string_dictionary(
dictionary: &DictionaryArray<Int32Type>,
) -> Vec<Option<String>> {
let values = dictionary.values();
let values = values.as_any().downcast_ref::<StringArray>().unwrap();

dictionary
.keys()
.iter()
.map(|key| key.map(|key| values.value(key as _).to_string()))
.collect()
}

fn concat_dictionary(
input_1: DictionaryArray<Int32Type>,
input_2: DictionaryArray<Int32Type>,
) -> Vec<Option<String>> {
let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
let concat = concat
.as_any()
.downcast_ref::<DictionaryArray<Int32Type>>()
.unwrap();

collect_string_dictionary(concat)
array: &DictionaryArray<Int32Type>,
) -> Vec<Option<&str>> {
let concrete = array.downcast_dict::<StringArray>().unwrap();
concrete.into_iter().collect()
}

#[test]
Expand All @@ -509,11 +558,19 @@ mod tests {
"E",
]
.into_iter()
.map(|x| Some(x.to_string()))
.map(Some)
.collect();

let concat = concat_dictionary(input_1, input_2);
assert_eq!(concat, expected);
let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
let dictionary = concat.as_dictionary::<Int32Type>();
let actual = collect_string_dictionary(dictionary);
assert_eq!(actual, expected);

// Should have concatenated inputs together
assert_eq!(
dictionary.values().len(),
input_1.values().len() + input_2.values().len(),
)
}

#[test]
Expand All @@ -523,16 +580,45 @@ mod tests {
.into_iter()
.collect();
let input_2: DictionaryArray<Int32Type> = vec![None].into_iter().collect();
let expected = vec![
Some("foo".to_string()),
Some("bar".to_string()),
None,
Some("fiz".to_string()),
None,
];
let expected = vec![Some("foo"), Some("bar"), None, Some("fiz"), None];

let concat = concat_dictionary(input_1, input_2);
assert_eq!(concat, expected);
let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
let dictionary = concat.as_dictionary::<Int32Type>();
let actual = collect_string_dictionary(dictionary);
assert_eq!(actual, expected);

// Should have concatenated inputs together
assert_eq!(
dictionary.values().len(),
input_1.values().len() + input_2.values().len(),
)
}

#[test]
fn test_string_dictionary_merge() {
let mut builder = StringDictionaryBuilder::<Int32Type>::new();
for i in 0..20 {
builder.append(&i.to_string()).unwrap();
}
let input_1 = builder.finish();

let mut builder = StringDictionaryBuilder::<Int32Type>::new();
for i in 0..30 {
builder.append(&i.to_string()).unwrap();
}
let input_2 = builder.finish();

let expected: Vec<_> = (0..20).chain(0..30).map(|x| x.to_string()).collect();
let expected: Vec<_> = expected.iter().map(|x| Some(x.as_str())).collect();

let concat = concat(&[&input_1 as _, &input_2 as _]).unwrap();
let dictionary = concat.as_dictionary::<Int32Type>();
let actual = collect_string_dictionary(dictionary);
assert_eq!(actual, expected);

// Should have merged inputs together
// Not 30 as this is done on a best-effort basis
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

assert_eq!(dictionary.values().len(), 33)
}

#[test]
Expand All @@ -556,7 +642,7 @@ mod tests {
fn test_dictionary_concat_reuse() {
let array: DictionaryArray<Int8Type> =
vec!["a", "a", "b", "c"].into_iter().collect();
let copy: DictionaryArray<Int8Type> = array.to_data().into();
let copy: DictionaryArray<Int8Type> = array.clone();

// dictionary is "a", "b", "c"
assert_eq!(
Expand All @@ -567,11 +653,7 @@ mod tests {

// concatenate it with itself
let combined = concat(&[&copy as _, &array as _]).unwrap();

let combined = combined
.as_any()
.downcast_ref::<DictionaryArray<Int8Type>>()
.unwrap();
let combined = combined.as_dictionary::<Int8Type>();

assert_eq!(
combined.values(),
Expand Down Expand Up @@ -738,4 +820,16 @@ mod tests {
assert_eq!(data.buffers()[1].len(), 200);
assert_eq!(data.buffers()[1].capacity(), 256); // Nearest multiple of 64
}

#[test]
fn concat_sparse_nulls() {
    // First input: ten valid keys, all referencing entry 1 of a 100-entry dictionary
    let dense_values = StringArray::from_iter_values((0..100).map(|x| x.to_string()));
    let dense = DictionaryArray::new(Int32Array::from(vec![1; 10]), Arc::new(dense_values));

    // Second input: ten null keys over an empty dictionary
    let empty_values = StringArray::new_null(0);
    let all_null = DictionaryArray::new(Int32Array::new_null(10), Arc::new(empty_values));

    // The ten nulls of the second input must survive concatenation
    let combined = concat(&[&dense, &all_null]).unwrap();
    assert_eq!(combined.null_count(), 10);
}
}
Loading
Loading