-
Notifications
You must be signed in to change notification settings - Fork 1.2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Optimize COUNT( DISTINCT ...)
for strings (up to 9x faster)
#8849
Merged
Merged
Changes from all commits
Commits
Show all changes
37 commits
Select commit
Hold shift + click to select a range
9c44d04
chkp
jayzhan211 6cb8bbe
chkp
jayzhan211 9d662a7
draft
jayzhan211 1744cb3
iter done
jayzhan211 e3b0568
short string test
jayzhan211 12cf50c
add test
jayzhan211 4f9a3f0
remove unused
jayzhan211 626b1cb
to_string directly
jayzhan211 2e80cb7
rewrite evaluate
jayzhan211 d2d1d6d
return Vec<String>
jayzhan211 ebb8726
fmt
jayzhan211 98a9cd1
add more queries
jayzhan211 07831fa
add group by query and rewrite evalute with state()
jayzhan211 62c8084
move evaluate back
jayzhan211 e3b65c8
upd test
jayzhan211 3f0e9a9
add row sort
jayzhan211 4bc483a
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb 0475687
Update benchmarks/queries/clickbench/README.md
alamb a764e99
Rework set to avoid copies
alamb bde49c6
Merge branch 'bytes-distinctcount' of github.com:jayzhan211/arrow-dat…
alamb a101b62
Simplify offset construction
alamb 0f2fa02
fmt
alamb 489e130
Improve comments
alamb c39988a
Improve comments
alamb 0e33b12
add fuzz test
jayzhan211 b3bcc68
Add support for LargeStringArray
alamb d7efcf6
Merge branch 'bytes-distinctcount' of github.com:jayzhan211/arrow-dat…
alamb a80b39c
refine fuzz test
alamb 3e9289a
Add tests for size accounting
alamb 7b9d067
Split into new module
alamb d405744
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb 3a6a066
Remove use of Mutex
alamb f177aed
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb 8640907
revert changes
alamb 214ba5b
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb 1e10b9c
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb f5e268d
Use reference rather than owned ArrayRef
alamb File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
211 changes: 211 additions & 0 deletions
211
datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,211 @@ | ||
// Licensed to the Apache Software Foundation (ASF) under one | ||
// or more contributor license agreements. See the NOTICE file | ||
// distributed with this work for additional information | ||
// regarding copyright ownership. The ASF licenses this file | ||
// to you under the Apache License, Version 2.0 (the | ||
// "License"); you may not use this file except in compliance | ||
// with the License. You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, | ||
// software distributed under the License is distributed on an | ||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | ||
// KIND, either express or implied. See the License for the | ||
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
//! Compare DistinctCount for string with naive HashSet and Short String Optimized HashSet | ||
|
||
use std::sync::Arc; | ||
|
||
use arrow::array::ArrayRef; | ||
use arrow::record_batch::RecordBatch; | ||
use arrow_array::{Array, GenericStringArray, OffsetSizeTrait, UInt32Array}; | ||
|
||
use arrow_array::cast::AsArray; | ||
use datafusion::datasource::MemTable; | ||
use rand::rngs::StdRng; | ||
use rand::{thread_rng, Rng, SeedableRng}; | ||
use std::collections::HashSet; | ||
use tokio::task::JoinSet; | ||
|
||
use datafusion::prelude::{SessionConfig, SessionContext}; | ||
use test_utils::stagger_batch; | ||
|
||
#[tokio::test(flavor = "multi_thread")] | ||
async fn distinct_count_string_test() { | ||
// max length of generated strings | ||
let mut join_set = JoinSet::new(); | ||
let mut rng = thread_rng(); | ||
for null_pct in [0.0, 0.01, 0.1, 0.5] { | ||
for _ in 0..100 { | ||
let max_len = rng.gen_range(1..50); | ||
let num_strings = rng.gen_range(1..100); | ||
let num_distinct_strings = if num_strings > 1 { | ||
rng.gen_range(1..num_strings) | ||
} else { | ||
num_strings | ||
}; | ||
let generator = BatchGenerator { | ||
max_len, | ||
num_strings, | ||
num_distinct_strings, | ||
null_pct, | ||
rng: StdRng::from_seed(rng.gen()), | ||
}; | ||
join_set.spawn(async move { run_distinct_count_test(generator).await }); | ||
} | ||
} | ||
while let Some(join_handle) = join_set.join_next().await { | ||
// propagate errors | ||
join_handle.unwrap(); | ||
} | ||
} | ||
|
||
/// Run COUNT DISTINCT using SQL and compare the result to computing the | ||
/// distinct count using HashSet<String> | ||
async fn run_distinct_count_test(mut generator: BatchGenerator) { | ||
let input = generator.make_input_batches(); | ||
|
||
let schema = input[0].schema(); | ||
let session_config = SessionConfig::new().with_batch_size(50); | ||
let ctx = SessionContext::new_with_config(session_config); | ||
|
||
// split input into two partitions | ||
let partition_len = input.len() / 2; | ||
let partitions = vec![ | ||
input[0..partition_len].to_vec(), | ||
input[partition_len..].to_vec(), | ||
]; | ||
|
||
let provider = MemTable::try_new(schema, partitions).unwrap(); | ||
ctx.register_table("t", Arc::new(provider)).unwrap(); | ||
// input has two columns, a and b. The result is the number of distinct | ||
// values in each column. | ||
// | ||
// Note, we need at least two count distinct aggregates to trigger the | ||
// count distinct aggregate. Otherwise, the optimizer will rewrite the | ||
// `COUNT(DISTINCT a)` to `COUNT(*) from (SELECT DISTINCT a FROM t)` | ||
let results = ctx | ||
.sql("SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM t") | ||
.await | ||
.unwrap() | ||
.collect() | ||
.await | ||
.unwrap(); | ||
|
||
// get all the strings from the first column of the result (distinct a) | ||
let expected_a = extract_distinct_strings::<i32>(&input, 0).len(); | ||
let result_a = extract_i64(&results, 0); | ||
assert_eq!(expected_a, result_a); | ||
|
||
// get all the strings from the second column of the result (distinct b( | ||
let expected_b = extract_distinct_strings::<i64>(&input, 1).len(); | ||
let result_b = extract_i64(&results, 1); | ||
assert_eq!(expected_b, result_b); | ||
} | ||
|
||
/// Return all (non null) distinct strings from column col_idx | ||
fn extract_distinct_strings<O: OffsetSizeTrait>( | ||
results: &[RecordBatch], | ||
col_idx: usize, | ||
) -> Vec<String> { | ||
results | ||
.iter() | ||
.flat_map(|batch| { | ||
let array = batch.column(col_idx).as_string::<O>(); | ||
// remove nulls via 'flatten' | ||
array.iter().flatten().map(|s| s.to_string()) | ||
}) | ||
.collect::<HashSet<_>>() | ||
.into_iter() | ||
.collect() | ||
} | ||
|
||
// extract the value from the Int64 column in col_idx in batch and return | ||
// it as a usize | ||
fn extract_i64(results: &[RecordBatch], col_idx: usize) -> usize { | ||
assert_eq!(results.len(), 1); | ||
let array = results[0] | ||
.column(col_idx) | ||
.as_any() | ||
.downcast_ref::<arrow::array::Int64Array>() | ||
.unwrap(); | ||
assert_eq!(array.len(), 1); | ||
assert!(!array.is_null(0)); | ||
array.value(0).try_into().unwrap() | ||
} | ||
|
||
struct BatchGenerator { | ||
//// The maximum length of the strings | ||
max_len: usize, | ||
/// the total number of strings in the output | ||
num_strings: usize, | ||
/// The number of distinct strings in the columns | ||
num_distinct_strings: usize, | ||
/// The percentage of nulls in the columns | ||
null_pct: f64, | ||
/// Random number generator | ||
rng: StdRng, | ||
} | ||
|
||
impl BatchGenerator { | ||
/// Make batches of random strings with a random length columns "a" and "b": | ||
/// | ||
/// * "a" is a StringArray | ||
/// * "b" is a LargeStringArray | ||
fn make_input_batches(&mut self) -> Vec<RecordBatch> { | ||
// use a random number generator to pick a random sized output | ||
|
||
let batch = RecordBatch::try_from_iter(vec![ | ||
("a", self.gen_data::<i32>()), | ||
("b", self.gen_data::<i64>()), | ||
]) | ||
.unwrap(); | ||
|
||
stagger_batch(batch) | ||
} | ||
|
||
/// Creates a StringArray or LargeStringArray with random strings according | ||
/// to the parameters of the BatchGenerator | ||
fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef { | ||
// table of strings from which to draw | ||
let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings) | ||
.map(|_| Some(random_string(&mut self.rng, self.max_len))) | ||
.collect(); | ||
|
||
// pick num_strings randomly from the distinct string table | ||
let indicies: UInt32Array = (0..self.num_strings) | ||
.map(|_| { | ||
if self.rng.gen::<f64>() < self.null_pct { | ||
None | ||
} else if self.num_distinct_strings > 1 { | ||
let range = 1..(self.num_distinct_strings as u32); | ||
Some(self.rng.gen_range(range)) | ||
} else { | ||
Some(0) | ||
} | ||
}) | ||
.collect(); | ||
|
||
let options = None; | ||
arrow::compute::take(&distinct_strings, &indicies, options).unwrap() | ||
} | ||
} | ||
|
||
/// Return a string of random characters of length 1..=max_len | ||
fn random_string(rng: &mut StdRng, max_len: usize) -> String { | ||
// pick characters at random (not just ascii) | ||
match max_len { | ||
0 => "".to_string(), | ||
1 => String::from(rng.gen::<char>()), | ||
_ => { | ||
let len = rng.gen_range(1..=max_len); | ||
rng.sample_iter::<char, _>(rand::distributions::Standard) | ||
.take(len) | ||
.map(char::from) | ||
.collect::<String>() | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -15,34 +15,37 @@ | |
// specific language governing permissions and limitations | ||
// under the License. | ||
|
||
use arrow::datatypes::{DataType, Field, TimeUnit}; | ||
use arrow_array::types::{ | ||
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, | ||
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, | ||
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, | ||
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, | ||
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, | ||
}; | ||
use arrow_array::PrimitiveArray; | ||
mod strings; | ||
|
||
use std::any::Any; | ||
use std::cmp::Eq; | ||
use std::collections::HashSet; | ||
use std::fmt::Debug; | ||
use std::hash::Hash; | ||
use std::sync::Arc; | ||
|
||
use ahash::RandomState; | ||
use arrow::array::{Array, ArrayRef}; | ||
use std::collections::HashSet; | ||
use arrow::datatypes::{DataType, Field, TimeUnit}; | ||
use arrow_array::types::{ | ||
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, | ||
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, | ||
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType, | ||
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, | ||
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, | ||
}; | ||
use arrow_array::PrimitiveArray; | ||
|
||
use crate::aggregate::utils::{down_cast_any_ref, Hashable}; | ||
use crate::expressions::format_state_name; | ||
use crate::{AggregateExpr, PhysicalExpr}; | ||
use datafusion_common::cast::{as_list_array, as_primitive_array}; | ||
use datafusion_common::utils::array_into_list_array; | ||
use datafusion_common::{Result, ScalarValue}; | ||
use datafusion_expr::Accumulator; | ||
|
||
use crate::aggregate::count_distinct::strings::StringDistinctCountAccumulator; | ||
use crate::aggregate::utils::{down_cast_any_ref, Hashable}; | ||
use crate::expressions::format_state_name; | ||
use crate::{AggregateExpr, PhysicalExpr}; | ||
|
||
type DistinctScalarValues = ScalarValue; | ||
|
||
/// Expression for a COUNT(DISTINCT) aggregation. | ||
|
@@ -61,10 +64,10 @@ impl DistinctCount { | |
pub fn new( | ||
input_data_type: DataType, | ||
expr: Arc<dyn PhysicalExpr>, | ||
name: String, | ||
name: impl Into<String>, | ||
) -> Self { | ||
Self { | ||
name, | ||
name: name.into(), | ||
state_data_type: input_data_type, | ||
expr, | ||
} | ||
|
@@ -152,6 +155,9 @@ impl AggregateExpr for DistinctCount { | |
Float32 => float_distinct_count_accumulator!(Float32Type), | ||
Float64 => float_distinct_count_accumulator!(Float64Type), | ||
|
||
Utf8 => Ok(Box::new(StringDistinctCountAccumulator::<i32>::new())), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The key contribution in this PR is to add these specialized accumulators |
||
LargeUtf8 => Ok(Box::new(StringDistinctCountAccumulator::<i64>::new())), | ||
|
||
_ => Ok(Box::new(DistinctCountAccumulator { | ||
values: HashSet::default(), | ||
state_data_type: self.state_data_type.clone(), | ||
|
@@ -244,7 +250,7 @@ impl Accumulator for DistinctCountAccumulator { | |
assert_eq!(states.len(), 1, "array_agg states must be singleton!"); | ||
let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?; | ||
for scalars in scalar_vec.into_iter() { | ||
self.values.extend(scalars) | ||
self.values.extend(scalars); | ||
} | ||
Ok(()) | ||
} | ||
|
@@ -440,9 +446,6 @@ where | |
|
||
#[cfg(test)] | ||
mod tests { | ||
use crate::expressions::NoOp; | ||
|
||
use super::*; | ||
use arrow::array::{ | ||
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, | ||
Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, | ||
|
@@ -454,10 +457,15 @@ mod tests { | |
}; | ||
use arrow_array::Decimal256Array; | ||
use arrow_buffer::i256; | ||
|
||
use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; | ||
use datafusion_common::internal_err; | ||
use datafusion_common::DataFusionError; | ||
|
||
use crate::expressions::NoOp; | ||
|
||
use super::*; | ||
|
||
macro_rules! state_to_vec_primitive { | ||
($LIST:expr, $DATA_TYPE:ident) => {{ | ||
let arr = ScalarValue::raw_data($LIST).unwrap(); | ||
|
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needed to use RawTableAlloc trait