Optimize COUNT(DISTINCT ...) for strings (up to 9x faster)
#8849
Conversation
const SHORT_STRING_LEN: usize = mem::size_of::<usize>();

#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
Allow Copy since they are all native types
Thanks @jayzhan211 -- looks basically on the right track. Is there any chance you can run some sort of benchmark on this code? My thinking is that we should get benchmark results showing that the idea actually improves performance before spending too much time polishing. I looked at ClickBench and I don't actually think there are any queries that exercise this path. Q8 looks like it should be helped: SELECT COUNT(DISTINCT "SearchPhrase") FROM hits; however, I am pretty sure DataFusion rewrites this query to avoid the distinct. Maybe you could try manually running a query that can't be rewritten (throw in multiple distinct aggregates):
❯ SELECT
COUNT(DISTINCT "SearchPhrase"),
COUNT(DISTINCT "MobilePhone"),
COUNT(DISTINCT "MobilePhoneModel")
FROM 'hits.parquet';
+-------------------------------------------+------------------------------------------+-----------------------------------------------+
| COUNT(DISTINCT hits.parquet.SearchPhrase) | COUNT(DISTINCT hits.parquet.MobilePhone) | COUNT(DISTINCT hits.parquet.MobilePhoneModel) |
+-------------------------------------------+------------------------------------------+-----------------------------------------------+
| 6019103 | 44 | 166 |
+-------------------------------------------+------------------------------------------+-----------------------------------------------+
I tested by manually injecting a string array to find out whether SSO is better than a simple HashSet. Comparison with n = 1e5:
Mostly-non-distinct case
All-distinct case
I thought that the smaller the string, the larger the performance gain, but it turns out longer strings benefit more. TL;DR: testing file below.
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use arrow::datatypes::{DataType, Field, TimeUnit};
use arrow_array::types::{
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::{PrimitiveArray, StringArray};
use arrow_buffer::BufferBuilder;
use rand::Rng;
use std::any::Any;
use std::cmp::Eq;
use std::fmt::Debug;
use std::hash::Hash;
use std::mem;
use std::sync::Arc;
use ahash::RandomState;
use arrow::array::{Array, ArrayRef};
use std::collections::HashSet;
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{Result, ScalarValue};
use datafusion_execution::memory_pool::proxy::RawTableAllocExt;
use datafusion_expr::Accumulator;
type DistinctScalarValues = ScalarValue;
/// Expression for a COUNT(DISTINCT) aggregation.
#[derive(Debug)]
pub struct DistinctCount {
/// Column name
name: String,
/// The DataType used to hold the state for each input
state_data_type: DataType,
/// The input arguments
expr: Arc<dyn PhysicalExpr>,
}
impl DistinctCount {
/// Create a new COUNT(DISTINCT) aggregate function.
pub fn new(
input_data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
name: String,
) -> Self {
Self {
name,
state_data_type: input_data_type,
expr,
}
}
}
macro_rules! native_distinct_count_accumulator {
($TYPE:ident) => {{
Ok(Box::new(NativeDistinctCountAccumulator::<$TYPE>::new()))
}};
}
macro_rules! float_distinct_count_accumulator {
($TYPE:ident) => {{
Ok(Box::new(FloatDistinctCountAccumulator::<$TYPE>::new()))
}};
}
impl AggregateExpr for DistinctCount {
/// Return a reference to Any that can be used for downcasting
fn as_any(&self) -> &dyn Any {
self
}
fn field(&self) -> Result<Field> {
Ok(Field::new(&self.name, DataType::Int64, true))
}
fn state_fields(&self) -> Result<Vec<Field>> {
Ok(vec![Field::new_list(
format_state_name(&self.name, "count distinct"),
Field::new("item", self.state_data_type.clone(), true),
false,
)])
}
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
vec![self.expr.clone()]
}
fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
use DataType::*;
use TimeUnit::*;
match &self.state_data_type {
Int8 => native_distinct_count_accumulator!(Int8Type),
Int16 => native_distinct_count_accumulator!(Int16Type),
Int32 => native_distinct_count_accumulator!(Int32Type),
Int64 => native_distinct_count_accumulator!(Int64Type),
UInt8 => native_distinct_count_accumulator!(UInt8Type),
UInt16 => native_distinct_count_accumulator!(UInt16Type),
UInt32 => native_distinct_count_accumulator!(UInt32Type),
UInt64 => native_distinct_count_accumulator!(UInt64Type),
Decimal128(_, _) => native_distinct_count_accumulator!(Decimal128Type),
Decimal256(_, _) => native_distinct_count_accumulator!(Decimal256Type),
Date32 => native_distinct_count_accumulator!(Date32Type),
Date64 => native_distinct_count_accumulator!(Date64Type),
Time32(Millisecond) => {
native_distinct_count_accumulator!(Time32MillisecondType)
}
Time32(Second) => {
native_distinct_count_accumulator!(Time32SecondType)
}
Time64(Microsecond) => {
native_distinct_count_accumulator!(Time64MicrosecondType)
}
Time64(Nanosecond) => {
native_distinct_count_accumulator!(Time64NanosecondType)
}
Timestamp(Microsecond, _) => {
native_distinct_count_accumulator!(TimestampMicrosecondType)
}
Timestamp(Millisecond, _) => {
native_distinct_count_accumulator!(TimestampMillisecondType)
}
Timestamp(Nanosecond, _) => {
native_distinct_count_accumulator!(TimestampNanosecondType)
}
Timestamp(Second, _) => {
native_distinct_count_accumulator!(TimestampSecondType)
}
Float16 => float_distinct_count_accumulator!(Float16Type),
Float32 => float_distinct_count_accumulator!(Float32Type),
Float64 => float_distinct_count_accumulator!(Float64Type),
Utf8 => Ok(Box::new(StringDistinctCountAccumulator::new())),
_ => Ok(Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: self.state_data_type.clone(),
})),
}
}
fn name(&self) -> &str {
&self.name
}
}
impl PartialEq<dyn Any> for DistinctCount {
fn eq(&self, other: &dyn Any) -> bool {
down_cast_any_ref(other)
.downcast_ref::<Self>()
.map(|x| {
self.name == x.name
&& self.state_data_type == x.state_data_type
&& self.expr.eq(&x.expr)
})
.unwrap_or(false)
}
}
#[derive(Debug)]
struct DistinctCountAccumulator {
values: HashSet<DistinctScalarValues, RandomState>,
state_data_type: DataType,
}
impl DistinctCountAccumulator {
// Calculates the size for fixed-length values: per-entry size times capacity,
// plus the extra size of one representative value (all fixed-length values are
// the same size). Faster than .full_size(), but not suitable for variable-length
// values like strings or complex types
fn fixed_size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
+ self
.values
.iter()
.next()
.map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
.unwrap_or(0)
+ std::mem::size_of::<DataType>()
}
// calculates the size as accurate as possible, call to this method is expensive
fn full_size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
+ self
.values
.iter()
.map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
.sum::<usize>()
+ std::mem::size_of::<DataType>()
}
}
fn get_vec_string(n: usize) -> Vec<String> {
// Create a vector and generate random strings
let mut rng = rand::thread_rng();
let random_strings: Vec<String> = (0..n)
.map(|_| {
let random_char = match rng.gen_range(0..3) {
0 => "a",
1 => "b",
2 => "cccccccc",
_ => unreachable!(), // This should never happen
};
random_char.to_string()
})
.collect();
random_strings
}
fn get_distinct_string(n: usize) -> Vec<String> {
// Generate n distinct strings, each longer than SHORT_STRING_LEN bytes
(1..=n).map(|i| format!("{}aaaaaaaa", i)).collect()
}
impl Accumulator for DistinctCountAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
let scalars = self.values.iter().cloned().collect::<Vec<_>>();
let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type);
Ok(vec![ScalarValue::List(arr)])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
let arr = &values[0];
if arr.data_type() == &DataType::Null {
return Ok(());
}
(0..arr.len()).try_for_each(|index| {
if !arr.is_null(index) {
let scalar = ScalarValue::try_from_array(arr, index)?;
self.values.insert(scalar);
}
Ok(())
})
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(states.len(), 1, "count_distinct states must be singleton!");
let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
for scalars in scalar_vec.into_iter() {
self.values.extend(scalars);
}
Ok(())
}
fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
}
fn size(&self) -> usize {
match &self.state_data_type {
DataType::Boolean | DataType::Null => self.fixed_size(),
d if d.is_primitive() => self.fixed_size(),
_ => self.full_size(),
}
}
}
#[derive(Debug)]
struct NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
T::Native: Eq + Hash,
{
values: HashSet<T::Native, RandomState>,
}
impl<T> NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
T::Native: Eq + Hash,
{
fn new() -> Self {
Self {
values: HashSet::default(),
}
}
}
impl<T> Accumulator for NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send + Debug,
T::Native: Eq + Hash,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().cloned(),
)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr));
Ok(vec![ScalarValue::List(list)])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
let arr = as_primitive_array::<T>(&values[0])?;
arr.iter().for_each(|value| {
if let Some(value) = value {
self.values.insert(value);
}
});
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);
let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
let list = as_primitive_array::<T>(&list)?;
self.values.extend(list.values())
};
Ok(())
})
}
fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
}
fn size(&self) -> usize {
let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
/ 7)
.next_power_of_two();
// Size of accumulator
// + size of entry * number of buckets
// + 1 byte for each bucket
// + fixed size of HashSet
std::mem::size_of_val(self)
+ std::mem::size_of::<T::Native>() * estimated_buckets
+ estimated_buckets
+ std::mem::size_of_val(&self.values)
}
}
#[derive(Debug)]
struct FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
{
values: HashSet<Hashable<T::Native>, RandomState>,
}
impl<T> FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
{
fn new() -> Self {
Self {
values: HashSet::default(),
}
}
}
impl<T> Accumulator for FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send + Debug,
{
fn state(&self) -> Result<Vec<ScalarValue>> {
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().map(|v| v.0),
)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr));
Ok(vec![ScalarValue::List(list)])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
let arr = as_primitive_array::<T>(&values[0])?;
arr.iter().for_each(|value| {
if let Some(value) = value {
self.values.insert(Hashable(value));
}
});
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);
let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
let list = as_primitive_array::<T>(&list)?;
self.values
.extend(list.values().iter().map(|v| Hashable(*v)));
};
Ok(())
})
}
fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
}
fn size(&self) -> usize {
let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
/ 7)
.next_power_of_two();
// Size of accumulator
// + size of entry * number of buckets
// + 1 byte for each bucket
// + fixed size of HashSet
std::mem::size_of_val(self)
+ std::mem::size_of::<T::Native>() * estimated_buckets
+ estimated_buckets
+ std::mem::size_of_val(&self.values)
}
}
#[derive(Debug)]
struct StringDistinctCountAccumulator2(HashSet<String>);
impl StringDistinctCountAccumulator2 {
fn new() -> Self {
Self(HashSet::new())
}
}
impl Accumulator for StringDistinctCountAccumulator2 {
fn state(&self) -> Result<Vec<ScalarValue>> {
let arr = StringArray::from_iter_values(self.0.iter());
let list = Arc::new(array_into_list_array(Arc::new(arr)));
Ok(vec![ScalarValue::List(list)])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
// Benchmark hack: inject synthetic strings instead of the real input
let vs = get_distinct_string(100000);
// let vs = get_vec_string(100000);
let arr = &StringArray::from_iter_values(vs);
// let arr = as_string_array(&values[0])?;
arr.iter().for_each(|value| {
if let Some(value) = value {
self.0.insert(value.to_string());
}
});
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);
let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
let list = as_string_array(&list)?;
list.iter().for_each(|value| {
if let Some(value) = value {
self.0.insert(value.to_string());
}
})
};
Ok(())
})
}
fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
}
fn size(&self) -> usize {
// Size of the accumulator struct only; the HashSet contents are not
// accounted for here (this type is a benchmark-only baseline)
std::mem::size_of_val(self)
}
}
#[derive(Debug)]
struct StringDistinctCountAccumulator(SSOStringHashSet);
impl StringDistinctCountAccumulator {
fn new() -> Self {
Self(SSOStringHashSet::new())
}
}
impl Accumulator for StringDistinctCountAccumulator {
fn state(&self) -> Result<Vec<ScalarValue>> {
let arr = StringArray::from_iter_values(self.0.iter());
let list = Arc::new(array_into_list_array(Arc::new(arr)));
Ok(vec![ScalarValue::List(list)])
}
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}
// Benchmark hack: inject synthetic strings instead of the real input
let vs = get_distinct_string(100000);
// let vs = get_vec_string(100000);
let arr = &StringArray::from_iter_values(vs);
// let arr = as_string_array(&values[0])?;
arr.iter().for_each(|value| {
if let Some(value) = value {
self.0.insert(value);
}
});
Ok(())
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);
let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
let list = as_string_array(&list)?;
list.iter().for_each(|value| {
if let Some(value) = value {
self.0.insert(value);
}
})
};
Ok(())
})
}
fn evaluate(&self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
}
fn size(&self) -> usize {
// Size of accumulator
// + SSOStringHashSet size
std::mem::size_of_val(self) + self.0.size()
}
}
const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct SSOStringHeader {
/// hash of the string value (used when resizing table)
hash: u64,
len: usize,
offset_or_inline: usize,
}
impl SSOStringHeader {
fn evaluate(&self, buffer: &[u8]) -> String {
if self.len <= SHORT_STRING_LEN {
// Unpack the bytes that `insert` inlined (big-endian) back into a string
let bytes: Vec<u8> = (0..self.len)
.rev()
.map(|i| ((self.offset_or_inline >> (8 * i)) & 0xff) as u8)
.collect();
String::from_utf8(bytes).expect("inlined bytes came from a valid &str")
} else {
let offset = self.offset_or_inline;
// SAFETY: buffer is only appended to, and we correctly inserted values
unsafe {
std::str::from_utf8_unchecked(
buffer.get_unchecked(offset..offset + self.len),
)
}
.to_string()
}
}
}
// Short-String-Optimized (SSO) HashSet for strings
// Equivalent to HashSet<String> but with better memory usage
#[derive(Default)]
struct SSOStringHashSet {
header_set: HashSet<SSOStringHeader>,
long_string_map: hashbrown::raw::RawTable<SSOStringHeader>,
map_size: usize,
buffer: BufferBuilder<u8>,
state: RandomState,
}
impl SSOStringHashSet {
fn new() -> Self {
Self::default()
}
fn insert(&mut self, value: &str) {
let value_len = value.len();
let value_bytes = value.as_bytes();
if value_len <= SHORT_STRING_LEN {
let inline = value_bytes
.iter()
.fold(0usize, |acc, &x| acc << 8 | x as usize);
let short_string_header = SSOStringHeader {
// no need for short string cases
hash: 0,
len: value_len,
offset_or_inline: inline,
};
self.header_set.insert(short_string_header);
} else {
let hash = self.state.hash_one(value_bytes);
let entry = self.long_string_map.get_mut(hash, |header| {
// if hash matches, check if the bytes match
let offset = header.offset_or_inline;
let len = header.len;
// SAFETY: buffer is only appended to, and we correctly inserted values
let existing_value =
unsafe { self.buffer.as_slice().get_unchecked(offset..offset + len) };
value_bytes == existing_value
});
if entry.is_none() {
let offset = self.buffer.len();
self.buffer.append_slice(value_bytes);
let header = SSOStringHeader {
hash,
len: value_len,
offset_or_inline: offset,
};
self.long_string_map.insert_accounted(
header,
|header| header.hash,
&mut self.map_size,
);
self.header_set.insert(header);
}
}
}
fn iter(&self) -> Vec<String> {
self.header_set
.iter()
.map(|header| header.evaluate(self.buffer.as_slice()))
.collect()
}
fn len(&self) -> usize {
self.header_set.len()
}
// TODO: double-check this size estimate (HashSet overhead is not included)
fn size(&self) -> usize {
self.header_set.len() * mem::size_of::<SSOStringHeader>()
+ self.map_size
+ self.buffer.len()
}
}
impl Debug for SSOStringHashSet {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SSOStringHashSet")
.field("header_set", &self.header_set)
// TODO: Print long_string_map
.field("map_size", &self.map_size)
.field("buffer", &self.buffer)
.field("state", &self.state)
.finish()
}
}
#[cfg(test)]
mod tests {
use crate::expressions::NoOp;
use super::*;
use arrow::array::{
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
};
use arrow::datatypes::DataType;
use arrow::datatypes::{
Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::Decimal256Array;
use arrow_buffer::i256;
use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array};
use datafusion_common::internal_err;
use datafusion_common::DataFusionError;
macro_rules! state_to_vec_primitive {
($LIST:expr, $DATA_TYPE:ident) => {{
let arr = ScalarValue::raw_data($LIST).unwrap();
let list_arr = as_list_array(&arr).unwrap();
let arr = list_arr.values();
let arr = as_primitive_array::<$DATA_TYPE>(arr)?;
arr.values().iter().cloned().collect::<Vec<_>>()
}};
}
macro_rules! test_count_distinct_update_batch_numeric {
($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
let values: Vec<Option<$PRIM_TYPE>> = vec![
Some(1),
Some(1),
None,
Some(3),
Some(2),
None,
Some(2),
Some(3),
Some(1),
];
let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];
let (states, result) = run_update_batch(&arrays)?;
let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE);
state_vec.sort();
assert_eq!(states.len(), 1);
assert_eq!(state_vec, vec![1, 2, 3]);
assert_eq!(result, ScalarValue::Int64(Some(3)));
Ok(())
}};
}
fn state_to_vec_bool(sv: &ScalarValue) -> Result<Vec<bool>> {
let arr = ScalarValue::raw_data(sv)?;
let list_arr = as_list_array(&arr)?;
let arr = list_arr.values();
let bool_arr = as_boolean_array(arr)?;
Ok(bool_arr.iter().flatten().collect())
}
fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>, ScalarValue)> {
let agg = DistinctCount::new(
arrays[0].data_type().clone(),
Arc::new(NoOp::new()),
String::from("__col_name__"),
);
let mut accum = agg.create_accumulator()?;
accum.update_batch(arrays)?;
Ok((accum.state()?, accum.evaluate()?))
}
fn run_update(
data_types: &[DataType],
rows: &[Vec<ScalarValue>],
) -> Result<(Vec<ScalarValue>, ScalarValue)> {
let agg = DistinctCount::new(
data_types[0].clone(),
Arc::new(NoOp::new()),
String::from("__col_name__"),
);
let mut accum = agg.create_accumulator()?;
let cols = (0..rows[0].len())
.map(|i| {
rows.iter()
.map(|inner| inner[i].clone())
.collect::<Vec<ScalarValue>>()
})
.collect::<Vec<_>>();
let arrays: Vec<ArrayRef> = cols
.iter()
.map(|c| ScalarValue::iter_to_array(c.clone()))
.collect::<Result<Vec<ArrayRef>>>()?;
accum.update_batch(&arrays)?;
Ok((accum.state()?, accum.evaluate()?))
}
// Used trait to create associated constant for f32 and f64
trait SubNormal: 'static {
const SUBNORMAL: Self;
}
impl SubNormal for f64 {
const SUBNORMAL: Self = 1.0e-308_f64;
}
impl SubNormal for f32 {
const SUBNORMAL: Self = 1.0e-38_f32;
}
macro_rules! test_count_distinct_update_batch_floating_point {
($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
let values: Vec<Option<$PRIM_TYPE>> = vec![
Some(<$PRIM_TYPE>::INFINITY),
Some(<$PRIM_TYPE>::NAN),
Some(1.0),
Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL),
Some(1.0),
Some(<$PRIM_TYPE>::INFINITY),
None,
Some(3.0),
Some(-4.5),
Some(2.0),
None,
Some(2.0),
Some(3.0),
Some(<$PRIM_TYPE>::NEG_INFINITY),
Some(1.0),
Some(<$PRIM_TYPE>::NAN),
Some(<$PRIM_TYPE>::NEG_INFINITY),
];
let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];
let (states, result) = run_update_batch(&arrays)?;
let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE);
dbg!(&state_vec);
state_vec.sort_by(|a, b| match (a, b) {
(lhs, rhs) => lhs.total_cmp(rhs),
});
let nan_idx = state_vec.len() - 1;
assert_eq!(states.len(), 1);
assert_eq!(
&state_vec[..nan_idx],
vec![
<$PRIM_TYPE>::NEG_INFINITY,
-4.5,
<$PRIM_TYPE as SubNormal>::SUBNORMAL,
1.0,
2.0,
3.0,
<$PRIM_TYPE>::INFINITY
]
);
assert!(state_vec[nan_idx].is_nan());
assert_eq!(result, ScalarValue::Int64(Some(8)));
Ok(())
}};
}
macro_rules! test_count_distinct_update_batch_bigint {
($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
let values: Vec<Option<$PRIM_TYPE>> = vec![
Some(i256::from(1)),
Some(i256::from(1)),
None,
Some(i256::from(3)),
Some(i256::from(2)),
None,
Some(i256::from(2)),
Some(i256::from(3)),
Some(i256::from(1)),
];
let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];
let (states, result) = run_update_batch(&arrays)?;
let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE);
state_vec.sort();
assert_eq!(states.len(), 1);
assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]);
assert_eq!(result, ScalarValue::Int64(Some(3)));
Ok(())
}};
}
#[test]
fn count_distinct_update_batch_i8() -> Result<()> {
test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8)
}
#[test]
fn count_distinct_update_batch_i16() -> Result<()> {
test_count_distinct_update_batch_numeric!(Int16Array, Int16Type, i16)
}
#[test]
fn count_distinct_update_batch_i32() -> Result<()> {
test_count_distinct_update_batch_numeric!(Int32Array, Int32Type, i32)
}
#[test]
fn count_distinct_update_batch_i64() -> Result<()> {
test_count_distinct_update_batch_numeric!(Int64Array, Int64Type, i64)
}
#[test]
fn count_distinct_update_batch_u8() -> Result<()> {
test_count_distinct_update_batch_numeric!(UInt8Array, UInt8Type, u8)
}
#[test]
fn count_distinct_update_batch_u16() -> Result<()> {
test_count_distinct_update_batch_numeric!(UInt16Array, UInt16Type, u16)
}
#[test]
fn count_distinct_update_batch_u32() -> Result<()> {
test_count_distinct_update_batch_numeric!(UInt32Array, UInt32Type, u32)
}
#[test]
fn count_distinct_update_batch_u64() -> Result<()> {
test_count_distinct_update_batch_numeric!(UInt64Array, UInt64Type, u64)
}
#[test]
fn count_distinct_update_batch_f32() -> Result<()> {
test_count_distinct_update_batch_floating_point!(Float32Array, Float32Type, f32)
}
#[test]
fn count_distinct_update_batch_f64() -> Result<()> {
test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64)
}
#[test]
fn count_distinct_update_batch_i256() -> Result<()> {
test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256)
}
#[test]
fn count_distinct_update_batch_boolean() -> Result<()> {
let get_count = |data: BooleanArray| -> Result<(Vec<bool>, i64)> {
let arrays = vec![Arc::new(data) as ArrayRef];
let (states, result) = run_update_batch(&arrays)?;
let mut state_vec = state_to_vec_bool(&states[0])?;
state_vec.sort();
let count = match result {
ScalarValue::Int64(c) => c.ok_or_else(|| {
DataFusionError::Internal("Found None count".to_string())
}),
scalar => {
internal_err!("Found non int64 scalar value from count: {scalar}")
}
}?;
Ok((state_vec, count))
};
let zero_count_values = BooleanArray::from(Vec::<bool>::new());
let one_count_values = BooleanArray::from(vec![false, false]);
let one_count_values_with_null =
BooleanArray::from(vec![Some(true), Some(true), None, None]);
let two_count_values = BooleanArray::from(vec![true, false, true, false, true]);
let two_count_values_with_null = BooleanArray::from(vec![
Some(true),
Some(false),
None,
None,
Some(true),
Some(false),
]);
assert_eq!(get_count(zero_count_values)?, (Vec::<bool>::new(), 0));
assert_eq!(get_count(one_count_values)?, (vec![false], 1));
assert_eq!(get_count(one_count_values_with_null)?, (vec![true], 1));
assert_eq!(get_count(two_count_values)?, (vec![false, true], 2));
assert_eq!(
get_count(two_count_values_with_null)?,
(vec![false, true], 2)
);
Ok(())
}
#[test]
fn count_distinct_update_batch_all_nulls() -> Result<()> {
let arrays = vec![Arc::new(Int32Array::from(
vec![None, None, None, None] as Vec<Option<i32>>
)) as ArrayRef];
let (states, result) = run_update_batch(&arrays)?;
let state_vec = state_to_vec_primitive!(&states[0], Int32Type);
assert_eq!(states.len(), 1);
assert!(state_vec.is_empty());
assert_eq!(result, ScalarValue::Int64(Some(0)));
Ok(())
}
#[test]
fn count_distinct_update_batch_empty() -> Result<()> {
let arrays = vec![Arc::new(Int32Array::from(vec![0_i32; 0])) as ArrayRef];
let (states, result) = run_update_batch(&arrays)?;
let state_vec = state_to_vec_primitive!(&states[0], Int32Type);
assert_eq!(states.len(), 1);
assert!(state_vec.is_empty());
assert_eq!(result, ScalarValue::Int64(Some(0)));
Ok(())
}
#[test]
fn count_distinct_update() -> Result<()> {
let (states, result) = run_update(
&[DataType::Int32],
&[
vec![ScalarValue::Int32(Some(-1))],
vec![ScalarValue::Int32(Some(5))],
vec![ScalarValue::Int32(Some(-1))],
vec![ScalarValue::Int32(Some(5))],
vec![ScalarValue::Int32(Some(-1))],
vec![ScalarValue::Int32(Some(-1))],
vec![ScalarValue::Int32(Some(2))],
],
)?;
assert_eq!(states.len(), 1);
assert_eq!(result, ScalarValue::Int64(Some(3)));
let (states, result) = run_update(
&[DataType::UInt64],
&[
vec![ScalarValue::UInt64(Some(1))],
vec![ScalarValue::UInt64(Some(5))],
vec![ScalarValue::UInt64(Some(1))],
vec![ScalarValue::UInt64(Some(5))],
vec![ScalarValue::UInt64(Some(1))],
vec![ScalarValue::UInt64(Some(1))],
vec![ScalarValue::UInt64(Some(2))],
],
)?;
assert_eq!(states.len(), 1);
assert_eq!(result, ScalarValue::Int64(Some(3)));
Ok(())
}
#[test]
fn count_distinct_update_with_nulls() -> Result<()> {
let (states, result) = run_update(
&[DataType::Int32],
&[
// None of these updates contains a None, so these are accumulated.
vec![ScalarValue::Int32(Some(-1))],
vec![ScalarValue::Int32(Some(-1))],
vec![ScalarValue::Int32(Some(-2))],
// Each of these updates contains at least one None, so these
// won't be accumulated.
vec![ScalarValue::Int32(Some(-1))],
vec![ScalarValue::Int32(None)],
vec![ScalarValue::Int32(None)],
],
)?;
assert_eq!(states.len(), 1);
assert_eq!(result, ScalarValue::Int64(Some(2)));
let (states, result) = run_update(
&[DataType::UInt64],
&[
// None of these updates contains a None, so these are accumulated.
vec![ScalarValue::UInt64(Some(1))],
vec![ScalarValue::UInt64(Some(1))],
vec![ScalarValue::UInt64(Some(2))],
// Each of these updates contains at least one None, so these
// won't be accumulated.
vec![ScalarValue::UInt64(Some(1))],
vec![ScalarValue::UInt64(None)],
vec![ScalarValue::UInt64(None)],
],
)?;
assert_eq!(states.len(), 1);
assert_eq!(result, ScalarValue::Int64(Some(2)));
Ok(())
}
}
It seems to me this means that maybe the small string optimization is unnecessary at this time, given it doesn't seem to make a significant difference to performance 🤔 Maybe we could simplify the code?
If the number of rows is large (> 1e6), then the speed gain for short strings grows to whole seconds (5s faster for n = 1e6).
hits.parquet data is not large enough: the strings there are mostly length 1 or 2, which shows no difference between HashSet and SSO. URL length is mostly > 8 bytes, where SSO improves the query from 11s (HashSet) to 9s.
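The win for short strings comes from avoiding a per-row heap allocation: any string of at most `mem::size_of::<usize>()` bytes can be packed into a single `usize` by the fold used on the insert path. A minimal, standalone sketch of that encoding (function names here are illustrative, not the PR's API):

```rust
use std::mem;

const SHORT_STRING_LEN: usize = mem::size_of::<usize>(); // 8 on 64-bit targets

/// Pack the bytes of a short string into a usize, mirroring the fold
/// used on the accumulator's insert path; returns None for long strings.
fn inline_encode(s: &str) -> Option<usize> {
    if s.len() > SHORT_STRING_LEN {
        return None; // too long to inline; goes to the spill buffer instead
    }
    Some(s.bytes().fold(0usize, |acc, b| (acc << 8) | b as usize))
}

fn main() {
    // Distinct short strings encode to distinct values
    assert_ne!(inline_encode("a"), inline_encode("b"));
    // The packed value alone is length-ambiguous: "\0a" and "a" collide,
    // which is why the header must also store the string length
    assert_eq!(inline_encode("\0a"), inline_encode("a"));
    // Strings longer than a usize cannot be inlined
    assert_eq!(inline_encode("123456789"), None);
}
```

Note the packed integer alone does not distinguish leading-zero bytes, which is exactly why the header stores `len` alongside `offset_or_inline`.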
Nice! I'll add this query to the "extended" benchmark I am working on in #8861
@jayzhan211 -- I plan to review this PR more carefully over the next day or two. I am hoping that we can then use the same structure for #7064
Signed-off-by: jayzhan211 <[email protected]>
@jayzhan211 -- Thank you for 0e33b12
I have an idea for this test, let me push it up and see what you think. Update: 3e9289a
Ok, I think this PR is now blocked on the following two PRs:
Once those are merged I think we can merge up from this branch, run some final benchmark numbers, and get it reviewed. Thanks again @jayzhan211
Thanks, @alamb. It seems I misunderstood both the goal of the fuzz test and the memory accounting test 😅
I think this was my bad for not explaining it well. |
Ok, now I am just waiting on #8950 to merge and then I'll run the benchmarks and I think this PR will be ready for review |
Changed the title: Optimize `COUNT( DISTINCT ...)` for strings → Optimize `COUNT( DISTINCT ...)` for strings (up to 9x faster)
I think this PR is now ready for review. The benchmarks are looking very nice
Thanks again for this great teamwork @jayzhan211
Since I wrote a bunch of this PR I think another committer should also approve prior to merging it
@@ -54,6 +54,7 @@ blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
chrono = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
Needed to use the `RawTableAlloc` trait
@@ -152,6 +155,9 @@ impl AggregateExpr for DistinctCount {
Float32 => float_distinct_count_accumulator!(Float32Type),
Float64 => float_distinct_count_accumulator!(Float64Type),

Utf8 => Ok(Box::new(StringDistinctCountAccumulator::<i32>::new())),
The key contribution in this PR is to add these specialized accumulators
/// Maximum size of a string that can be inlined in the hash table
const SHORT_STRING_LEN: usize = mem::size_of::<usize>();

/// Entry that is stored in a `SSOStringHashSet` that represents a string
This explains the core change in this PR and how things work
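A rough standalone sketch of that core idea (field and type names here are illustrative; see the actual `SSOStringHashSet` in the diff for the real layout): short strings are inlined into a fixed-size array, and long strings become an (offset, len) reference into one shared byte buffer, so neither case allocates a `String` per distinct value.

```rust
use std::collections::HashSet;

const SHORT_STRING_LEN: usize = std::mem::size_of::<usize>();

/// Illustrative entry type: short strings are stored inline, long strings
/// point into one shared byte buffer — no per-string heap allocation.
#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
enum Entry {
    Short { len: usize, data: [u8; SHORT_STRING_LEN] },
    Long { offset: usize, len: usize },
}

fn make_entry(s: &str, buffer: &mut Vec<u8>) -> Entry {
    let bytes = s.as_bytes();
    if bytes.len() <= SHORT_STRING_LEN {
        let mut data = [0u8; SHORT_STRING_LEN];
        data[..bytes.len()].copy_from_slice(bytes);
        Entry::Short { len: bytes.len(), data }
    } else {
        // Simplification: the real implementation only appends to the
        // buffer when the string is not already present in the set.
        let offset = buffer.len();
        buffer.extend_from_slice(bytes);
        Entry::Long { offset, len: bytes.len() }
    }
}

fn main() {
    let mut buffer = Vec::new();
    let mut set = HashSet::new();
    for s in ["a", "bb", "a", "a long search phrase here"] {
        set.insert(make_entry(s, &mut buffer));
    }
    // "a" is deduplicated via the inline entry; the long phrase spills.
    assert_eq!(set.len(), 3);
    println!("distinct = {}", set.len());
}
```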
very nice!
return Ok(());
}

self.0.insert(values[0].clone());
nit: `insert` should be able to take a reference right?
Yes, you are right -- I did this in f5e268d
Thank you
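The clone-avoidance pattern behind this nit can be sketched in a standalone example (not the PR's actual code): probing the set by reference first means a clone is only paid the first time a value is actually inserted, not on every row.

```rust
use std::collections::HashSet;

fn main() {
    let mut set: HashSet<String> = HashSet::new();
    let values = ["a".to_string(), "a".to_string(), "b".to_string()];
    for v in &values {
        // Check membership by reference; clone only on first insertion.
        if !set.contains(v) {
            set.insert(v.clone());
        }
    }
    assert_eq!(set.len(), 2);
    println!("distinct = {}", set.len());
}
```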
Thank you for the review @thinkharderdev 🙏 |
🚀 |
Note for reviewers this PR has around 200 lines of code, and the rest is testing
This PR is a collaboration between @jayzhan211 and @alamb
Which issue does this PR close?
Part of #5472
Follow up on #8721
Rationale for this change
Speed up queries that include multiple `COUNT DISTINCT`s for `String` or `LargeString`.
What changes are included in this PR?
Implement a specialized Accumulator for `COUNT DISTINCT` that avoids copying string data or allocating individual strings.
Are these changes tested?
Benchmark results:
Clickbench Extended
Admittedly these benchmarks were chosen to highlight this particular change but I am still feeling pretty good with 9x faster query . See Docs for more details about what these tests are
Entire Clickbench (basically the same)
Entire TPCH_1 (basically the same)
Benchmark tpch_mem.json
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Query ┃ main_base ┃ bytes-distinctcount ┃ Change ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ QQuery 1 │ 219.29ms │ 208.36ms │ no change │
│ QQuery 2 │ 47.19ms │ 45.30ms │ no change │
│ QQuery 3 │ 79.84ms │ 78.33ms │ no change │
│ QQuery 4 │ 75.31ms │ 76.34ms │ no change │
│ QQuery 5 │ 125.54ms │ 126.23ms │ no change │
│ QQuery 6 │ 16.68ms │ 16.15ms │ no change │
│ QQuery 7 │ 337.79ms │ 320.61ms │ +1.05x faster │
│ QQuery 8 │ 82.13ms │ 80.68ms │ no change │
│ QQuery 9 │ 127.99ms │ 128.50ms │ no change │
│ QQuery 10 │ 158.88ms │ 155.37ms │ no change │
│ QQuery 11 │ 33.85ms │ 33.83ms │ no change │
│ QQuery 12 │ 72.23ms │ 70.90ms │ no change │
│ QQuery 13 │ 83.32ms │ 85.77ms │ no change │
│ QQuery 14 │ 26.51ms │ 26.13ms │ no change │
│ QQuery 15 │ 62.08ms │ 60.46ms │ no change │
│ QQuery 16 │ 48.43ms │ 45.78ms │ +1.06x faster │
│ QQuery 17 │ 166.20ms │ 161.16ms │ no change │
│ QQuery 18 │ 471.68ms │ 465.24ms │ no change │
│ QQuery 19 │ 65.11ms │ 65.86ms │ no change │
│ QQuery 20 │ 117.70ms │ 117.06ms │ no change │
│ QQuery 21 │ 368.86ms │ 363.66ms │ no change │
│ QQuery 22 │ 29.75ms │ 29.44ms │ no change │
└──────────────┴───────────┴─────────────────────┴───────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Benchmark Summary ┃ ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ Total Time (main_base) │ 2816.38ms │
│ Total Time (bytes-distinctcount) │ 2761.14ms │
│ Average Time (main_base) │ 128.02ms │
│ Average Time (bytes-distinctcount) │ 125.51ms │
│ Queries Faster │ 2 │
│ Queries Slower │ 0 │
│ Queries with No Change │ 20 │
└────────────────────────────────────┴───────────┘