From af0e8a95ca60a231ee4e7665a14645db55a6b97a Mon Sep 17 00:00:00 2001 From: Jay Zhan Date: Mon, 29 Jan 2024 20:28:20 +0800 Subject: [PATCH] Optimize `COUNT( DISTINCT ...)` for strings (up to 9x faster) (#8849) * chkp Signed-off-by: jayzhan211 * chkp Signed-off-by: jayzhan211 * draft Signed-off-by: jayzhan211 * iter done Signed-off-by: jayzhan211 * short string test Signed-off-by: jayzhan211 * add test Signed-off-by: jayzhan211 * remove unused Signed-off-by: jayzhan211 * to_string directly Signed-off-by: jayzhan211 * rewrite evaluate Signed-off-by: jayzhan211 * return Vec Signed-off-by: jayzhan211 * fmt Signed-off-by: jayzhan211 * add more queries Signed-off-by: jayzhan211 * add group by query and rewrite evalute with state() Signed-off-by: jayzhan211 * move evaluate back Signed-off-by: jayzhan211 * upd test Signed-off-by: jayzhan211 * add row sort Signed-off-by: jayzhan211 * Update benchmarks/queries/clickbench/README.md * Rework set to avoid copies * Simplify offset construction * fmt * Improve comments * Improve comments * add fuzz test Signed-off-by: jayzhan211 * Add support for LargeStringArray * refine fuzz test * Add tests for size accounting * Split into new module * Remove use of Mutex * revert changes * Use reference rather than owned ArrayRef --------- Signed-off-by: jayzhan211 Co-authored-by: Andrew Lamb --- benchmarks/queries/clickbench/README.md | 3 +- datafusion-cli/Cargo.lock | 1 + .../fuzz_cases/distinct_count_string_fuzz.rs | 211 ++++++++ datafusion/core/tests/fuzz_cases/mod.rs | 1 + datafusion/physical-expr/Cargo.toml | 1 + .../mod.rs} | 46 +- .../src/aggregate/count_distinct/strings.rs | 490 ++++++++++++++++++ .../sqllogictest/test_files/aggregate.slt | 57 ++ .../sqllogictest/test_files/clickbench.slt | 3 + 9 files changed, 792 insertions(+), 21 deletions(-) create mode 100644 datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs rename datafusion/physical-expr/src/aggregate/{count_distinct.rs => count_distinct/mod.rs} (98%) create mode 100644 datafusion/physical-expr/src/aggregate/count_distinct/strings.rs diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md index e03b7d519d91..ef540ccf9c91 100644 --- a/benchmarks/queries/clickbench/README.md +++ b/benchmarks/queries/clickbench/README.md @@ -29,7 +29,6 @@ SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DIST FROM hits; ``` - ### Q1: Data Exploration **Question**: "How many distinct "hit color", "browser country" and "language" are there in the dataset?" 
@@ -42,7 +41,7 @@ SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTI
 FROM hits;
 ```
 
-### Q2: Top 10 anaylsis
+### Q2: Top 10 analysis
 
 **Question**: "Find the top 10 "browser country" by number of distinct "social network"s,
 including the distinct counts of "hit color", "browser language",
diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock
index a718f7591a45..6b881e3105da 100644
--- a/datafusion-cli/Cargo.lock
+++ b/datafusion-cli/Cargo.lock
@@ -1255,6 +1255,7 @@ dependencies = [
  "blake3",
  "chrono",
  "datafusion-common",
+ "datafusion-execution",
  "datafusion-expr",
  "half",
  "hashbrown 0.14.3",
diff --git a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs
new file mode 100644
index 000000000000..343a1756476f
--- /dev/null
+++ b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs
@@ -0,0 +1,211 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Compare DistinctCount for string with naive HashSet and Short String Optimized HashSet
+
+use std::sync::Arc;
+
+use arrow::array::ArrayRef;
+use arrow::record_batch::RecordBatch;
+use arrow_array::{Array, GenericStringArray, Int64Array, OffsetSizeTrait, UInt32Array};
+
+use arrow_array::cast::AsArray;
+use datafusion::datasource::MemTable;
+use rand::rngs::StdRng;
+use rand::{thread_rng, Rng, SeedableRng};
+use std::collections::HashSet;
+use tokio::task::JoinSet;
+
+use datafusion::prelude::{SessionConfig, SessionContext};
+use test_utils::stagger_batch;
+
+#[tokio::test(flavor = "multi_thread")]
+async fn distinct_count_string_test() {
+    let mut join_set = JoinSet::new();
+    let mut rng = thread_rng();
+    for null_pct in [0.0, 0.01, 0.1, 0.5] {
+        for _ in 0..100 {
+            // max length of generated strings
+            let max_len = rng.gen_range(1..50);
+            let num_strings = rng.gen_range(1..100);
+            let num_distinct_strings = if num_strings > 1 {
+                rng.gen_range(1..num_strings)
+            } else {
+                num_strings
+            };
+            let generator = BatchGenerator {
+                max_len,
+                num_strings,
+                num_distinct_strings,
+                null_pct,
+                rng: StdRng::from_seed(rng.gen()),
+            };
+            join_set.spawn(async move { run_distinct_count_test(generator).await });
+        }
+    }
+    while let Some(join_handle) = join_set.join_next().await {
+        // propagate errors
+        join_handle.unwrap();
+    }
+}
+
+/// Run COUNT DISTINCT using SQL and compare the result to computing the
+/// distinct count using HashSet
+async fn run_distinct_count_test(mut generator: BatchGenerator) {
+    let input = generator.make_input_batches();
+
+    let schema = input[0].schema();
+    let session_config = SessionConfig::new().with_batch_size(50);
+    let ctx = SessionContext::new_with_config(session_config);
+
+    // split input into two partitions
+    let partition_len = input.len() / 2;
+    let partitions = vec![
+        input[0..partition_len].to_vec(),
+        input[partition_len..].to_vec(),
+    ];
+
+    let provider = MemTable::try_new(schema, partitions).unwrap();
+    ctx.register_table("t", Arc::new(provider)).unwrap();
+    // input has two columns, a and b. The result is the number of distinct
+    // values in each column.
+    //
+    // Note, we need at least two count distinct aggregates to trigger the
+    // count distinct aggregate. Otherwise, the optimizer will rewrite the
+    // `COUNT(DISTINCT a)` to `COUNT(*) from (SELECT DISTINCT a FROM t)`
+    let results = ctx
+        .sql("SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM t")
+        .await
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+
+    // get all the strings from the first column of the result (distinct a)
+    let expected_a = extract_distinct_strings::<i32>(&input, 0).len();
+    let result_a = extract_i64(&results, 0);
+    assert_eq!(expected_a, result_a);
+
+    // get all the strings from the second column of the result (distinct b)
+    let expected_b = extract_distinct_strings::<i64>(&input, 1).len();
+    let result_b = extract_i64(&results, 1);
+    assert_eq!(expected_b, result_b);
+}
+
+/// Return all (non null) distinct strings from column col_idx
+fn extract_distinct_strings<O: OffsetSizeTrait>(
+    results: &[RecordBatch],
+    col_idx: usize,
+) -> Vec<String> {
+    results
+        .iter()
+        .flat_map(|batch| {
+            let array = batch.column(col_idx).as_string::<O>();
+            // remove nulls via 'flatten'
+            array.iter().flatten().map(|s| s.to_string())
+        })
+        .collect::<HashSet<_>>()
+        .into_iter()
+        .collect()
+}
+
+// extract the value from the Int64 column in col_idx in batch and return
+// it as a usize
+fn extract_i64(results: &[RecordBatch], col_idx: usize) -> usize {
+    assert_eq!(results.len(), 1);
+    let array = results[0]
+        .column(col_idx)
+        .as_any()
+        .downcast_ref::<Int64Array>()
+        .unwrap();
+    assert_eq!(array.len(), 1);
+    assert!(!array.is_null(0));
+    array.value(0).try_into().unwrap()
+}
+
+struct BatchGenerator {
+    /// The maximum length of the strings
+    max_len: usize,
+    /// the total number of strings in the output
+    num_strings: usize,
+    /// The number of distinct strings in the columns
+    num_distinct_strings: usize,
+    /// The percentage of nulls in the columns
+    null_pct: f64,
+    /// Random number generator
+    rng: StdRng,
+}
+
+impl BatchGenerator {
+    /// Make batches of random strings with random lengths in columns "a" and "b":
+    ///
+    /// * "a" is a StringArray
+    /// * "b" is a LargeStringArray
+    fn make_input_batches(&mut self) -> Vec<RecordBatch> {
+        // use a random number generator to pick a random sized output
+
+        let batch = RecordBatch::try_from_iter(vec![
+            ("a", self.gen_data::<i32>()),
+            ("b", self.gen_data::<i64>()),
+        ])
+        .unwrap();
+
+        stagger_batch(batch)
+    }
+
+    /// Creates a StringArray or LargeStringArray with random strings according
+    /// to the parameters of the BatchGenerator
+    fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
+        // table of strings from which to draw
+        let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
+            .map(|_| Some(random_string(&mut self.rng, self.max_len)))
+            .collect();
+
+        // pick num_strings randomly from the distinct string table
+        let indices: UInt32Array = (0..self.num_strings)
+            .map(|_| {
+                if self.rng.gen::<f64>() < self.null_pct {
+                    None
+                } else if self.num_distinct_strings > 1 {
+                    let range = 1..(self.num_distinct_strings as u32);
+                    Some(self.rng.gen_range(range))
+                } else {
+                    Some(0)
+                }
+            })
+            .collect();
+
+        let options = None;
+        arrow::compute::take(&distinct_strings, &indices, options).unwrap()
+    }
+}
+
+/// Return a string of random characters of length 1..=max_len
+fn random_string(rng: &mut StdRng, max_len: usize) -> String {
+    // pick characters at random (not just ascii)
+    match max_len {
+        0 => "".to_string(),
+        1 => String::from(rng.gen::<char>()),
+        _ => {
+            let len = rng.gen_range(1..=max_len);
+            rng.sample_iter::<char, _>(rand::distributions::Standard)
+                .take(len)
+                .map(char::from)
+                .collect::<String>()
+        }
+    }
+}
diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs
index 83ec928ae229..69241571b4af 100644
--- a/datafusion/core/tests/fuzz_cases/mod.rs
+++ b/datafusion/core/tests/fuzz_cases/mod.rs
@@ -16,6 +16,7 @@
 // under the License.
 
 mod aggregate_fuzz;
+mod distinct_count_string_fuzz;
 mod join_fuzz;
 mod merge_fuzz;
 mod sort_fuzz;
diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml
index d237c68657a1..61eba042f939 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -54,6 +54,7 @@ blake2 = { version = "^0.10.2", optional = true }
 blake3 = { version = "1.0", optional = true }
 chrono = { workspace = true }
 datafusion-common = { workspace = true }
+datafusion-execution = { workspace = true }
 datafusion-expr = { workspace = true }
 half = { version = "2.1", default-features = false }
 hashbrown = { version = "0.14", features = ["raw"] }
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
similarity index 98%
rename from datafusion/physical-expr/src/aggregate/count_distinct.rs
rename to datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
index ef1a248d5f82..891ef8588030 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs
@@ -15,34 +15,37 @@
 // specific language governing permissions and limitations
 // under the License.
 
-use arrow::datatypes::{DataType, Field, TimeUnit};
-use arrow_array::types::{
-    ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
-    Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
-    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
-    TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
-    TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
-};
-use arrow_array::PrimitiveArray;
+mod strings;
 
 use std::any::Any;
 use std::cmp::Eq;
+use std::collections::HashSet;
 use std::fmt::Debug;
 use std::hash::Hash;
 use std::sync::Arc;
 
 use ahash::RandomState;
 use arrow::array::{Array, ArrayRef};
-use std::collections::HashSet;
+use arrow::datatypes::{DataType, Field, TimeUnit};
+use arrow_array::types::{
+    ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
+    Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
+    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
+    TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
+    TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
+};
+use arrow_array::PrimitiveArray;
 
-use crate::aggregate::utils::{down_cast_any_ref, Hashable};
-use crate::expressions::format_state_name;
-use crate::{AggregateExpr, PhysicalExpr};
 use datafusion_common::cast::{as_list_array, as_primitive_array};
 use datafusion_common::utils::array_into_list_array;
 use datafusion_common::{Result, ScalarValue};
 use datafusion_expr::Accumulator;
 
+use crate::aggregate::count_distinct::strings::StringDistinctCountAccumulator;
+use crate::aggregate::utils::{down_cast_any_ref, Hashable};
+use crate::expressions::format_state_name;
+use crate::{AggregateExpr, PhysicalExpr};
+
 type DistinctScalarValues = ScalarValue;
 
 /// Expression for a COUNT(DISTINCT) aggregation.
@@ -61,10 +64,10 @@ impl DistinctCount {
     pub fn new(
         input_data_type: DataType,
         expr: Arc<dyn PhysicalExpr>,
-        name: String,
+        name: impl Into<String>,
     ) -> Self {
         Self {
-            name,
+            name: name.into(),
             state_data_type: input_data_type,
             expr,
         }
@@ -152,6 +155,9 @@ impl AggregateExpr for DistinctCount {
             Float32 => float_distinct_count_accumulator!(Float32Type),
             Float64 => float_distinct_count_accumulator!(Float64Type),
 
+            Utf8 => Ok(Box::new(StringDistinctCountAccumulator::<i32>::new())),
+            LargeUtf8 => Ok(Box::new(StringDistinctCountAccumulator::<i64>::new())),
+
             _ => Ok(Box::new(DistinctCountAccumulator {
                 values: HashSet::default(),
                 state_data_type: self.state_data_type.clone(),
@@ -244,7 +250,7 @@ impl Accumulator for DistinctCountAccumulator {
         assert_eq!(states.len(), 1, "array_agg states must be singleton!");
         let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
         for scalars in scalar_vec.into_iter() {
-            self.values.extend(scalars)
+            self.values.extend(scalars);
         }
         Ok(())
     }
@@ -440,9 +446,6 @@ where
 
 #[cfg(test)]
 mod tests {
-    use crate::expressions::NoOp;
-
-    use super::*;
     use arrow::array::{
         ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
         Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
     };
     use arrow::datatypes::{
         ...
     };
     use arrow_array::Decimal256Array;
     use arrow_buffer::i256;
+    use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array};
     use datafusion_common::internal_err;
     use datafusion_common::DataFusionError;
+
+    use crate::expressions::NoOp;
+
+    use super::*;
+
     macro_rules! state_to_vec_primitive {
state_to_vec_primitive { ($LIST:expr, $DATA_TYPE:ident) => {{ let arr = ScalarValue::raw_data($LIST).unwrap(); diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs new file mode 100644 index 000000000000..d7a9ea5c373d --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs @@ -0,0 +1,490 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Specialized implementation of `COUNT DISTINCT` for `StringArray` and `LargeStringArray` + +use ahash::RandomState; +use arrow_array::cast::AsArray; +use arrow_array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait}; +use arrow_buffer::{BufferBuilder, OffsetBuffer, ScalarBuffer}; +use datafusion_common::cast::as_list_array; +use datafusion_common::hash_utils::create_hashes; +use datafusion_common::utils::array_into_list_array; +use datafusion_common::ScalarValue; +use datafusion_execution::memory_pool::proxy::RawTableAllocExt; +use datafusion_expr::Accumulator; +use std::fmt::Debug; +use std::mem; +use std::ops::Range; +use std::sync::Arc; + +#[derive(Debug)] +pub(super) struct StringDistinctCountAccumulator(SSOStringHashSet); +impl StringDistinctCountAccumulator { + pub(super) fn new() -> Self { + Self(SSOStringHashSet::::new()) + } +} + +impl Accumulator for StringDistinctCountAccumulator { + fn state(&mut self) -> datafusion_common::Result> { + // take the state out of the string set and replace with default + let set = std::mem::take(&mut self.0); + let arr = set.into_state(); + let list = Arc::new(array_into_list_array(arr)); + Ok(vec![ScalarValue::List(list)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> { + if values.is_empty() { + return Ok(()); + } + + self.0.insert(&values[0]); + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!( + states.len(), + 1, + "count_distinct states must be single array" + ); + + let arr = as_list_array(&states[0])?; + arr.iter().try_for_each(|maybe_list| { + if let Some(list) = maybe_list { + self.0.insert(&list); + }; + Ok(()) + }) + } + + fn evaluate(&mut self) -> datafusion_common::Result { + Ok(ScalarValue::Int64(Some(self.0.len() as i64))) + } + + fn size(&self) -> usize { + // Size of accumulator + // + SSOStringHashSet size + std::mem::size_of_val(self) + self.0.size() + } +} + +/// Maximum size of a string that can be inlined in the hash table +const SHORT_STRING_LEN: usize = mem::size_of::(); + +/// Entry that is stored in a `SSOStringHashSet` that represents a string +/// that is either stored inline or in the buffer +/// +/// This helps the case where there are many short (less 
+/// that are the same (e.g. "MA", "CA", "NY", "TX", etc)
+///
+/// ```text
+///                                                          ┌──────────────────┐
+///                                                          │...               │
+///                                                          │TheQuickBrownFox  │
+///                                            ─ ─ ─ ─ ─ ─ ─▶│...               │
+///                                           │              │                  │
+///                                                          └──────────────────┘
+///                                           │               buffer of u8
+///
+///                                           │
+///                  ┌────────────────┬───────────────┬───────────────┐
+///  Storing         │                │ starting byte │  length, in   │
+///  "TheQuickBrownFox"  │ hash value │   offset in   │  bytes (not   │
+///  (long string)   │                │    buffer     │  characters)  │
+///                  └────────────────┴───────────────┴───────────────┘
+///                        8 bytes         8 bytes        4 or 8
+///
+///
+///                  ┌───────────────┬─┬─┬─┬─┬─┬─┬─┬─┬───────────────┐
+/// Storing "foobar" │               │ │ │ │ │ │ │ │ │  length, in   │
+/// (short string)   │  hash value   │?│?│f│o│o│b│a│r│  bytes (not   │
+///                  │               │ │ │ │ │ │ │ │ │  characters)  │
+///                  └───────────────┴─┴─┴─┴─┴─┴─┴─┴─┴───────────────┘
+///                        8 bytes      8 bytes           4 or 8
+/// ```
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+struct SSOStringHeader {
+    /// hash of the string value (stored to avoid recomputing it in hash table
+    /// check)
+    hash: u64,
+    /// if len <= SHORT_STRING_LEN: the string data inlined
+    /// if len > SHORT_STRING_LEN, the offset of where the data starts
+    offset_or_inline: usize,
+    /// length of the string, in bytes
+    len: usize,
+}
+
+impl SSOStringHeader {
+    /// returns self.offset..self.offset + self.len
+    fn range(&self) -> Range<usize> {
+        self.offset_or_inline..self.offset_or_inline + self.len
+    }
+}
+
+/// HashSet optimized for storing `String` and `LargeString` values
+/// and producing the final set as a GenericStringArray with minimal copies.
+///
+/// Equivalent to `HashSet<String>` but with better performance for arrow data.
+struct SSOStringHashSet<O: OffsetSizeTrait> {
+    /// Underlying hash set for each distinct string
+    map: hashbrown::raw::RawTable<SSOStringHeader>,
+    /// Total size of the map in bytes
+    map_size: usize,
+    /// In progress arrow `Buffer` containing all string values
+    buffer: BufferBuilder<u8>,
+    /// Offsets into `buffer` for each distinct string value. These offsets
+    /// are used directly to create the final `GenericStringArray`
+    offsets: Vec<O>,
+    /// random state used to generate hashes
+    random_state: RandomState,
+    /// buffer that stores hash values (reused across batches to save allocations)
+    hashes_buffer: Vec<u64>,
+}
+
+impl<O: OffsetSizeTrait> Default for SSOStringHashSet<O> {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<O: OffsetSizeTrait> SSOStringHashSet<O> {
+    fn new() -> Self {
+        Self {
+            map: hashbrown::raw::RawTable::new(),
+            map_size: 0,
+            buffer: BufferBuilder::new(0),
+            offsets: vec![O::default()], // first offset is always 0
+            random_state: RandomState::new(),
+            hashes_buffer: vec![],
+        }
+    }
+
+    fn insert(&mut self, values: &ArrayRef) {
+        // step 1: compute hashes for the strings
+        let batch_hashes = &mut self.hashes_buffer;
+        batch_hashes.clear();
+        batch_hashes.resize(values.len(), 0);
+        create_hashes(&[values.clone()], &self.random_state, batch_hashes)
+            // hash is supported for all string types and create_hashes only
+            // returns errors for unsupported types
+            .unwrap();
+
+        // step 2: insert each string into the set, if not already present
+        let values = values.as_string::<O>();
+
+        // Ensure lengths are equivalent (to guard unsafe values calls below)
+        assert_eq!(values.len(), batch_hashes.len());
+
+        for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
+            // count distinct ignores nulls
+            let Some(value) = value else {
+                continue;
+            };
+
+            // from here on only use bytes (not str/chars) for value
+            let value = value.as_bytes();
+
+            // value is a "small" string
+            if value.len() <= SHORT_STRING_LEN {
+                let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize);
+
+                // is the value already present in the set?
+                let entry = self.map.get_mut(hash, |header| {
+                    // compare value if hashes match
+                    if header.len != value.len() {
+                        return false;
+                    }
+                    // value is stored inline so no need to consult buffer
+                    // (this is the "small string optimization")
+                    inline == header.offset_or_inline
+                });
+
+                // if no existing entry, make a new one
+                if entry.is_none() {
+                    // Put the small value's bytes into buffer and offsets so it
+                    // appears in the output array, but store the actual bytes
+                    // inline for comparison
+                    self.buffer.append_slice(value);
+                    self.offsets.push(O::from_usize(self.buffer.len()).unwrap());
+                    let new_header = SSOStringHeader {
+                        hash,
+                        len: value.len(),
+                        offset_or_inline: inline,
+                    };
+                    self.map.insert_accounted(
+                        new_header,
+                        |header| header.hash,
+                        &mut self.map_size,
+                    );
+                }
+            }
+            // value is not a "small" string
+            else {
+                // Check if the value is already present in the set
+                let entry = self.map.get_mut(hash, |header| {
+                    // compare value if hashes match
+                    if header.len != value.len() {
+                        return false;
+                    }
+                    // Need to compare the bytes in the buffer
+                    // SAFETY: buffer is only appended to, and we correctly inserted values and offsets
+                    let existing_value =
+                        unsafe { self.buffer.as_slice().get_unchecked(header.range()) };
+                    value == existing_value
+                });
+
+                // if no existing entry, make a new one
+                if entry.is_none() {
+                    // Put the value's bytes into buffer and offsets so it
+                    // appears in the output array, and store that offset
+                    // so the bytes can be compared if needed
+                    let offset = self.buffer.len(); // offset of start of data
+                    self.buffer.append_slice(value);
+                    self.offsets.push(O::from_usize(self.buffer.len()).unwrap());
+
+                    let new_header = SSOStringHeader {
+                        hash,
+                        len: value.len(),
+                        offset_or_inline: offset,
+                    };
+                    self.map.insert_accounted(
+                        new_header,
+                        |header| header.hash,
+                        &mut self.map_size,
+                    );
+                }
+            }
+        }
+    }
+
+    /// Converts this set into a `StringArray` or `LargeStringArray` with each
+    /// distinct string value without any copies
+    fn into_state(self) -> ArrayRef {
+        let Self {
+            map: _,
+            map_size: _,
+            offsets,
+            mut buffer,
+            random_state: _,
+            hashes_buffer: _,
+        } = self;
+
+        let offsets: ScalarBuffer<O> = offsets.into();
+        let values = buffer.finish();
+        let nulls = None; // count distinct ignores nulls so intermediate state never has nulls
+
+        // SAFETY: all the values that went in were valid utf8 so are all the values that come out
+        let array = unsafe {
+            GenericStringArray::new_unchecked(OffsetBuffer::new(offsets), values, nulls)
+        };
+        Arc::new(array)
+    }
+
+    fn len(&self) -> usize {
+        self.map.len()
+    }
+
+    /// Return the total size, in bytes, of memory used to store the data in
+    /// this set, not including `self`
+    fn size(&self) -> usize {
+        self.map_size
+            + self.buffer.capacity() * std::mem::size_of::<u8>()
+            + self.offsets.capacity() * std::mem::size_of::<O>()
+            + self.hashes_buffer.capacity() * std::mem::size_of::<u64>()
+    }
+}
+
+impl<O: OffsetSizeTrait> Debug for SSOStringHashSet<O> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SSOStringHashSet")
+            .field("map", &"<map>")
+            .field("map_size", &self.map_size)
+            .field("buffer", &self.buffer)
+            .field("random_state", &self.random_state)
+            .field("hashes_buffer", &self.hashes_buffer)
+            .finish()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::ArrayRef;
+    use arrow_array::StringArray;
+
+    #[test]
+    fn string_set_empty() {
+        for values in [StringArray::new_null(0), StringArray::new_null(11)] {
+            let mut set = SSOStringHashSet::<i32>::new();
+            let array: ArrayRef = Arc::new(values);
+            set.insert(&array);
+            assert_set(set, &[]);
+        }
+    }
+
+    #[test]
+    fn string_set_basic_i32() {
+        test_string_set_basic::<i32>();
+    }
+
+    #[test]
+    fn string_set_basic_i64() {
+        test_string_set_basic::<i64>();
+    }
+
+    fn test_string_set_basic<O: OffsetSizeTrait>() {
+        // basic test for mixed small and large string values
+        let values = GenericStringArray::<O>::from(vec![
+            Some("a"),
+            Some("b"),
+            Some("CXCCCCCCCC"), // 10 bytes
+            Some(""),
+            Some("cbcxx"), // 5 bytes
+            None,
+            Some("AAAAAAAA"),  // 8 bytes
+            Some("BBBBBQBBB"), // 9 bytes
+            Some("a"),
+            Some("cbcxx"),
+            Some("b"),
+            Some("cbcxx"),
+            Some(""),
+            None,
+            Some("BBBBBQBBB"),
+            Some("BBBBBQBBB"),
+            Some("AAAAAAAA"),
+            Some("CXCCCCCCCC"),
+        ]);
+
+        let mut set = SSOStringHashSet::<O>::new();
+        let array: ArrayRef = Arc::new(values);
+        set.insert(&array);
+        assert_set(
+            set,
+            &[
+                Some(""),
+                Some("AAAAAAAA"),
+                Some("BBBBBQBBB"),
+                Some("CXCCCCCCCC"),
+                Some("a"),
+                Some("b"),
+                Some("cbcxx"),
+            ],
+        );
+    }
+
+    #[test]
+    fn string_set_non_utf8_32() {
+        test_string_set_non_utf8::<i32>();
+    }
+
+    #[test]
+    fn string_set_non_utf8_64() {
+        test_string_set_non_utf8::<i64>();
+    }
+
+    fn test_string_set_non_utf8<O: OffsetSizeTrait>() {
+        // basic test for mixed small and large string values
+        let values = GenericStringArray::<O>::from(vec![
+            Some("a"),
+            Some("✨🔥"),
+            Some("🔥"),
+            Some("✨✨✨"),
+            Some("foobarbaz"),
+            Some("🔥"),
+            Some("✨🔥"),
+        ]);
+
+        let mut set = SSOStringHashSet::<O>::new();
+        let array: ArrayRef = Arc::new(values);
+        set.insert(&array);
+        assert_set(
+            set,
+            &[
+                Some("a"),
+                Some("foobarbaz"),
+                Some("✨✨✨"),
+                Some("✨🔥"),
+                Some("🔥"),
+            ],
+        );
+    }
+
+    // asserts that the set contains the expected strings
+    fn assert_set<O: OffsetSizeTrait>(
+        set: SSOStringHashSet<O>,
+        expected: &[Option<&str>],
+    ) {
+        let strings = set.into_state();
+        let strings = strings.as_string::<O>();
+        let mut state = strings.into_iter().collect::<Vec<_>>();
+        state.sort();
+        assert_eq!(state, expected);
+    }
+
+    // inserting strings into the set is reflected in its reported memory size
+    #[test]
+    fn test_string_set_memory_usage() {
+        let strings1 = GenericStringArray::<i32>::from(vec![
+            Some("a"),
+            Some("b"),
+            Some("CXCCCCCCCC"), // 10 bytes
+            Some("AAAAAAAA"),  // 8 bytes
+            Some("BBBBBQBBB"), // 9 bytes
+        ]);
+        let total_strings1_len = strings1
+            .iter()
+            .map(|s| s.map(|s| s.len()).unwrap_or(0))
+            .sum::<usize>();
+        let values1: ArrayRef = Arc::new(GenericStringArray::<i32>::from(strings1));
+
+        // Much larger strings in strings2
+        let strings2 = GenericStringArray::<i32>::from(vec![
+            "FOO".repeat(1000),
+            "BAR".repeat(2000),
+            "BAZ".repeat(3000),
+        ]);
+        let total_strings2_len = strings2
+            .iter()
+            .map(|s| s.map(|s| s.len()).unwrap_or(0))
+            .sum::<usize>();
+        let values2: ArrayRef = Arc::new(GenericStringArray::<i32>::from(strings2));
+
+        let mut set = SSOStringHashSet::<i32>::new();
+        let size_empty = set.size();
+
+        set.insert(&values1);
+        let size_after_values1 = set.size();
+        assert!(size_empty < size_after_values1);
+        assert!(
+            size_after_values1 > total_strings1_len,
+            "expect {size_after_values1} to be more than {total_strings1_len}"
+        );
+        assert!(size_after_values1 < total_strings1_len + total_strings2_len);
+
+        // inserting the same strings should not affect the size
+        set.insert(&values1);
+        assert_eq!(set.size(), size_after_values1);
+
+        // inserting the large strings should increase the reported size
+        set.insert(&values2);
+        let size_after_values2 = set.size();
+        assert!(size_after_values2 > size_after_values1);
+        assert!(size_after_values2 > total_strings1_len + total_strings2_len);
+    }
+}
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 5cd728c4344b..136fb39c673e 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3069,6 +3069,62 @@ select count(*) from (select count(*) a, count(*) b from (select 1));
 ----
 1
 
+# Distinct Count for string
+# (test for the specialized implementation of distinct count for strings)
+
+# Non-ASCII UTF8 strings matter for the string to &[u8] conversion; include them to prevent regressions
+statement ok
+create table distinct_count_string_table as values
+  (1, 'a', 'longstringtest_a', '台灣'),
+  (2, 'b', 'longstringtest_b1', '日本'),
+  (2, 'b', 'longstringtest_b2', '中國'),
+  (3, 'c', 'longstringtest_c1', '美國'),
+  (3, 'c', 'longstringtest_c2', '歐洲'),
+  (3, 'c', 'longstringtest_c3', '韓國')
+;
+
+# run through update_batch
+query IIII
+select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_string_table;
+----
+3 3 6 6
+
+# run through merge_batch
+query IIII rowsort
+select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_string_table group by column1;
+----
+1 1 1 1
+1 1 2 2
+1 1 3 3
+
+
+# test with long strings as well
+statement ok
+create table distinct_count_long_string_table as
+SELECT column1,
+  arrow_cast(column2, 'LargeUtf8') as column2,
+  arrow_cast(column3, 'LargeUtf8') as column3,
+  arrow_cast(column4, 'LargeUtf8') as column4
+FROM distinct_count_string_table;
+
+# run through update_batch
+query IIII
+select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_long_string_table;
+----
+3 3 6 6
+
+# run through merge_batch
+query IIII rowsort
+select count(distinct column1), count(distinct column2), count(distinct column3),
+count(distinct column4) from distinct_count_long_string_table group by column1;
+----
+1 1 1 1
+1 1 2 2
+1 1 3 3
+
+statement ok
+drop table distinct_count_string_table;
+
+
 # rule `aggregate_statistics` should not optimize MIN/MAX to wrong values on empty relation
 
 statement ok
@@ -3122,3 +3178,4 @@ NULL
 
 statement ok
 DROP TABLE t;
+
diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt
index 21befd78226e..b61bee670811 100644
--- a/datafusion/sqllogictest/test_files/clickbench.slt
+++ b/datafusion/sqllogictest/test_files/clickbench.slt
@@ -273,3 +273,6 @@ SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hit
 query PI
 SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000;
 ----
+
+statement ok
+drop table hits;
\ No newline at end of file
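
The heart of the patch is the short-string path in `SSOStringHashSet::insert`: a value of at most `size_of::<usize>()` bytes is packed into the header's `offset_or_inline` field, so membership checks for common short strings ("MA", "CA", "NY", ...) compare a single `usize` instead of following an offset into the byte buffer. Below is a minimal standalone sketch of that packing and comparison, assuming nothing beyond `std`; the helper names `inline_key` and `same_short_string` are illustrative, not part of the patch:

```rust
use std::mem;

/// Strings at most this long are stored inline (8 bytes on 64-bit targets)
const SHORT_STRING_LEN: usize = mem::size_of::<usize>();

/// Pack up to 8 bytes into a single usize, mirroring the patch's fold
fn inline_key(bytes: &[u8]) -> usize {
    assert!(bytes.len() <= SHORT_STRING_LEN);
    bytes.iter().fold(0usize, |acc, &b| acc << 8 | b as usize)
}

/// Short strings are equal iff both their lengths and packed keys match
fn same_short_string(a: &str, b: &str) -> bool {
    a.len() == b.len() && inline_key(a.as_bytes()) == inline_key(b.as_bytes())
}

fn main() {
    assert!(same_short_string("MA", "MA"));
    assert!(!same_short_string("MA", "CA"));
    // the length check matters: "a" and "\0a" pack to the same key
    assert!(!same_short_string("a", "\0a"));
    println!("short-string equality checked without touching the byte buffer");
}
```

This is why `SSOStringHeader` compares `len` before `offset_or_inline`: the fold alone cannot distinguish values that differ only in leading zero bytes.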
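The other half of the trick is the zero-copy finish step: `into_state` hands the accumulated bytes and offsets to arrow wholesale, so no string data is copied again when producing the output array. A sketch of that step under the same crate assumptions as the patch (`arrow`/`arrow-buffer`); it uses the checked `StringArray::try_new` where the patch, having maintained the invariants itself, can call `new_unchecked`:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, StringArray};
use arrow_buffer::{BufferBuilder, OffsetBuffer, ScalarBuffer};

fn main() {
    // During `insert`, each distinct value's bytes are appended to one
    // buffer and the new end position is pushed onto `offsets`
    let mut buffer = BufferBuilder::<u8>::new(0);
    let mut offsets: Vec<i32> = vec![0]; // first offset is always 0
    for s in ["foo", "bar", "bazqux"] {
        buffer.append_slice(s.as_bytes());
        offsets.push(buffer.len() as i32);
    }

    // The finish step hands both allocations to arrow as-is
    let offsets: ScalarBuffer<i32> = offsets.into();
    let array =
        StringArray::try_new(OffsetBuffer::new(offsets), buffer.finish(), None)
            .expect("offsets are monotonic and the bytes are valid utf8");
    let array: ArrayRef = Arc::new(array);

    assert_eq!(array.len(), 3);
}
```

Because nulls are ignored by `COUNT(DISTINCT ...)`, the null buffer is always `None`, which is what lets the intermediate state round-trip through `state()`/`merge_batch` as a plain list of distinct strings.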