From 9c44d04b3062074961be81f901e15a5519909bb2 Mon Sep 17 00:00:00 2001
From: jayzhan211
Date: Wed, 10 Jan 2024 20:39:30 +0800
Subject: [PATCH 01/30] chkp

Signed-off-by: jayzhan211
---
 datafusion/core/tests/data/distinct_count.csv |  11 ++
 .../physical-expr/src/aggregate/build_in.rs   |  13 ++-
 .../src/aggregate/count_distinct.rs           | 100 +++++++++++++++++-
 3 files changed, 117 insertions(+), 7 deletions(-)
 create mode 100644 datafusion/core/tests/data/distinct_count.csv

diff --git a/datafusion/core/tests/data/distinct_count.csv b/datafusion/core/tests/data/distinct_count.csv
new file mode 100644
index 000000000000..e9a65ceee4aa
--- /dev/null
+++ b/datafusion/core/tests/data/distinct_count.csv
@@ -0,0 +1,11 @@
+c1,c2,c3
+1,20,0
+2,20,1
+3,10,2
+4,10,3
+5,30,4
+6,30,5
+7,30,6
+8,30,7
+9,30,8
+10,10,9
\ No newline at end of file
diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs
index c40f0db19405..d913eef272cf 100644
--- a/datafusion/physical-expr/src/aggregate/build_in.rs
+++ b/datafusion/physical-expr/src/aggregate/build_in.rs
@@ -59,11 +59,14 @@ pub fn create_aggregate_expr(
         (AggregateFunction::Count, false) => Arc::new(
             expressions::Count::new_with_multiple_exprs(input_phy_exprs, name, data_type),
         ),
-        (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new(
-            data_type,
-            input_phy_exprs[0].clone(),
-            name,
-        )),
+        (AggregateFunction::Count, true) => {
+            println!("go to distinct count");
+            Arc::new(expressions::DistinctCount::new(
+                data_type,
+                input_phy_exprs[0].clone(),
+                name,
+            ))
+        }
         (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new(
             input_phy_exprs[0].clone(),
             name,
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index 021c33fb94a7..6a0d7c832560 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -24,7 +24,10 @@ use arrow_array::types::{
     TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
 };
 use arrow_array::PrimitiveArray;
+use arrow_buffer::BufferBuilder;
+use hashbrown::hash_map::DefaultHashBuilder;
 
+use core::slice::SlicePattern;
 use std::any::Any;
 use std::cmp::Eq;
 use std::fmt::Debug;
@@ -33,7 +36,7 @@ use std::sync::Arc;
 
 use ahash::RandomState;
 use arrow::array::{Array, ArrayRef};
-use std::collections::HashSet;
+use std::collections::{HashMap, HashSet};
 
 use crate::aggregate::utils::{down_cast_any_ref, Hashable};
 use crate::expressions::format_state_name;
@@ -215,10 +218,12 @@ impl Accumulator for DistinctCountAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
         let scalars = self.values.iter().cloned().collect::<Vec<_>>();
         let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type);
+        println!("state: {:?}", arr);
         Ok(vec![ScalarValue::List(arr)])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        println!("ub values: {:?}", values);
         if values.is_empty() {
             return Ok(());
         }
@@ -233,18 +238,21 @@ impl Accumulator for DistinctCountAccumulator {
             let scalar = ScalarValue::try_from_array(arr, index)?;
             self.values.insert(scalar);
         }
+        // println!("self.values: {:?}", self.values);
         Ok(())
     })
 }
 
 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        println!("mb values: {:?}", states);
         if states.is_empty() {
             return Ok(());
         }
         assert_eq!(states.len(), 1, "array_agg states must be singleton!");
         let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
         for scalars in scalar_vec.into_iter() {
-            self.values.extend(scalars)
+            self.values.extend(scalars);
+            // println!("self.values: {:?}", self.values);
         }
         Ok(())
     }
@@ -438,6 +446,94 @@ where
     }
 }
 
+// #[derive(Debug)]
+// struct StringDistinctCountAccumulator<T>
+// where
+//     T: ArrowPrimitiveType + Send,
+//     T::Native: Eq + Hash,
+// {
+//     values: HashSet<T::Native>,
+// }
+
+// impl<T> NativeDistinctCountAccumulator<T>
+// where
+//     T: ArrowPrimitiveType + Send,
+//     T::Native: Eq + Hash,
+// {
+//     fn new() -> Self {
+//         Self {
+//             values: HashSet::default(),
+//         }
+//     }
+// }
+
+// Short String Optimizated HashSet for String
+// Equivalent to HashSet<String> but with better memory usage (Speed unsure)
+struct SSOStringHashSet {
+    // header: u128
+    // short string: length(4bytes) + data(12bytes)
+    // long string: length(4bytes) + prefix(4bytes) + offset(8bytes)
+    header_set: HashSet<u128>,
+    // map
+    long_string_map: HashMap<u64, u64>,
+    buffer: BufferBuilder<u8>,
+}
+
+impl SSOStringHashSet {
+    fn insert(&mut self, value: &str) {
+        let value_len = value.len();
+        if value_len <= 12 {
+            let mut short_string_header = 0u128;
+            short_string_header |= (value_len << 96) as u128;
+            short_string_header |= value
+                .as_bytes()
+                .iter()
+                .fold(0u128, |acc, &x| acc << 8 | x as u128);
+            self.header_set.insert(short_string_header);
+        } else {
+            // 1) hash the string w/o 4 bytes prefix
+            // 2) check if the hash exists in the map
+            // 3) if exists, insert the offset into the header
+            // 4) if not exists, insert the hash and offset into the map
+
+            let mut long_string_header = 0u128;
+            long_string_header |= (value_len << 96) as u128;
+            long_string_header |= (value
+                .as_bytes()
+                .iter()
+                .take(4)
+                .fold(0u128, |acc, &x| acc << 8 | x as u128)
+                << 64) as u128;
+
+            let suffix = value
+                .as_bytes()
+                .iter()
+                .skip(4)
+                .collect::<Vec<_>>();
+
+            // NYI hash_bytes: hash &[u8] to u64, similar to hashbrown `make_hash` for &[u8]
+            let hashed_suffix = hash_bytes(suffix);
+            if let Some(offset) = self.long_string_map.get(&hashed_suffix) {
+                long_string_header |= *offset as u128;
+            } else {
+                let offset = self.buffer.len();
+                self.long_string_map.insert(hashed_suffix, offset as u64);
+                long_string_header |= offset as u128;
+                // convert suffix: Vec<&u8> to &[u8]
+                self.buffer.append_slice(suffix);
+            }
+
+            self.header_set.insert(long_string_header);
+
+        }
+    }
+}
+
+struct HashSetSSOString {
+    // values: HashSet<String>,
+    short_string_set: HashSet<u128>,
+}
+
 #[cfg(test)]
 mod tests {
     use crate::expressions::NoOp;
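The comments in the patch above describe a `u128` header per string: the length in the top bits, with short strings inlined below it and long strings represented by a prefix plus an offset into a shared buffer. A minimal standalone sketch of just the short-string packing round trip (the function names here are illustrative, not part of the patch series):

```rust
/// Pack a short string (<= 12 bytes) into a u128: length in the top 32 bits,
/// the bytes themselves folded into the low bits, most significant first.
fn pack_short(value: &str) -> u128 {
    assert!(value.len() <= 12);
    let mut header = (value.len() as u128) << 96;
    header |= value
        .as_bytes()
        .iter()
        .fold(0u128, |acc, &b| (acc << 8) | b as u128);
    header
}

/// Recover the original bytes from a packed header.
fn unpack_short(header: u128) -> Vec<u8> {
    let len = (header >> 96) as usize;
    (0..len)
        .rev()
        .map(|i| ((header >> (8 * i)) & 0xff) as u8)
        .collect()
}

fn main() {
    let h = pack_short("hello");
    assert_eq!(unpack_short(h), b"hello");
    // Distinct short strings produce distinct headers, so a HashSet<u128>
    // can deduplicate them without heap-allocating a String per value.
    assert_ne!(pack_short("ab"), pack_short("ba"));
}
```

Because the whole value fits in the header, membership tests for short strings are a single integer comparison; only strings longer than the inline budget ever touch the byte buffer.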
From 6cb8bbe504189dae39349230dfc11c6cd3f1db6b Mon Sep 17 00:00:00 2001
From: jayzhan211
Date: Sat, 13 Jan 2024 09:18:28 +0800
Subject: [PATCH 02/30] chkp

Signed-off-by: jayzhan211
---
 .../src/aggregate/count_distinct.rs           | 57 +++++++++++++------
 1 file changed, 39 insertions(+), 18 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index 6a0d7c832560..0d6b09f4d5fc 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -25,13 +25,14 @@ use arrow_array::types::{
 };
 use arrow_array::PrimitiveArray;
 use arrow_buffer::BufferBuilder;
-use hashbrown::hash_map::DefaultHashBuilder;
+use hashbrown::HashMap;
 
-use core::slice::SlicePattern;
 use std::any::Any;
 use std::cmp::Eq;
 use std::fmt::Debug;
 use std::hash::Hash;
+use std::mem;
 use std::sync::Arc;
 
 use ahash::RandomState;
 
+const LEN: usize = mem::size_of::<u128>();
+
+
+#[derive(Debug, PartialEq, Eq, Hash)]
+struct SSOStringHeader {
+    len: usize,
+    offset_or_inline: usize,
+}
+
+impl SSOStringHeader {
+    fn evaluate(&self) -> (bool, usize) {
+        // short string
+        if self.len <= LEN {
+            (true, self.offset_or_inline)
+        } else {
+            (false, self.offset_or_inline)
+        }
+    }
+}
+
+/// The size, in number of groups, of the initial hash table
+const INITIAL_CAPACITY: usize = 128;
+/// The size, in bytes, of the string data
+const INITIAL_BUFFER_CAPACITY: usize = 8 * 1024;
+
 // Short String Optimizated HashSet for String
 // Equivalent to HashSet<String> but with better memory usage (Speed unsure)
 struct SSOStringHashSet {
-    // header: u128
-    // short string: length(4bytes) + data(12bytes)
-    // long string: length(4bytes) + prefix(4bytes) + offset(8bytes)
-    header_set: HashSet<u128>,
-    // map
-    long_string_map: HashMap<u64, u64>,
+    header_set: HashSet<SSOStringHeader>,
+    long_string_map: hashbrown::HashMap<u64, u64>,
     buffer: BufferBuilder<u8>,
+    state: RandomState,
 }
 
 impl SSOStringHashSet {
     fn insert(&mut self, value: &str) {
         let value_len = value.len();
-        if value_len <= 12 {
-            let mut short_string_header = 0u128;
-            short_string_header |= (value_len << 96) as u128;
-            short_string_header |= value
-                .as_bytes()
-                .iter()
-                .fold(0u128, |acc, &x| acc << 8 | x as u128);
+        if value_len <= LEN {
+            let inline = value.as_bytes().iter().fold(0usize, |acc, &x| acc << 8 | x as usize);
+            let short_string_header = SSOStringHeader {
+                len: value_len,
+                offset_or_inline: inline,
+            };
             self.header_set.insert(short_string_header);
         } else {
+            let hash = s
+            let value_bytes = value.as_bytes();
+            self.long_string_map.raw_entry_mut()
 
             let mut long_string_header = 0u128;
             long_string_header |= (value_len << 96) as u128;

From 9d662a776ce0699ed541879328ae0af11a6da0a1 Mon Sep 17 00:00:00 2001
From: jayzhan211
Date: Sat, 13 Jan 2024 11:44:39 +0800
Subject: [PATCH 03/30] draft

Signed-off-by: jayzhan211
---
 datafusion/physical-expr/Cargo.toml           |   1 +
 .../src/aggregate/count_distinct.rs           | 174 +++++++++---------
 2 files changed, 93 insertions(+), 82 deletions(-)

diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml
index d237c68657a1..75ccbed83929 100644
--- a/datafusion/physical-expr/Cargo.toml
+++ b/datafusion/physical-expr/Cargo.toml
@@ -55,6 +55,7 @@ blake3 = { version = "1.0", optional = true }
 chrono = { workspace = true }
 datafusion-common = { workspace = true }
 datafusion-expr = { workspace = true }
+datafusion-execution = { workspace = true }
 half = { version = "2.1", default-features = false }
 hashbrown = { version = "0.14", features = ["raw"] }
 hex = { version = "0.4", optional = true }
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index 0d6b09f4d5fc..a7cf204dc30a 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -25,9 +25,7 @@ use arrow_array::types::{
 };
 use arrow_array::PrimitiveArray;
 use arrow_buffer::BufferBuilder;
-use hashbrown::HashMap;
 
 use std::any::Any;
 use std::cmp::Eq;
 use std::fmt::Debug;
@@ -37,7 +35,7 @@ use std::sync::Arc;
 
 use ahash::RandomState;
 use arrow::array::{Array, ArrayRef};
-use std::collections::{HashMap, HashSet};
+use std::collections::HashSet;
 
 use crate::aggregate::utils::{down_cast_any_ref, Hashable};
 use crate::expressions::format_state_name;
 use crate::{AggregateExpr, PhysicalExpr};
 use datafusion_common::cast::{as_list_array, as_primitive_array};
 use datafusion_common::utils::array_into_list_array;
 use datafusion_common::{Result, ScalarValue};
 use datafusion_expr::Accumulator;
+use datafusion_execution::memory_pool::proxy::RawTableAllocExt;
 
 type DistinctScalarValues = ScalarValue;
 
@@ -447,114 +446,125 @@ where
     }
 }
 
-// #[derive(Debug)]
-// struct StringDistinctCountAccumulator<T>
-// where
-//     T: ArrowPrimitiveType + Send,
-//     T::Native: Eq + Hash,
-// {
-//     values: HashSet<T::Native>,
-// }
-
-// impl<T> NativeDistinctCountAccumulator<T>
-// where
-//     T: ArrowPrimitiveType + Send,
-//     T::Native: Eq + Hash,
-// {
-//     fn new() -> Self {
-//         Self {
-//             values: HashSet::default(),
-//         }
-//     }
-// }
-
-const LEN: usize = mem::size_of::<u128>();
-
-
-#[derive(Debug, PartialEq, Eq, Hash)]
-struct SSOStringHeader {
-    len: usize,
-    offset_or_inline: usize,
+#[derive(Debug)]
+struct StringDistinctCountAccumulator(SSOStringHashSet);
+impl StringDistinctCountAccumulator {
+    fn new() -> Self {
+        Self(SSOStringHashSet::new())
+    }
 }
 
-impl SSOStringHeader {
-    fn evaluate(&self) -> (bool, usize) {
-        // short string
-        if self.len <= LEN {
-            (true, self.offset_or_inline)
-        } else {
-            (false, self.offset_or_inline)
-        }
+impl Accumulator for StringDistinctCountAccumulator {
+    fn state(&self) -> Result<Vec<ScalarValue>> {
+        // let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
+        //     self.values.iter().map(|v| v.0),
+        // )) as ArrayRef;
+        // let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
+        // Ok(vec![ScalarValue::List(list)])
+        todo!()
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        todo!()
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        todo!()
+    }
+
+    fn evaluate(&self) -> Result<ScalarValue> {
+        todo!()
+    }
+
+    fn size(&self) -> usize {
+        todo!()
     }
 }
 
-/// The size, in number of groups, of the initial hash table
-const INITIAL_CAPACITY: usize = 128;
-/// The size, in bytes, of the string data
-const INITIAL_BUFFER_CAPACITY: usize = 8 * 1024;
+const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
+
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+struct SSOStringHeader {
+    /// hash of the string value (used when resizing table)
+    hash: u64,
+
+    len: usize,
+    offset_or_inline: usize,
+}
 
 // Short String Optimizated HashSet for String
-// Equivalent to HashSet<String> but with better memory usage (Speed unsure)
+// Equivalent to HashSet<String> but with better memory usage
+#[derive(Default)]
 struct SSOStringHashSet {
     header_set: HashSet<SSOStringHeader>,
-    long_string_map: hashbrown::HashMap<u64, u64>,
+    long_string_map: hashbrown::raw::RawTable<SSOStringHeader>,
+    map_size: usize,
     buffer: BufferBuilder<u8>,
     state: RandomState,
 }
 
 impl SSOStringHashSet {
+    fn new() -> Self {
+        Self::default()
+    }
+
+    fn with_capacities() -> Self {
+        todo!("with_capacities")
+    }
+
     fn insert(&mut self, value: &str) {
         let value_len = value.len();
+        let value_bytes = value.as_bytes();
 
         if value_len <= SHORT_STRING_LEN {
-            let inline = value.as_bytes().iter().fold(0usize, |acc, &x| acc << 8 | x as usize);
+            let inline = value_bytes.iter().fold(0usize, |acc, &x| acc << 8 | x as usize);
             let short_string_header = SSOStringHeader {
+                // no need for short string cases
+                hash: 0,
                 len: value_len,
                 offset_or_inline: inline,
             };
             self.header_set.insert(short_string_header);
         } else {
-            let hash = s
-            let value_bytes = value.as_bytes();
-            self.long_string_map.raw_entry_mut()
+            let hash = self.state.hash_one(value_bytes);
 
-            let mut long_string_header = 0u128;
-            long_string_header |= (value_len << 96) as u128;
-            long_string_header |= (value
-                .as_bytes()
-                .iter()
-                .take(4)
-                .fold(0u128, |acc, &x| acc << 8 | x as u128)
-                << 64) as u128;
+            let entry = self.long_string_map.get_mut(hash, |header| {
+                // if hash matches, check if the bytes match
+                let offset = header.offset_or_inline;
+                let len = header.len;
 
-            let suffix = value
-                .as_bytes()
-                .iter()
-                .skip(4)
-                .collect::<Vec<_>>();
+                // SAFETY: buffer is only appended to, and we correctly inserted values
+                let existing_value = unsafe { self.buffer.as_slice().get_unchecked(offset..offset + len) };
 
-            // NYI hash_bytes: hash &[u8] to u64, similar to hashbrown `make_hash` for &[u8]
-            let hashed_suffix = hash_bytes(suffix);
-            if let Some(offset) = self.long_string_map.get(&hashed_suffix) {
-                long_string_header |= *offset as u128;
-            } else {
-                let offset = self.buffer.len();
-                self.long_string_map.insert(hashed_suffix, offset as u64);
-                long_string_header |= offset as u128;
-                // convert suffix: Vec<&u8> to &[u8]
-                self.buffer.append_slice(suffix);
-            }
+                value_bytes == existing_value
+            });
 
-            self.header_set.insert(long_string_header);
-
+            if entry.is_none() {
+                let offset = self.buffer.len();
+                self.buffer.append_slice(value_bytes);
+                let header = SSOStringHeader {
+                    hash,
+                    len: value_len,
+                    offset_or_inline: offset,
+                };
+                self.long_string_map.insert_accounted(header, |header| header.hash, &mut self.map_size);
+                self.header_set.insert(header);
+            }
         }
     }
 }
 
-struct HashSetSSOString {
-    // values: HashSet<String>,
-    short_string_set: HashSet<u128>,
-}
+impl Debug for SSOStringHashSet {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SSOStringHashSet")
+            .field("header_set", &self.header_set)
+            // TODO: Print long_string_map
+            .field("map_size", &self.map_size)
+            .field("buffer", &self.buffer)
+            .field("state", &self.state)
+            .finish()
+    }
+}
 
 #[cfg(test)]
 mod tests {
     use crate::expressions::NoOp;
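The patch above replaces the prefix-hashing experiment with hashbrown's raw-table pattern: each entry stores only `(hash, offset, len)`, and on a hash match the candidate bytes are compared against an append-only buffer. The same deduplication idea can be sketched without the raw API, using an explicit bucket `Vec` per hash (this is a simplified model, not the code in the patch):

```rust
use std::collections::hash_map::RandomState;
use std::collections::HashMap;
use std::hash::BuildHasher;

/// Simplified model of the dedup scheme above: entries are (offset, len)
/// pairs into one append-only byte buffer, grouped by the string's hash.
#[derive(Default)]
struct ByteSet {
    buckets: HashMap<u64, Vec<(usize, usize)>>, // hash -> [(offset, len)]
    buffer: Vec<u8>,
    state: RandomState,
}

impl ByteSet {
    fn insert(&mut self, value: &str) {
        let bytes = value.as_bytes();
        let hash = self.state.hash_one(bytes);
        let bucket = self.buckets.entry(hash).or_default();
        // on a hash match, confirm equality against the stored bytes
        if bucket
            .iter()
            .any(|&(off, len)| &self.buffer[off..off + len] == bytes)
        {
            return; // already present
        }
        bucket.push((self.buffer.len(), bytes.len()));
        self.buffer.extend_from_slice(bytes);
    }

    fn len(&self) -> usize {
        self.buckets.values().map(Vec::len).sum()
    }
}

fn main() {
    let mut set = ByteSet::default();
    for s in ["foo", "bar", "foo", "a longer string", "bar"] {
        set.insert(s);
    }
    assert_eq!(set.len(), 3);
}
```

The raw-table version in the patch fuses the bucket scan into `get_mut(hash, eq)` and tracks allocations via `insert_accounted`, but the invariant is the same: the buffer is append-only, so stored `(offset, len)` ranges never move.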
From 1744cb3f4398d1816ab64c7c956c2e3dcad01792 Mon Sep 17 00:00:00 2001
From: jayzhan211
Date: Sat, 13 Jan 2024 14:28:20 +0800
Subject: [PATCH 04/30] iter done

Signed-off-by: jayzhan211
---
 .../src/aggregate/count_distinct.rs           | 87 ++++++++++++++++---
 1 file changed, 76 insertions(+), 11 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index a7cf204dc30a..fdfe38ef651c 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -23,7 +23,7 @@ use arrow_array::types::{
     TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
     TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
 };
-use arrow_array::PrimitiveArray;
+use arrow_array::{PrimitiveArray, StringArray};
 use arrow_buffer::BufferBuilder;
 
 use std::any::Any;
@@ -40,7 +40,7 @@ use std::collections::HashSet;
 
 use crate::aggregate::utils::{down_cast_any_ref, Hashable};
 use crate::expressions::format_state_name;
 use crate::{AggregateExpr, PhysicalExpr};
-use datafusion_common::cast::{as_list_array, as_primitive_array};
+use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array};
 use datafusion_common::utils::array_into_list_array;
 use datafusion_common::{Result, ScalarValue};
 use datafusion_expr::Accumulator;
@@ -456,24 +456,53 @@ impl StringDistinctCountAccumulator {
 
 impl Accumulator for StringDistinctCountAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
-        // let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
-        //     self.values.iter().map(|v| v.0),
-        // )) as ArrayRef;
-        // let list = Arc::new(array_into_list_array(arr)) as ArrayRef;
-        // Ok(vec![ScalarValue::List(list)])
-        todo!()
+        let arr = StringArray::from_iter_values(self.0.iter());
+        let list = Arc::new(array_into_list_array(Arc::new(arr)));
+        Ok(vec![ScalarValue::List(list)])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        todo!()
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        let arr = as_string_array(&values[0])?;
+        arr.iter().for_each(|value| {
+            if let Some(value) = value {
+                self.0.insert(value);
+            }
+        });
+
+        Ok(())
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        todo!()
+        if states.is_empty() {
+            return Ok(());
+        }
+        assert_eq!(
+            states.len(),
+            1,
+            "count_distinct states must be single array"
+        );
+
+        let arr = as_list_array(&states[0])?;
+        arr.iter().try_for_each(|maybe_list| {
+            if let Some(list) = maybe_list {
+                let list = as_string_array(&list)?;
+
+                list.iter().for_each(|value| {
+                    if let Some(value) = value {
+                        self.0.insert(value);
+                    }
+                })
+            };
+            Ok(())
+        })
    }
 
     fn evaluate(&self) -> Result<ScalarValue> {
-        todo!()
+        Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
    }
 
     fn size(&self) -> usize {
         todo!()
     }
 }
@@ -492,6 +521,16 @@ impl SSOStringHeader {
     fn evaluate(&self) -> (bool, usize) {
         if self.len <= SHORT_STRING_LEN {
             (true, self.offset_or_inline)
         } else {
             (false, self.offset_or_inline)
         }
     }
 }
@@ -552,6 +591,32 @@ impl SSOStringHashSet {
         }
     }
 
+    fn iter(&self) -> impl Iterator<Item = String> + '_ {
+        self.header_set.iter().map(|header| {
+            let (is_short, offset_or_inline) = header.evaluate();
+            if is_short {
+                let mut inline = offset_or_inline;
+                // Convert usize to String
+                let mut bytes = [0u8; SHORT_STRING_LEN];
+                for i in (0..SHORT_STRING_LEN).rev() {
+                    bytes[i] = (inline & 0xFF) as u8;
+                    inline >>= 8;
+                }
+                // SAFETY: StringDistinctCountAccumulator only inserts valid utf8 strings
+                unsafe { std::str::from_utf8_unchecked(&bytes) }.to_string()
+            } else {
+                let offset = offset_or_inline;
+                let len = header.len;
+                // SAFETY: buffer is only appended to, and we correctly inserted values
+                unsafe { std::str::from_utf8_unchecked(self.buffer.as_slice().get_unchecked(offset..offset + len)) }.to_string()
+            }
+        })
+    }
+
+    fn len(&self) -> usize {
+        self.header_set.len()
+    }
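With `iter()` in place, the patch above can serialize the accumulator for `state()` by materializing the distinct values as a `StringArray` and wrapping it in a list scalar. A reduced sketch of just the Arrow round trip, using `arrow_array` directly and standing in plain strings for `set.iter()` (the DataFusion `ScalarValue` wrappers are omitted):

```rust
use arrow_array::{Array, StringArray};

fn main() {
    // distinct values as they would come out of the set's iter()
    let distinct = vec!["a".to_string(), "cbcxx".to_string(), "longstringtest_a".to_string()];

    // state(): materialize the values as a StringArray
    let arr = StringArray::from_iter_values(distinct.iter());

    // merge_batch(): read them back out, skipping nulls as the patch does
    let roundtrip: Vec<&str> = arr.iter().flatten().collect();
    assert_eq!(roundtrip, ["a", "cbcxx", "longstringtest_a"]);
    assert_eq!(arr.len(), 3);
}
```

This is why `merge_batch` can simply re-insert every element of the incoming list state: the set deduplicates across partial accumulators for free.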
From e3b0568186addbcf3a9ad48ac5764793ed42a175 Mon Sep 17 00:00:00 2001
From: jayzhan211
Date: Sat, 13 Jan 2024 14:48:23 +0800
Subject: [PATCH 05/30] short string test

Signed-off-by: jayzhan211
---
 .../src/aggregate/count_distinct.rs           | 48 ++++++++++++-------
 .../sqllogictest/test_files/aggregate.slt     | 25 ++++++++++
 2 files changed, 56 insertions(+), 17 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index fdfe38ef651c..3b09fded38ad 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -43,8 +43,8 @@ use crate::{AggregateExpr, PhysicalExpr};
 use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array};
 use datafusion_common::utils::array_into_list_array;
 use datafusion_common::{Result, ScalarValue};
-use datafusion_expr::Accumulator;
 use datafusion_execution::memory_pool::proxy::RawTableAllocExt;
+use datafusion_expr::Accumulator;
 
 type DistinctScalarValues = ScalarValue;
@@ -155,6 +155,8 @@ impl AggregateExpr for DistinctCount {
             Float32 => float_distinct_count_accumulator!(Float32Type),
             Float64 => float_distinct_count_accumulator!(Float64Type),
 
+            Utf8 => Ok(Box::new(StringDistinctCountAccumulator::new())),
+
             _ => Ok(Box::new(DistinctCountAccumulator {
                 values: HashSet::default(),
                 state_data_type: self.state_data_type.clone(),
@@ -218,12 +220,10 @@ impl Accumulator for DistinctCountAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
         let scalars = self.values.iter().cloned().collect::<Vec<_>>();
         let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type);
-        println!("state: {:?}", arr);
         Ok(vec![ScalarValue::List(arr)])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
-        println!("ub values: {:?}", values);
         if values.is_empty() {
             return Ok(());
         }
@@ -238,21 +238,18 @@ impl Accumulator for DistinctCountAccumulator {
             let scalar = ScalarValue::try_from_array(arr, index)?;
             self.values.insert(scalar);
         }
-        // println!("self.values: {:?}", self.values);
         Ok(())
     })
 }
 
 fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
-        println!("mb values: {:?}", states);
         if states.is_empty() {
             return Ok(());
         }
         assert_eq!(states.len(), 1, "array_agg states must be singleton!");
         let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
         for scalars in scalar_vec.into_iter() {
             self.values.extend(scalars);
-            // println!("self.values: {:?}", self.values);
         }
         Ok(())
     }
@@ -487,7 +487,7 @@ impl Accumulator for StringDistinctCountAccumulator {
         arr.iter().try_for_each(|maybe_list| {
             if let Some(list) = maybe_list {
                 let list = as_string_array(&list)?;
-
+
                 list.iter().for_each(|value| {
                     if let Some(value) = value {
                         self.0.insert(value);
                     }
                 })
             };
             Ok(())
         })
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
         Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
     }
 
     fn size(&self) -> usize {
-        todo!()
+        // Size of accumulator
+        // + SSOStringHashSet size
+        std::mem::size_of_val(self) + self.0.size()
     }
 }
@@ -547,16 +546,14 @@ impl SSOStringHashSet {
         Self::default()
     }
 
-    fn with_capacities() -> Self {
-        todo!("with_capacities")
-    }
-
     fn insert(&mut self, value: &str) {
         let value_len = value.len();
         let value_bytes = value.as_bytes();
 
         if value_len <= SHORT_STRING_LEN {
-            let inline = value_bytes.iter().fold(0usize, |acc, &x| acc << 8 | x as usize);
+            let inline = value_bytes
+                .iter()
+                .fold(0usize, |acc, &x| acc << 8 | x as usize);
             let short_string_header = SSOStringHeader {
-                // no need for short string cases
-                hash: 0,
+                hash: 0, // no need for short string cases
                 len: value_len,
                 offset_or_inline: inline,
             };
             self.header_set.insert(short_string_header);
         } else {
             let hash = self.state.hash_one(value_bytes);
 
             let entry = self.long_string_map.get_mut(hash, |header| {
                 // if hash matches, check if the bytes match
                 let offset = header.offset_or_inline;
                 let len = header.len;
-
+
                 // SAFETY: buffer is only appended to, and we correctly inserted values
-                let existing_value = unsafe { self.buffer.as_slice().get_unchecked(offset..offset + len) };
+                let existing_value =
+                    unsafe { self.buffer.as_slice().get_unchecked(offset..offset + len) };
 
                 value_bytes == existing_value
             });
 
             if entry.is_none() {
                 let offset = self.buffer.len();
                 self.buffer.append_slice(value_bytes);
                 let header = SSOStringHeader {
                     hash,
                     len: value_len,
                     offset_or_inline: offset,
                 };
-                self.long_string_map.insert_accounted(header, |header| header.hash, &mut self.map_size);
+                self.long_string_map.insert_accounted(
+                    header,
+                    |header| header.hash,
+                    &mut self.map_size,
+                );
                 self.header_set.insert(header);
             }
         }
     }
@@ -609,7 +611,12 @@ impl SSOStringHashSet {
             } else {
                 let offset = offset_or_inline;
                 let len = header.len;
                 // SAFETY: buffer is only appended to, and we correctly inserted values
-                unsafe { std::str::from_utf8_unchecked(self.buffer.as_slice().get_unchecked(offset..offset + len)) }.to_string()
+                unsafe {
+                    std::str::from_utf8_unchecked(
+                        self.buffer.as_slice().get_unchecked(offset..offset + len),
+                    )
+ } + .to_string() } }) } @@ -617,6 +624,13 @@ impl SSOStringHashSet { fn len(&self) -> usize { self.header_set.len() } + + // NEED HELPED + fn size(&self) -> usize { + self.header_set.len() * mem::size_of::() + + self.map_size + + self.buffer.len() + } } impl Debug for SSOStringHashSet { diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 50cdebd054a7..c0418cb85895 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3260,3 +3260,28 @@ query I select count(*) from (select count(*) a, count(*) b from (select 1)); ---- 1 + +# Distinct Count for string + +statement ok +create table distinct_count_string_table as values + (1, 'a'), + (2, 'b'), + (2, 'c') +; + +# run through update_batch +query II +select count(distinct column1), count(distinct column2) from distinct_count_string_table; +---- +2 3 + +# run through merge_batch +query II rowsort +select count(distinct column1), count(distinct column2) from distinct_count_string_table group by column1; +---- +1 1 +1 2 + +statement ok +drop table distinct_count_string_table; From 12cf50c3da39516b06fc561571ca423c9a3e15bb Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 13 Jan 2024 14:53:55 +0800 Subject: [PATCH 06/30] add test Signed-off-by: jayzhan211 --- .../physical-expr/src/aggregate/build_in.rs | 13 ++++------ .../sqllogictest/test_files/aggregate.slt | 24 +++++++++++-------- 2 files changed, 19 insertions(+), 18 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/build_in.rs b/datafusion/physical-expr/src/aggregate/build_in.rs index d913eef272cf..c40f0db19405 100644 --- a/datafusion/physical-expr/src/aggregate/build_in.rs +++ b/datafusion/physical-expr/src/aggregate/build_in.rs @@ -59,14 +59,11 @@ pub fn create_aggregate_expr( (AggregateFunction::Count, false) => Arc::new( expressions::Count::new_with_multiple_exprs(input_phy_exprs, name, data_type), ), - (AggregateFunction::Count, true) => { - println!("go to distinct count"); - Arc::new(expressions::DistinctCount::new( - data_type, - input_phy_exprs[0].clone(), - name, - )) - } + (AggregateFunction::Count, true) => Arc::new(expressions::DistinctCount::new( + data_type, + input_phy_exprs[0].clone(), + name, + )), (AggregateFunction::Grouping, _) => Arc::new(expressions::Grouping::new( input_phy_exprs[0].clone(), name, diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index c0418cb85895..0f81257f94fd 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3265,23 +3265,27 @@ select count(*) from (select count(*) a, count(*) b from (select 1)); statement ok create table distinct_count_string_table as values - (1, 'a'), - (2, 'b'), - (2, 'c') + (1, 'a', 'longstringtest_a'), + (2, 'b', 'longstringtest_b1'), + (2, 'b', 'longstringtest_b2'), + (3, 'c', 'longstringtest_c1'), + (3, 'c', 'longstringtest_c2'), + (3, 'c', 'longstringtest_c3') ; # run through update_batch -query II -select count(distinct column1), count(distinct column2) from distinct_count_string_table; +query III +select count(distinct column1), count(distinct column2), count(distinct column3) from distinct_count_string_table; ---- -2 3 +3 3 6 # run through merge_batch -query II rowsort -select count(distinct column1), count(distinct column2) from distinct_count_string_table group by column1; +query III rowsort +select count(distinct column1), count(distinct 
column2), count(distinct column3) from distinct_count_string_table group by column1; ---- -1 1 -1 2 +1 1 1 +1 1 2 +1 1 3 statement ok drop table distinct_count_string_table; From 4f9a3f02ea2cec9bea50c72460615f5b810b6f9f Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 13 Jan 2024 15:02:00 +0800 Subject: [PATCH 07/30] remove unused Signed-off-by: jayzhan211 --- datafusion-cli/Cargo.lock | 1 + datafusion/core/tests/data/distinct_count.csv | 11 ----------- datafusion/physical-expr/Cargo.toml | 2 +- 3 files changed, 2 insertions(+), 12 deletions(-) delete mode 100644 datafusion/core/tests/data/distinct_count.csv diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 5663e736dbd8..77fd4b7eb68c 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1255,6 +1255,7 @@ dependencies = [ "blake3", "chrono", "datafusion-common", + "datafusion-execution", "datafusion-expr", "half", "hashbrown 0.14.3", diff --git a/datafusion/core/tests/data/distinct_count.csv b/datafusion/core/tests/data/distinct_count.csv deleted file mode 100644 index e9a65ceee4aa..000000000000 --- a/datafusion/core/tests/data/distinct_count.csv +++ /dev/null @@ -1,11 +0,0 @@ -c1,c2,c3 -1,20,0 -2,20,1 -3,10,2 -4,10,3 -5,30,4 -6,30,5 -7,30,6 -8,30,7 -9,30,8 -10,10,9 \ No newline at end of file diff --git a/datafusion/physical-expr/Cargo.toml b/datafusion/physical-expr/Cargo.toml index 75ccbed83929..61eba042f939 100644 --- a/datafusion/physical-expr/Cargo.toml +++ b/datafusion/physical-expr/Cargo.toml @@ -54,8 +54,8 @@ blake2 = { version = "^0.10.2", optional = true } blake3 = { version = "1.0", optional = true } chrono = { workspace = true } datafusion-common = { workspace = true } -datafusion-expr = { workspace = true } datafusion-execution = { workspace = true } +datafusion-expr = { workspace = true } half = { version = "2.1", default-features = false } hashbrown = { version = "0.14", features = ["raw"] } hex = { version = "0.4", optional = true } From 626b1cb5100e23f18b599cc71f311b7c0b958927 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 13 Jan 2024 15:41:45 +0800 Subject: [PATCH 08/30] to_string directly Signed-off-by: jayzhan211 --- .../physical-expr/src/aggregate/count_distinct.rs | 10 +--------- 1 file changed, 1 insertion(+), 9 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 3b09fded38ad..3c517b6b2b29 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -598,15 +598,7 @@ impl SSOStringHashSet { self.header_set.iter().map(|header| { let (is_short, offset_or_inline) = header.evaluate(); if is_short { - let mut inline = offset_or_inline; - // Convert usize to String - let mut bytes = [0u8; SHORT_STRING_LEN]; - for i in (0..SHORT_STRING_LEN).rev() { - bytes[i] = (inline & 0xFF) as u8; - inline >>= 8; - } - // SAFETY: StringDistinctCountAccumulator only inserts valid utf8 strings - unsafe { std::str::from_utf8_unchecked(&bytes) }.to_string() + offset_or_inline.to_string() } else { let offset = offset_or_inline; let len = header.len; From 2e80cb721d4417ce98010506405e2dde9902b827 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Sat, 13 Jan 2024 15:54:08 +0800 Subject: [PATCH 09/30] rewrite evaluate Signed-off-by: jayzhan211 --- .../src/aggregate/count_distinct.rs | 32 ++++++++----------- 1 file changed, 13 insertions(+), 19 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs 
b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index 3c517b6b2b29..3a3538ffcd95 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -521,11 +521,18 @@ struct SSOStringHeader {
 }
 
 impl SSOStringHeader {
-    fn evaluate(&self) -> (bool, usize) {
+    fn evaluate(&self, buffer: &[u8]) -> String {
         if self.len <= SHORT_STRING_LEN {
-            (true, self.offset_or_inline)
+            self.offset_or_inline.to_string()
         } else {
-            (false, self.offset_or_inline)
+            let offset = self.offset_or_inline;
+            // SAFETY: buffer is only appended to, and we correctly inserted values
+            unsafe {
+                std::str::from_utf8_unchecked(
+                    buffer.get_unchecked(offset..offset + self.len),
+                )
+            }
+            .to_string()
         }
     }
 }
@@ -595,22 +602,9 @@ impl SSOStringHashSet {
     }
 
     fn iter(&self) -> impl Iterator<Item = String> + '_ {
-        self.header_set.iter().map(|header| {
-            let (is_short, offset_or_inline) = header.evaluate();
-            if is_short {
-                let mut inline = offset_or_inline;
-                // Convert usize to String
-                let mut bytes = [0u8; SHORT_STRING_LEN];
-                for i in (0..SHORT_STRING_LEN).rev() {
-                    bytes[i] = (inline & 0xFF) as u8;
-                    inline >>= 8;
-                }
-                // SAFETY: StringDistinctCountAccumulator only inserts valid utf8 strings
-                unsafe { std::str::from_utf8_unchecked(&bytes) }.to_string()
-            } else {
-                let offset = offset_or_inline;
-                let len = header.len;
-                // SAFETY: buffer is only appended to, and we correctly inserted values
-                unsafe {
-                    std::str::from_utf8_unchecked(
-                        self.buffer.as_slice().get_unchecked(offset..offset + len),
-                    )
-                }
-                .to_string()
-            }
-        })
+        self.header_set
+            .iter()
+            .map(|header| header.evaluate(self.buffer.as_slice()))
     }

From d2d1d6d1a5411a4a19db5a79705eac12ab9c75da Mon Sep 17 00:00:00 2001
From: jayzhan211
Date: Sat, 13 Jan 2024 15:55:48 +0800
Subject: [PATCH 10/30] return Vec<String>

Signed-off-by: jayzhan211
---
 datafusion/physical-expr/src/aggregate/count_distinct.rs | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index 3a3538ffcd95..a34d60cc83ad 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -601,10 +601,10 @@ impl SSOStringHashSet {
         }
     }
 
-    fn iter(&self) -> impl Iterator<Item = String> + '_ {
+    fn iter(&self) -> Vec<String> {
         self.header_set
             .iter()
-            .map(|header| header.evaluate(self.buffer.as_slice()))
+            .map(|header| header.evaluate(self.buffer.as_slice())).collect()
     }

From ebb8726134d82f9e5e942d5334fb75f10a92cd1d Mon Sep 17 00:00:00 2001
From: jayzhan211
Date: Sat, 13 Jan 2024 15:56:01 +0800
Subject: [PATCH 11/30] fmt

Signed-off-by: jayzhan211
---
 datafusion/physical-expr/src/aggregate/count_distinct.rs | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index a34d60cc83ad..cc320644c399 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -604,7 +604,8 @@ impl SSOStringHashSet {
     fn iter(&self) -> Vec<String> {
         self.header_set
             .iter()
-            .map(|header| header.evaluate(self.buffer.as_slice())).collect()
+            .map(|header| header.evaluate(self.buffer.as_slice()))
+            .collect()
     }

From 98a9cd1a7b4f8ee632643b1b70592a96da9e8259 Mon Sep 17 00:00:00 2001
From: jayzhan211
Date: Tue, 16 Jan 2024 19:49:26 +0800
Subject: [PATCH 12/30] add more queries

Signed-off-by: jayzhan211
---
 benchmarks/queries/clickbench/README.md    | 17 +++++++++++++++++
 benchmarks/queries/clickbench/extended.sql |  4 +++-
 2 files changed, 20 insertions(+), 1 deletion(-)

diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md
b/benchmarks/queries/clickbench/README.md
index d5105afd4832..9e41077ad785 100644
--- a/benchmarks/queries/clickbench/README.md
+++ b/benchmarks/queries/clickbench/README.md
@@ -28,6 +28,23 @@ SELECT
 FROM hits;
 ```
 
+### Q1
+Query to test distinct count for String. Three of them are all small string (length either 1 or 2).
+
+```sql
+SELECT
+    COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage")
+FROM hits;
+```
+
+### Q2
+Query to test distinct count for String. "URL" has length greater than 8
+
+```sql
+SELECT
+    COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "URL")
+FROM hits;
+```
diff --git a/benchmarks/queries/clickbench/extended.sql b/benchmarks/queries/clickbench/extended.sql
index 82c0266af61a..d52e7c0861f0 100644
--- a/benchmarks/queries/clickbench/extended.sql
+++ b/benchmarks/queries/clickbench/extended.sql
@@ -1 +1,3 @@
-SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits;
\ No newline at end of file
+SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits;
+SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits;
+SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "URL") FROM hits;
\ No newline at end of file

From 07831faac7d15dbafae4a317a24f9594d8d0dcc4 Mon Sep 17 00:00:00 2001
From: jayzhan211
Date: Wed, 17 Jan 2024 20:18:40 +0800
Subject: [PATCH 13/30] add group by query and rewrite evaluate with state()

Signed-off-by: jayzhan211
---
 benchmarks/queries/clickbench/README.md       |  8 ++-
 benchmarks/queries/clickbench/extended.sql    |  2 +-
 .../src/aggregate/count_distinct.rs           | 68 +++++++++++--------
 3 files changed, 46 insertions(+), 32 deletions(-)

diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md
index 9e41077ad785..69aa580c46db 100644
--- a/benchmarks/queries/clickbench/README.md
+++ b/benchmarks/queries/clickbench/README.md
@@ -29,6 +29,7 @@ FROM hits;
 ```
 
 ### Q1
+Models initial Data exploration, to understand some statistics of data.
 Query to test distinct count for String. Three of them are all small string (length either 1 or 2).
 
 ```sql
@@ -38,12 +39,13 @@ FROM hits;
 ```
 
 ### Q2
-Query to test distinct count for String. "URL" has length greater than 8
+Models initial Data exploration, to understand some statistics of data.
+Extend with `group by` from Q1
 
 ```sql
 SELECT
-    COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "URL")
-FROM hits;
+    "BrowserCountry", COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage")
+FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10;
 ```
diff --git a/benchmarks/queries/clickbench/extended.sql b/benchmarks/queries/clickbench/extended.sql
index d52e7c0861f0..5972a175b1e0 100644
--- a/benchmarks/queries/clickbench/extended.sql
+++ b/benchmarks/queries/clickbench/extended.sql
@@ -1,3 +1,3 @@
 SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits;
 SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits;
-SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "URL") FROM hits;
\ No newline at end of file
+SELECT "BrowserCountry", COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10;
\ No newline at end of file
diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index cc320644c399..9285d1876a75 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -24,7 +24,7 @@ use arrow_array::types::{
     TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
 };
 use arrow_array::{PrimitiveArray, StringArray};
-use arrow_buffer::BufferBuilder;
+use arrow_buffer::{BufferBuilder, MutableBuffer, OffsetBuffer};
 
 use std::any::Any;
 use std::cmp::Eq;
@@ -453,7 +453,7 @@ impl StringDistinctCountAccumulator {
 
 impl Accumulator for StringDistinctCountAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
-        let arr = StringArray::from_iter_values(self.0.iter());
+        let arr = self.0.state();
         let list = Arc::new(array_into_list_array(Arc::new(arr)));
         Ok(vec![ScalarValue::List(list)])
     }
@@ -515,37 +515,28 @@ const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
 struct SSOStringHeader {
     /// hash of the string value (used when resizing table)
     hash: u64,
-
+    /// length of the string
     len: usize,
+    /// short strings are stored inline, long strings are stored in the buffer
     offset_or_inline: usize,
 }
 
-impl SSOStringHeader {
-    fn evaluate(&self, buffer: &[u8]) -> String {
-        if self.len <= SHORT_STRING_LEN {
-            self.offset_or_inline.to_string()
-        } else {
-            let offset = self.offset_or_inline;
-            // SAFETY: buffer is only appended to, and we correctly inserted values
-            unsafe {
-                std::str::from_utf8_unchecked(
-                    buffer.get_unchecked(offset..offset + self.len),
-                )
-            }
-            .to_string()
-        }
-    }
-}
-
-// Short String Optimizated HashSet for String
-// Equivalent to HashSet<String> but with better memory usage
+// Short String Optimizated HashSet for String
+// Equivalent to HashSet<String> but with better memory usage
 #[derive(Default)]
 struct SSOStringHashSet {
+    /// Core of the HashSet, it stores both the short and long string headers
     header_set: HashSet<SSOStringHeader>,
+    /// Used to check if the long string already exists
     long_string_map: hashbrown::raw::RawTable<SSOStringHeader>,
+    /// Total size of the map in bytes
     map_size: usize,
+    /// Buffer containing all long strings
     buffer: BufferBuilder<u8>,
+    /// The random state used to generate hashes
     state: RandomState,
+    /// Used for capacity calculation, equivalent to the sum of all string lengths
+    size_hint: usize,
 }
 
 impl SSOStringHashSet {
     fn new() -> Self {
         Self::default()
     }
 
     fn insert(&mut self, value: &str) {
        let value_len = value.len();
+        self.size_hint += value_len;
         let value_bytes = value.as_bytes();
 
         if value_len <= SHORT_STRING_LEN {
             let inline = value_bytes
                 .iter()
                 .fold(0usize, |acc, &x| acc << 8 | x as usize);
             let short_string_header = SSOStringHeader {
-                hash: 0, // no need for short string cases
+                hash: 0, // no need for short string cases
                 len: value_len,
                 offset_or_inline: inline,
             };
             self.header_set.insert(short_string_header);
         } else {
@@ -601,18 +592,39 @@ impl SSOStringHashSet {
         }
     }
 
-    fn iter(&self) -> Vec<String> {
-        self.header_set
-            .iter()
-            .map(|header| header.evaluate(self.buffer.as_slice()))
-            .collect()
+    // Returns a StringArray with the current state of the set
+    fn state(&self) -> StringArray {
+        let mut offsets = Vec::with_capacity(self.size_hint + 1);
+        offsets.push(0);
+
+        let mut values = MutableBuffer::new(0);
+        let buffer = self.buffer.as_slice();
+
+        for header in self.header_set.iter() {
+            let s = if header.len <= SHORT_STRING_LEN {
+                let inline = header.offset_or_inline;
+                // convert usize to &[u8]
+                let ptr = &inline as *const usize as *const u8;
+                let len = std::mem::size_of::<usize>();
+                unsafe { std::slice::from_raw_parts(ptr, len) }
+            } else {
+                let offset = header.offset_or_inline;
+                // SAFETY: buffer is only appended to, and we correctly inserted values
+                unsafe { buffer.get_unchecked(offset..offset + header.len) }
+            };
+
+            values.extend_from_slice(s);
+            offsets.push(values.len() as i32);
+        }
+
+        let value_offsets = OffsetBuffer::<i32>::new(offsets.into());
+        StringArray::new(value_offsets, values.into(), None)
     }
 
     fn len(&self) -> usize {
         self.header_set.len()
     }
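The `state()` added above walks the headers and emits the `(offsets, values)` pair that Arrow's `StringArray` is built from. A minimal plain-Rust model of the invariant it must maintain — `string[i]` lives at `values[offsets[i]..offsets[i+1]]`, with offsets non-decreasing — assuming UTF-8 input (this is a model, not the patch code):

```rust
/// Minimal model of the offsets/values encoding built in `state()` above.
fn encode(strings: &[&str]) -> (Vec<i32>, Vec<u8>) {
    let mut offsets = vec![0i32];
    let mut values = Vec::new();
    for s in strings {
        values.extend_from_slice(s.as_bytes());
        offsets.push(values.len() as i32); // offsets must stay monotonically increasing
    }
    (offsets, values)
}

fn get<'a>(offsets: &[i32], values: &'a [u8], i: usize) -> &'a str {
    let (start, end) = (offsets[i] as usize, offsets[i + 1] as usize);
    std::str::from_utf8(&values[start..end]).unwrap()
}

fn main() {
    let (offsets, values) = encode(&["台灣", "", "longstringtest_a"]);
    assert_eq!(get(&offsets, &values, 0), "台灣");
    assert_eq!(get(&offsets, &values, 1), ""); // empty string: zero-width slice
    assert_eq!(get(&offsets, &values, 2), "longstringtest_a");
}
```

Note that lengths here are byte lengths, not character counts, which is why the multibyte UTF-8 test data added later in this series matters.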
From 62c80849bddc44a285c4ff0a379ca3079b24e89e Mon Sep 17 00:00:00 2001
From: jayzhan211
Date: Wed, 17 Jan 2024 21:23:30 +0800
Subject: [PATCH 14/30] move evaluate back

Signed-off-by: jayzhan211
---
 .../src/aggregate/count_distinct.rs           | 32 +++++++++++--------
 .../sqllogictest/test_files/clickbench.slt    |  3 ++
 2 files changed, 22 insertions(+), 13 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs
index 9285d1876a75..48ed7a301707 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs
@@ -521,6 +521,23 @@ struct SSOStringHeader {
     offset_or_inline: usize,
 }
 
+impl SSOStringHeader {
+    fn evaluate(&self, buffer: &[u8]) -> String {
+        if self.len <= SHORT_STRING_LEN {
+            self.offset_or_inline.to_string()
+        } else {
+            let offset = self.offset_or_inline;
+            // SAFETY: buffer is only appended to, and we correctly inserted values
+            unsafe {
+                std::str::from_utf8_unchecked(
+                    buffer.get_unchecked(offset..offset + self.len),
+                )
+            }
+            .to_string()
+        }
+    }
+}
+
@@ -601,19 +618,8 @@ impl SSOStringHashSet {
         for header in self.header_set.iter() {
-            let s = if header.len <= SHORT_STRING_LEN {
-                let inline = header.offset_or_inline;
-                // convert usize to &[u8]
-                let ptr = &inline as *const usize as *const u8;
-                let len = std::mem::size_of::<usize>();
-                unsafe { std::slice::from_raw_parts(ptr, len) }
-            } else {
-                let offset = header.offset_or_inline;
-                // SAFETY: buffer is only appended to, and we correctly inserted values
-                unsafe { buffer.get_unchecked(offset..offset + header.len) }
-            };
-
-            values.extend_from_slice(s);
+            let s = header.evaluate(buffer);
+            values.extend_from_slice(s.as_bytes());
             offsets.push(values.len() as i32);
         }
diff --git a/datafusion/sqllogictest/test_files/clickbench.slt b/datafusion/sqllogictest/test_files/clickbench.slt
index 21befd78226e..b61bee670811 100644
--- a/datafusion/sqllogictest/test_files/clickbench.slt
+++ b/datafusion/sqllogictest/test_files/clickbench.slt
@@ -273,3 +273,6 @@ SELECT "WindowClientWidth", "WindowClientHeight", COUNT(*) AS PageViews FROM hit
 query PI
 SELECT DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) AS M, COUNT(*) AS PageViews FROM hits WHERE "CounterID" = 62 AND "EventDate"::INT::DATE >= '2013-07-14' AND "EventDate"::INT::DATE <= '2013-07-15' AND "IsRefresh" = 0 AND "DontCountHits" = 0 GROUP BY DATE_TRUNC('minute', to_timestamp_seconds("EventTime")) ORDER BY DATE_TRUNC('minute', M) LIMIT 10 OFFSET 1000;
 ----
+
+statement ok
+drop table hits;
\ No newline at end of file

From e3b65c80fffff26a05f872331c943b56de521eea Mon Sep 17 00:00:00 2001
From: jayzhan211
Date: Wed, 17 Jan 2024 21:56:17 +0800
Subject: [PATCH 15/30] upd test

Signed-off-by: jayzhan211
---
 .../sqllogictest/test_files/aggregate.slt     | 29 ++++++++++---------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index 0f81257f94fd..d81b65b0ed0d 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3263,29 +3263,30 @@ select count(*) from (select count(*) a, count(*) b from (select 1));
 
 # Distinct Count for string
 
+# Non-ASCII UTF-8 strings exercise the string-to-&[u8] conversion; added to prevent regressions
 statement ok
 create table distinct_count_string_table as values
-    (1, 'a', 'longstringtest_a'),
-    (2, 'b', 'longstringtest_b1'),
-    (2, 'b', 'longstringtest_b2'),
-    (3, 'c', 'longstringtest_c1'),
-    (3, 'c', 'longstringtest_c2'),
-    (3, 'c', 'longstringtest_c3')
+    (1, 'a', 'longstringtest_a', '台灣'),
+    (2, 'b', 'longstringtest_b1', '日本'),
+    (2, 'b', 'longstringtest_b2', '中國'),
+    (3, 'c', 'longstringtest_c1', '美國'),
+    (3, 'c', 'longstringtest_c2', '歐洲'),
+    (3, 'c', 'longstringtest_c3', '韓國')
 ;
 
 # run through update_batch
-query III
-select count(distinct column1), count(distinct column2), count(distinct column3) from distinct_count_string_table;
+query IIII
+select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_string_table;
 ----
-3 3 6
+3 3 6 6
 
 # run through merge_batch
-query III rowsort
-select count(distinct column1), count(distinct column2), count(distinct column3) from distinct_count_string_table group by column1;
+query IIII
+select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_string_table group by column1;
 ----
-1 1 1
-1 1 2
-1 1 3
+1 1 1 1
+1 1 2 2
+1 1 3 3
 
 statement ok
 drop table distinct_count_string_table;

From 3f0e9a9aaaab31a19512ed64d27015af1fb027cf Mon Sep 17 00:00:00 2001
From: jayzhan211
Date: Wed, 17 Jan 2024 22:16:12 +0800
Subject: [PATCH 16/30] add row sort

Signed-off-by: jayzhan211
---
 datafusion/sqllogictest/test_files/aggregate.slt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt
index d81b65b0ed0d..5351f3d3f871 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -3281,7 +3281,7 @@ select count(distinct column1),
count(distinct column2), count(distinct column3) 3 3 6 6 # run through merge_batch -query IIII +query IIII rowsort select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_string_table group by column1; ---- 1 1 1 1 From 0475687e578c4e96ae7ffd0d555fc52db6e1d1f3 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 20 Jan 2024 05:50:18 -0500 Subject: [PATCH 17/30] Update benchmarks/queries/clickbench/README.md --- benchmarks/queries/clickbench/README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md index 69aa580c46db..00699d90a3fe 100644 --- a/benchmarks/queries/clickbench/README.md +++ b/benchmarks/queries/clickbench/README.md @@ -30,6 +30,7 @@ FROM hits; ### Q1 Models initial Data exploration, to understand some statistics of data. +Models initial Data exploration, to understand some statistics of data. Query to test distinct count for String. Three of them are all small string (length either 1 or 2). ```sql From a764e997e968f6747b78f227f1423c43cdc97d35 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 20 Jan 2024 07:42:13 -0500 Subject: [PATCH 18/30] Rework set to avoid copies --- .../src/aggregate/count_distinct.rs | 412 +++++++++++++----- 1 file changed, 291 insertions(+), 121 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 48ed7a301707..fe7c9dcd2ba7 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -15,7 +15,19 @@ // specific language governing permissions and limitations // under the License. +use std::any::Any; +use std::cmp::Eq; +use std::collections::HashSet; +use std::fmt::Debug; +use std::hash::Hash; +use std::mem; +use std::ops::Range; +use std::sync::{Arc, Mutex}; + +use ahash::RandomState; +use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field, TimeUnit}; +use arrow_array::cast::AsArray; use arrow_array::types::{ ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, @@ -24,28 +36,19 @@ use arrow_array::types::{ TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; use arrow_array::{PrimitiveArray, StringArray}; -use arrow_buffer::{BufferBuilder, MutableBuffer, OffsetBuffer}; +use arrow_buffer::{BufferBuilder, OffsetBuffer, ScalarBuffer}; -use std::any::Any; -use std::cmp::Eq; -use std::fmt::Debug; -use std::hash::Hash; -use std::mem; -use std::sync::Arc; - -use ahash::RandomState; -use arrow::array::{Array, ArrayRef}; -use std::collections::HashSet; - -use crate::aggregate::utils::{down_cast_any_ref, Hashable}; -use crate::expressions::format_state_name; -use crate::{AggregateExpr, PhysicalExpr}; -use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array}; +use datafusion_common::cast::{as_list_array, as_primitive_array}; +use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::array_into_list_array; use datafusion_common::{Result, ScalarValue}; use datafusion_execution::memory_pool::proxy::RawTableAllocExt; use datafusion_expr::Accumulator; +use crate::aggregate::utils::{down_cast_any_ref, Hashable}; +use crate::expressions::format_state_name; +use crate::{AggregateExpr, PhysicalExpr}; + type DistinctScalarValues = ScalarValue; /// Expression for a 
COUNT(DISTINCT) aggregation.
@@ -444,16 +447,22 @@ where
 }
 
 #[derive(Debug)]
-struct StringDistinctCountAccumulator(SSOStringHashSet);
+struct StringDistinctCountAccumulator(Mutex<SSOStringHashSet>);
 impl StringDistinctCountAccumulator {
     fn new() -> Self {
-        Self(SSOStringHashSet::new())
+        Self(Mutex::new(SSOStringHashSet::new()))
     }
 }
 
 impl Accumulator for StringDistinctCountAccumulator {
     fn state(&self) -> Result<Vec<ScalarValue>> {
+        // TODO this should not need a lock/clone (should make
+        // `Accumulator::state` take a mutable reference)
+        let mut lk = self.0.lock().unwrap();
+        let set: &mut SSOStringHashSet = &mut lk;
+        // take the state out of the string set and replace with default
+        let set = std::mem::take(set);
+        let arr = set.into_state();
         let list = Arc::new(array_into_list_array(Arc::new(arr)));
         Ok(vec![ScalarValue::List(list)])
     }
 
     fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
         if values.is_empty() {
             return Ok(());
         }
 
-        let arr = as_string_array(&values[0])?;
-        arr.iter().for_each(|value| {
-            if let Some(value) = value {
-                self.0.insert(value);
-            }
-        });
+        self.0.lock().unwrap().insert(values[0].clone());
 
         Ok(())
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
         if states.is_empty() {
             return Ok(());
         }
         assert_eq!(
             states.len(),
             1,
             "count_distinct states must be single array"
         );
 
         let arr = as_list_array(&states[0])?;
         arr.iter().try_for_each(|maybe_list| {
             if let Some(list) = maybe_list {
-                let list = as_string_array(&list)?;
-
-                list.iter().for_each(|value| {
-                    if let Some(value) = value {
-                        self.0.insert(value);
-                    }
-                })
+                self.0.lock().unwrap().insert(list);
             };
             Ok(())
         })
     }
 
     fn evaluate(&self) -> Result<ScalarValue> {
-        Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
+        Ok(ScalarValue::Int64(
+            Some(self.0.lock().unwrap().len() as i64),
+        ))
     }
 
     fn size(&self) -> usize {
         // Size of accumulator
         // + SSOStringHashSet size
-        std::mem::size_of_val(self) + self.0.size()
+        std::mem::size_of_val(self) + self.0.lock().unwrap().size()
     }
 }
 
+/// Maximum size of a string that can be inlined in the hash table
 const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
 
+/// Entry that is stored in the actual hash table
 #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
 struct SSOStringHeader {
-    /// hash of the string value (used when resizing table)
+    /// hash of the string value (stored to avoid recomputing it when checking)
+    /// TODO can we simply recreate when needed
     hash: u64,
-    /// length of the string
+    /// length of the string, in bytes
     len: usize,
-    /// short strings are stored inline, long strings are stored in the buffer
+    /// if len =< SHORT_STRING_LEN: the string data inlined
+    /// if len > SHORT_STRING_LEN, the offset
     offset_or_inline: usize,
 }
 
+impl SSOStringHeader {}
+
+impl SSOStringHeader {
+    /// returns self.offset..self.offset + self.len
+    fn range(&self) -> Range<usize> {
+        self.offset_or_inline..self.offset_or_inline + self.len
+    }
+}
+
-// Short String Optimizated HashSet for String
+// Short String Optimized HashSet for String
 // Equivalent to HashSet<String> but with better memory usage
 #[derive(Default)]
 struct SSOStringHashSet {
-    /// Core of the HashSet, it stores both the short and long string headers
-    header_set: HashSet<SSOStringHeader>,
-    /// Used to check if the long string already exists
-    long_string_map: hashbrown::raw::RawTable<SSOStringHeader>,
+    /// Store entries for each distinct string
+    map: hashbrown::raw::RawTable<SSOStringHeader>,
+    /// Total size of the map in bytes (TODO)
     map_size: usize,
     /// Buffer containing all long strings
     buffer: BufferBuilder<u8>,
     /// The random state used to generate hashes
-    state: RandomState,
+    random_state: RandomState,
+    // buffer to be reused to store hashes
+    hashes_buffer: Vec<u64>,
 }
 
 impl SSOStringHashSet {
     fn new() -> Self {
         Self::default()
     }
 
-    fn insert(&mut self, value: &str) {
+    fn insert(&mut self, values: ArrayRef) {
+        // step 1: compute hashes for the strings
+        let batch_hashes = &mut self.hashes_buffer;
+        batch_hashes.clear();
+        batch_hashes.resize(values.len(), 0);
+        create_hashes(&[values.clone()], &self.random_state, batch_hashes)
+            // hash is supported for all string types and create_hashes only
+            // returns errors for unsupported types
+            .unwrap();
+
+        // TODO make this generic (to support large strings)
+        let values = values.as_string::<i32>();
+
+        // step 2: insert each string into the set, if not already present
+
+        // Assert for unsafe values call
+        assert_eq!(values.len(), batch_hashes.len());
+
+        for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
+            // count distinct ignores nulls
+            let Some(value) = value else {
+                continue;
+            };
+
+            // from here on only use bytes (not str/chars) for value
+            let value = value.as_bytes();
 
-        if value_len <= SHORT_STRING_LEN {
-            let inline = value_bytes
-                .iter()
-                .fold(0usize, |acc, &x| acc << 8 | x as usize);
-            let short_string_header = SSOStringHeader {
-                hash: 0, // no need for short string cases
-                len: value_len,
-                offset_or_inline: inline,
-            };
-            self.header_set.insert(short_string_header);
-        } else {
+            if value.len() <= SHORT_STRING_LEN {
+                let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize);
+
+                // Check if the value is already present in the set
+                let entry = self.map.get_mut(hash, |header| {
+                    // if hash matches, must also compare the values
+                    if header.len != value.len() {
+                        return false;
+                    }
+                    inline == header.offset_or_inline
+                });
+
+                // Insert an entry for this value if it is not present
+                if entry.is_none() {
+                    let new_header = SSOStringHeader {
+                        hash,
+                        len: value.len(),
+                        offset_or_inline: inline,
+                    };
+                    self.map.insert_accounted(
+                        new_header,
+                        |header| header.hash,
+                        &mut self.map_size,
+                    );
+                }
+            }
+            // handle large strings
+            else {
+                // Check if the value is already present in the set
+                let entry = self.map.get_mut(hash, |header| {
+                    // if hash matches, must also compare the values
+                    if header.len != value.len() {
+                        return false;
+                    }
+                    // SAFETY: buffer is only appended to, and we correctly inserted values
buffer is only appended to, and we correctly inserted values + let existing_value = + unsafe { self.buffer.as_slice().get_unchecked(header.range()) }; + value == existing_value + }); + + // Insert the value if it is not present + if entry.is_none() { + // long strings are stored as a length/offset into the buffer + let offset = self.buffer.len(); + self.buffer.append_slice(value); + let new_header = SSOStringHeader { + hash, + len: value.len(), + offset_or_inline: offset, + }; + self.map.insert_accounted( + new_header, + |header| header.hash, + &mut self.map_size, + ); + } } } } - // Returns a StringArray with the current state of the set - fn state(&self) -> StringArray { - let mut offsets = Vec::with_capacity(self.size_hint + 1); - offsets.push(0); + /// Converts this set into a StringArray of the distinct string values + fn into_state(self) -> StringArray { + // The map contains entries that have offsets in some arbitrary order + // but the buffer contains the actual strings in the order they were inserted + // so we need to build offsets for the strings in the buffer in order + // then append short strings, if any, and then build the StringArray + // TODO a picture would be nice here + let Self { + map, + map_size: _, + mut buffer, + random_state: _, + hashes_buffer: _, + } = self; + + // Sort all headers so that long strings come first, in offset order + // followed by short strings ordered by value + let mut headers = map.into_iter().collect::>(); + headers.sort_unstable_by(|a, b| { + if a.len <= SHORT_STRING_LEN && b.len <= SHORT_STRING_LEN { + // both are short strings, compare the inlined values + a.offset_or_inline.cmp(&b.offset_or_inline) + } else if a.len <= SHORT_STRING_LEN { + // a is a short string, b is a long string + // (long strings sort before short strings) + std::cmp::Ordering::Greater + } else if b.len <= SHORT_STRING_LEN { + // a is a long string, b is a short string + // (long strings sort before short strings) + std::cmp::Ordering::Less + } else { + // both are long strings, sort by offsets + a.offset_or_inline.cmp(&b.offset_or_inline) + } + }); - let mut values = MutableBuffer::new(0); - let buffer = self.buffer.as_slice(); + // create offsets for the long strings + let offsets: ScalarBuffer<_> = std::iter::once(0) + .chain(headers.into_iter().map(|header| { + if header.len > SHORT_STRING_LEN { + // long strings are already stored in the buffer, so take + // offset directly + (header.offset_or_inline + header.len) as i32 + } else { + // short strings are inlined, so append their bytes to the + // buffer now + // a string like {10, 20, 30} was stored as [30, 20, 10] + // so need to reverse here + // todo maybe we could cast directly to *u8 and avoid this shifting / finagling + for i in 0..header.len { + let shift = 8 * (header.len - i - 1); + let mask = 0xffusize << shift; + let v = ((header.offset_or_inline & mask) >> shift) as u8; + buffer.append(v); + } + buffer.len() as i32 + } + })) + .collect(); - for header in self.header_set.iter() { - let s = header.evaluate(buffer); - values.extend_from_slice(s.as_bytes()); - offsets.push(values.len() as i32); - } + // get the values and reset self.buffer + let values = buffer.finish(); - let value_offsets = OffsetBuffer::::new(offsets.into()); - StringArray::new(value_offsets, values.into(), None) + let nulls = None; // count distinct ignores nulls + // todo could use unchecked to avoid utf8 validation + StringArray::new(OffsetBuffer::new(offsets), values, nulls) } fn len(&self) -> usize { - self.header_set.len() + 
self.map.len() } fn size(&self) -> usize { - self.header_set.len() * mem::size_of::() - + self.map_size - + self.buffer.len() + self.map_size + self.buffer.len() } } impl Debug for SSOStringHashSet { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SSOStringHashSet") - .field("header_set", &self.header_set) - // TODO: Print long_string_map + .field("map", &"") .field("map_size", &self.map_size) .field("buffer", &self.buffer) - .field("state", &self.state) + .field("random_state", &self.random_state) + .field("hashes_buffer", &self.hashes_buffer) .finish() } } #[cfg(test)] mod tests { - use crate::expressions::NoOp; - - use super::*; use arrow::array::{ ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array, @@ -665,10 +743,15 @@ mod tests { }; use arrow_array::Decimal256Array; use arrow_buffer::i256; + use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; use datafusion_common::internal_err; use datafusion_common::DataFusionError; + use crate::expressions::NoOp; + + use super::*; + macro_rules! state_to_vec_primitive { ($LIST:expr, $DATA_TYPE:ident) => {{ let arr = ScalarValue::raw_data($LIST).unwrap(); @@ -1055,4 +1138,91 @@ mod tests { assert_eq!(result, ScalarValue::Int64(Some(2))); Ok(()) } + #[test] + fn string_set_empty() { + for values in [StringArray::new_null(0), StringArray::new_null(11)] { + let mut set = SSOStringHashSet::new(); + set.insert(Arc::new(values)); + assert_set(set, &[]); + } + } + + #[test] + fn string_set_basic() { + // basic test for mixed small and large string values + let values = StringArray::from(vec![ + Some("a"), + Some("b"), + Some("CXCCCCCCCC"), // 10 bytes + Some(""), + Some("cbcxx"), // 5 bytes + None, + Some("AAAAAAAA"), // 8 bytes + Some("BBBBBQBBB"), // 9 bytes + Some("a"), + Some("cbcxx"), + Some("b"), + Some("cbcxx"), + Some(""), + None, + Some("BBBBBQBBB"), + Some("BBBBBQBBB"), + Some("AAAAAAAA"), + Some("CXCCCCCCCC"), + ]); + + let mut set = SSOStringHashSet::new(); + set.insert(Arc::new(values)); + assert_set( + set, + &[ + Some(""), + Some("AAAAAAAA"), + Some("BBBBBQBBB"), + Some("CXCCCCCCCC"), + Some("a"), + Some("b"), + Some("cbcxx"), + ], + ); + } + + #[test] + fn string_set_non_utf8() { + // basic test for mixed small and large string values + let values = StringArray::from(vec![ + Some("a"), + Some("✨🔥"), + Some("🔥"), + Some("✨✨✨"), + Some("foobarbaz"), + Some("🔥"), + Some("✨🔥"), + ]); + + let mut set = SSOStringHashSet::new(); + set.insert(Arc::new(values)); + assert_set( + set, + &[ + Some("a"), + Some("foobarbaz"), + Some("✨✨✨"), + Some("✨🔥"), + Some("🔥"), + ], + ); + } + + // asserts that the set contains the expected strings + fn assert_set(set: SSOStringHashSet, expected: &[Option<&str>]) { + let strings = set.into_state(); + let mut state = strings.into_iter().collect::>(); + state.sort(); + assert_eq!(state, expected); + } + + // TODO fuzz testing + + // inserting strings into the set does not increase reported memoyr } From a101b62557d0c0aaa0402256cc4ca249995d5b04 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 20 Jan 2024 12:44:10 -0500 Subject: [PATCH 19/30] Simplify offset construction --- .../src/aggregate/count_distinct.rs | 87 ++++++++----------- 1 file changed, 35 insertions(+), 52 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index fe7c9dcd2ba7..0a7d1efb3f11 100644 
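For reference, the short-string inlining that the preceding patch introduced can be exercised in isolation: strings of at most `mem::size_of::<usize>()` bytes are folded into `offset_or_inline`, so equality checks need no buffer lookup. A minimal standalone sketch of that packing and its inverse (the helper names `pack_inline`/`unpack_inline` are illustrative, not part of the patch):

```rust
use std::mem;

/// Fold the bytes of a short string into a usize, first byte landing in the
/// most significant packed position (mirrors the fold used by `insert` above).
fn pack_inline(s: &str) -> usize {
    debug_assert!(s.len() <= mem::size_of::<usize>());
    s.as_bytes().iter().fold(0usize, |acc, &b| acc << 8 | b as usize)
}

/// Recover the original bytes given the packed value and the string length.
fn unpack_inline(inline: usize, len: usize) -> Vec<u8> {
    (0..len)
        .map(|i| {
            let shift = 8 * (len - i - 1);
            ((inline >> shift) & 0xff) as u8
        })
        .collect()
}

fn main() {
    let packed = pack_inline("foo");
    assert_eq!(unpack_inline(packed, 3), b"foo");
}
```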
--- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -521,7 +521,7 @@ struct SSOStringHeader { /// length of the string, in bytes len: usize, /// if len =< SHORT_STRING_LEN: the string data inlined - /// if len > SHORT_STRING_LEN, the offset + /// if len > SHORT_STRING_LEN, the offset of where the data starts offset_or_inline: usize, } @@ -536,23 +536,38 @@ impl SSOStringHeader { // Short String Optimized HashSet for String // Equivalent to HashSet but with better memory usage -#[derive(Default)] struct SSOStringHashSet { /// Store entries for each distinct string map: hashbrown::raw::RawTable, /// Total size of the map in bytes (TODO) map_size: usize, - /// Buffer containing all long strings + /// Buffer containing all string values buffer: BufferBuilder, + /// offsets into buffer of the distinct values. These are the same offsets + /// as are used for a GenericStringArray + offsets: Vec, /// The random state used to generate hashes random_state: RandomState, // buffer to be reused to store hashes hashes_buffer: Vec, } +impl Default for SSOStringHashSet { + fn default() -> Self { + Self::new() + } +} + impl SSOStringHashSet { fn new() -> Self { - Self::default() + Self { + map: hashbrown::raw::RawTable::new(), + map_size: 0, + buffer: BufferBuilder::new(0), + offsets: vec![0], // first offset is always 0 + random_state: RandomState::new(), + hashes_buffer: vec![], + } } fn insert(&mut self, values: ArrayRef) { @@ -596,6 +611,12 @@ impl SSOStringHashSet { // Insert an entry for this value if it is not present if entry.is_none() { + // Put the small values into buffer and output so it appears + // the output array, but store the actual bytes inline + self.buffer.append_slice(value); + self.offsets.push(self.buffer.len() as i32); + + // store the actual value inline let new_header = SSOStringHeader { hash, len: value.len(), @@ -619,14 +640,17 @@ impl SSOStringHashSet { // SAFETY: buffer is only appended to, and we correctly inserted values let existing_value = unsafe { self.buffer.as_slice().get_unchecked(header.range()) }; + value == existing_value }); // Insert the value if it is not present if entry.is_none() { // long strings are stored as a length/offset into the buffer - let offset = self.buffer.len(); + let offset = self.buffer.len(); // offset of start fof data self.buffer.append_slice(value); + self.offsets.push(self.buffer.len() as i32); + let new_header = SSOStringHeader { hash, len: value.len(), @@ -650,64 +674,23 @@ impl SSOStringHashSet { // then append short strings, if any, and then build the StringArray // TODO a picture would be nice here let Self { - map, + map: _, map_size: _, + offsets, mut buffer, random_state: _, hashes_buffer: _, } = self; - // Sort all headers so that long strings come first, in offset order - // followed by short strings ordered by value - let mut headers = map.into_iter().collect::>(); - headers.sort_unstable_by(|a, b| { - if a.len <= SHORT_STRING_LEN && b.len <= SHORT_STRING_LEN { - // both are short strings, compare the inlined values - a.offset_or_inline.cmp(&b.offset_or_inline) - } else if a.len <= SHORT_STRING_LEN { - // a is a short string, b is a long string - // (long strings sort before short strings) - std::cmp::Ordering::Greater - } else if b.len <= SHORT_STRING_LEN { - // a is a long string, b is a short string - // (long strings sort before short strings) - std::cmp::Ordering::Less - } else { - // both are long strings, sort by offsets - 
a.offset_or_inline.cmp(&b.offset_or_inline) - } - }); - - // create offsets for the long strings - let offsets: ScalarBuffer<_> = std::iter::once(0) - .chain(headers.into_iter().map(|header| { - if header.len > SHORT_STRING_LEN { - // long strings are already stored in the buffer, so take - // offset directly - (header.offset_or_inline + header.len) as i32 - } else { - // short strings are inlined, so append their bytes to the - // buffer now - // a string like {10, 20, 30} was stored as [30, 20, 10] - // so need to reverse here - // todo maybe we could cast directly to *u8 and avoid this shifting / finagling - for i in 0..header.len { - let shift = 8 * (header.len - i - 1); - let mask = 0xffusize << shift; - let v = ((header.offset_or_inline & mask) >> shift) as u8; - buffer.append(v); - } - buffer.len() as i32 - } - })) - .collect(); + // Add any + let offsets: ScalarBuffer<_> = offsets.into(); // get the values and reset self.buffer let values = buffer.finish(); let nulls = None; // count distinct ignores nulls - // todo could use unchecked to avoid utf8 validation - StringArray::new(OffsetBuffer::new(offsets), values, nulls) + // SAFETY: all the values that went in are coming + unsafe { StringArray::new_unchecked(OffsetBuffer::new(offsets), values, nulls) } } fn len(&self) -> usize { From 0f2fa02542883001d58680daf4156bb41bc36ba8 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sat, 20 Jan 2024 15:32:35 -0500 Subject: [PATCH 20/30] fmt --- datafusion/physical-expr/src/aggregate/count_distinct.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 0a7d1efb3f11..70edf6102541 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -689,7 +689,7 @@ impl SSOStringHashSet { let values = buffer.finish(); let nulls = None; // count distinct ignores nulls - // SAFETY: all the values that went in are coming + // SAFETY: all the values that went in are coming unsafe { StringArray::new_unchecked(OffsetBuffer::new(offsets), values, nulls) } } From 489e1306fa7df1ec70f4ddbe594e66c3bae36b93 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 21 Jan 2024 11:56:12 -0500 Subject: [PATCH 21/30] Improve comments --- .../src/aggregate/count_distinct.rs | 49 +++++++++++++------ 1 file changed, 34 insertions(+), 15 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 70edf6102541..344fd41b89f1 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -512,17 +512,44 @@ impl Accumulator for StringDistinctCountAccumulator { /// Maximum size of a string that can be inlined in the hash table const SHORT_STRING_LEN: usize = mem::size_of::(); -/// Entry that is stored in the actual hash table +/// Entry that is stored in a `SSOStringHashSet` that represents a string +/// that is either stored inline or in the buffer +/// +/// ```text +/// ┌──────────────────┐ +/// │... │ +/// │TheQuickBrownFox │ +/// ─ ─ ─ ─ ─ ─ ─▶│... 
│ +/// │ │ │ +/// └──────────────────┘ +/// │ buffer of u8 +/// +/// │ +/// ┌────────────────┬───────────────┬───────────────┐ +/// Storing │ │ starting byte │ length, in │ +/// "TheQuickBrownFox" │ hash value │ offset in │ bytes (not │ +/// (long string) │ │ buffer │ characters) │ +/// └────────────────┴───────────────┴───────────────┘ +/// 8 bytes 8 bytes 4 or 8 +/// +/// +/// ┌───────────────┬─┬─┬─┬─┬─┬─┬─┬─┬───────────────┐ +/// Storing "foobar" │ │ │ │ │ │ │ │ │ │ length, in │ +/// (short string) │ hash value │?│?│f│o│o│b│a│r│ bytes (not │ +/// │ │ │ │ │ │ │ │ │ │ characters) │ +/// └───────────────┴─┴─┴─┴─┴─┴─┴─┴─┴───────────────┘ +/// 8 bytes 8 bytes 4 or 8 +/// ``` #[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] struct SSOStringHeader { - /// hash of the string value (stored to avoid recomputing it when checking) - /// TODO can we simply recreate when needed + /// hash of the string value (stored to avoid recomputing it in hash table + /// check) hash: u64, - /// length of the string, in bytes - len: usize, /// if len =< SHORT_STRING_LEN: the string data inlined /// if len > SHORT_STRING_LEN, the offset of where the data starts offset_or_inline: usize, + /// length of the string, in bytes + len: usize, } impl SSOStringHeader {} @@ -668,11 +695,6 @@ impl SSOStringHashSet { /// Converts this set into a StringArray of the distinct string values fn into_state(self) -> StringArray { - // The map contains entries that have offsets in some arbitrary order - // but the buffer contains the actual strings in the order they were inserted - // so we need to build offsets for the strings in the buffer in order - // then append short strings, if any, and then build the StringArray - // TODO a picture would be nice here let Self { map: _, map_size: _, @@ -682,14 +704,11 @@ impl SSOStringHashSet { hashes_buffer: _, } = self; - // Add any let offsets: ScalarBuffer<_> = offsets.into(); - - // get the values and reset self.buffer let values = buffer.finish(); + let nulls = None; // count distinct ignores nulls so intermediate state never has nulls - let nulls = None; // count distinct ignores nulls - // SAFETY: all the values that went in are coming + // SAFETY: all the values that went in were valid utf8 so are all the values that come out unsafe { StringArray::new_unchecked(OffsetBuffer::new(offsets), values, nulls) } } From c39988a61e25c46316966d9ce50c1201d00420aa Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Sun, 21 Jan 2024 12:17:03 -0500 Subject: [PATCH 22/30] Improve comments --- .../src/aggregate/count_distinct.rs | 30 +++++++++++-------- 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 344fd41b89f1..7a0f431c2525 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -515,6 +515,9 @@ const SHORT_STRING_LEN: usize = mem::size_of::(); /// Entry that is stored in a `SSOStringHashSet` that represents a string /// that is either stored inline or in the buffer /// +/// This helps the case where there are many short (less than 8 bytes) strings +/// that are the same (e.g. "MA", "CA", "NY", "TX", etc) +/// /// ```text /// ┌──────────────────┐ /// │... 
│ @@ -629,21 +632,22 @@ impl SSOStringHashSet { // Check if the value is already present in the set let entry = self.map.get_mut(hash, |header| { - // if hash matches, must also compare the values + // compare value if hashes match if header.len != value.len() { return false; } + // value is stored inline so no need to consult buffer + // (this is the "small string optimization") here inline == header.offset_or_inline }); - // Insert an entry for this value if it is not present + // if no existing entry, make a new one if entry.is_none() { - // Put the small values into buffer and output so it appears - // the output array, but store the actual bytes inline + // Put the small values into buffer and offsets so it appears + // the output array, but store the actual bytes inline for + // comparison self.buffer.append_slice(value); self.offsets.push(self.buffer.len() as i32); - - // store the actual value inline let new_header = SSOStringHeader { hash, len: value.len(), @@ -656,24 +660,26 @@ impl SSOStringHashSet { ); } } - // handle large strings + // value is not a "small" string else { // Check if the value is already present in the set let entry = self.map.get_mut(hash, |header| { - // if hash matches, must also compare the values + // compare value if hashes match if header.len != value.len() { return false; } - // SAFETY: buffer is only appended to, and we correctly inserted values + // Need to compare the bytes in the buffer + // SAFETY: buffer is only appended to, and we correctly inserted values and offsets let existing_value = unsafe { self.buffer.as_slice().get_unchecked(header.range()) }; - value == existing_value }); - // Insert the value if it is not present + // if no existing entry, make a new one if entry.is_none() { - // long strings are stored as a length/offset into the buffer + // Put the small values into buffer and offsets so it + // appears the output array, and store that offset + // so the bytes can be compared if needed let offset = self.buffer.len(); // offset of start fof data self.buffer.append_slice(value); self.offsets.push(self.buffer.len() as i32); From 0e33b12880d4125a0a46d7a9cfc9b5f1e719f8d2 Mon Sep 17 00:00:00 2001 From: jayzhan211 Date: Mon, 22 Jan 2024 09:02:31 +0800 Subject: [PATCH 23/30] add fuzz test Signed-off-by: jayzhan211 --- .../fuzz_cases/distinct_count_string_fuzz.rs | 323 ++++++++++++++++++ datafusion/core/tests/fuzz_cases/mod.rs | 1 + .../src/aggregate/count_distinct.rs | 6 +- 3 files changed, 326 insertions(+), 4 deletions(-) create mode 100644 datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs diff --git a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs new file mode 100644 index 000000000000..038deddb483f --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs @@ -0,0 +1,323 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Compare DistinctCount for string with naive HashSet and Short String Optimized HashSet + +use std::any::Any; +use std::sync::Arc; + +use arrow::array::ArrayRef; +use arrow::datatypes::DataType; +use arrow::datatypes::Field; +use arrow::record_batch::RecordBatch; +use arrow_array::cast::{as_list_array, as_string_array}; +use arrow_array::{ListArray, StringArray}; +use datafusion::physical_expr::aggregate::utils::down_cast_any_ref; +use datafusion::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion_common::utils::array_into_list_array; +use datafusion_common::Result; +use datafusion_common::ScalarValue; +use datafusion_expr::Accumulator; +use datafusion_physical_expr::expressions::format_state_name; +use datafusion_physical_expr::PhysicalExpr; + +use rand::rngs::StdRng; +use rand::{Rng, SeedableRng}; +use std::collections::HashSet; + +use datafusion::physical_plan::memory::MemoryExec; +use datafusion::physical_plan::{collect, ExecutionPlan}; +use datafusion::prelude::{SessionConfig, SessionContext}; +use datafusion_physical_expr::expressions::{col, DistinctCount}; +use datafusion_physical_expr::AggregateExpr; + +#[tokio::test(flavor = "multi_thread", worker_threads = 8)] +async fn distinct_count_string_test() { + // max length of generated strings + let max_lens = [4, 8, 16, 32, 64, 128, 256, 512]; + let n = 300; + // number of rows in each batch + let row_lens = [8, 16, 32, 64, 128, 256, 512, 1024]; + for row_len in row_lens { + let mut handles = Vec::new(); + for i in 0..n { + let test_idx = i % max_lens.len(); + let max_len = max_lens[test_idx]; + let job = tokio::spawn(run_distinct_count_test(make_staggered_batches( + max_len, row_len, i as u64, + ))); + handles.push(job); + } + for job in handles { + job.await.unwrap(); + } + } +} + +/// Perform batch and streaming aggregation with same input +/// and verify outputs of `AggregateExec` with pipeline breaking stream `GroupedHashAggregateStream` +/// and non-pipeline breaking stream `BoundedAggregateStream` produces same result. 
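This harness is a differential test: the same batches go through the optimized `DistinctCount` and through a naive control, and the collected outputs must be identical. The control side of that idea in miniature (a sketch; `count_distinct_naive` is an illustrative name, not part of the test):

```rust
use std::collections::HashSet;

/// naive oracle: COUNT(DISTINCT) over optional strings, ignoring NULLs
fn count_distinct_naive<'a>(values: impl IntoIterator<Item = Option<&'a str>>) -> usize {
    values.into_iter().flatten().collect::<HashSet<_>>().len()
}

fn main() {
    let input = [Some("a"), None, Some("b"), Some("a")];
    assert_eq!(count_distinct_naive(input), 2);
}
```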
+async fn run_distinct_count_test(input1: Vec) { + let schema = input1[0].schema(); + let session_config = SessionConfig::new().with_batch_size(50); + let ctx = SessionContext::new_with_config(session_config); + + let control_group_source = + Arc::new(MemoryExec::try_new(&[input1.clone()], schema.clone(), None).unwrap()); + + let experimental_source = + Arc::new(MemoryExec::try_new(&[input1.clone()], schema.clone(), None).unwrap()); + + let distinct_count_expr = vec![ + Arc::new(DistinctCount::new( + DataType::Utf8, + col("a", &schema).unwrap(), + "distinct_count1", + )) as Arc, + Arc::new(DistinctCount::new( + DataType::Utf8, + col("b", &schema).unwrap(), + "distinct_count2", + )) as Arc, + ]; + + let expr = vec![(col("c", &schema).unwrap(), "c".to_string())]; + let group_by = PhysicalGroupBy::new_single(expr); + + let mode = AggregateMode::FinalPartitioned; + let filter_expr = vec![None; distinct_count_expr.len()]; + + let distinct_count_control_group = Arc::new( + AggregateExec::try_new( + mode, + group_by.clone(), + distinct_count_expr, + filter_expr.clone(), + experimental_source, + schema.clone(), + ) + .unwrap(), + ) as Arc; + + let distinct_count_expr = vec![ + Arc::new(DistinctCountForTest::new( + DataType::Utf8, + col("a", &schema).unwrap(), + "distinct_count1", + )) as Arc, + Arc::new(DistinctCountForTest::new( + DataType::Utf8, + col("b", &schema).unwrap(), + "distinct_count2", + )) as Arc, + ]; + + let distinct_count_experimental_group = Arc::new( + AggregateExec::try_new( + mode, + group_by, + distinct_count_expr, + filter_expr, + control_group_source, + schema, + ) + .unwrap(), + ) as Arc; + + let task_ctx = ctx.task_ctx(); + let collected_control_group = + collect(distinct_count_experimental_group.clone(), task_ctx.clone()) + .await + .unwrap(); + + let collected_experimental_group = + collect(distinct_count_control_group.clone(), task_ctx.clone()) + .await + .unwrap(); + + assert_eq!(collected_control_group, collected_experimental_group); +} + +fn make_staggered_batches( + max_len: usize, + row_len: usize, + random_seed: u64, +) -> Vec { + // use a random number generator to pick a random sized output + let mut rng = StdRng::seed_from_u64(random_seed); + + fn gen_data(rng: &mut StdRng, row_len: usize, max_len: usize) -> ListArray { + let data: Vec = (0..row_len) + .map(|_| { + let len = rng.gen_range(0..max_len) + 1; + "a".repeat(len) + }) + .collect(); + array_into_list_array(Arc::new(StringArray::from(data))) + } + + let inputa = gen_data(&mut rng, row_len, max_len); + let inputb = gen_data(&mut rng, row_len, max_len); + let input_groupby = gen_data(&mut rng, row_len, max_len); + let batch = RecordBatch::try_from_iter(vec![ + ("a", Arc::new(inputa) as ArrayRef), + ("b", Arc::new(inputb) as ArrayRef), + ("c", Arc::new(input_groupby) as ArrayRef), // column for group by + ]) + .unwrap(); + + vec![batch] +} + +/// Expression for a COUNT(DISTINCT) aggregation. +#[derive(Debug)] +pub struct DistinctCountForTest { + /// Column name + name: String, + /// The DataType used to hold the state for each input + state_data_type: DataType, + /// The input arguments + expr: Arc, +} + +impl DistinctCountForTest { + /// Create a new COUNT(DISTINCT) aggregate function. 
+ pub fn new( + input_data_type: DataType, + expr: Arc, + name: impl Into, + ) -> Self { + Self { + name: name.into(), + state_data_type: input_data_type, + expr, + } + } +} +impl AggregateExpr for DistinctCountForTest { + /// Return a reference to Any that can be used for downcasting + fn as_any(&self) -> &dyn Any { + self + } + + fn field(&self) -> Result { + Ok(Field::new(&self.name, DataType::Int64, true)) + } + + fn state_fields(&self) -> Result> { + Ok(vec![Field::new_list( + format_state_name(&self.name, "count distinct"), + Field::new("item", self.state_data_type.clone(), true), + false, + )]) + } + + fn expressions(&self) -> Vec> { + vec![self.expr.clone()] + } + + fn create_accumulator(&self) -> Result> { + match &self.state_data_type { + DataType::Utf8 => Ok(Box::new(StringDistinctCountAccumulatorForTest::new())), + _ => panic!( + "Unsupported type for COUNT(DISTINCT): {:?}", + self.state_data_type + ), + } + } + + fn name(&self) -> &str { + &self.name + } +} + +impl PartialEq for DistinctCountForTest { + fn eq(&self, other: &dyn Any) -> bool { + down_cast_any_ref(other) + .downcast_ref::() + .map(|x| { + self.name == x.name + && self.state_data_type == x.state_data_type + && self.expr.eq(&x.expr) + }) + .unwrap_or(false) + } +} + +#[derive(Debug)] +struct StringDistinctCountAccumulatorForTest(HashSet); +impl StringDistinctCountAccumulatorForTest { + fn new() -> Self { + Self(HashSet::new()) + } +} + +impl Accumulator for StringDistinctCountAccumulatorForTest { + fn state(&self) -> Result> { + let arr = + StringArray::from(self.0.iter().map(|s| s.as_str()).collect::>()); + let list = Arc::new(array_into_list_array(Arc::new(arr))); + Ok(vec![ScalarValue::List(list)]) + } + + fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { + if values.is_empty() { + return Ok(()); + } + + let array = as_string_array(&values[0]); + for v in array.iter().flatten() { + self.0.insert(v.to_string()); + } + + Ok(()) + } + + fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { + if states.is_empty() { + return Ok(()); + } + assert_eq!( + states.len(), + 1, + "count_distinct states must be single array" + ); + + let arr = as_list_array(&states[0]); + arr.iter().try_for_each(|maybe_list| { + if let Some(list) = maybe_list { + let array = as_string_array(&list); + for v in array.iter().flatten() { + self.0.insert(v.to_string()); + } + }; + Ok(()) + }) + } + + fn evaluate(&self) -> Result { + Ok(ScalarValue::Int64(Some(self.0.len() as i64))) + } + + fn size(&self) -> usize { + // Size of accumulator + // + SSOStringHashSet size + std::mem::size_of_val(self) + self.0.capacity() * std::mem::size_of::() + } +} diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 83ec928ae229..69241571b4af 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -16,6 +16,7 @@ // under the License. 
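The control accumulator defined above goes through the standard two-phase aggregate lifecycle: partial accumulators serialize their set via `state()`, and a final accumulator unions those states in `merge_batch()` before `evaluate()`. The same flow in miniature with a plain `HashSet<String>` (a sketch, independent of the DataFusion traits):

```rust
use std::collections::HashSet;

fn main() {
    // two "partial" sets built on different partitions
    let p1: HashSet<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
    let p2: HashSet<String> = ["b", "c"].iter().map(|s| s.to_string()).collect();

    // merge_batch in miniature: union the serialized partial states
    let mut merged = p1;
    merged.extend(p2);

    // evaluate in miniature: "b" is not double counted
    assert_eq!(merged.len(), 3);
}
```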
mod aggregate_fuzz; +mod distinct_count_string_fuzz; mod join_fuzz; mod merge_fuzz; mod sort_fuzz; diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 344fd41b89f1..677a14b8b77c 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -67,10 +67,10 @@ impl DistinctCount { pub fn new( input_data_type: DataType, expr: Arc, - name: String, + name: impl Into, ) -> Self { Self { - name, + name: name.into(), state_data_type: input_data_type, expr, } @@ -552,8 +552,6 @@ struct SSOStringHeader { len: usize, } -impl SSOStringHeader {} - impl SSOStringHeader { /// returns self.offset..self.offset + self.len fn range(&self) -> Range { From b3bcc68dfc93f4ee31da1a338a5447b98dccdab4 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 22 Jan 2024 07:07:25 -0500 Subject: [PATCH 24/30] Add support for LargeStringArray --- .../src/aggregate/count_distinct.rs | 124 +++++++++++------- .../sqllogictest/test_files/aggregate.slt | 25 ++++ 2 files changed, 102 insertions(+), 47 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 7a0f431c2525..3ae7eb3aafc4 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -35,7 +35,7 @@ use arrow_array::types::{ TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use arrow_array::{PrimitiveArray, StringArray}; +use arrow_array::{GenericStringArray, OffsetSizeTrait, PrimitiveArray}; use arrow_buffer::{BufferBuilder, OffsetBuffer, ScalarBuffer}; use datafusion_common::cast::{as_list_array, as_primitive_array}; @@ -158,7 +158,8 @@ impl AggregateExpr for DistinctCount { Float32 => float_distinct_count_accumulator!(Float32Type), Float64 => float_distinct_count_accumulator!(Float64Type), - Utf8 => Ok(Box::new(StringDistinctCountAccumulator::new())), + Utf8 => Ok(Box::new(StringDistinctCountAccumulator::::new())), + LargeUtf8 => Ok(Box::new(StringDistinctCountAccumulator::::new())), _ => Ok(Box::new(DistinctCountAccumulator { values: HashSet::default(), @@ -447,23 +448,24 @@ where } #[derive(Debug)] -struct StringDistinctCountAccumulator(Mutex); -impl StringDistinctCountAccumulator { +struct StringDistinctCountAccumulator(Mutex>); +impl StringDistinctCountAccumulator { fn new() -> Self { - Self(Mutex::new(SSOStringHashSet::new())) + Self(Mutex::new(SSOStringHashSet::::new())) } } -impl Accumulator for StringDistinctCountAccumulator { +impl Accumulator for StringDistinctCountAccumulator { fn state(&self) -> Result> { // TODO this should not need a lock/clone (should make // `Accumulator::state` take a mutable reference) + // see https://github.com/apache/arrow-datafusion/pull/8925 let mut lk = self.0.lock().unwrap(); - let set: &mut SSOStringHashSet = &mut lk; + let set: &mut SSOStringHashSet<_> = &mut lk; // take the state out of the string set and replace with default let set = std::mem::take(set); let arr = set.into_state(); - let list = Arc::new(array_into_list_array(Arc::new(arr))); + let list = Arc::new(array_into_list_array(arr)); Ok(vec![ScalarValue::List(list)]) } @@ -564,37 +566,39 @@ impl SSOStringHeader { } } -// Short String Optimized HashSet for String -// Equivalent to HashSet but with better memory usage -struct SSOStringHashSet { - /// 
Store entries for each distinct string +/// HashSet optimized for storing `String` and `LargeString` values +/// and producing the final set as a GenericStringArray with minimal copies. +/// +/// Equivalent to `HashSet` but with better performance for arrow data. +struct SSOStringHashSet { + /// Underlying hash set for each distinct string map: hashbrown::raw::RawTable, - /// Total size of the map in bytes (TODO) + /// Total size of the map in bytes map_size: usize, - /// Buffer containing all string values + /// In progress arrow `Buffer` containing all string values buffer: BufferBuilder, - /// offsets into buffer of the distinct values. These are the same offsets - /// as are used for a GenericStringArray - offsets: Vec, - /// The random state used to generate hashes + /// Offsets into `buffer` for each distinct string value. These offsets + /// as used directly to create the final `GenericStringArray` + offsets: Vec, + /// random state used to generate hashes random_state: RandomState, - // buffer to be reused to store hashes + /// buffer that stores hash values (reused across batches to save allocations) hashes_buffer: Vec, } -impl Default for SSOStringHashSet { +impl Default for SSOStringHashSet { fn default() -> Self { Self::new() } } -impl SSOStringHashSet { +impl SSOStringHashSet { fn new() -> Self { Self { map: hashbrown::raw::RawTable::new(), map_size: 0, buffer: BufferBuilder::new(0), - offsets: vec![0], // first offset is always 0 + offsets: vec![O::default()], // first offset is always 0 random_state: RandomState::new(), hashes_buffer: vec![], } @@ -610,12 +614,10 @@ impl SSOStringHashSet { // returns errors for unsupported types .unwrap(); - // TODO make this generic (to support large strings) - let values = values.as_string::(); - // step 2: insert each string into the set, if not already present + let values = values.as_string::(); - // Assert for unsafe values call + // Ensure lengths are equivalent (to guard unsafe values calls below) assert_eq!(values.len(), batch_hashes.len()); for (value, &hash) in values.iter().zip(batch_hashes.iter()) { @@ -627,27 +629,28 @@ impl SSOStringHashSet { // from here on only use bytes (not str/chars) for value let value = value.as_bytes(); + // value is a "small" string if value.len() <= SHORT_STRING_LEN { let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize); - // Check if the value is already present in the set + // is value is already present in the set? 
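The lookup below is hashbrown's raw-table probe: `get_mut` takes a precomputed hash plus an equality closure, and a miss is followed by an insert that supplies a re-hash closure for table resizing. That pattern in isolation (a sketch assuming the `hashbrown` crate with its `raw` feature enabled, as this file already relies on; the real code uses `insert_accounted` to additionally track allocation size):

```rust
use std::hash::BuildHasher;

use ahash::RandomState;
use hashbrown::raw::RawTable;

fn main() {
    let state = RandomState::new();
    // store (hash, value) so the table can re-hash entries when it grows
    let mut table: RawTable<(u64, String)> = RawTable::new();

    for value in ["a", "b", "a"] {
        let hash = state.hash_one(value.as_bytes());
        // the equality closure only runs for slots whose hash matches
        if table.get_mut(hash, |(_, v)| v.as_str() == value).is_none() {
            table.insert(hash, (hash, value.to_string()), |(h, _)| *h);
        }
    }
    assert_eq!(table.len(), 2);
}
```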
let entry = self.map.get_mut(hash, |header| { // compare value if hashes match if header.len != value.len() { return false; } // value is stored inline so no need to consult buffer - // (this is the "small string optimization") here + // (this is the "small string optimization") inline == header.offset_or_inline }); // if no existing entry, make a new one if entry.is_none() { - // Put the small values into buffer and offsets so it appears + // Put the small values into buffer and offsets so it appears // the output array, but store the actual bytes inline for // comparison self.buffer.append_slice(value); - self.offsets.push(self.buffer.len() as i32); + self.offsets.push(O::from_usize(self.buffer.len()).unwrap()); let new_header = SSOStringHeader { hash, len: value.len(), @@ -682,7 +685,7 @@ impl SSOStringHashSet { // so the bytes can be compared if needed let offset = self.buffer.len(); // offset of start fof data self.buffer.append_slice(value); - self.offsets.push(self.buffer.len() as i32); + self.offsets.push(O::from_usize(self.buffer.len()).unwrap()); let new_header = SSOStringHeader { hash, @@ -699,8 +702,9 @@ impl SSOStringHashSet { } } - /// Converts this set into a StringArray of the distinct string values - fn into_state(self) -> StringArray { + /// Converts this set into a `StringArray` or `LargeStringArray` with each + /// distinct string value without any copies + fn into_state(self) -> ArrayRef { let Self { map: _, map_size: _, @@ -710,24 +714,32 @@ impl SSOStringHashSet { hashes_buffer: _, } = self; - let offsets: ScalarBuffer<_> = offsets.into(); + let offsets: ScalarBuffer = offsets.into(); let values = buffer.finish(); let nulls = None; // count distinct ignores nulls so intermediate state never has nulls // SAFETY: all the values that went in were valid utf8 so are all the values that come out - unsafe { StringArray::new_unchecked(OffsetBuffer::new(offsets), values, nulls) } + let array = unsafe { + GenericStringArray::new_unchecked(OffsetBuffer::new(offsets), values, nulls) + }; + Arc::new(array) } fn len(&self) -> usize { self.map.len() } + /// Return the total size, in bytes, of memory used to store the data in + /// this set, not including `self` fn size(&self) -> usize { - self.map_size + self.buffer.len() + self.map_size + + self.buffer.capacity() * std::mem::size_of::() + + self.offsets.capacity() * std::mem::size_of::() + + self.hashes_buffer.capacity() * std::mem::size_of::() } } -impl Debug for SSOStringHashSet { +impl Debug for SSOStringHashSet { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { f.debug_struct("SSOStringHashSet") .field("map", &"") @@ -749,7 +761,7 @@ mod tests { Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; - use arrow_array::Decimal256Array; + use arrow_array::{Decimal256Array, StringArray}; use arrow_buffer::i256; use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; @@ -1149,16 +1161,23 @@ mod tests { #[test] fn string_set_empty() { for values in [StringArray::new_null(0), StringArray::new_null(11)] { - let mut set = SSOStringHashSet::new(); - set.insert(Arc::new(values)); + let mut set = SSOStringHashSet::::new(); + set.insert(Arc::new(values.clone())); assert_set(set, &[]); } } #[test] - fn string_set_basic() { + fn string_set_basic_i32() { + test_string_set_basic::(); + } + #[test] + fn string_set_basic_i64() { + test_string_set_basic::(); + } + fn test_string_set_basic() { // basic test for mixed small and large 
string values - let values = StringArray::from(vec![ + let values = GenericStringArray::::from(vec![ Some("a"), Some("b"), Some("CXCCCCCCCC"), // 10 bytes @@ -1179,7 +1198,7 @@ mod tests { Some("CXCCCCCCCC"), ]); - let mut set = SSOStringHashSet::new(); + let mut set = SSOStringHashSet::::new(); set.insert(Arc::new(values)); assert_set( set, @@ -1196,9 +1215,16 @@ mod tests { } #[test] - fn string_set_non_utf8() { + fn string_set_non_utf8_32() { + test_string_set_non_utf8::(); + } + #[test] + fn string_set_non_utf8_64() { + test_string_set_non_utf8::(); + } + fn test_string_set_non_utf8() { // basic test for mixed small and large string values - let values = StringArray::from(vec![ + let values = GenericStringArray::::from(vec![ Some("a"), Some("✨🔥"), Some("🔥"), @@ -1208,7 +1234,7 @@ mod tests { Some("✨🔥"), ]); - let mut set = SSOStringHashSet::new(); + let mut set = SSOStringHashSet::::new(); set.insert(Arc::new(values)); assert_set( set, @@ -1223,8 +1249,12 @@ mod tests { } // asserts that the set contains the expected strings - fn assert_set(set: SSOStringHashSet, expected: &[Option<&str>]) { + fn assert_set( + set: SSOStringHashSet, + expected: &[Option<&str>], + ) { let strings = set.into_state(); + let strings = strings.as_string::(); let mut state = strings.into_iter().collect::>(); state.sort(); assert_eq!(state, expected); diff --git a/datafusion/sqllogictest/test_files/aggregate.slt b/datafusion/sqllogictest/test_files/aggregate.slt index 62253603761a..4cfa2f3483ed 100644 --- a/datafusion/sqllogictest/test_files/aggregate.slt +++ b/datafusion/sqllogictest/test_files/aggregate.slt @@ -3262,6 +3262,7 @@ select count(*) from (select count(*) a, count(*) b from (select 1)); 1 # Distinct Count for string +# (test for the specialized implementation of distinct count for strings) # UTF8 string matters for string to &[u8] conversion, add it to prevent regression statement ok @@ -3288,6 +3289,30 @@ select count(distinct column1), count(distinct column2), count(distinct column3) 1 1 2 2 1 1 3 3 + +# test with long strings as well +statement ok +create table distinct_count_long_string_table as +SELECT column1, + arrow_cast(column2, 'LargeUtf8') as column2, + arrow_cast(column3, 'LargeUtf8') as column3, + arrow_cast(column4, 'LargeUtf8') as column4 +FROM distinct_count_string_table; + +# run through update_batch +query IIII +select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_long_string_table; +---- +3 3 6 6 + +# run through merge_batch +query IIII rowsort +select count(distinct column1), count(distinct column2), count(distinct column3), count(distinct column4) from distinct_count_long_string_table group by column1; +---- +1 1 1 1 +1 1 2 2 +1 1 3 3 + statement ok drop table distinct_count_string_table; From a80b39cd25dbdaaba2adca974215aae1b03d1303 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 22 Jan 2024 09:28:45 -0500 Subject: [PATCH 25/30] refine fuzz test --- .../fuzz_cases/distinct_count_string_fuzz.rs | 420 +++++++----------- 1 file changed, 154 insertions(+), 266 deletions(-) diff --git a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs index 038deddb483f..343a1756476f 100644 --- a/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/distinct_count_string_fuzz.rs @@ -17,307 +17,195 @@ //! 
Compare DistinctCount for string with naive HashSet and Short String Optimized HashSet -use std::any::Any; use std::sync::Arc; use arrow::array::ArrayRef; -use arrow::datatypes::DataType; -use arrow::datatypes::Field; use arrow::record_batch::RecordBatch; -use arrow_array::cast::{as_list_array, as_string_array}; -use arrow_array::{ListArray, StringArray}; -use datafusion::physical_expr::aggregate::utils::down_cast_any_ref; -use datafusion::physical_plan::aggregates::{ - AggregateExec, AggregateMode, PhysicalGroupBy, -}; -use datafusion_common::utils::array_into_list_array; -use datafusion_common::Result; -use datafusion_common::ScalarValue; -use datafusion_expr::Accumulator; -use datafusion_physical_expr::expressions::format_state_name; -use datafusion_physical_expr::PhysicalExpr; +use arrow_array::{Array, GenericStringArray, OffsetSizeTrait, UInt32Array}; +use arrow_array::cast::AsArray; +use datafusion::datasource::MemTable; use rand::rngs::StdRng; -use rand::{Rng, SeedableRng}; +use rand::{thread_rng, Rng, SeedableRng}; use std::collections::HashSet; +use tokio::task::JoinSet; -use datafusion::physical_plan::memory::MemoryExec; -use datafusion::physical_plan::{collect, ExecutionPlan}; use datafusion::prelude::{SessionConfig, SessionContext}; -use datafusion_physical_expr::expressions::{col, DistinctCount}; -use datafusion_physical_expr::AggregateExpr; +use test_utils::stagger_batch; -#[tokio::test(flavor = "multi_thread", worker_threads = 8)] +#[tokio::test(flavor = "multi_thread")] async fn distinct_count_string_test() { // max length of generated strings - let max_lens = [4, 8, 16, 32, 64, 128, 256, 512]; - let n = 300; - // number of rows in each batch - let row_lens = [8, 16, 32, 64, 128, 256, 512, 1024]; - for row_len in row_lens { - let mut handles = Vec::new(); - for i in 0..n { - let test_idx = i % max_lens.len(); - let max_len = max_lens[test_idx]; - let job = tokio::spawn(run_distinct_count_test(make_staggered_batches( - max_len, row_len, i as u64, - ))); - handles.push(job); - } - for job in handles { - job.await.unwrap(); + let mut join_set = JoinSet::new(); + let mut rng = thread_rng(); + for null_pct in [0.0, 0.01, 0.1, 0.5] { + for _ in 0..100 { + let max_len = rng.gen_range(1..50); + let num_strings = rng.gen_range(1..100); + let num_distinct_strings = if num_strings > 1 { + rng.gen_range(1..num_strings) + } else { + num_strings + }; + let generator = BatchGenerator { + max_len, + num_strings, + num_distinct_strings, + null_pct, + rng: StdRng::from_seed(rng.gen()), + }; + join_set.spawn(async move { run_distinct_count_test(generator).await }); } } + while let Some(join_handle) = join_set.join_next().await { + // propagate errors + join_handle.unwrap(); + } } -/// Perform batch and streaming aggregation with same input -/// and verify outputs of `AggregateExec` with pipeline breaking stream `GroupedHashAggregateStream` -/// and non-pipeline breaking stream `BoundedAggregateStream` produces same result. 
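A note on the seeding above: each spawned case gets `StdRng::from_seed(rng.gen())`, so a failing case is reproducible only if the seed is captured somewhere. One common remedy, sketched here (the `run_case` helper and the println are illustrative, not part of the patch):

```rust
use rand::rngs::StdRng;
use rand::{thread_rng, Rng, SeedableRng};

fn run_case(mut rng: StdRng) {
    // stand-in for a real test body such as run_distinct_count_test
    let _sample: u64 = rng.gen();
}

fn main() {
    let seed: u64 = thread_rng().gen();
    // record the seed so a failing case can be replayed deterministically
    println!("fuzz seed: {seed}");
    run_case(StdRng::seed_from_u64(seed));
}
```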
-async fn run_distinct_count_test(input1: Vec) { - let schema = input1[0].schema(); +/// Run COUNT DISTINCT using SQL and compare the result to computing the +/// distinct count using HashSet +async fn run_distinct_count_test(mut generator: BatchGenerator) { + let input = generator.make_input_batches(); + + let schema = input[0].schema(); let session_config = SessionConfig::new().with_batch_size(50); let ctx = SessionContext::new_with_config(session_config); - let control_group_source = - Arc::new(MemoryExec::try_new(&[input1.clone()], schema.clone(), None).unwrap()); - - let experimental_source = - Arc::new(MemoryExec::try_new(&[input1.clone()], schema.clone(), None).unwrap()); - - let distinct_count_expr = vec![ - Arc::new(DistinctCount::new( - DataType::Utf8, - col("a", &schema).unwrap(), - "distinct_count1", - )) as Arc, - Arc::new(DistinctCount::new( - DataType::Utf8, - col("b", &schema).unwrap(), - "distinct_count2", - )) as Arc, - ]; - - let expr = vec![(col("c", &schema).unwrap(), "c".to_string())]; - let group_by = PhysicalGroupBy::new_single(expr); - - let mode = AggregateMode::FinalPartitioned; - let filter_expr = vec![None; distinct_count_expr.len()]; - - let distinct_count_control_group = Arc::new( - AggregateExec::try_new( - mode, - group_by.clone(), - distinct_count_expr, - filter_expr.clone(), - experimental_source, - schema.clone(), - ) - .unwrap(), - ) as Arc; - - let distinct_count_expr = vec![ - Arc::new(DistinctCountForTest::new( - DataType::Utf8, - col("a", &schema).unwrap(), - "distinct_count1", - )) as Arc, - Arc::new(DistinctCountForTest::new( - DataType::Utf8, - col("b", &schema).unwrap(), - "distinct_count2", - )) as Arc, + // split input into two partitions + let partition_len = input.len() / 2; + let partitions = vec![ + input[0..partition_len].to_vec(), + input[partition_len..].to_vec(), ]; - let distinct_count_experimental_group = Arc::new( - AggregateExec::try_new( - mode, - group_by, - distinct_count_expr, - filter_expr, - control_group_source, - schema, - ) - .unwrap(), - ) as Arc; - - let task_ctx = ctx.task_ctx(); - let collected_control_group = - collect(distinct_count_experimental_group.clone(), task_ctx.clone()) - .await - .unwrap(); - - let collected_experimental_group = - collect(distinct_count_control_group.clone(), task_ctx.clone()) - .await - .unwrap(); - - assert_eq!(collected_control_group, collected_experimental_group); + let provider = MemTable::try_new(schema, partitions).unwrap(); + ctx.register_table("t", Arc::new(provider)).unwrap(); + // input has two columns, a and b. The result is the number of distinct + // values in each column. + // + // Note, we need at least two count distinct aggregates to trigger the + // count distinct aggregate. 
Otherwise, the optimizer will rewrite the
+    // `COUNT(DISTINCT a)` to `COUNT(*) from (SELECT DISTINCT a FROM t)`
+    let results = ctx
+        .sql("SELECT COUNT(DISTINCT a), COUNT(DISTINCT b) FROM t")
+        .await
+        .unwrap()
+        .collect()
+        .await
+        .unwrap();
+
+    // get all the strings from the first column of the result (distinct a)
+    let expected_a = extract_distinct_strings::<i32>(&input, 0).len();
+    let result_a = extract_i64(&results, 0);
+    assert_eq!(expected_a, result_a);
+
+    // get all the strings from the second column of the result (distinct b)
+    let expected_b = extract_distinct_strings::<i64>(&input, 1).len();
+    let result_b = extract_i64(&results, 1);
+    assert_eq!(expected_b, result_b);
+}
 
-fn make_staggered_batches(
-    max_len: usize,
-    row_len: usize,
-    random_seed: u64,
-) -> Vec<RecordBatch> {
-    // use a random number generator to pick a random sized output
-    let mut rng = StdRng::seed_from_u64(random_seed);
-
-    fn gen_data(rng: &mut StdRng, row_len: usize, max_len: usize) -> ListArray {
-        let data: Vec<String> = (0..row_len)
-            .map(|_| {
-                let len = rng.gen_range(0..max_len) + 1;
-                "a".repeat(len)
-            })
-            .collect();
-        array_into_list_array(Arc::new(StringArray::from(data)))
-    }
-
-    let inputa = gen_data(&mut rng, row_len, max_len);
-    let inputb = gen_data(&mut rng, row_len, max_len);
-    let input_groupby = gen_data(&mut rng, row_len, max_len);
-    let batch = RecordBatch::try_from_iter(vec![
-        ("a", Arc::new(inputa) as ArrayRef),
-        ("b", Arc::new(inputb) as ArrayRef),
-        ("c", Arc::new(input_groupby) as ArrayRef), // column for group by
-    ])
-    .unwrap();
-
-    vec![batch]
+/// Return all (non null) distinct strings from column col_idx
+fn extract_distinct_strings<O: OffsetSizeTrait>(
+    results: &[RecordBatch],
+    col_idx: usize,
+) -> Vec<String> {
+    results
+        .iter()
+        .flat_map(|batch| {
+            let array = batch.column(col_idx).as_string::<O>();
+            // remove nulls via 'flatten'
+            array.iter().flatten().map(|s| s.to_string())
+        })
+        .collect::<HashSet<_>>()
+        .into_iter()
+        .collect()
 }
 
-/// Expression for a COUNT(DISTINCT) aggregation.
-#[derive(Debug)]
-pub struct DistinctCountForTest {
-    /// Column name
-    name: String,
-    /// The DataType used to hold the state for each input
-    state_data_type: DataType,
-    /// The input arguments
-    expr: Arc<dyn PhysicalExpr>,
+// extract the value from the Int64 column in col_idx in batch and return
+// it as a usize
+fn extract_i64(results: &[RecordBatch], col_idx: usize) -> usize {
+    assert_eq!(results.len(), 1);
+    let array = results[0]
+        .column(col_idx)
+        .as_any()
+        .downcast_ref::<Int64Array>()
+        .unwrap();
+    assert_eq!(array.len(), 1);
+    assert!(!array.is_null(0));
+    array.value(0).try_into().unwrap()
 }
 
+struct BatchGenerator {
+    /// The maximum length of the strings
+    max_len: usize,
+    /// the total number of strings in the output
+    num_strings: usize,
+    /// The number of distinct strings in the columns
+    num_distinct_strings: usize,
+    /// The percentage of nulls in the columns
+    null_pct: f64,
+    /// Random number generator
+    rng: StdRng,
+}
 
-impl DistinctCountForTest {
-    /// Create a new COUNT(DISTINCT) aggregate function.
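`gen_data` below drives its sampling through the arrow `take` kernel, where a null index produces a null output row; that is how `null_pct` is realized. The same call in isolation (a self-contained sketch, not part of the patch):

```rust
use arrow_array::{Array, ArrayRef, StringArray, UInt32Array};

fn main() {
    let dictionary = StringArray::from(vec!["x", "y", "z"]);
    // a None index yields a null output row; Some(i) draws dictionary[i]
    let indices = UInt32Array::from(vec![Some(2), None, Some(0), Some(2)]);
    let drawn: ArrayRef = arrow::compute::take(&dictionary, &indices, None).unwrap();
    assert_eq!(drawn.len(), 4);
    assert_eq!(drawn.null_count(), 1);
}
```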
- pub fn new( - input_data_type: DataType, - expr: Arc, - name: impl Into, - ) -> Self { - Self { - name: name.into(), - state_data_type: input_data_type, - expr, - } - } +struct BatchGenerator { + //// The maximum length of the strings + max_len: usize, + /// the total number of strings in the output + num_strings: usize, + /// The number of distinct strings in the columns + num_distinct_strings: usize, + /// The percentage of nulls in the columns + null_pct: f64, + /// Random number generator + rng: StdRng, } -impl AggregateExpr for DistinctCountForTest { - /// Return a reference to Any that can be used for downcasting - fn as_any(&self) -> &dyn Any { - self - } - fn field(&self) -> Result { - Ok(Field::new(&self.name, DataType::Int64, true)) - } - - fn state_fields(&self) -> Result> { - Ok(vec![Field::new_list( - format_state_name(&self.name, "count distinct"), - Field::new("item", self.state_data_type.clone(), true), - false, - )]) - } - - fn expressions(&self) -> Vec> { - vec![self.expr.clone()] - } - - fn create_accumulator(&self) -> Result> { - match &self.state_data_type { - DataType::Utf8 => Ok(Box::new(StringDistinctCountAccumulatorForTest::new())), - _ => panic!( - "Unsupported type for COUNT(DISTINCT): {:?}", - self.state_data_type - ), - } - } - - fn name(&self) -> &str { - &self.name - } -} +impl BatchGenerator { + /// Make batches of random strings with a random length columns "a" and "b": + /// + /// * "a" is a StringArray + /// * "b" is a LargeStringArray + fn make_input_batches(&mut self) -> Vec { + // use a random number generator to pick a random sized output + + let batch = RecordBatch::try_from_iter(vec![ + ("a", self.gen_data::()), + ("b", self.gen_data::()), + ]) + .unwrap(); + + stagger_batch(batch) + } + + /// Creates a StringArray or LargeStringArray with random strings according + /// to the parameters of the BatchGenerator + fn gen_data(&mut self) -> ArrayRef { + // table of strings from which to draw + let distinct_strings: GenericStringArray = (0..self.num_distinct_strings) + .map(|_| Some(random_string(&mut self.rng, self.max_len))) + .collect(); -impl PartialEq for DistinctCountForTest { - fn eq(&self, other: &dyn Any) -> bool { - down_cast_any_ref(other) - .downcast_ref::() - .map(|x| { - self.name == x.name - && self.state_data_type == x.state_data_type - && self.expr.eq(&x.expr) + // pick num_strings randomly from the distinct string table + let indicies: UInt32Array = (0..self.num_strings) + .map(|_| { + if self.rng.gen::() < self.null_pct { + None + } else if self.num_distinct_strings > 1 { + let range = 1..(self.num_distinct_strings as u32); + Some(self.rng.gen_range(range)) + } else { + Some(0) + } }) - .unwrap_or(false) - } -} + .collect(); -#[derive(Debug)] -struct StringDistinctCountAccumulatorForTest(HashSet); -impl StringDistinctCountAccumulatorForTest { - fn new() -> Self { - Self(HashSet::new()) + let options = None; + arrow::compute::take(&distinct_strings, &indicies, options).unwrap() } } -impl Accumulator for StringDistinctCountAccumulatorForTest { - fn state(&self) -> Result> { - let arr = - StringArray::from(self.0.iter().map(|s| s.as_str()).collect::>()); - let list = Arc::new(array_into_list_array(Arc::new(arr))); - Ok(vec![ScalarValue::List(list)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - let array = as_string_array(&values[0]); - for v in array.iter().flatten() { - self.0.insert(v.to_string()); - } - - Ok(()) - } - - fn merge_batch(&mut self, states: 
&[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); +/// Return a string of random characters of length 1..=max_len +fn random_string(rng: &mut StdRng, max_len: usize) -> String { + // pick characters at random (not just ascii) + match max_len { + 0 => "".to_string(), + 1 => String::from(rng.gen::()), + _ => { + let len = rng.gen_range(1..=max_len); + rng.sample_iter::(rand::distributions::Standard) + .take(len) + .map(char::from) + .collect::() } - assert_eq!( - states.len(), - 1, - "count_distinct states must be single array" - ); - - let arr = as_list_array(&states[0]); - arr.iter().try_for_each(|maybe_list| { - if let Some(list) = maybe_list { - let array = as_string_array(&list); - for v in array.iter().flatten() { - self.0.insert(v.to_string()); - } - }; - Ok(()) - }) - } - - fn evaluate(&self) -> Result { - Ok(ScalarValue::Int64(Some(self.0.len() as i64))) - } - - fn size(&self) -> usize { - // Size of accumulator - // + SSOStringHashSet size - std::mem::size_of_val(self) + self.0.capacity() * std::mem::size_of::() } } From 3e9289aefb5492091a6c96913f283bc0259b14c3 Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 22 Jan 2024 09:43:51 -0500 Subject: [PATCH 26/30] Add tests for size accounting --- .../src/aggregate/count_distinct.rs | 51 ++++++++++++++++++- 1 file changed, 49 insertions(+), 2 deletions(-) diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct.rs index 89ab4474f1ae..6044421451b5 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct.rs @@ -1258,7 +1258,54 @@ mod tests { assert_eq!(state, expected); } - // TODO fuzz testing - // inserting strings into the set does not increase reported memoyr + #[test] + fn test_string_set_memory_usage() { + let strings1 = GenericStringArray::::from(vec![ + Some("a"), + Some("b"), + Some("CXCCCCCCCC"), // 10 bytes + Some("AAAAAAAA"), // 8 bytes + Some("BBBBBQBBB"), // 9 bytes + ]); + let total_strings1_len = strings1 + .iter() + .map(|s| s.map(|s| s.len()).unwrap_or(0)) + .sum::(); + let values1: ArrayRef = Arc::new(GenericStringArray::::from(strings1)); + + // Much larger strings in strings2 + let strings2 = GenericStringArray::::from(vec![ + "FOO".repeat(1000), + "BAR".repeat(2000), + "BAZ".repeat(3000), + ]); + let total_strings2_len = strings2 + .iter() + .map(|s| s.map(|s| s.len()).unwrap_or(0)) + .sum::(); + let values2: ArrayRef = Arc::new(GenericStringArray::::from(strings2)); + + let mut set = SSOStringHashSet::::new(); + let size_empty = set.size(); + + set.insert(values1.clone()); + let size_after_values1 = set.size(); + assert!(size_empty < size_after_values1); + assert!( + size_after_values1 > total_strings1_len, + "expect {size_after_values1} to be more than {total_strings1_len}" + ); + assert!(size_after_values1 < total_strings1_len + total_strings2_len); + + // inserting the same strings should not affect the size + set.insert(values1.clone()); + assert_eq!(set.size(), size_after_values1); + + // inserting the large strings should increase the reported size + set.insert(values2); + let size_after_values2 = set.size(); + assert!(size_after_values2 > size_after_values1); + assert!(size_after_values2 > total_strings1_len + total_strings2_len); + } } From 7b9d067e8beeda0ff7fb4ab8f168bcb05bd1811d Mon Sep 17 00:00:00 2001 From: Andrew Lamb Date: Mon, 22 Jan 2024 09:51:35 -0500 Subject: [PATCH 27/30] Split into new module --- .../mod.rs} | 468 +---------------- 
.../src/aggregate/count_distinct/strings.rs | 496 ++++++++++++++++++ 2 files changed, 502 insertions(+), 462 deletions(-) rename datafusion/physical-expr/src/aggregate/{count_distinct.rs => count_distinct/mod.rs} (61%) create mode 100644 datafusion/physical-expr/src/aggregate/count_distinct/strings.rs diff --git a/datafusion/physical-expr/src/aggregate/count_distinct.rs b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs similarity index 61% rename from datafusion/physical-expr/src/aggregate/count_distinct.rs rename to datafusion/physical-expr/src/aggregate/count_distinct/mod.rs index 6044421451b5..38d5eea4fab9 100644 --- a/datafusion/physical-expr/src/aggregate/count_distinct.rs +++ b/datafusion/physical-expr/src/aggregate/count_distinct/mod.rs @@ -15,19 +15,18 @@ // specific language governing permissions and limitations // under the License. +mod strings; + use std::any::Any; use std::cmp::Eq; use std::collections::HashSet; use std::fmt::Debug; use std::hash::Hash; -use std::mem; -use std::ops::Range; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use ahash::RandomState; use arrow::array::{Array, ArrayRef}; use arrow::datatypes::{DataType, Field, TimeUnit}; -use arrow_array::cast::AsArray; use arrow_array::types::{ ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, @@ -35,16 +34,14 @@ use arrow_array::types::{ TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; -use arrow_array::{GenericStringArray, OffsetSizeTrait, PrimitiveArray}; -use arrow_buffer::{BufferBuilder, OffsetBuffer, ScalarBuffer}; +use arrow_array::PrimitiveArray; use datafusion_common::cast::{as_list_array, as_primitive_array}; -use datafusion_common::hash_utils::create_hashes; use datafusion_common::utils::array_into_list_array; use datafusion_common::{Result, ScalarValue}; -use datafusion_execution::memory_pool::proxy::RawTableAllocExt; use datafusion_expr::Accumulator; +use crate::aggregate::count_distinct::strings::StringDistinctCountAccumulator; use crate::aggregate::utils::{down_cast_any_ref, Hashable}; use crate::expressions::format_state_name; use crate::{AggregateExpr, PhysicalExpr}; @@ -447,307 +444,6 @@ where } } -#[derive(Debug)] -struct StringDistinctCountAccumulator(Mutex>); -impl StringDistinctCountAccumulator { - fn new() -> Self { - Self(Mutex::new(SSOStringHashSet::::new())) - } -} - -impl Accumulator for StringDistinctCountAccumulator { - fn state(&self) -> Result> { - // TODO this should not need a lock/clone (should make - // `Accumulator::state` take a mutable reference) - // see https://github.com/apache/arrow-datafusion/pull/8925 - let mut lk = self.0.lock().unwrap(); - let set: &mut SSOStringHashSet<_> = &mut lk; - // take the state out of the string set and replace with default - let set = std::mem::take(set); - let arr = set.into_state(); - let list = Arc::new(array_into_list_array(arr)); - Ok(vec![ScalarValue::List(list)]) - } - - fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> { - if values.is_empty() { - return Ok(()); - } - - self.0.lock().unwrap().insert(values[0].clone()); - - Ok(()) - } - - fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> { - if states.is_empty() { - return Ok(()); - } - assert_eq!( - states.len(), - 1, - "count_distinct states must be single array" - ); - - let arr = as_list_array(&states[0])?; - 
arr.iter().try_for_each(|maybe_list| { - if let Some(list) = maybe_list { - self.0.lock().unwrap().insert(list); - }; - Ok(()) - }) - } - - fn evaluate(&self) -> Result { - Ok(ScalarValue::Int64( - Some(self.0.lock().unwrap().len() as i64), - )) - } - - fn size(&self) -> usize { - // Size of accumulator - // + SSOStringHashSet size - std::mem::size_of_val(self) + self.0.lock().unwrap().size() - } -} - -/// Maximum size of a string that can be inlined in the hash table -const SHORT_STRING_LEN: usize = mem::size_of::(); - -/// Entry that is stored in a `SSOStringHashSet` that represents a string -/// that is either stored inline or in the buffer -/// -/// This helps the case where there are many short (less than 8 bytes) strings -/// that are the same (e.g. "MA", "CA", "NY", "TX", etc) -/// -/// ```text -/// ┌──────────────────┐ -/// │... │ -/// │TheQuickBrownFox │ -/// ─ ─ ─ ─ ─ ─ ─▶│... │ -/// │ │ │ -/// └──────────────────┘ -/// │ buffer of u8 -/// -/// │ -/// ┌────────────────┬───────────────┬───────────────┐ -/// Storing │ │ starting byte │ length, in │ -/// "TheQuickBrownFox" │ hash value │ offset in │ bytes (not │ -/// (long string) │ │ buffer │ characters) │ -/// └────────────────┴───────────────┴───────────────┘ -/// 8 bytes 8 bytes 4 or 8 -/// -/// -/// ┌───────────────┬─┬─┬─┬─┬─┬─┬─┬─┬───────────────┐ -/// Storing "foobar" │ │ │ │ │ │ │ │ │ │ length, in │ -/// (short string) │ hash value │?│?│f│o│o│b│a│r│ bytes (not │ -/// │ │ │ │ │ │ │ │ │ │ characters) │ -/// └───────────────┴─┴─┴─┴─┴─┴─┴─┴─┴───────────────┘ -/// 8 bytes 8 bytes 4 or 8 -/// ``` -#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)] -struct SSOStringHeader { - /// hash of the string value (stored to avoid recomputing it in hash table - /// check) - hash: u64, - /// if len =< SHORT_STRING_LEN: the string data inlined - /// if len > SHORT_STRING_LEN, the offset of where the data starts - offset_or_inline: usize, - /// length of the string, in bytes - len: usize, -} - -impl SSOStringHeader { - /// returns self.offset..self.offset + self.len - fn range(&self) -> Range { - self.offset_or_inline..self.offset_or_inline + self.len - } -} - -/// HashSet optimized for storing `String` and `LargeString` values -/// and producing the final set as a GenericStringArray with minimal copies. -/// -/// Equivalent to `HashSet` but with better performance for arrow data. -struct SSOStringHashSet { - /// Underlying hash set for each distinct string - map: hashbrown::raw::RawTable, - /// Total size of the map in bytes - map_size: usize, - /// In progress arrow `Buffer` containing all string values - buffer: BufferBuilder, - /// Offsets into `buffer` for each distinct string value. 
These offsets - /// as used directly to create the final `GenericStringArray` - offsets: Vec, - /// random state used to generate hashes - random_state: RandomState, - /// buffer that stores hash values (reused across batches to save allocations) - hashes_buffer: Vec, -} - -impl Default for SSOStringHashSet { - fn default() -> Self { - Self::new() - } -} - -impl SSOStringHashSet { - fn new() -> Self { - Self { - map: hashbrown::raw::RawTable::new(), - map_size: 0, - buffer: BufferBuilder::new(0), - offsets: vec![O::default()], // first offset is always 0 - random_state: RandomState::new(), - hashes_buffer: vec![], - } - } - - fn insert(&mut self, values: ArrayRef) { - // step 1: compute hashes for the strings - let batch_hashes = &mut self.hashes_buffer; - batch_hashes.clear(); - batch_hashes.resize(values.len(), 0); - create_hashes(&[values.clone()], &self.random_state, batch_hashes) - // hash is supported for all string types and create_hashes only - // returns errors for unsupported types - .unwrap(); - - // step 2: insert each string into the set, if not already present - let values = values.as_string::(); - - // Ensure lengths are equivalent (to guard unsafe values calls below) - assert_eq!(values.len(), batch_hashes.len()); - - for (value, &hash) in values.iter().zip(batch_hashes.iter()) { - // count distinct ignores nulls - let Some(value) = value else { - continue; - }; - - // from here on only use bytes (not str/chars) for value - let value = value.as_bytes(); - - // value is a "small" string - if value.len() <= SHORT_STRING_LEN { - let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize); - - // is value is already present in the set? - let entry = self.map.get_mut(hash, |header| { - // compare value if hashes match - if header.len != value.len() { - return false; - } - // value is stored inline so no need to consult buffer - // (this is the "small string optimization") - inline == header.offset_or_inline - }); - - // if no existing entry, make a new one - if entry.is_none() { - // Put the small values into buffer and offsets so it appears - // the output array, but store the actual bytes inline for - // comparison - self.buffer.append_slice(value); - self.offsets.push(O::from_usize(self.buffer.len()).unwrap()); - let new_header = SSOStringHeader { - hash, - len: value.len(), - offset_or_inline: inline, - }; - self.map.insert_accounted( - new_header, - |header| header.hash, - &mut self.map_size, - ); - } - } - // value is not a "small" string - else { - // Check if the value is already present in the set - let entry = self.map.get_mut(hash, |header| { - // compare value if hashes match - if header.len != value.len() { - return false; - } - // Need to compare the bytes in the buffer - // SAFETY: buffer is only appended to, and we correctly inserted values and offsets - let existing_value = - unsafe { self.buffer.as_slice().get_unchecked(header.range()) }; - value == existing_value - }); - - // if no existing entry, make a new one - if entry.is_none() { - // Put the small values into buffer and offsets so it - // appears the output array, and store that offset - // so the bytes can be compared if needed - let offset = self.buffer.len(); // offset of start fof data - self.buffer.append_slice(value); - self.offsets.push(O::from_usize(self.buffer.len()).unwrap()); - - let new_header = SSOStringHeader { - hash, - len: value.len(), - offset_or_inline: offset, - }; - self.map.insert_accounted( - new_header, - |header| header.hash, - &mut self.map_size, - ); - } - } - } - } 
- - /// Converts this set into a `StringArray` or `LargeStringArray` with each - /// distinct string value without any copies - fn into_state(self) -> ArrayRef { - let Self { - map: _, - map_size: _, - offsets, - mut buffer, - random_state: _, - hashes_buffer: _, - } = self; - - let offsets: ScalarBuffer = offsets.into(); - let values = buffer.finish(); - let nulls = None; // count distinct ignores nulls so intermediate state never has nulls - - // SAFETY: all the values that went in were valid utf8 so are all the values that come out - let array = unsafe { - GenericStringArray::new_unchecked(OffsetBuffer::new(offsets), values, nulls) - }; - Arc::new(array) - } - - fn len(&self) -> usize { - self.map.len() - } - - /// Return the total size, in bytes, of memory used to store the data in - /// this set, not including `self` - fn size(&self) -> usize { - self.map_size - + self.buffer.capacity() * std::mem::size_of::() - + self.offsets.capacity() * std::mem::size_of::() - + self.hashes_buffer.capacity() * std::mem::size_of::() - } -} - -impl Debug for SSOStringHashSet { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("SSOStringHashSet") - .field("map", &"") - .field("map_size", &self.map_size) - .field("buffer", &self.buffer) - .field("random_state", &self.random_state) - .field("hashes_buffer", &self.hashes_buffer) - .finish() - } -} #[cfg(test)] mod tests { use arrow::array::{ @@ -759,7 +455,7 @@ mod tests { Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type, }; - use arrow_array::{Decimal256Array, StringArray}; + use arrow_array::Decimal256Array; use arrow_buffer::i256; use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array}; @@ -1156,156 +852,4 @@ mod tests { assert_eq!(result, ScalarValue::Int64(Some(2))); Ok(()) } - #[test] - fn string_set_empty() { - for values in [StringArray::new_null(0), StringArray::new_null(11)] { - let mut set = SSOStringHashSet::::new(); - set.insert(Arc::new(values.clone())); - assert_set(set, &[]); - } - } - - #[test] - fn string_set_basic_i32() { - test_string_set_basic::(); - } - #[test] - fn string_set_basic_i64() { - test_string_set_basic::(); - } - fn test_string_set_basic() { - // basic test for mixed small and large string values - let values = GenericStringArray::::from(vec![ - Some("a"), - Some("b"), - Some("CXCCCCCCCC"), // 10 bytes - Some(""), - Some("cbcxx"), // 5 bytes - None, - Some("AAAAAAAA"), // 8 bytes - Some("BBBBBQBBB"), // 9 bytes - Some("a"), - Some("cbcxx"), - Some("b"), - Some("cbcxx"), - Some(""), - None, - Some("BBBBBQBBB"), - Some("BBBBBQBBB"), - Some("AAAAAAAA"), - Some("CXCCCCCCCC"), - ]); - - let mut set = SSOStringHashSet::::new(); - set.insert(Arc::new(values)); - assert_set( - set, - &[ - Some(""), - Some("AAAAAAAA"), - Some("BBBBBQBBB"), - Some("CXCCCCCCCC"), - Some("a"), - Some("b"), - Some("cbcxx"), - ], - ); - } - - #[test] - fn string_set_non_utf8_32() { - test_string_set_non_utf8::(); - } - #[test] - fn string_set_non_utf8_64() { - test_string_set_non_utf8::(); - } - fn test_string_set_non_utf8() { - // basic test for mixed small and large string values - let values = GenericStringArray::::from(vec![ - Some("a"), - Some("✨🔥"), - Some("🔥"), - Some("✨✨✨"), - Some("foobarbaz"), - Some("🔥"), - Some("✨🔥"), - ]); - - let mut set = SSOStringHashSet::::new(); - set.insert(Arc::new(values)); - assert_set( - set, - &[ - Some("a"), - Some("foobarbaz"), - Some("✨✨✨"), - Some("✨🔥"), - Some("🔥"), - ], - ); - 
} - - // asserts that the set contains the expected strings - fn assert_set( - set: SSOStringHashSet, - expected: &[Option<&str>], - ) { - let strings = set.into_state(); - let strings = strings.as_string::(); - let mut state = strings.into_iter().collect::>(); - state.sort(); - assert_eq!(state, expected); - } - - // inserting strings into the set does not increase reported memoyr - #[test] - fn test_string_set_memory_usage() { - let strings1 = GenericStringArray::::from(vec![ - Some("a"), - Some("b"), - Some("CXCCCCCCCC"), // 10 bytes - Some("AAAAAAAA"), // 8 bytes - Some("BBBBBQBBB"), // 9 bytes - ]); - let total_strings1_len = strings1 - .iter() - .map(|s| s.map(|s| s.len()).unwrap_or(0)) - .sum::(); - let values1: ArrayRef = Arc::new(GenericStringArray::::from(strings1)); - - // Much larger strings in strings2 - let strings2 = GenericStringArray::::from(vec![ - "FOO".repeat(1000), - "BAR".repeat(2000), - "BAZ".repeat(3000), - ]); - let total_strings2_len = strings2 - .iter() - .map(|s| s.map(|s| s.len()).unwrap_or(0)) - .sum::(); - let values2: ArrayRef = Arc::new(GenericStringArray::::from(strings2)); - - let mut set = SSOStringHashSet::::new(); - let size_empty = set.size(); - - set.insert(values1.clone()); - let size_after_values1 = set.size(); - assert!(size_empty < size_after_values1); - assert!( - size_after_values1 > total_strings1_len, - "expect {size_after_values1} to be more than {total_strings1_len}" - ); - assert!(size_after_values1 < total_strings1_len + total_strings2_len); - - // inserting the same strings should not affect the size - set.insert(values1.clone()); - assert_eq!(set.size(), size_after_values1); - - // inserting the large strings should increase the reported size - set.insert(values2); - let size_after_values2 = set.size(); - assert!(size_after_values2 > size_after_values1); - assert!(size_after_values2 > total_strings1_len + total_strings2_len); - } } diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs new file mode 100644 index 000000000000..3af3c4486346 --- /dev/null +++ b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs @@ -0,0 +1,496 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! 
+//! Specialized implementation of `COUNT DISTINCT` for `StringArray` and `LargeStringArray`
+
+use ahash::RandomState;
+use arrow_array::cast::AsArray;
+use arrow_array::{Array, ArrayRef, GenericStringArray, OffsetSizeTrait};
+use arrow_buffer::{BufferBuilder, OffsetBuffer, ScalarBuffer};
+use datafusion_common::cast::as_list_array;
+use datafusion_common::hash_utils::create_hashes;
+use datafusion_common::utils::array_into_list_array;
+use datafusion_common::ScalarValue;
+use datafusion_execution::memory_pool::proxy::RawTableAllocExt;
+use datafusion_expr::Accumulator;
+use std::fmt::Debug;
+use std::mem;
+use std::ops::Range;
+use std::sync::{Arc, Mutex};
+
+#[derive(Debug)]
+pub(super) struct StringDistinctCountAccumulator<O: OffsetSizeTrait>(
+    Mutex<SSOStringHashSet<O>>,
+);
+impl<O: OffsetSizeTrait> StringDistinctCountAccumulator<O> {
+    pub(super) fn new() -> Self {
+        Self(Mutex::new(SSOStringHashSet::<O>::new()))
+    }
+}
+
+impl<O: OffsetSizeTrait> Accumulator for StringDistinctCountAccumulator<O> {
+    fn state(&self) -> datafusion_common::Result<Vec<ScalarValue>> {
+        // TODO this should not need a lock/clone (should make
+        // `Accumulator::state` take a mutable reference)
+        // see https://github.com/apache/arrow-datafusion/pull/8925
+        let mut lk = self.0.lock().unwrap();
+        let set: &mut SSOStringHashSet<_> = &mut lk;
+        // take the state out of the string set and replace with default
+        let set = std::mem::take(set);
+        let arr = set.into_state();
+        let list = Arc::new(array_into_list_array(arr));
+        Ok(vec![ScalarValue::List(list)])
+    }
+
+    fn update_batch(&mut self, values: &[ArrayRef]) -> datafusion_common::Result<()> {
+        if values.is_empty() {
+            return Ok(());
+        }
+
+        self.0.lock().unwrap().insert(values[0].clone());
+
+        Ok(())
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> datafusion_common::Result<()> {
+        if states.is_empty() {
+            return Ok(());
+        }
+        assert_eq!(
+            states.len(),
+            1,
+            "count_distinct states must be single array"
+        );
+
+        let arr = as_list_array(&states[0])?;
+        arr.iter().try_for_each(|maybe_list| {
+            if let Some(list) = maybe_list {
+                self.0.lock().unwrap().insert(list);
+            };
+            Ok(())
+        })
+    }
+
+    fn evaluate(&self) -> datafusion_common::Result<ScalarValue> {
+        Ok(ScalarValue::Int64(
+            Some(self.0.lock().unwrap().len() as i64),
+        ))
+    }
+
+    fn size(&self) -> usize {
+        // Size of accumulator
+        // + SSOStringHashSet size
+        std::mem::size_of_val(self) + self.0.lock().unwrap().size()
+    }
+}
+
+/// Maximum size of a string that can be inlined in the hash table
+const SHORT_STRING_LEN: usize = mem::size_of::<usize>();
+
+/// Entry that is stored in a `SSOStringHashSet` that represents a string
+/// that is either stored inline or in the buffer
+///
+/// This helps the case where there are many short (less than 8 bytes) strings
+/// that are the same (e.g. "MA", "CA", "NY", "TX", etc)
+///
+/// ```text
+///                ┌──────────────────┐
+///                │...               │
+///                │TheQuickBrownFox  │
+/// ─ ─ ─ ─ ─ ─ ─▶│...               │
+/// │              │                  │
+///                └──────────────────┘
+/// │                  buffer of u8
+///
+/// │
+///                    ┌────────────────┬───────────────┬───────────────┐
+/// Storing            │                │ starting byte │  length, in   │
+/// "TheQuickBrownFox" │   hash value   │   offset in   │  bytes (not   │
+/// (long string)      │                │    buffer     │  characters)  │
+///                    └────────────────┴───────────────┴───────────────┘
+///                          8 bytes          8 bytes        4 or 8
+///
+///
+///                    ┌───────────────┬─┬─┬─┬─┬─┬─┬─┬─┬───────────────┐
+/// Storing "foobar"   │               │ │ │ │ │ │ │ │ │  length, in   │
+/// (short string)     │  hash value   │?│?│f│o│o│b│a│r│  bytes (not   │
+///                    │               │ │ │ │ │ │ │ │ │  characters)  │
+///                    └───────────────┴─┴─┴─┴─┴─┴─┴─┴─┴───────────────┘
+///                         8 bytes         8 bytes          4 or 8
+/// ```
+#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
+struct SSOStringHeader {
+    /// hash of the string value (stored to avoid recomputing it in hash table
+    /// check)
+    hash: u64,
+    /// if len <= SHORT_STRING_LEN: the string data inlined
+    /// if len > SHORT_STRING_LEN, the offset of where the data starts
+    offset_or_inline: usize,
+    /// length of the string, in bytes
+    len: usize,
+}
+
+impl SSOStringHeader {
+    /// returns self.offset..self.offset + self.len
+    fn range(&self) -> Range<usize> {
+        self.offset_or_inline..self.offset_or_inline + self.len
+    }
+}
+
+/// HashSet optimized for storing `String` and `LargeString` values
+/// and producing the final set as a GenericStringArray with minimal copies.
+///
+/// Equivalent to `HashSet<String>` but with better performance for arrow data.
+struct SSOStringHashSet<O: OffsetSizeTrait> {
+    /// Underlying hash set for each distinct string
+    map: hashbrown::raw::RawTable<SSOStringHeader>,
+    /// Total size of the map in bytes
+    map_size: usize,
+    /// In progress arrow `Buffer` containing all string values
+    buffer: BufferBuilder<u8>,
+    /// Offsets into `buffer` for each distinct string value. These offsets
+    /// are used directly to create the final `GenericStringArray`
+    offsets: Vec<O>,
+    /// random state used to generate hashes
+    random_state: RandomState,
+    /// buffer that stores hash values (reused across batches to save allocations)
+    hashes_buffer: Vec<u64>,
+}
+
+impl<O: OffsetSizeTrait> Default for SSOStringHashSet<O> {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl<O: OffsetSizeTrait> SSOStringHashSet<O> {
+    fn new() -> Self {
+        Self {
+            map: hashbrown::raw::RawTable::new(),
+            map_size: 0,
+            buffer: BufferBuilder::new(0),
+            offsets: vec![O::default()], // first offset is always 0
+            random_state: RandomState::new(),
+            hashes_buffer: vec![],
+        }
+    }
+
+    fn insert(&mut self, values: ArrayRef) {
+        // step 1: compute hashes for the strings
+        let batch_hashes = &mut self.hashes_buffer;
+        batch_hashes.clear();
+        batch_hashes.resize(values.len(), 0);
+        create_hashes(&[values.clone()], &self.random_state, batch_hashes)
+            // hash is supported for all string types and create_hashes only
+            // returns errors for unsupported types
+            .unwrap();
+
+        // step 2: insert each string into the set, if not already present
+        let values = values.as_string::<O>();
+
+        // Ensure lengths are equivalent (to guard unsafe values calls below)
+        assert_eq!(values.len(), batch_hashes.len());
+
+        for (value, &hash) in values.iter().zip(batch_hashes.iter()) {
+            // count distinct ignores nulls
+            let Some(value) = value else {
+                continue;
+            };
+
+            // from here on only use bytes (not str/chars) for value
+            let value = value.as_bytes();
+
+            // value is a "small" string
+            if value.len() <= SHORT_STRING_LEN {
+                let inline = value.iter().fold(0usize, |acc, &x| acc << 8 | x as usize);
+
+                // is the value already present in the set?
+                let entry = self.map.get_mut(hash, |header| {
+                    // compare value if hashes match
+                    if header.len != value.len() {
+                        return false;
+                    }
+                    // value is stored inline so no need to consult buffer
+                    // (this is the "small string optimization")
+                    inline == header.offset_or_inline
+                });
+
+                // if no existing entry, make a new one
+                if entry.is_none() {
+                    // Put the small values into buffer and offsets so it appears
+                    // in the output array, but store the actual bytes inline for
+                    // comparison
+                    self.buffer.append_slice(value);
+                    self.offsets.push(O::from_usize(self.buffer.len()).unwrap());
+                    let new_header = SSOStringHeader {
+                        hash,
+                        len: value.len(),
+                        offset_or_inline: inline,
+                    };
+                    self.map.insert_accounted(
+                        new_header,
+                        |header| header.hash,
+                        &mut self.map_size,
+                    );
+                }
+            }
+            // value is not a "small" string
+            else {
+                // Check if the value is already present in the set
+                let entry = self.map.get_mut(hash, |header| {
+                    // compare value if hashes match
+                    if header.len != value.len() {
+                        return false;
+                    }
+                    // Need to compare the bytes in the buffer
+                    // SAFETY: buffer is only appended to, and we correctly inserted values and offsets
+                    let existing_value =
+                        unsafe { self.buffer.as_slice().get_unchecked(header.range()) };
+                    value == existing_value
+                });
+
+                // if no existing entry, make a new one
+                if entry.is_none() {
+                    // Put the large values into buffer and offsets so it
+                    // appears in the output array, and store that offset
+                    // so the bytes can be compared if needed
+                    let offset = self.buffer.len(); // offset of start of data
+                    self.buffer.append_slice(value);
+                    self.offsets.push(O::from_usize(self.buffer.len()).unwrap());
+
+                    let new_header = SSOStringHeader {
+                        hash,
+                        len: value.len(),
+                        offset_or_inline: offset,
+                    };
+                    self.map.insert_accounted(
+                        new_header,
+                        |header| header.hash,
+                        &mut self.map_size,
+                    );
+                }
+            }
+        }
+    }
+
+    /// Converts this set into a `StringArray` or `LargeStringArray` with each
+    /// distinct string value without any copies
+    fn into_state(self) -> ArrayRef {
+        let Self {
+            map: _,
+            map_size: _,
+            offsets,
+            mut buffer,
+            random_state: _,
+            hashes_buffer: _,
+        } = self;
+
+        let offsets: ScalarBuffer<O> = offsets.into();
+        let values = buffer.finish();
+        let nulls = None; // count distinct ignores nulls so intermediate state never has nulls
+
+        // SAFETY: all the values that went in were valid utf8 so are all the values that come out
+        let array = unsafe {
+            GenericStringArray::new_unchecked(OffsetBuffer::new(offsets), values, nulls)
+        };
+        Arc::new(array)
+    }
+
+    fn len(&self) -> usize {
+        self.map.len()
+    }
+
+    /// Return the total size, in bytes, of memory used to store the data in
+    /// this set, not including `self`
+    fn size(&self) -> usize {
+        self.map_size
+            + self.buffer.capacity() * std::mem::size_of::<u8>()
+            + self.offsets.capacity() * std::mem::size_of::<O>()
+            + self.hashes_buffer.capacity() * std::mem::size_of::<u64>()
+    }
+}
+
+impl<O: OffsetSizeTrait> Debug for SSOStringHashSet<O> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        f.debug_struct("SSOStringHashSet")
+            .field("map", &"<map>")
+            .field("map_size", &self.map_size)
+            .field("buffer", &self.buffer)
+            .field("random_state", &self.random_state)
+            .field("hashes_buffer", &self.hashes_buffer)
+            .finish()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::array::ArrayRef;
+    use arrow_array::StringArray;
+    #[test]
+    fn string_set_empty() {
+        for values in [StringArray::new_null(0), StringArray::new_null(11)] {
+            let mut set = SSOStringHashSet::<i32>::new();
+            set.insert(Arc::new(values.clone()));
+            assert_set(set, &[]);
+        }
+    }
+
+    #[test]
+    fn string_set_basic_i32() {
+        test_string_set_basic::<i32>();
+    }
+    #[test]
+    fn string_set_basic_i64() {
+        test_string_set_basic::<i64>();
+    }
+    fn test_string_set_basic<O: OffsetSizeTrait>() {
+        // basic test for mixed small and large string values
+        let values = GenericStringArray::<O>::from(vec![
+            Some("a"),
+            Some("b"),
+            Some("CXCCCCCCCC"), // 10 bytes
+            Some(""),
+            Some("cbcxx"), // 5 bytes
+            None,
+            Some("AAAAAAAA"), // 8 bytes
+            Some("BBBBBQBBB"), // 9 bytes
+            Some("a"),
+            Some("cbcxx"),
+            Some("b"),
+            Some("cbcxx"),
+            Some(""),
+            None,
+            Some("BBBBBQBBB"),
+            Some("BBBBBQBBB"),
+            Some("AAAAAAAA"),
+            Some("CXCCCCCCCC"),
+        ]);
+
+        let mut set = SSOStringHashSet::<O>::new();
+        set.insert(Arc::new(values));
+        assert_set(
+            set,
+            &[
+                Some(""),
+                Some("AAAAAAAA"),
+                Some("BBBBBQBBB"),
+                Some("CXCCCCCCCC"),
+                Some("a"),
+                Some("b"),
+                Some("cbcxx"),
+            ],
+        );
+    }
+
+    #[test]
+    fn string_set_non_utf8_32() {
+        test_string_set_non_utf8::<i32>();
+    }
+    #[test]
+    fn string_set_non_utf8_64() {
+        test_string_set_non_utf8::<i64>();
+    }
+    fn test_string_set_non_utf8<O: OffsetSizeTrait>() {
+        // basic test for mixed small and large string values
+        let values = GenericStringArray::<O>::from(vec![
+            Some("a"),
+            Some("✨🔥"),
+            Some("🔥"),
+            Some("✨✨✨"),
+            Some("foobarbaz"),
+            Some("🔥"),
+            Some("✨🔥"),
+        ]);
+
+        let mut set = SSOStringHashSet::<O>::new();
+        set.insert(Arc::new(values));
+        assert_set(
+            set,
+            &[
+                Some("a"),
+                Some("foobarbaz"),
+                Some("✨✨✨"),
+                Some("✨🔥"),
+                Some("🔥"),
+            ],
+        );
+    }
+
+    // asserts that the set contains the expected strings
+    fn assert_set<O: OffsetSizeTrait>(
+        set: SSOStringHashSet<O>,
+        expected: &[Option<&str>],
+    ) {
+        let strings = set.into_state();
+        let strings = strings.as_string::<O>();
+        let mut state = strings.into_iter().collect::<Vec<_>>();
+        state.sort();
+        assert_eq!(state, expected);
+    }
+
+    // inserting the same strings into the set does not increase reported memory
+    #[test]
+    fn test_string_set_memory_usage() {
+        let strings1 = GenericStringArray::<i32>::from(vec![
+            Some("a"),
+            Some("b"),
+            Some("CXCCCCCCCC"), // 10 bytes
+            Some("AAAAAAAA"),  // 8 bytes
+            Some("BBBBBQBBB"), // 9 bytes
+        ]);
+        let total_strings1_len = strings1
+            .iter()
+            .map(|s| s.map(|s| s.len()).unwrap_or(0))
+            .sum::<usize>();
+        let values1: ArrayRef = Arc::new(GenericStringArray::<i32>::from(strings1));
+
+        // Much larger strings in strings2
+        let strings2 = GenericStringArray::<i32>::from(vec![
+            "FOO".repeat(1000),
+            "BAR".repeat(2000),
+            "BAZ".repeat(3000),
+        ]);
+        let total_strings2_len = strings2
+            .iter()
+            .map(|s| s.map(|s| s.len()).unwrap_or(0))
+            .sum::<usize>();
+        let values2: ArrayRef = Arc::new(GenericStringArray::<i32>::from(strings2));
+
+        let mut set = SSOStringHashSet::<i32>::new();
+        let size_empty = set.size();
+
+        set.insert(values1.clone());
+        let size_after_values1 = set.size();
+        assert!(size_empty < size_after_values1);
+        assert!(
+            size_after_values1 > total_strings1_len,
+            "expect {size_after_values1} to be more than {total_strings1_len}"
+        );
+        assert!(size_after_values1 < total_strings1_len + total_strings2_len);
+
+        // inserting the same strings should not affect the size
+        set.insert(values1.clone());
+        assert_eq!(set.size(), size_after_values1);
+
+        // inserting the large strings should increase the reported size
+        set.insert(values2);
+        let size_after_values2 = set.size();
+        assert!(size_after_values2 > size_after_values1);
+        assert!(size_after_values2 > total_strings1_len + total_strings2_len);
+    }
+}
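Series note: the heart of what patch 27 moves into `strings.rs` is the short-string inlining. As a rough standalone sketch (the names and the `main` driver here are illustrative, not part of the patch), the same packing trick can be reproduced with a plain `HashSet`:

```rust
// Standalone sketch of the short-string inlining used by SSOStringHashSet:
// strings of at most size_of::<usize>() bytes are packed into a single usize,
// so membership checks touch no heap-allocated Strings at all.
use std::collections::HashSet;

const SHORT_STRING_LEN: usize = std::mem::size_of::<usize>();

fn inline_key(s: &str) -> Option<(usize, usize)> {
    let bytes = s.as_bytes();
    if bytes.len() > SHORT_STRING_LEN {
        return None; // long strings go through the buffer + offset path instead
    }
    // same fold as the patch: big-endian packing of the bytes, e.g.
    // "foobar" becomes 0x0000_666f_6f62_6172
    let packed = bytes.iter().fold(0usize, |acc, &b| acc << 8 | b as usize);
    // keep (length, packed bytes) so e.g. "a" and "\0a" stay distinct
    Some((bytes.len(), packed))
}

fn main() {
    let mut set = HashSet::new();
    for s in ["MA", "CA", "NY", "CA", "MA"] {
        set.insert(inline_key(s).unwrap());
    }
    assert_eq!(set.len(), 3); // three distinct states, no String allocations
    println!("distinct short strings: {}", set.len());
}
```

Because the key is a plain `usize`, probing for a duplicate short string compares two machine words; the real implementation additionally keeps the hash and a shared byte buffer so the final `GenericStringArray` can be built without copies.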
From 3a6a06623eae825f972b6fbcd480bd50632267a1 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Wed, 24 Jan 2024 07:05:39 -0500
Subject: [PATCH 28/30] Remove use of Mutex

---
 .../src/aggregate/count_distinct/strings.rs   | 29 +++++++------------
 1 file changed, 10 insertions(+), 19 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs
index 3af3c4486346..3994f0112967 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs
@@ -30,27 +30,20 @@ use datafusion_expr::Accumulator;
 use std::fmt::Debug;
 use std::mem;
 use std::ops::Range;
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;
 
 #[derive(Debug)]
-pub(super) struct StringDistinctCountAccumulator<O: OffsetSizeTrait>(
-    Mutex<SSOStringHashSet<O>>,
-);
+pub(super) struct StringDistinctCountAccumulator<O: OffsetSizeTrait>(SSOStringHashSet<O>);
 impl<O: OffsetSizeTrait> StringDistinctCountAccumulator<O> {
     pub(super) fn new() -> Self {
-        Self(Mutex::new(SSOStringHashSet::<O>::new()))
+        Self(SSOStringHashSet::<O>::new())
     }
 }
 
 impl<O: OffsetSizeTrait> Accumulator for StringDistinctCountAccumulator<O> {
-    fn state(&self) -> datafusion_common::Result<Vec<ScalarValue>> {
-        // TODO this should not need a lock/clone (should make
-        // `Accumulator::state` take a mutable reference)
-        // see https://github.com/apache/arrow-datafusion/pull/8925
-        let mut lk = self.0.lock().unwrap();
-        let set: &mut SSOStringHashSet<_> = &mut lk;
+    fn state(&mut self) -> datafusion_common::Result<Vec<ScalarValue>> {
         // take the state out of the string set and replace with default
-        let set = std::mem::take(set);
+        let set = std::mem::take(&mut self.0);
         let arr = set.into_state();
         let list = Arc::new(array_into_list_array(arr));
         Ok(vec![ScalarValue::List(list)])
@@ -61,7 +54,7 @@ impl<O: OffsetSizeTrait> Accumulator for StringDistinctCountAccumulator<O> {
             return Ok(());
         }
 
-        self.0.lock().unwrap().insert(values[0].clone());
+        self.0.insert(values[0].clone());
 
         Ok(())
     }
@@ -79,22 +72,20 @@ impl<O: OffsetSizeTrait> Accumulator for StringDistinctCountAccumulator<O> {
         let arr = as_list_array(&states[0])?;
         arr.iter().try_for_each(|maybe_list| {
             if let Some(list) = maybe_list {
-                self.0.lock().unwrap().insert(list);
+                self.0.insert(list);
             };
             Ok(())
         })
     }
 
-    fn evaluate(&self) -> datafusion_common::Result<ScalarValue> {
-        Ok(ScalarValue::Int64(
-            Some(self.0.lock().unwrap().len() as i64),
-        ))
+    fn evaluate(&mut self) -> datafusion_common::Result<ScalarValue> {
+        Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
     }
 
     fn size(&self) -> usize {
         // Size of accumulator
         // + SSOStringHashSet size
-        std::mem::size_of_val(self) + self.0.lock().unwrap().size()
+        std::mem::size_of_val(self) + self.0.size()
     }
 }
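A note on patch 28: it is enabled by `Accumulator::state` and `evaluate` taking `&mut self`, the change tracked in https://github.com/apache/arrow-datafusion/pull/8925 that the deleted TODO referenced. A minimal sketch of the pattern, with hypothetical types standing in for the accumulator and set:

```rust
// Sketch (illustrative, not from the patch) of why `state(&mut self)` removes
// the need for a Mutex: with only `&self`, moving internal state out requires
// interior mutability; with `&mut self`, `std::mem::take` suffices.
#[derive(Default, Debug)]
struct Set(Vec<String>);

struct Acc(Set);

impl Acc {
    // move the accumulated set out, leaving a fresh default behind,
    // with no locking and no cloning
    fn state(&mut self) -> Set {
        std::mem::take(&mut self.0)
    }
}

fn main() {
    let mut acc = Acc(Set(vec!["a".into(), "b".into()]));
    let state = acc.state();
    assert_eq!(state.0.len(), 2);
    assert!(acc.0 .0.is_empty()); // accumulator reset to its default
    println!("extracted {:?}", state);
}
```

`mem::take` leaves a valid empty set behind, which matches the contract that `state` consumes the intermediate state of the accumulator.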
From 8640907927b67fa2759fb3f903172ceafb45bf1e Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Thu, 25 Jan 2024 07:35:38 -0500
Subject: [PATCH 29/30] revert changes

---
 benchmarks/queries/clickbench/README.md    | 23 +---------------------
 benchmarks/queries/clickbench/extended.sql |  2 +-
 2 files changed, 2 insertions(+), 23 deletions(-)

diff --git a/benchmarks/queries/clickbench/README.md b/benchmarks/queries/clickbench/README.md
index e3ada9858978..ef540ccf9c91 100644
--- a/benchmarks/queries/clickbench/README.md
+++ b/benchmarks/queries/clickbench/README.md
@@ -29,27 +29,6 @@ SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DIST
 FROM hits;
 ```
 
-### Q1
-Models initial Data exploration, to understand some statistics of data.
-Models initial Data exploration, to understand some statistics of data.
-Query to test distinct count for String. Three of them are all small string (length either 1 or 2).
-
-```sql
-SELECT
-    COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage")
-FROM hits;
-```
-
-### Q2
-Models initial Data exploration, to understand some statistics of data.
-Extend with `group by` from Q1
-
-```sql
-SELECT
-    "BrowserCountry", COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage")
-FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10;
-```
-
 ### Q1: Data Exploration
 
 **Question**: "How many distinct "hit color", "browser country" and "language" are there in the dataset?"
@@ -62,7 +41,7 @@ SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTI
 FROM hits;
 ```
 
-### Q2: Top 10 anaylsis
+### Q2: Top 10 analysis
 
 **Question**: "Find the top 10 "browser country" by number of distinct "social network"s,
 including the distinct counts of  "hit color", "browser language",
diff --git a/benchmarks/queries/clickbench/extended.sql b/benchmarks/queries/clickbench/extended.sql
index bcba4fee9277..0a2999fceb49 100644
--- a/benchmarks/queries/clickbench/extended.sql
+++ b/benchmarks/queries/clickbench/extended.sql
@@ -1,3 +1,3 @@
 SELECT COUNT(DISTINCT "SearchPhrase"), COUNT(DISTINCT "MobilePhone"), COUNT(DISTINCT "MobilePhoneModel") FROM hits;
 SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage") FROM hits;
-SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10;
+SELECT "BrowserCountry", COUNT(DISTINCT "SocialNetwork"), COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "SocialAction") FROM hits GROUP BY 1 ORDER BY 2 DESC LIMIT 10;
\ No newline at end of file
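For context on the reverted queries: they exercise exactly the `COUNT(DISTINCT <string column>)` path this series optimizes. A hedged sketch of running Q1 from `extended.sql` through the public API (the table registration and file path are illustrative, and this assumes the `datafusion` and `tokio` crates; the real benchmark harness in `benchmarks/` uses the full ClickBench `hits` dataset):

```rust
use datafusion::error::Result;
use datafusion::prelude::{CsvReadOptions, SessionContext};

#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();
    // hypothetical local sample of the ClickBench `hits` table
    ctx.register_csv("hits", "hits_sample.csv", CsvReadOptions::new())
        .await?;

    // Q1 from extended.sql: three COUNT DISTINCTs over short string columns,
    // the case the SSOStringHashSet in this series is designed to speed up
    let df = ctx
        .sql(
            "SELECT COUNT(DISTINCT \"HitColor\"), \
                    COUNT(DISTINCT \"BrowserCountry\"), \
                    COUNT(DISTINCT \"BrowserLanguage\") \
             FROM hits",
        )
        .await?;
    df.show().await?;
    Ok(())
}
```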
From f5e268d07a3bb3d7d244775dd4ef3ee0b6380829 Mon Sep 17 00:00:00 2001
From: Andrew Lamb
Date: Sun, 28 Jan 2024 07:40:13 -0500
Subject: [PATCH 30/30] Use reference rather than owned ArrayRef

---
 .../src/aggregate/count_distinct/strings.rs   | 21 +++++++++++--------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs
index 3994f0112967..d7a9ea5c373d 100644
--- a/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs
+++ b/datafusion/physical-expr/src/aggregate/count_distinct/strings.rs
@@ -54,7 +54,7 @@ impl<O: OffsetSizeTrait> Accumulator for StringDistinctCountAccumulator<O> {
             return Ok(());
         }
 
-        self.0.insert(values[0].clone());
+        self.0.insert(&values[0]);
 
         Ok(())
     }
@@ -72,7 +72,7 @@ impl<O: OffsetSizeTrait> Accumulator for StringDistinctCountAccumulator<O> {
         let arr = as_list_array(&states[0])?;
         arr.iter().try_for_each(|maybe_list| {
             if let Some(list) = maybe_list {
-                self.0.insert(list);
+                self.0.insert(&list);
             };
             Ok(())
         })
@@ -180,7 +180,7 @@ impl<O: OffsetSizeTrait> SSOStringHashSet<O> {
         }
     }
 
-    fn insert(&mut self, values: ArrayRef) {
+    fn insert(&mut self, values: &ArrayRef) {
         // step 1: compute hashes for the strings
         let batch_hashes = &mut self.hashes_buffer;
         batch_hashes.clear();
@@ -336,7 +336,8 @@ mod tests {
     fn string_set_empty() {
         for values in [StringArray::new_null(0), StringArray::new_null(11)] {
             let mut set = SSOStringHashSet::<i32>::new();
-            set.insert(Arc::new(values.clone()));
+            let array: ArrayRef = Arc::new(values);
+            set.insert(&array);
             assert_set(set, &[]);
         }
     }
@@ -373,7 +374,8 @@ mod tests {
         ]);
 
         let mut set = SSOStringHashSet::<O>::new();
-        set.insert(Arc::new(values));
+        let array: ArrayRef = Arc::new(values);
+        set.insert(&array);
         assert_set(
             set,
             &[
@@ -409,7 +411,8 @@ mod tests {
         ]);
 
         let mut set = SSOStringHashSet::<O>::new();
-        set.insert(Arc::new(values));
+        let array: ArrayRef = Arc::new(values);
+        set.insert(&array);
         assert_set(
             set,
             &[
@@ -465,7 +468,7 @@ mod tests {
         let mut set = SSOStringHashSet::<i32>::new();
         let size_empty = set.size();
 
-        set.insert(values1.clone());
+        set.insert(&values1);
         let size_after_values1 = set.size();
         assert!(size_empty < size_after_values1);
         assert!(
@@ -475,11 +478,11 @@ mod tests {
         assert!(size_after_values1 < total_strings1_len + total_strings2_len);
 
         // inserting the same strings should not affect the size
-        set.insert(values1.clone());
+        set.insert(&values1);
         assert_eq!(set.size(), size_after_values1);
 
         // inserting the large strings should increase the reported size
-        set.insert(values2);
+        set.insert(&values2);
         let size_after_values2 = set.size();
         assert!(size_after_values2 > size_after_values1);
         assert!(size_after_values2 > total_strings1_len + total_strings2_len);
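Patch 30's signature change is small but worth spelling out: `insert` previously took an owned `ArrayRef`, forcing callers that only held a borrowed array to bump the `Arc` reference count first. A minimal illustration of the difference (the `ArrayRef` alias below is a stand-in for arrow's `Arc<dyn Array>`, not the real type):

```rust
// Sketch of the API change: taking `&ArrayRef` instead of `ArrayRef` means
// call sites pass `&values` and no longer clone the Arc just to call insert.
use std::sync::Arc;

type ArrayRef = Arc<Vec<u8>>; // stand-in for arrow's Arc<dyn Array>

struct Set;

impl Set {
    // owned: every call site must supply (and usually clone) an Arc
    fn insert_owned(&mut self, _values: ArrayRef) {}
    // borrowed: no reference-count traffic at the call site
    fn insert(&mut self, _values: &ArrayRef) {}
}

fn main() {
    let values: ArrayRef = Arc::new(vec![1, 2, 3]);
    let mut set = Set;
    set.insert_owned(values.clone()); // clone required
    set.insert(&values); // no clone
    set.insert(&values); // can be called repeatedly from the same binding
}
```

The updated call sites in `merge_batch` and the tests show the payoff: `set.insert(&values1)` is used twice in a row from the same binding with no `clone()`.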