Skip to content

Commit

Permalink
Optimize COUNT( DISTINCT ...) for strings (up to 9x faster) (#8849)
Browse files Browse the repository at this point in the history
* chkp

Signed-off-by: jayzhan211 <[email protected]>

* chkp

Signed-off-by: jayzhan211 <[email protected]>

* draft

Signed-off-by: jayzhan211 <[email protected]>

* iter done

Signed-off-by: jayzhan211 <[email protected]>

* short string test

Signed-off-by: jayzhan211 <[email protected]>

* add test

Signed-off-by: jayzhan211 <[email protected]>

* remove unused

Signed-off-by: jayzhan211 <[email protected]>

* to_string directly

Signed-off-by: jayzhan211 <[email protected]>

* rewrite evaluate

Signed-off-by: jayzhan211 <[email protected]>

* return Vec<String>

Signed-off-by: jayzhan211 <[email protected]>

* fmt

Signed-off-by: jayzhan211 <[email protected]>

* add more queries

Signed-off-by: jayzhan211 <[email protected]>

* add group by query and rewrite evalute with state()

Signed-off-by: jayzhan211 <[email protected]>

* move evaluate back

Signed-off-by: jayzhan211 <[email protected]>

* upd test

Signed-off-by: jayzhan211 <[email protected]>

* add row sort

Signed-off-by: jayzhan211 <[email protected]>

* Update benchmarks/queries/clickbench/README.md

* Rework set to avoid copies

* Simplify offset construction

* fmt

* Improve comments

* Improve comments

* add fuzz test

Signed-off-by: jayzhan211 <[email protected]>

* Add support for LargeStringArray

* refine fuzz test

* Add tests for size accounting

* Split into new module

* Remove use of Mutex

* revert changes

* Use reference rather than owned ArrayRef

---------

Signed-off-by: jayzhan211 <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
  • Loading branch information
jayzhan211 and alamb authored Jan 29, 2024
1 parent a57e270 commit af0e8a9
Show file tree
Hide file tree
Showing 9 changed files with 792 additions and 21 deletions.
3 changes: 1 addition & 2 deletions benchmarks/queries/clickbench/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@ SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DIST
FROM hits;
```


### Q1: Data Exploration

**Question**: "How many distinct "hit color", "browser country" and "language" are there in the dataset?"
Expand All @@ -42,7 +41,7 @@ SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTI
FROM hits;
```

### Q2: Top 10 anaylsis
### Q2: Top 10 analysis

**Question**: "Find the top 10 "browser country" by number of distinct "social network"s,
including the distinct counts of "hit color", "browser language",
Expand Down
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

211 changes: 211 additions & 0 deletions datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,211 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

//! Compare DistinctCount for string with naive HashSet and Short String Optimized HashSet
use std::sync::Arc;

use arrow::array::ArrayRef;
use arrow::record_batch::RecordBatch;
use arrow_array::{Array, GenericStringArray, OffsetSizeTrait, UInt32Array};

use arrow_array::cast::AsArray;
use datafusion::datasource::MemTable;
use rand::rngs::StdRng;
use rand::{thread_rng, Rng, SeedableRng};
use std::collections::HashSet;
use tokio::task::JoinSet;

use datafusion::prelude::{SessionConfig, SessionContext};
use test_utils::stagger_batch;

#[tokio::test(flavor = "multi_thread")]
async fn distinct_count_string_test() {
// max length of generated strings
let mut join_set = JoinSet::new();
let mut rng = thread_rng();
for null_pct in [0.0, 0.01, 0.1, 0.5] {
for _ in 0..100 {
let max_len = rng.gen_range(1..50);
let num_strings = rng.gen_range(1..100);
let num_distinct_strings = if num_strings > 1 {
rng.gen_range(1..num_strings)
} else {
num_strings
};
let generator = BatchGenerator {
max_len,
num_strings,
num_distinct_strings,
null_pct,
rng: StdRng::from_seed(rng.gen()),
};
join_set.spawn(async move { run_distinct_count_test(generator).await });
}
}
while let Some(join_handle) = join_set.join_next().await {
// propagate errors
join_handle.unwrap();
}
}

/// Run COUNT DISTINCT using SQL and compare the result to computing the
/// distinct count using HashSet<String>
async fn run_distinct_count_test(mut generator: BatchGenerator) {
let input = generator.make_input_batches();

let schema = input[0].schema();
let session_config = SessionConfig::new().with_batch_size(50);
let ctx = SessionContext::new_with_config(session_config);

// split input into two partitions
let partition_len = input.len() / 2;
let partitions = vec![
input[0..partition_len].to_vec(),
input[partition_len..].to_vec(),
];

let provider = MemTable::try_new(schema, partitions).unwrap();
ctx.register_table("t", Arc::new(provider)).unwrap();
// input has two columns, a and b. The result is the number of distinct
// values in each column.
//
// Note, we need at least two count distinct aggregates to trigger the
// count distinct aggregate. Otherwise, the optimizer will rewrite the
// `COUNT(DISTINCT a)` to `COUNT(*) from (SELECT DISTINCT a FROM t)`
let results = ctx
.sql("SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM t")
.await
.unwrap()
.collect()
.await
.unwrap();

// get all the strings from the first column of the result (distinct a)
let expected_a = extract_distinct_strings::<i32>(&input, 0).len();
let result_a = extract_i64(&results, 0);
assert_eq!(expected_a, result_a);

// get all the strings from the second column of the result (distinct b(
let expected_b = extract_distinct_strings::<i64>(&input, 1).len();
let result_b = extract_i64(&results, 1);
assert_eq!(expected_b, result_b);
}

/// Return all (non null) distinct strings from column col_idx
fn extract_distinct_strings<O: OffsetSizeTrait>(
results: &[RecordBatch],
col_idx: usize,
) -> Vec<String> {
results
.iter()
.flat_map(|batch| {
let array = batch.column(col_idx).as_string::<O>();
// remove nulls via 'flatten'
array.iter().flatten().map(|s| s.to_string())
})
.collect::<HashSet<_>>()
.into_iter()
.collect()
}

// extract the value from the Int64 column in col_idx in batch and return
// it as a usize
fn extract_i64(results: &[RecordBatch], col_idx: usize) -> usize {
assert_eq!(results.len(), 1);
let array = results[0]
.column(col_idx)
.as_any()
.downcast_ref::<arrow::array::Int64Array>()
.unwrap();
assert_eq!(array.len(), 1);
assert!(!array.is_null(0));
array.value(0).try_into().unwrap()
}

struct BatchGenerator {
//// The maximum length of the strings
max_len: usize,
/// the total number of strings in the output
num_strings: usize,
/// The number of distinct strings in the columns
num_distinct_strings: usize,
/// The percentage of nulls in the columns
null_pct: f64,
/// Random number generator
rng: StdRng,
}

impl BatchGenerator {
/// Make batches of random strings with a random length columns "a" and "b":
///
/// * "a" is a StringArray
/// * "b" is a LargeStringArray
fn make_input_batches(&mut self) -> Vec<RecordBatch> {
// use a random number generator to pick a random sized output

let batch = RecordBatch::try_from_iter(vec![
("a", self.gen_data::<i32>()),
("b", self.gen_data::<i64>()),
])
.unwrap();

stagger_batch(batch)
}

/// Creates a StringArray or LargeStringArray with random strings according
/// to the parameters of the BatchGenerator
fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
// table of strings from which to draw
let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
.map(|_| Some(random_string(&mut self.rng, self.max_len)))
.collect();

// pick num_strings randomly from the distinct string table
let indicies: UInt32Array = (0..self.num_strings)
.map(|_| {
if self.rng.gen::<f64>() < self.null_pct {
None
} else if self.num_distinct_strings > 1 {
let range = 1..(self.num_distinct_strings as u32);
Some(self.rng.gen_range(range))
} else {
Some(0)
}
})
.collect();

let options = None;
arrow::compute::take(&distinct_strings, &indicies, options).unwrap()
}
}

/// Return a string of random characters of length 1..=max_len
fn random_string(rng: &mut StdRng, max_len: usize) -> String {
// pick characters at random (not just ascii)
match max_len {
0 => "".to_string(),
1 => String::from(rng.gen::<char>()),
_ => {
let len = rng.gen_range(1..=max_len);
rng.sample_iter::<char, _>(rand::distributions::Standard)
.take(len)
.map(char::from)
.collect::<String>()
}
}
}
1 change: 1 addition & 0 deletions datafusion/core/tests/fuzz_cases/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

mod aggregate_fuzz;
mod distinct_count_string_fuzz;
mod join_fuzz;
mod merge_fuzz;
mod sort_fuzz;
Expand Down
1 change: 1 addition & 0 deletions datafusion/physical-expr/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
chrono = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
datafusion-expr = { workspace = true }
half = { version = "2.1", default-features = false }
hashbrown = { version = "0.14", features = ["raw"] }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,34 +15,37 @@
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Field, TimeUnit};
use arrow_array::types::{
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::PrimitiveArray;
mod strings;

use std::any::Any;
use std::cmp::Eq;
use std::collections::HashSet;
use std::fmt::Debug;
use std::hash::Hash;
use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{Array, ArrayRef};
use std::collections::HashSet;
use arrow::datatypes::{DataType, Field, TimeUnit};
use arrow_array::types::{
ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::PrimitiveArray;

use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::cast::{as_list_array, as_primitive_array};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::Accumulator;

use crate::aggregate::count_distinct::strings::StringDistinctCountAccumulator;
use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};

type DistinctScalarValues = ScalarValue;

/// Expression for a COUNT(DISTINCT) aggregation.
Expand All @@ -61,10 +64,10 @@ impl DistinctCount {
pub fn new(
input_data_type: DataType,
expr: Arc<dyn PhysicalExpr>,
name: String,
name: impl Into<String>,
) -> Self {
Self {
name,
name: name.into(),
state_data_type: input_data_type,
expr,
}
Expand Down Expand Up @@ -152,6 +155,9 @@ impl AggregateExpr for DistinctCount {
Float32 => float_distinct_count_accumulator!(Float32Type),
Float64 => float_distinct_count_accumulator!(Float64Type),

Utf8 => Ok(Box::new(StringDistinctCountAccumulator::<i32>::new())),
LargeUtf8 => Ok(Box::new(StringDistinctCountAccumulator::<i64>::new())),

_ => Ok(Box::new(DistinctCountAccumulator {
values: HashSet::default(),
state_data_type: self.state_data_type.clone(),
Expand Down Expand Up @@ -244,7 +250,7 @@ impl Accumulator for DistinctCountAccumulator {
assert_eq!(states.len(), 1, "array_agg states must be singleton!");
let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
for scalars in scalar_vec.into_iter() {
self.values.extend(scalars)
self.values.extend(scalars);
}
Ok(())
}
Expand Down Expand Up @@ -440,9 +446,6 @@ where

#[cfg(test)]
mod tests {
use crate::expressions::NoOp;

use super::*;
use arrow::array::{
ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
Expand All @@ -454,10 +457,15 @@ mod tests {
};
use arrow_array::Decimal256Array;
use arrow_buffer::i256;

use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array};
use datafusion_common::internal_err;
use datafusion_common::DataFusionError;

use crate::expressions::NoOp;

use super::*;

macro_rules! state_to_vec_primitive {
($LIST:expr, $DATA_TYPE:ident) => {{
let arr = ScalarValue::raw_data($LIST).unwrap();
Expand Down
Loading

0 comments on commit af0e8a9

Please sign in to comment.