Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize COUNT( DISTINCT ...) for strings (up to 9x faster) #8849

Merged
merged 37 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from 35 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
9c44d04
chkp
jayzhan211 Jan 10, 2024
6cb8bbe
chkp
jayzhan211 Jan 13, 2024
9d662a7
draft
jayzhan211 Jan 13, 2024
1744cb3
iter done
jayzhan211 Jan 13, 2024
e3b0568
short string test
jayzhan211 Jan 13, 2024
12cf50c
add test
jayzhan211 Jan 13, 2024
4f9a3f0
remove unused
jayzhan211 Jan 13, 2024
626b1cb
to_string directly
jayzhan211 Jan 13, 2024
2e80cb7
rewrite evaluate
jayzhan211 Jan 13, 2024
d2d1d6d
return Vec<String>
jayzhan211 Jan 13, 2024
ebb8726
fmt
jayzhan211 Jan 13, 2024
98a9cd1
add more queries
jayzhan211 Jan 16, 2024
07831fa
add group by query and rewrite evalute with state()
jayzhan211 Jan 17, 2024
62c8084
move evaluate back
jayzhan211 Jan 17, 2024
e3b65c8
upd test
jayzhan211 Jan 17, 2024
3f0e9a9
add row sort
jayzhan211 Jan 17, 2024
4bc483a
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb Jan 20, 2024
0475687
Update benchmarks/queries/clickbench/README.md
alamb Jan 20, 2024
a764e99
Rework set to avoid copies
alamb Jan 20, 2024
bde49c6
Merge branch 'bytes-distinctcount' of github.com:jayzhan211/arrow-dat…
alamb Jan 20, 2024
a101b62
Simplify offset construction
alamb Jan 20, 2024
0f2fa02
fmt
alamb Jan 20, 2024
489e130
Improve comments
alamb Jan 21, 2024
c39988a
Improve comments
alamb Jan 21, 2024
0e33b12
add fuzz test
jayzhan211 Jan 22, 2024
b3bcc68
Add support for LargeStringArray
alamb Jan 22, 2024
d7efcf6
Merge branch 'bytes-distinctcount' of github.com:jayzhan211/arrow-dat…
alamb Jan 22, 2024
a80b39c
refine fuzz test
alamb Jan 22, 2024
3e9289a
Add tests for size accounting
alamb Jan 22, 2024
7b9d067
Split into new module
alamb Jan 22, 2024
d405744
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb Jan 24, 2024
3a6a066
Remove use of Mutex
alamb Jan 24, 2024
f177aed
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb Jan 25, 2024
8640907
revert changes
alamb Jan 25, 2024
214ba5b
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb Jan 27, 2024
1e10b9c
Merge remote-tracking branch 'apache/main' into bytes-distinctcount
alamb Jan 28, 2024
f5e268d
Use reference rather than owned ArrayRef
alamb Jan 28, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions benchmarks/queries/clickbench/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DIST
FROM hits;
```


### Q1: Data Exploration

**Question**: "How many distinct "hit color", "browser country" and "language" are there in the dataset?"
Expand All @@ -42,7 +41,7 @@ SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTI
FROM hits;
```

### Q2: Top 10 anaylsis
### Q2: Top 10 analysis

**Question**: "Find the top 10 "browser country" by number of distinct "social network"s,
including the distinct counts of "hit color", "browser language",
Expand Down
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

211 changes: 211 additions & 0 deletions datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Compare DistinctCount for string with naive HashSet and Short String Optimized HashSet

use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow::record_batch::RecordBatch;
use arrow_array::{Array, GenericStringArray, OffsetSizeTrait, UInt32Array};

use arrow_array::cast::AsArray;
use datafusion::datasource::MemTable;
use rand::rngs::StdRng;
use rand::{thread_rng, Rng, SeedableRng};
use std::collections::HashSet;
use tokio::task::JoinSet;

use datafusion::prelude::{SessionConfig, SessionContext};
use test_utils::stagger_batch;

#[tokio::test(flavor = "multi_thread")]
async fn distinct_count_string_test() {
// max length of generated strings
let mut join_set = JoinSet::new();
let mut rng = thread_rng();
for null_pct in [0.0, 0.01, 0.1, 0.5] {
for _ in 0..100 {
let max_len = rng.gen_range(1..50);
let num_strings = rng.gen_range(1..100);
let num_distinct_strings = if num_strings > 1 {
rng.gen_range(1..num_strings)
} else {
num_strings
};
let generator = BatchGenerator {
max_len,
num_strings,
num_distinct_strings,
null_pct,
rng: StdRng::from_seed(rng.gen()),
};
join_set.spawn(async move { run_distinct_count_test(generator).await });
}
}
while let Some(join_handle) = join_set.join_next().await {
// propagate errors
join_handle.unwrap();
}
}

/// Run COUNT DISTINCT using SQL and compare the result to computing the
/// distinct count using HashSet<String>
async fn run_distinct_count_test(mut generator: BatchGenerator) {
let input = generator.make_input_batches();

let schema = input[0].schema();
let session_config = SessionConfig::new().with_batch_size(50);
let ctx = SessionContext::new_with_config(session_config);

// split input into two partitions
let partition_len = input.len() / 2;
let partitions = vec![
input[0..partition_len].to_vec(),
input[partition_len..].to_vec(),
];

let provider = MemTable::try_new(schema, partitions).unwrap();
ctx.register_table("t", Arc::new(provider)).unwrap();
// input has two columns, a and b. The result is the number of distinct
// values in each column.
//
// Note, we need at least two count distinct aggregates to trigger the
// count distinct aggregate. Otherwise, the optimizer will rewrite the
// `COUNT(DISTINCT a)` to `COUNT(*) from (SELECT DISTINCT a FROM t)`
let results = ctx
.sql("SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM t")
.await
.unwrap()
.collect()
.await
.unwrap();

// get all the strings from the first column of the result (distinct a)
let expected_a = extract_distinct_strings::<i32>(&input, 0).len();
let result_a = extract_i64(&results, 0);
assert_eq!(expected_a, result_a);

// get all the strings from the second column of the result (distinct b(
let expected_b = extract_distinct_strings::<i64>(&input, 1).len();
let result_b = extract_i64(&results, 1);
assert_eq!(expected_b, result_b);
}

/// Return all (non null) distinct strings from column col_idx
fn extract_distinct_strings<O: OffsetSizeTrait>(
results: &[RecordBatch],
col_idx: usize,
) -> Vec<String> {
results
.iter()
.flat_map(|batch| {
let array = batch.column(col_idx).as_string::<O>();
// remove nulls via 'flatten'
array.iter().flatten().map(|s| s.to_string())
})
.collect::<HashSet<_>>()
.into_iter()
.collect()
}

// extract the value from the Int64 column in col_idx in batch and return
// it as a usize
fn extract_i64(results: &[RecordBatch], col_idx: usize) -> usize {
assert_eq!(results.len(), 1);
let array = results[0]
.column(col_idx)
.as_any()
.downcast_ref::<arrow::array::Int64Array>()
.unwrap();
assert_eq!(array.len(), 1);
assert!(!array.is_null(0));
array.value(0).try_into().unwrap()
}

struct BatchGenerator {
//// The maximum length of the strings
max_len: usize,
/// the total number of strings in the output
num_strings: usize,
/// The number of distinct strings in the columns
num_distinct_strings: usize,
/// The percentage of nulls in the columns
null_pct: f64,
/// Random number generator
rng: StdRng,
}

impl BatchGenerator {
/// Make batches of random strings with a random length columns "a" and "b":
///
/// * "a" is a StringArray
/// * "b" is a LargeStringArray
fn make_input_batches(&mut self) -> Vec<RecordBatch> {
// use a random number generator to pick a random sized output

let batch = RecordBatch::try_from_iter(vec![
("a", self.gen_data::<i32>()),
("b", self.gen_data::<i64>()),
])
.unwrap();

stagger_batch(batch)
}

/// Creates a StringArray or LargeStringArray with random strings according
/// to the parameters of the BatchGenerator
fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
// table of strings from which to draw
let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
.map(|_| Some(random_string(&mut self.rng, self.max_len)))
.collect();

// pick num_strings randomly from the distinct string table
let indicies: UInt32Array = (0..self.num_strings)
.map(|_| {
if self.rng.gen::<f64>() < self.null_pct {
None
} else if self.num_distinct_strings > 1 {
let range = 1..(self.num_distinct_strings as u32);
Some(self.rng.gen_range(range))
} else {
Some(0)
}
})
.collect();

let options = None;
arrow::compute::take(&distinct_strings, &indicies, options).unwrap()
}
}

/// Return a string of random characters of length 1..=max_len
fn random_string(rng: &mut StdRng, max_len: usize) -> String {
// pick characters at random (not just ascii)
match max_len {
0 => "".to_string(),
1 => String::from(rng.gen::<char>()),
_ => {
let len = rng.gen_range(1..=max_len);
rng.sample_iter::<char, _>(rand::distributions::Standard)
.take(len)
.map(char::from)
.collect::<String>()
}
}
}
1 change: 1 addition & 0 deletions datafusion/core/tests/fuzz_cases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

mod aggregate_fuzz;
mod distinct_count_string_fuzz;
mod join_fuzz;
mod merge_fuzz;
mod sort_fuzz;
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
chrono = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needed to use RawTableAlloc trait

datafusion-expr = { workspace = true }
half = { version = "2.1", default-features = false }
hashbrown = { version = "0.14", features = ["raw"] }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,37 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Field, TimeUnit};
use arrow_array::types::{
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::PrimitiveArray;
mod strings;

use std::any::Any;
use std::cmp::Eq;
use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{Array, ArrayRef};
use std::collections::HashSet;
use arrow::datatypes::{DataType, Field, TimeUnit};
use arrow_array::types::{
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::PrimitiveArray;

use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;

use crate::aggregate::count_distinct::strings::StringDistinctCountAccumulator;
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};

type DistinctScalarValues = ScalarValue;

/// Expression for a COUNT(DISTINCT) aggregation.
Expand All @@ -61,10 +64,10 @@ impl DistinctCount {
pub fn new(
input_data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
name: String,
name: impl Into<String>,
) -> Self {
Self {
name,
name: name.into(),
state_data_type: input_data_type,
expr,
}
Expand Down Expand Up @@ -152,6 +155,9 @@ impl AggregateExpr for DistinctCount {
Float32 => float_distinct_count_accumulator!(Float32Type),
Float64 => float_distinct_count_accumulator!(Float64Type),

Utf8 => Ok(Box::new(StringDistinctCountAccumulator::<i32>::new())),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The key contribution in this PR is to add these specialized accumulators

LargeUtf8 => Ok(Box::new(StringDistinctCountAccumulator::<i64>::new())),

_ => Ok(Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: self.state_data_type.clone(),
Expand Down Expand Up @@ -244,7 +250,7 @@ impl Accumulator for DistinctCountAccumulator {
assert_eq!(states.len(), 1, "array_agg states must be singleton!");
let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
for scalars in scalar_vec.into_iter() {
self.values.extend(scalars)
self.values.extend(scalars);
}
Ok(())
}
Expand Down Expand Up @@ -440,9 +446,6 @@ where

#[cfg(test)]
mod tests {
use crate::expressions::NoOp;

use super::*;
use arrow::array::{
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
Expand All @@ -454,10 +457,15 @@ mod tests {
};
use arrow_array::Decimal256Array;
use arrow_buffer::i256;

use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array};
use datafusion_common::internal_err;
use datafusion_common::DataFusionError;

use crate::expressions::NoOp;

use super::*;

macro_rules! state_to_vec_primitive {
($LIST:expr, $DATA_TYPE:ident) => {{
let arr = ScalarValue::raw_data($LIST).unwrap();
Expand Down
Loading
Loading