Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split count_distinct.rs into separate modules #9087

Merged
merged 5 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 41 additions & 216 deletions datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,39 +15,36 @@
// specific language governing permissions and limitations
// under the License.

mod native;
mod strings;

use std::any::Any;
use std::cmp::Eq;
use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{Array, ArrayRef};
use arrow::datatypes::{DataType, Field, TimeUnit};
use arrow_array::types::{
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, Float32Type,
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, Time32MillisecondType,
Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::PrimitiveArray;

use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;

use crate::aggregate::count_distinct::native::{
FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator,
};
use crate::aggregate::count_distinct::strings::StringDistinctCountAccumulator;
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::aggregate::utils::down_cast_any_ref;
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};

type DistinctScalarValues = ScalarValue;

/// Expression for a COUNT(DISTINCT) aggregation.
#[derive(Debug)]
pub struct DistinctCount {
Expand Down Expand Up @@ -101,46 +98,46 @@ impl AggregateExpr for DistinctCount {
use TimeUnit::*;

Ok(match &self.state_data_type {
Int8 => Box::new(NativeDistinctCountAccumulator::<Int8Type>::new()),
Int16 => Box::new(NativeDistinctCountAccumulator::<Int16Type>::new()),
Int32 => Box::new(NativeDistinctCountAccumulator::<Int32Type>::new()),
Int64 => Box::new(NativeDistinctCountAccumulator::<Int64Type>::new()),
UInt8 => Box::new(NativeDistinctCountAccumulator::<UInt8Type>::new()),
UInt16 => Box::new(NativeDistinctCountAccumulator::<UInt16Type>::new()),
UInt32 => Box::new(NativeDistinctCountAccumulator::<UInt32Type>::new()),
UInt64 => Box::new(NativeDistinctCountAccumulator::<UInt64Type>::new()),
Int8 => Box::new(PrimitiveDistinctCountAccumulator::<Int8Type>::new()),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed

Int16 => Box::new(PrimitiveDistinctCountAccumulator::<Int16Type>::new()),
Int32 => Box::new(PrimitiveDistinctCountAccumulator::<Int32Type>::new()),
Int64 => Box::new(PrimitiveDistinctCountAccumulator::<Int64Type>::new()),
UInt8 => Box::new(PrimitiveDistinctCountAccumulator::<UInt8Type>::new()),
UInt16 => Box::new(PrimitiveDistinctCountAccumulator::<UInt16Type>::new()),
UInt32 => Box::new(PrimitiveDistinctCountAccumulator::<UInt32Type>::new()),
UInt64 => Box::new(PrimitiveDistinctCountAccumulator::<UInt64Type>::new()),
Decimal128(_, _) => {
Box::new(NativeDistinctCountAccumulator::<Decimal128Type>::new())
Box::new(PrimitiveDistinctCountAccumulator::<Decimal128Type>::new())
}
Decimal256(_, _) => {
Box::new(NativeDistinctCountAccumulator::<Decimal256Type>::new())
Box::new(PrimitiveDistinctCountAccumulator::<Decimal256Type>::new())
}

Date32 => Box::new(NativeDistinctCountAccumulator::<Date32Type>::new()),
Date64 => Box::new(NativeDistinctCountAccumulator::<Date64Type>::new()),
Time32(Millisecond) => {
Box::new(NativeDistinctCountAccumulator::<Time32MillisecondType>::new())
}
Date32 => Box::new(PrimitiveDistinctCountAccumulator::<Date32Type>::new()),
Date64 => Box::new(PrimitiveDistinctCountAccumulator::<Date64Type>::new()),
Time32(Millisecond) => Box::new(PrimitiveDistinctCountAccumulator::<
Time32MillisecondType,
>::new()),
Time32(Second) => {
Box::new(NativeDistinctCountAccumulator::<Time32SecondType>::new())
}
Time64(Microsecond) => {
Box::new(NativeDistinctCountAccumulator::<Time64MicrosecondType>::new())
Box::new(PrimitiveDistinctCountAccumulator::<Time32SecondType>::new())
}
Time64(Microsecond) => Box::new(PrimitiveDistinctCountAccumulator::<
Time64MicrosecondType,
>::new()),
Time64(Nanosecond) => {
Box::new(NativeDistinctCountAccumulator::<Time64NanosecondType>::new())
Box::new(PrimitiveDistinctCountAccumulator::<Time64NanosecondType>::new())
}
Timestamp(Microsecond, _) => Box::new(NativeDistinctCountAccumulator::<
Timestamp(Microsecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
TimestampMicrosecondType,
>::new()),
Timestamp(Millisecond, _) => Box::new(NativeDistinctCountAccumulator::<
Timestamp(Millisecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
TimestampMillisecondType,
>::new()),
Timestamp(Nanosecond, _) => {
Box::new(NativeDistinctCountAccumulator::<TimestampNanosecondType>::new())
}
Timestamp(Nanosecond, _) => Box::new(PrimitiveDistinctCountAccumulator::<
TimestampNanosecondType,
>::new()),
Timestamp(Second, _) => {
Box::new(NativeDistinctCountAccumulator::<TimestampSecondType>::new())
Box::new(PrimitiveDistinctCountAccumulator::<TimestampSecondType>::new())
}

Float16 => Box::new(FloatDistinctCountAccumulator::<Float16Type>::new()),
Expand Down Expand Up @@ -175,9 +172,13 @@ impl PartialEq<dyn Any> for DistinctCount {
}
}

/// General purpose distinct accumulator that works for any DataType by using
/// [`ScalarValue`]. Some types have specialized accumulators that are (much)
/// more efficient such as [`PrimitiveDistinctCountAccumulator`] and
/// [`StringDistinctCountAccumulator`]
#[derive(Debug)]
struct DistinctCountAccumulator {
values: HashSet<DistinctScalarValues, RandomState>,
values: HashSet<ScalarValue, RandomState>,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was typedefd to

type DistinctScalarValues = ScalarValue;

Which serves no purpose that I can tell other than obscuring what this code is doing

state_data_type: DataType,
}

Expand All @@ -186,7 +187,7 @@ impl DistinctCountAccumulator {
// This method is faster than .full_size(), however it is not suitable for variable length values like strings or complex types
fn fixed_size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
+ (std::mem::size_of::<ScalarValue>() * self.values.capacity())
+ self
.values
.iter()
Expand All @@ -199,7 +200,7 @@ impl DistinctCountAccumulator {
// calculates the size as accurate as possible, call to this method is expensive
fn full_size(&self) -> usize {
std::mem::size_of_val(self)
+ (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
+ (std::mem::size_of::<ScalarValue>() * self.values.capacity())
+ self
.values
.iter()
Expand Down Expand Up @@ -260,182 +261,6 @@ impl Accumulator for DistinctCountAccumulator {
}
}

#[derive(Debug)]
struct NativeDistinctCountAccumulator<T>
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Moved to native.rs

where
T: ArrowPrimitiveType + Send,
T::Native: Eq + Hash,
{
values: HashSet<T::Native, RandomState>,
}

impl<T> NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
T::Native: Eq + Hash,
{
fn new() -> Self {
Self {
values: HashSet::default(),
}
}
}

impl<T> Accumulator for NativeDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send + Debug,
T::Native: Eq + Hash,
{
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().cloned(),
)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr));
Ok(vec![ScalarValue::List(list)])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}

let arr = as_primitive_array::<T>(&values[0])?;
arr.iter().for_each(|value| {
if let Some(value) = value {
self.values.insert(value);
}
});

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);

let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
let list = as_primitive_array::<T>(&list)?;
self.values.extend(list.values())
};
Ok(())
})
}

fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
}

fn size(&self) -> usize {
let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
/ 7)
.next_power_of_two();

// Size of accumulator
// + size of entry * number of buckets
// + 1 byte for each bucket
// + fixed size of HashSet
std::mem::size_of_val(self)
+ std::mem::size_of::<T::Native>() * estimated_buckets
+ estimated_buckets
+ std::mem::size_of_val(&self.values)
}
}

#[derive(Debug)]
struct FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
{
values: HashSet<Hashable<T::Native>, RandomState>,
}

impl<T> FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send,
{
fn new() -> Self {
Self {
values: HashSet::default(),
}
}
}

impl<T> Accumulator for FloatDistinctCountAccumulator<T>
where
T: ArrowPrimitiveType + Send + Debug,
{
fn state(&mut self) -> Result<Vec<ScalarValue>> {
let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
self.values.iter().map(|v| v.0),
)) as ArrayRef;
let list = Arc::new(array_into_list_array(arr));
Ok(vec![ScalarValue::List(list)])
}

fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
if values.is_empty() {
return Ok(());
}

let arr = as_primitive_array::<T>(&values[0])?;
arr.iter().for_each(|value| {
if let Some(value) = value {
self.values.insert(Hashable(value));
}
});

Ok(())
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
if states.is_empty() {
return Ok(());
}
assert_eq!(
states.len(),
1,
"count_distinct states must be single array"
);

let arr = as_list_array(&states[0])?;
arr.iter().try_for_each(|maybe_list| {
if let Some(list) = maybe_list {
let list = as_primitive_array::<T>(&list)?;
self.values
.extend(list.values().iter().map(|v| Hashable(*v)));
};
Ok(())
})
}

fn evaluate(&mut self) -> Result<ScalarValue> {
Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
}

fn size(&self) -> usize {
let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
/ 7)
.next_power_of_two();

// Size of accumulator
// + size of entry * number of buckets
// + 1 byte for each bucket
// + fixed size of HashSet
std::mem::size_of_val(self)
+ std::mem::size_of::<T::Native>() * estimated_buckets
+ estimated_buckets
+ std::mem::size_of_val(&self.values)
}
}

#[cfg(test)]
mod tests {
use arrow::array::{
Expand Down
Loading
Loading