
Optimize COUNT( DISTINCT ...) for strings (up to 9x faster) #8849

Merged
merged 37 commits on Jan 29, 2024

Conversation

@jayzhan211 (Contributor) commented Jan 13, 2024

Note for reviewers: this PR has around 200 lines of code; the rest is testing

This PR is a collaboration between @jayzhan211 and @alamb

Which issue does this PR close?

Part of #5472
Follow up on #8721

Rationale for this change

Speed up queries that include multiple COUNT DISTINCTs for String or LargeString

What changes are included in this PR?

Implement a specialized Accumulator for COUNT DISTINCT that avoids copying string data or allocating individual strings

Are these changes tested?

  1. New unit tests
  2. New fuzz test
  3. New sqllogictests

Benchmark results:

Clickbench Extended

Admittedly these benchmarks were chosen to highlight this particular change, but I am still feeling pretty good about a 9x faster query :bowtie:. See the docs for more details about what these tests are

--------------------
Benchmark clickbench_extended.json
--------------------
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Query        ┃  main_base ┃ bytes-distinctcount ┃        Change ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ QQuery 0     │ 31810.05ms │           3506.66ms │ +9.07x faster │
│ QQuery 1     │  4567.84ms │           1509.76ms │ +3.03x faster │
│ QQuery 2     │  9525.14ms │           3482.35ms │ +2.74x faster │
└──────────────┴────────────┴─────────────────────┴───────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┓
┃ Benchmark Summary                  ┃            ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━┩
│ Total Time (main_base)             │ 45903.03ms │
│ Total Time (bytes-distinctcount)   │  8498.77ms │
│ Average Time (main_base)           │ 15301.01ms │
│ Average Time (bytes-distinctcount) │  2832.92ms │
│ Queries Faster                     │          3 │
│ Queries Slower                     │          0 │
│ Queries with No Change             │          0 │
└────────────────────────────────────┴────────────┘
Entire Clickbench (basically the same)

--------------------
Benchmark clickbench_1.json
--------------------
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Query        ┃  main_base ┃ bytes-distinctcount ┃        Change ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ QQuery 0     │     0.95ms │              0.95ms │     no change │
│ QQuery 1     │    92.20ms │             91.78ms │     no change │
│ QQuery 2     │   192.76ms │            188.63ms │     no change │
│ QQuery 3     │   209.18ms │            205.88ms │     no change │
│ QQuery 4     │  2005.31ms │           2029.24ms │     no change │
│ QQuery 5     │  2665.90ms │           2660.46ms │     no change │
│ QQuery 6     │    82.57ms │             80.39ms │     no change │
│ QQuery 7     │    95.49ms │             95.38ms │     no change │
│ QQuery 8     │  3065.70ms │           3025.37ms │     no change │
│ QQuery 9     │  2256.66ms │           2229.14ms │     no change │
│ QQuery 10    │   807.43ms │            806.34ms │     no change │
│ QQuery 11    │   873.79ms │            875.82ms │     no change │
│ QQuery 12    │  2449.75ms │           2442.61ms │     no change │
│ QQuery 13    │  4760.71ms │           4639.70ms │     no change │
│ QQuery 14    │  2709.37ms │           2641.73ms │     no change │
│ QQuery 15    │  2276.90ms │           2205.92ms │     no change │
│ QQuery 16    │  5532.90ms │           5476.04ms │     no change │
│ QQuery 17    │  5487.97ms │           5326.65ms │     no change │
│ QQuery 18    │ 11310.75ms │          10917.21ms │     no change │
│ QQuery 19    │   163.58ms │            155.14ms │ +1.05x faster │
│ QQuery 20    │  2596.50ms │           2583.46ms │     no change │
│ QQuery 21    │  3348.95ms │           3334.02ms │     no change │
│ QQuery 22    │  9301.68ms │           9143.42ms │     no change │
│ QQuery 23    │ 21843.06ms │          21409.51ms │     no change │
│ QQuery 24    │  1342.49ms │           1315.66ms │     no change │
│ QQuery 25    │  1159.28ms │           1122.84ms │     no change │
│ QQuery 26    │  1461.98ms │           1438.08ms │     no change │
│ QQuery 27    │  3878.11ms │           3787.23ms │     no change │
│ QQuery 28    │ 30404.48ms │          30145.97ms │     no change │
│ QQuery 29    │  1030.54ms │           1029.82ms │     no change │
│ QQuery 30    │  2459.54ms │           2344.20ms │     no change │
│ QQuery 31    │  3139.26ms │           3048.40ms │     no change │
│ QQuery 32    │ 16329.44ms │          15284.43ms │ +1.07x faster │
│ QQuery 33    │ 12134.41ms │          11993.85ms │     no change │
│ QQuery 34    │ 12840.20ms │          12805.82ms │     no change │
│ QQuery 35    │  3701.97ms │           3576.27ms │     no change │
│ QQuery 36    │   417.23ms │            396.03ms │ +1.05x faster │
│ QQuery 37    │   245.99ms │            249.50ms │     no change │
│ QQuery 38    │   189.14ms │            185.00ms │     no change │
│ QQuery 39    │  1080.01ms │           1075.50ms │     no change │
│ QQuery 40    │    85.94ms │             87.76ms │     no change │
│ QQuery 41    │    81.70ms │             80.47ms │     no change │
│ QQuery 42    │    89.83ms │             90.62ms │     no change │
└──────────────┴────────────┴─────────────────────┴───────────────┘

┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ Benchmark Summary                  ┃             ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ Total Time (main_base)             │ 176201.59ms │
│ Total Time (bytes-distinctcount)   │ 172622.23ms │
│ Average Time (main_base)           │   4097.71ms │
│ Average Time (bytes-distinctcount) │   4014.47ms │
│ Queries Faster                     │           3 │
│ Queries Slower                     │           0 │
│ Queries with No Change             │          40 │
└────────────────────────────────────┴─────────────┘

Entire TPCH_1 (basically the same)


Benchmark tpch_mem.json

┏━━━━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Query        ┃ main_base ┃ bytes-distinctcount ┃        Change ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ QQuery 1     │  219.29ms │            208.36ms │     no change │
│ QQuery 2     │   47.19ms │             45.30ms │     no change │
│ QQuery 3     │   79.84ms │             78.33ms │     no change │
│ QQuery 4     │   75.31ms │             76.34ms │     no change │
│ QQuery 5     │  125.54ms │            126.23ms │     no change │
│ QQuery 6     │   16.68ms │             16.15ms │     no change │
│ QQuery 7     │  337.79ms │            320.61ms │ +1.05x faster │
│ QQuery 8     │   82.13ms │             80.68ms │     no change │
│ QQuery 9     │  127.99ms │            128.50ms │     no change │
│ QQuery 10    │  158.88ms │            155.37ms │     no change │
│ QQuery 11    │   33.85ms │             33.83ms │     no change │
│ QQuery 12    │   72.23ms │             70.90ms │     no change │
│ QQuery 13    │   83.32ms │             85.77ms │     no change │
│ QQuery 14    │   26.51ms │             26.13ms │     no change │
│ QQuery 15    │   62.08ms │             60.46ms │     no change │
│ QQuery 16    │   48.43ms │             45.78ms │ +1.06x faster │
│ QQuery 17    │  166.20ms │            161.16ms │     no change │
│ QQuery 18    │  471.68ms │            465.24ms │     no change │
│ QQuery 19    │   65.11ms │             65.86ms │     no change │
│ QQuery 20    │  117.70ms │            117.06ms │     no change │
│ QQuery 21    │  368.86ms │            363.66ms │     no change │
│ QQuery 22    │   29.75ms │             29.44ms │     no change │
└──────────────┴───────────┴─────────────────────┴───────────────┘
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃ Benchmark Summary                  ┃           ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ Total Time (main_base)             │ 2816.38ms │
│ Total Time (bytes-distinctcount)   │ 2761.14ms │
│ Average Time (main_base)           │  128.02ms │
│ Average Time (bytes-distinctcount) │  125.51ms │
│ Queries Faster                     │         2 │
│ Queries Slower                     │         0 │
│ Queries with No Change             │        20 │
└────────────────────────────────────┴───────────┘



Are there any user-facing changes?

Faster queries 🚀 

@github-actions bot added the physical-expr, core, and sqllogictest labels and removed the core label on Jan 13, 2024

const SHORT_STRING_LEN: usize = mem::size_of::<usize>();

#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
jayzhan211 (Contributor, Author) commented on this diff:
Allow Copy since they are all native types

@alamb (Contributor) commented Jan 13, 2024

Thanks @jayzhan211 -- looks basically on the right track. Is there any chance you can run some sort of benchmark on this code? My thinking is that we should get benchmark results showing that the idea actually improves performance before spending too much time polishing

I looked at ClickBench and I don't actually think there are any queries that do COUNT(distinct <utf8>)

Q8 looks like it should be helped

SELECT COUNT(DISTINCT "SearchPhrase") FROM hits;

however, I am pretty sure DataFusion rewrites this query to avoid the distinct, turning it into roughly SELECT COUNT(..) FROM (SELECT "SearchPhrase" FROM hits GROUP BY "SearchPhrase")

Maybe you could try manually running a query that can't be rewritten (throw in multiple DISTINCTs) such as

SELECT
  COUNT(DISTINCT "SearchPhrase"),
  COUNT(DISTINCT "MobilePhone"),
  COUNT(DISTINCT "MobilePhoneModel")
FROM 'hits.parquet';
+-------------------------------------------+------------------------------------------+-----------------------------------------------+
| COUNT(DISTINCT hits.parquet.SearchPhrase) | COUNT(DISTINCT hits.parquet.MobilePhone) | COUNT(DISTINCT hits.parquet.MobilePhoneModel) |
+-------------------------------------------+------------------------------------------+-----------------------------------------------+
| 6019103                                   | 44                                       | 166                                           |
+-------------------------------------------+------------------------------------------+-----------------------------------------------+

@jayzhan211 (Contributor, Author) commented:
It looks pretty nice.

SELECT
  COUNT(DISTINCT "SearchPhrase"),
  COUNT(DISTINCT "MobilePhone"),
  COUNT(DISTINCT "MobilePhoneModel")
[Screenshot 2024-01-14 at 2 23 23 PM: benchmark output]

@jayzhan211 (Contributor, Author) commented:
Running with a simple HashSet has a similar result... I think I need to test with lots of short-string data
[Screenshot 2024-01-14 at 2 46 39 PM: benchmark output]

@jayzhan211 (Contributor, Author) commented Jan 14, 2024

I tested by injecting string arrays manually to find out whether SSO is better than a simple HashSet

Comparison with n=1e5. Times below are reported as (1) HashSet vs (2) short-string set:

Mostly non-distinct case

  • all short, 13k vs 13k ms
  • half short, 16k vs 13k ms
  • all long, 18k vs 14k ms

All distinct case

  • all short, 20k vs 19k ms
  • half short, 26k vs 23k ms
  • all long, 31k vs 24k ms

I thought that the more short strings there were, the bigger the performance gain, but it turns out the gain grows with the share of long strings
🤔
Also, I expected that more distinct long strings would mean less difference between SSO and HashSet, but that has not been the case.

TL;DR: I think we can move on with the SSO HashSet

Testing file:
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

use arrow::datatypes::{DataType, Field, TimeUnit};
use arrow_array::types::{
    ArrowPrimitiveType, Date32Type, Date64Type, Decimal128Type, Decimal256Type,
    Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type,
    Time32MillisecondType, Time32SecondType, Time64MicrosecondType, Time64NanosecondType,
    TimestampMicrosecondType, TimestampMillisecondType, TimestampNanosecondType,
    TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
};
use arrow_array::{PrimitiveArray, StringArray};
use arrow_buffer::BufferBuilder;
use rand::Rng;

use std::any::Any;
use std::cmp::Eq;
use std::fmt::Debug;
use std::hash::Hash;
use std::mem;
use std::sync::Arc;

use ahash::RandomState;
use arrow::array::{Array, ArrayRef};
use std::collections::HashSet;

use crate::aggregate::utils::{down_cast_any_ref, Hashable};
use crate::expressions::format_state_name;
use crate::{AggregateExpr, PhysicalExpr};
use datafusion_common::cast::{as_list_array, as_primitive_array, as_string_array};
use datafusion_common::utils::array_into_list_array;
use datafusion_common::{Result, ScalarValue};
use datafusion_execution::memory_pool::proxy::RawTableAllocExt;
use datafusion_expr::Accumulator;

type DistinctScalarValues = ScalarValue;

/// Expression for a COUNT(DISTINCT) aggregation.
#[derive(Debug)]
pub struct DistinctCount {
    /// Column name
    name: String,
    /// The DataType used to hold the state for each input
    state_data_type: DataType,
    /// The input arguments
    expr: Arc<dyn PhysicalExpr>,
}

impl DistinctCount {
    /// Create a new COUNT(DISTINCT) aggregate function.
    pub fn new(
        input_data_type: DataType,
        expr: Arc<dyn PhysicalExpr>,
        name: String,
    ) -> Self {
        Self {
            name,
            state_data_type: input_data_type,
            expr,
        }
    }
}

macro_rules! native_distinct_count_accumulator {
    ($TYPE:ident) => {{
        Ok(Box::new(NativeDistinctCountAccumulator::<$TYPE>::new()))
    }};
}

macro_rules! float_distinct_count_accumulator {
    ($TYPE:ident) => {{
        Ok(Box::new(FloatDistinctCountAccumulator::<$TYPE>::new()))
    }};
}

impl AggregateExpr for DistinctCount {
    /// Return a reference to Any that can be used for downcasting
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn field(&self) -> Result<Field> {
        Ok(Field::new(&self.name, DataType::Int64, true))
    }

    fn state_fields(&self) -> Result<Vec<Field>> {
        Ok(vec![Field::new_list(
            format_state_name(&self.name, "count distinct"),
            Field::new("item", self.state_data_type.clone(), true),
            false,
        )])
    }

    fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
        vec![self.expr.clone()]
    }

    fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
        use DataType::*;
        use TimeUnit::*;

        match &self.state_data_type {
            Int8 => native_distinct_count_accumulator!(Int8Type),
            Int16 => native_distinct_count_accumulator!(Int16Type),
            Int32 => native_distinct_count_accumulator!(Int32Type),
            Int64 => native_distinct_count_accumulator!(Int64Type),
            UInt8 => native_distinct_count_accumulator!(UInt8Type),
            UInt16 => native_distinct_count_accumulator!(UInt16Type),
            UInt32 => native_distinct_count_accumulator!(UInt32Type),
            UInt64 => native_distinct_count_accumulator!(UInt64Type),
            Decimal128(_, _) => native_distinct_count_accumulator!(Decimal128Type),
            Decimal256(_, _) => native_distinct_count_accumulator!(Decimal256Type),

            Date32 => native_distinct_count_accumulator!(Date32Type),
            Date64 => native_distinct_count_accumulator!(Date64Type),
            Time32(Millisecond) => {
                native_distinct_count_accumulator!(Time32MillisecondType)
            }
            Time32(Second) => {
                native_distinct_count_accumulator!(Time32SecondType)
            }
            Time64(Microsecond) => {
                native_distinct_count_accumulator!(Time64MicrosecondType)
            }
            Time64(Nanosecond) => {
                native_distinct_count_accumulator!(Time64NanosecondType)
            }
            Timestamp(Microsecond, _) => {
                native_distinct_count_accumulator!(TimestampMicrosecondType)
            }
            Timestamp(Millisecond, _) => {
                native_distinct_count_accumulator!(TimestampMillisecondType)
            }
            Timestamp(Nanosecond, _) => {
                native_distinct_count_accumulator!(TimestampNanosecondType)
            }
            Timestamp(Second, _) => {
                native_distinct_count_accumulator!(TimestampSecondType)
            }

            Float16 => float_distinct_count_accumulator!(Float16Type),
            Float32 => float_distinct_count_accumulator!(Float32Type),
            Float64 => float_distinct_count_accumulator!(Float64Type),

            Utf8 => Ok(Box::new(StringDistinctCountAccumulator::new())),

            _ => Ok(Box::new(DistinctCountAccumulator {
                values: HashSet::default(),
                state_data_type: self.state_data_type.clone(),
            })),
        }
    }

    fn name(&self) -> &str {
        &self.name
    }
}

impl PartialEq<dyn Any> for DistinctCount {
    fn eq(&self, other: &dyn Any) -> bool {
        down_cast_any_ref(other)
            .downcast_ref::<Self>()
            .map(|x| {
                self.name == x.name
                    && self.state_data_type == x.state_data_type
                    && self.expr.eq(&x.expr)
            })
            .unwrap_or(false)
    }
}

#[derive(Debug)]
struct DistinctCountAccumulator {
    values: HashSet<DistinctScalarValues, RandomState>,
    state_data_type: DataType,
}

impl DistinctCountAccumulator {
    // Calculates the size for fixed-length values, taking the size of the
    // first value * number of values.
    // This method is faster than .full_size(), however it is not suitable
    // for variable length values like strings or complex types
    fn fixed_size(&self) -> usize {
        std::mem::size_of_val(self)
            + (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
            + self
                .values
                .iter()
                .next()
                .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
                .unwrap_or(0)
            + std::mem::size_of::<DataType>()
    }

    // calculates the size as accurate as possible, call to this method is expensive
    fn full_size(&self) -> usize {
        std::mem::size_of_val(self)
            + (std::mem::size_of::<DistinctScalarValues>() * self.values.capacity())
            + self
                .values
                .iter()
                .map(|vals| ScalarValue::size(vals) - std::mem::size_of_val(vals))
                .sum::<usize>()
            + std::mem::size_of::<DataType>()
    }
}

fn get_vec_string(n: usize) -> Vec<String> {
    // Generate n random strings drawn from a small pool (two short, one long)
    let mut rng = rand::thread_rng();
    (0..n)
        .map(|_| {
            match rng.gen_range(0..3) {
                0 => "a",
                1 => "b",
                2 => "cccccccc",
                _ => unreachable!(),
            }
            .to_string()
        })
        .collect()
}

fn get_distinct_string(n: usize) -> Vec<String> {
    // Generate n distinct strings, each longer than SHORT_STRING_LEN bytes
    (1..=n).map(|i| format!("{}aaaaaaaa", i)).collect()
}

impl Accumulator for DistinctCountAccumulator {
    fn state(&self) -> Result<Vec<ScalarValue>> {
        let scalars = self.values.iter().cloned().collect::<Vec<_>>();
        let arr = ScalarValue::new_list(scalars.as_slice(), &self.state_data_type);
        Ok(vec![ScalarValue::List(arr)])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let arr = &values[0];
        if arr.data_type() == &DataType::Null {
            return Ok(());
        }

        (0..arr.len()).try_for_each(|index| {
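            // This generic path materializes a ScalarValue per row (which
            // allocates for strings); the specialized accumulators below
            // avoid that cost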
            if !arr.is_null(index) {
                let scalar = ScalarValue::try_from_array(arr, index)?;
                self.values.insert(scalar);
            }
            Ok(())
        })
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        assert_eq!(states.len(), 1, "count_distinct states must be singleton!");
        let scalar_vec = ScalarValue::convert_array_to_scalar_vec(&states[0])?;
        for scalars in scalar_vec.into_iter() {
            self.values.extend(scalars);
        }
        Ok(())
    }

    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
    }

    fn size(&self) -> usize {
        match &self.state_data_type {
            DataType::Boolean | DataType::Null => self.fixed_size(),
            d if d.is_primitive() => self.fixed_size(),
            _ => self.full_size(),
        }
    }
}

#[derive(Debug)]
struct NativeDistinctCountAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
    T::Native: Eq + Hash,
{
    values: HashSet<T::Native, RandomState>,
}

impl<T> NativeDistinctCountAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
    T::Native: Eq + Hash,
{
    fn new() -> Self {
        Self {
            values: HashSet::default(),
        }
    }
}

impl<T> Accumulator for NativeDistinctCountAccumulator<T>
where
    T: ArrowPrimitiveType + Send + Debug,
    T::Native: Eq + Hash,
{
    fn state(&self) -> Result<Vec<ScalarValue>> {
        let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
            self.values.iter().cloned(),
        )) as ArrayRef;
        let list = Arc::new(array_into_list_array(arr));
        Ok(vec![ScalarValue::List(list)])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let arr = as_primitive_array::<T>(&values[0])?;
        arr.iter().for_each(|value| {
            if let Some(value) = value {
                self.values.insert(value);
            }
        });

        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        assert_eq!(
            states.len(),
            1,
            "count_distinct states must be single array"
        );

        let arr = as_list_array(&states[0])?;
        arr.iter().try_for_each(|maybe_list| {
            if let Some(list) = maybe_list {
                let list = as_primitive_array::<T>(&list)?;
                self.values.extend(list.values())
            };
            Ok(())
        })
    }

    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
    }

    fn size(&self) -> usize {
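        // hashbrown keeps its load factor at or below 7/8 and sizes its
        // bucket array to a power of two, so len * 8 / 7 rounded up to the
        // next power of two approximates the allocated bucket count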
        let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
            / 7)
        .next_power_of_two();

        // Size of accumulator
        // + size of entry * number of buckets
        // + 1 byte for each bucket
        // + fixed size of HashSet
        std::mem::size_of_val(self)
            + std::mem::size_of::<T::Native>() * estimated_buckets
            + estimated_buckets
            + std::mem::size_of_val(&self.values)
    }
}

#[derive(Debug)]
struct FloatDistinctCountAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    values: HashSet<Hashable<T::Native>, RandomState>,
}

impl<T> FloatDistinctCountAccumulator<T>
where
    T: ArrowPrimitiveType + Send,
{
    fn new() -> Self {
        Self {
            values: HashSet::default(),
        }
    }
}

impl<T> Accumulator for FloatDistinctCountAccumulator<T>
where
    T: ArrowPrimitiveType + Send + Debug,
{
    fn state(&self) -> Result<Vec<ScalarValue>> {
        let arr = Arc::new(PrimitiveArray::<T>::from_iter_values(
            self.values.iter().map(|v| v.0),
        )) as ArrayRef;
        let list = Arc::new(array_into_list_array(arr));
        Ok(vec![ScalarValue::List(list)])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        let arr = as_primitive_array::<T>(&values[0])?;
        arr.iter().for_each(|value| {
            if let Some(value) = value {
                self.values.insert(Hashable(value));
            }
        });

        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        assert_eq!(
            states.len(),
            1,
            "count_distinct states must be single array"
        );

        let arr = as_list_array(&states[0])?;
        arr.iter().try_for_each(|maybe_list| {
            if let Some(list) = maybe_list {
                let list = as_primitive_array::<T>(&list)?;
                self.values
                    .extend(list.values().iter().map(|v| Hashable(*v)));
            };
            Ok(())
        })
    }

    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.values.len() as i64)))
    }

    fn size(&self) -> usize {
        let estimated_buckets = (self.values.len().checked_mul(8).unwrap_or(usize::MAX)
            / 7)
        .next_power_of_two();

        // Size of accumulator
        // + size of entry * number of buckets
        // + 1 byte for each bucket
        // + fixed size of HashSet
        std::mem::size_of_val(self)
            + std::mem::size_of::<T::Native>() * estimated_buckets
            + estimated_buckets
            + std::mem::size_of_val(&self.values)
    }
}

#[derive(Debug)]
struct StringDistinctCountAccumulator2(HashSet<String>);
impl StringDistinctCountAccumulator2 {
    fn new() -> Self {
        Self(HashSet::new())
    }
}

impl Accumulator for StringDistinctCountAccumulator2 {
    fn state(&self) -> Result<Vec<ScalarValue>> {
        let arr = StringArray::from_iter_values(self.0.iter());
        let list = Arc::new(array_into_list_array(Arc::new(arr)));
        Ok(vec![ScalarValue::List(list)])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }


        // NOTE: for this benchmark the real input is ignored and replaced
        // with generated test data; toggle the commented lines to switch cases
        let vs = get_distinct_string(100000);
        // let vs = get_vec_string(100000);
        let arr = &StringArray::from_iter_values(vs);
        // let arr = as_string_array(&values[0])?;
        arr.iter().for_each(|value| {
            if let Some(value) = value {
                self.0.insert(value.to_string());
            }
        });

        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        assert_eq!(
            states.len(),
            1,
            "count_distinct states must be single array"
        );

        let arr = as_list_array(&states[0])?;
        arr.iter().try_for_each(|maybe_list| {
            if let Some(list) = maybe_list {
                let list = as_string_array(&list)?;

                list.iter().for_each(|value| {
                    if let Some(value) = value {
                        self.0.insert(value.to_string());
                    }
                })
            };
            Ok(())
        })
    }

    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
    }

    fn size(&self) -> usize {
        // Size of the accumulator itself; the strings owned by the
        // HashSet<String> are not accounted for in this test-only variant
        std::mem::size_of_val(self)
    }
}

#[derive(Debug)]
struct StringDistinctCountAccumulator(SSOStringHashSet);
impl StringDistinctCountAccumulator {
    fn new() -> Self {
        Self(SSOStringHashSet::new())
    }
}

impl Accumulator for StringDistinctCountAccumulator {
    fn state(&self) -> Result<Vec<ScalarValue>> {
        let arr = StringArray::from_iter_values(self.0.iter());
        let list = Arc::new(array_into_list_array(Arc::new(arr)));
        Ok(vec![ScalarValue::List(list)])
    }

    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
        if values.is_empty() {
            return Ok(());
        }

        // NOTE: for this benchmark the real input is ignored and replaced
        // with generated test data; toggle the commented lines to switch cases
        let vs = get_distinct_string(100000);
        // let vs = get_vec_string(100000);
        let arr = &StringArray::from_iter_values(vs);
        // let arr = as_string_array(&values[0])?;
        arr.iter().for_each(|value| {
            if let Some(value) = value {
                self.0.insert(value);
            }
        });

        Ok(())
    }

    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
        if states.is_empty() {
            return Ok(());
        }
        assert_eq!(
            states.len(),
            1,
            "count_distinct states must be single array"
        );

        let arr = as_list_array(&states[0])?;
        arr.iter().try_for_each(|maybe_list| {
            if let Some(list) = maybe_list {
                let list = as_string_array(&list)?;

                list.iter().for_each(|value| {
                    if let Some(value) = value {
                        self.0.insert(value);
                    }
                })
            };
            Ok(())
        })
    }

    fn evaluate(&self) -> Result<ScalarValue> {
        Ok(ScalarValue::Int64(Some(self.0.len() as i64)))
    }

    fn size(&self) -> usize {
        // Size of accumulator
        // + SSOStringHashSet size
        std::mem::size_of_val(self) + self.0.size()
    }
}

/// Maximum size of a string that can be inlined in the hash table
const SHORT_STRING_LEN: usize = mem::size_of::<usize>();

#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy)]
struct SSOStringHeader {
    /// hash of the string value (used when resizing table)
    hash: u64,

    /// length of the string, in bytes
    len: usize,
    /// if len <= SHORT_STRING_LEN, the string's bytes inlined into the usize;
    /// otherwise the offset into `buffer` where the string starts
    offset_or_inline: usize,
}

impl SSOStringHeader {
    fn evaluate(&self, buffer: &[u8]) -> String {
        if self.len <= SHORT_STRING_LEN {
            // Reconstruct the inlined string: `insert` packed the bytes
            // big-endian, so they are the low `len` bytes of the usize
            let bytes = self.offset_or_inline.to_be_bytes();
            String::from_utf8_lossy(&bytes[bytes.len() - self.len..]).into_owned()
        } else {
            let offset = self.offset_or_inline;
            // SAFETY: buffer is only appended to, and we correctly inserted values
            unsafe {
                std::str::from_utf8_unchecked(
                    buffer.get_unchecked(offset..offset + self.len),
                )
            }
            .to_string()
        }
    }
}

// Short String Optimized (SSO) HashSet for strings
// Equivalent to HashSet<String> but with better memory usage
#[derive(Default)]
struct SSOStringHashSet {
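    // All distinct strings seen so far, stored as headers: short strings
    // are inlined, long strings point into `buffer`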
    header_set: HashSet<SSOStringHeader>,
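    // Deduplicates long strings: looked up by hash, compared against `buffer`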
    long_string_map: hashbrown::raw::RawTable<SSOStringHeader>,
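    // Memory used by `long_string_map`, tracked via `insert_accounted`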
    map_size: usize,
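    // Append-only arena holding the bytes of every distinct long string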
    buffer: BufferBuilder<u8>,
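    // Hash state used to hash long strings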
    state: RandomState,
}

impl SSOStringHashSet {
    fn new() -> Self {
        Self::default()
    }

    fn insert(&mut self, value: &str) {
        let value_len = value.len();
        let value_bytes = value.as_bytes();

        if value_len <= SHORT_STRING_LEN {
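            // Pack the bytes into a usize in big-endian order: the first
            // byte ends up most significant among the low `len` bytes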
            let inline = value_bytes
                .iter()
                .fold(0usize, |acc, &x| acc << 8 | x as usize);
            let short_string_header = SSOStringHeader {
                // hash is not needed for short strings; they compare by inline value
                hash: 0,
                len: value_len,
                offset_or_inline: inline,
            };
            self.header_set.insert(short_string_header);
        } else {
            let hash = self.state.hash_one(value_bytes);

            let entry = self.long_string_map.get_mut(hash, |header| {
                // if hash matches, check if the bytes match
                let offset = header.offset_or_inline;
                let len = header.len;

                // SAFETY: buffer is only appended to, and we correctly inserted values
                let existing_value =
                    unsafe { self.buffer.as_slice().get_unchecked(offset..offset + len) };

                value_bytes == existing_value
            });

            if entry.is_none() {
                let offset = self.buffer.len();
                self.buffer.append_slice(value_bytes);
                let header = SSOStringHeader {
                    hash,
                    len: value_len,
                    offset_or_inline: offset,
                };
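                // insert_accounted tracks the raw table's allocations in
                // map_size so the memory accounting can see table growth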
                self.long_string_map.insert_accounted(
                    header,
                    |header| header.hash,
                    &mut self.map_size,
                );
                self.header_set.insert(header);
            }
        }
    }

    fn iter(&self) -> Vec<String> {
        self.header_set
            .iter()
            .map(|header| header.evaluate(self.buffer.as_slice()))
            .collect()
    }

    fn len(&self) -> usize {
        self.header_set.len()
    }

    // NEEDS HELP: is this size estimate correct and complete?
    fn size(&self) -> usize {
        self.header_set.len() * mem::size_of::<SSOStringHeader>()
            + self.map_size
            + self.buffer.len()
    }
}

impl Debug for SSOStringHashSet {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SSOStringHashSet")
            .field("header_set", &self.header_set)
            // TODO: Print long_string_map
            .field("map_size", &self.map_size)
            .field("buffer", &self.buffer)
            .field("state", &self.state)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use crate::expressions::NoOp;

    use super::*;
    use arrow::array::{
        ArrayRef, BooleanArray, Float32Array, Float64Array, Int16Array, Int32Array,
        Int64Array, Int8Array, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
    };
    use arrow::datatypes::DataType;
    use arrow::datatypes::{
        Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
        UInt32Type, UInt64Type, UInt8Type,
    };
    use arrow_array::Decimal256Array;
    use arrow_buffer::i256;
    use datafusion_common::cast::{as_boolean_array, as_list_array, as_primitive_array};
    use datafusion_common::internal_err;
    use datafusion_common::DataFusionError;

    macro_rules! state_to_vec_primitive {
        ($LIST:expr, $DATA_TYPE:ident) => {{
            let arr = ScalarValue::raw_data($LIST).unwrap();
            let list_arr = as_list_array(&arr).unwrap();
            let arr = list_arr.values();
            let arr = as_primitive_array::<$DATA_TYPE>(arr)?;
            arr.values().iter().cloned().collect::<Vec<_>>()
        }};
    }

    macro_rules! test_count_distinct_update_batch_numeric {
        ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
            let values: Vec<Option<$PRIM_TYPE>> = vec![
                Some(1),
                Some(1),
                None,
                Some(3),
                Some(2),
                None,
                Some(2),
                Some(3),
                Some(1),
            ];

            let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];

            let (states, result) = run_update_batch(&arrays)?;

            let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE);
            state_vec.sort();

            assert_eq!(states.len(), 1);
            assert_eq!(state_vec, vec![1, 2, 3]);
            assert_eq!(result, ScalarValue::Int64(Some(3)));

            Ok(())
        }};
    }

    fn state_to_vec_bool(sv: &ScalarValue) -> Result<Vec<bool>> {
        let arr = ScalarValue::raw_data(sv)?;
        let list_arr = as_list_array(&arr)?;
        let arr = list_arr.values();
        let bool_arr = as_boolean_array(arr)?;
        Ok(bool_arr.iter().flatten().collect())
    }

    fn run_update_batch(arrays: &[ArrayRef]) -> Result<(Vec<ScalarValue>, ScalarValue)> {
        let agg = DistinctCount::new(
            arrays[0].data_type().clone(),
            Arc::new(NoOp::new()),
            String::from("__col_name__"),
        );

        let mut accum = agg.create_accumulator()?;
        accum.update_batch(arrays)?;

        Ok((accum.state()?, accum.evaluate()?))
    }

    fn run_update(
        data_types: &[DataType],
        rows: &[Vec<ScalarValue>],
    ) -> Result<(Vec<ScalarValue>, ScalarValue)> {
        let agg = DistinctCount::new(
            data_types[0].clone(),
            Arc::new(NoOp::new()),
            String::from("__col_name__"),
        );

        let mut accum = agg.create_accumulator()?;

        let cols = (0..rows[0].len())
            .map(|i| {
                rows.iter()
                    .map(|inner| inner[i].clone())
                    .collect::<Vec<ScalarValue>>()
            })
            .collect::<Vec<_>>();

        let arrays: Vec<ArrayRef> = cols
            .iter()
            .map(|c| ScalarValue::iter_to_array(c.clone()))
            .collect::<Result<Vec<ArrayRef>>>()?;

        accum.update_batch(&arrays)?;

        Ok((accum.state()?, accum.evaluate()?))
    }

    // Use a trait to create an associated constant for f32 and f64
    trait SubNormal: 'static {
        const SUBNORMAL: Self;
    }

    impl SubNormal for f64 {
        const SUBNORMAL: Self = 1.0e-308_f64;
    }

    impl SubNormal for f32 {
        const SUBNORMAL: Self = 1.0e-38_f32;
    }

    macro_rules! test_count_distinct_update_batch_floating_point {
        ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
            let values: Vec<Option<$PRIM_TYPE>> = vec![
                Some(<$PRIM_TYPE>::INFINITY),
                Some(<$PRIM_TYPE>::NAN),
                Some(1.0),
                Some(<$PRIM_TYPE as SubNormal>::SUBNORMAL),
                Some(1.0),
                Some(<$PRIM_TYPE>::INFINITY),
                None,
                Some(3.0),
                Some(-4.5),
                Some(2.0),
                None,
                Some(2.0),
                Some(3.0),
                Some(<$PRIM_TYPE>::NEG_INFINITY),
                Some(1.0),
                Some(<$PRIM_TYPE>::NAN),
                Some(<$PRIM_TYPE>::NEG_INFINITY),
            ];

            let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];

            let (states, result) = run_update_batch(&arrays)?;

            let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE);

            state_vec.sort_by(|a, b| a.total_cmp(b));

            let nan_idx = state_vec.len() - 1;
            assert_eq!(states.len(), 1);
            assert_eq!(
                &state_vec[..nan_idx],
                vec![
                    <$PRIM_TYPE>::NEG_INFINITY,
                    -4.5,
                    <$PRIM_TYPE as SubNormal>::SUBNORMAL,
                    1.0,
                    2.0,
                    3.0,
                    <$PRIM_TYPE>::INFINITY
                ]
            );
            assert!(state_vec[nan_idx].is_nan());
            assert_eq!(result, ScalarValue::Int64(Some(8)));

            Ok(())
        }};
    }

    macro_rules! test_count_distinct_update_batch_bigint {
        ($ARRAY_TYPE:ident, $DATA_TYPE:ident, $PRIM_TYPE:ty) => {{
            let values: Vec<Option<$PRIM_TYPE>> = vec![
                Some(i256::from(1)),
                Some(i256::from(1)),
                None,
                Some(i256::from(3)),
                Some(i256::from(2)),
                None,
                Some(i256::from(2)),
                Some(i256::from(3)),
                Some(i256::from(1)),
            ];

            let arrays = vec![Arc::new($ARRAY_TYPE::from(values)) as ArrayRef];

            let (states, result) = run_update_batch(&arrays)?;

            let mut state_vec = state_to_vec_primitive!(&states[0], $DATA_TYPE);
            state_vec.sort();

            assert_eq!(states.len(), 1);
            assert_eq!(state_vec, vec![i256::from(1), i256::from(2), i256::from(3)]);
            assert_eq!(result, ScalarValue::Int64(Some(3)));

            Ok(())
        }};
    }

    #[test]
    fn count_distinct_update_batch_i8() -> Result<()> {
        test_count_distinct_update_batch_numeric!(Int8Array, Int8Type, i8)
    }

    #[test]
    fn count_distinct_update_batch_i16() -> Result<()> {
        test_count_distinct_update_batch_numeric!(Int16Array, Int16Type, i16)
    }

    #[test]
    fn count_distinct_update_batch_i32() -> Result<()> {
        test_count_distinct_update_batch_numeric!(Int32Array, Int32Type, i32)
    }

    #[test]
    fn count_distinct_update_batch_i64() -> Result<()> {
        test_count_distinct_update_batch_numeric!(Int64Array, Int64Type, i64)
    }

    #[test]
    fn count_distinct_update_batch_u8() -> Result<()> {
        test_count_distinct_update_batch_numeric!(UInt8Array, UInt8Type, u8)
    }

    #[test]
    fn count_distinct_update_batch_u16() -> Result<()> {
        test_count_distinct_update_batch_numeric!(UInt16Array, UInt16Type, u16)
    }

    #[test]
    fn count_distinct_update_batch_u32() -> Result<()> {
        test_count_distinct_update_batch_numeric!(UInt32Array, UInt32Type, u32)
    }

    #[test]
    fn count_distinct_update_batch_u64() -> Result<()> {
        test_count_distinct_update_batch_numeric!(UInt64Array, UInt64Type, u64)
    }

    #[test]
    fn count_distinct_update_batch_f32() -> Result<()> {
        test_count_distinct_update_batch_floating_point!(Float32Array, Float32Type, f32)
    }

    #[test]
    fn count_distinct_update_batch_f64() -> Result<()> {
        test_count_distinct_update_batch_floating_point!(Float64Array, Float64Type, f64)
    }

    #[test]
    fn count_distinct_update_batch_i256() -> Result<()> {
        test_count_distinct_update_batch_bigint!(Decimal256Array, Decimal256Type, i256)
    }

    #[test]
    fn count_distinct_update_batch_boolean() -> Result<()> {
        let get_count = |data: BooleanArray| -> Result<(Vec<bool>, i64)> {
            let arrays = vec![Arc::new(data) as ArrayRef];
            let (states, result) = run_update_batch(&arrays)?;
            let mut state_vec = state_to_vec_bool(&states[0])?;
            state_vec.sort();

            let count = match result {
                ScalarValue::Int64(c) => c.ok_or_else(|| {
                    DataFusionError::Internal("Found None count".to_string())
                }),
                scalar => {
                    internal_err!("Found non int64 scalar value from count: {scalar}")
                }
            }?;
            Ok((state_vec, count))
        };

        let zero_count_values = BooleanArray::from(Vec::<bool>::new());

        let one_count_values = BooleanArray::from(vec![false, false]);
        let one_count_values_with_null =
            BooleanArray::from(vec![Some(true), Some(true), None, None]);

        let two_count_values = BooleanArray::from(vec![true, false, true, false, true]);
        let two_count_values_with_null = BooleanArray::from(vec![
            Some(true),
            Some(false),
            None,
            None,
            Some(true),
            Some(false),
        ]);

        assert_eq!(get_count(zero_count_values)?, (Vec::<bool>::new(), 0));
        assert_eq!(get_count(one_count_values)?, (vec![false], 1));
        assert_eq!(get_count(one_count_values_with_null)?, (vec![true], 1));
        assert_eq!(get_count(two_count_values)?, (vec![false, true], 2));
        assert_eq!(
            get_count(two_count_values_with_null)?,
            (vec![false, true], 2)
        );
        Ok(())
    }

    #[test]
    fn count_distinct_update_batch_all_nulls() -> Result<()> {
        let arrays = vec![Arc::new(Int32Array::from(
            vec![None, None, None, None] as Vec<Option<i32>>
        )) as ArrayRef];

        let (states, result) = run_update_batch(&arrays)?;
        let state_vec = state_to_vec_primitive!(&states[0], Int32Type);
        assert_eq!(states.len(), 1);
        assert!(state_vec.is_empty());
        assert_eq!(result, ScalarValue::Int64(Some(0)));

        Ok(())
    }

    #[test]
    fn count_distinct_update_batch_empty() -> Result<()> {
        let arrays = vec![Arc::new(Int32Array::from(vec![0_i32; 0])) as ArrayRef];

        let (states, result) = run_update_batch(&arrays)?;
        let state_vec = state_to_vec_primitive!(&states[0], Int32Type);
        assert_eq!(states.len(), 1);
        assert!(state_vec.is_empty());
        assert_eq!(result, ScalarValue::Int64(Some(0)));

        Ok(())
    }

    #[test]
    fn count_distinct_update() -> Result<()> {
        let (states, result) = run_update(
            &[DataType::Int32],
            &[
                vec![ScalarValue::Int32(Some(-1))],
                vec![ScalarValue::Int32(Some(5))],
                vec![ScalarValue::Int32(Some(-1))],
                vec![ScalarValue::Int32(Some(5))],
                vec![ScalarValue::Int32(Some(-1))],
                vec![ScalarValue::Int32(Some(-1))],
                vec![ScalarValue::Int32(Some(2))],
            ],
        )?;
        assert_eq!(states.len(), 1);
        assert_eq!(result, ScalarValue::Int64(Some(3)));

        let (states, result) = run_update(
            &[DataType::UInt64],
            &[
                vec![ScalarValue::UInt64(Some(1))],
                vec![ScalarValue::UInt64(Some(5))],
                vec![ScalarValue::UInt64(Some(1))],
                vec![ScalarValue::UInt64(Some(5))],
                vec![ScalarValue::UInt64(Some(1))],
                vec![ScalarValue::UInt64(Some(1))],
                vec![ScalarValue::UInt64(Some(2))],
            ],
        )?;
        assert_eq!(states.len(), 1);
        assert_eq!(result, ScalarValue::Int64(Some(3)));
        Ok(())
    }

    #[test]
    fn count_distinct_update_with_nulls() -> Result<()> {
        let (states, result) = run_update(
            &[DataType::Int32],
            &[
                // None of these updates contains a None, so these are accumulated.
                vec![ScalarValue::Int32(Some(-1))],
                vec![ScalarValue::Int32(Some(-1))],
                vec![ScalarValue::Int32(Some(-2))],
                // Each of these updates contains at least one None, so these
                // won't be accumulated.
                vec![ScalarValue::Int32(Some(-1))],
                vec![ScalarValue::Int32(None)],
                vec![ScalarValue::Int32(None)],
            ],
        )?;
        assert_eq!(states.len(), 1);
        assert_eq!(result, ScalarValue::Int64(Some(2)));

        let (states, result) = run_update(
            &[DataType::UInt64],
            &[
                // None of these updates contains a None, so these are accumulated.
                vec![ScalarValue::UInt64(Some(1))],
                vec![ScalarValue::UInt64(Some(1))],
                vec![ScalarValue::UInt64(Some(2))],
                // Each of these updates contains at least one None, so these
                // won't be accumulated.
                vec![ScalarValue::UInt64(Some(1))],
                vec![ScalarValue::UInt64(None)],
                vec![ScalarValue::UInt64(None)],
            ],
        )?;
        assert_eq!(states.len(), 1);
        assert_eq!(result, ScalarValue::Int64(Some(2)));
        Ok(())
    }
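
    // A minimal sketch (not part of the original PR) exercising
    // SSOStringHashSet directly: short strings are inlined into the header,
    // long strings land in the shared buffer, and duplicates count once
    #[test]
    fn sso_string_hash_set_dedups_short_and_long() {
        let mut set = SSOStringHashSet::new();
        set.insert("a"); // short: len <= SHORT_STRING_LEN, inlined
        set.insert("a"); // duplicate short string is ignored
        set.insert("a string longer than eight bytes"); // long: stored in buffer
        set.insert("a string longer than eight bytes"); // duplicate long string
        assert_eq!(set.len(), 2);

        let mut values = set.iter();
        values.sort();
        assert_eq!(values, vec!["a", "a string longer than eight bytes"]);
    }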
}

@alamb (Contributor) commented Jan 14, 2024

I thought that the more short strings there were, the bigger the performance gain, but it turns out the gain grows with the share of long strings

It seems to me this means that maybe the small string optimization is unnecessary at this time, given it doesn't seem to make a significant difference to performance 🤔

Maybe we could simplify the code?

@alamb (Contributor) commented Jan 14, 2024

I think adding a benchmark like clickbench_extended would help this mission: #8860

Update: here is a PR to add this query #8861

I wonder if we can find any plausible query based on hits.parquet that shows a need for a short string optimization 🤔

@jayzhan211 (Contributor, Author) commented Jan 14, 2024

I thought that the more short strings there were, the bigger the performance gain, but it turns out the gain grows with the share of long strings

It seems to me this means that maybe the small string optimization is unnecessary at this time, given it doesn't seem to make a significant difference to performance 🤔

Maybe we could simplify the code?

If the number of rows is large (> 1e6), then the speed gain from short strings grows to multiple seconds (5s faster for n=1e6)

@jayzhan211 (Contributor, Author) commented Jan 15, 2024

I thought that the more short strings there were, the bigger the performance gain, but it turns out the gain grows with the share of long strings

It seems to me this means that maybe the small string optimization is unnecessary at this time, given it doesn't seem to make a significant difference to performance 🤔
Maybe we could simplify the code?

If the number of rows is large (> 1e6), then the speed gain from short strings grows to multiple seconds (5s faster for n=1e6)

hits.parquet data is not large enough

SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage")  FROM hits

where the values are all of length 1 or 2. This does not show a difference.

HashSet

Query 0 iteration 0 took 1048.5 ms and returned 1 rows
Query 0 iteration 1 took 924.0 ms and returned 1 rows
Query 0 iteration 2 took 957.4 ms and returned 1 rows
Query 0 iteration 3 took 933.0 ms and returned 1 rows
Query 0 iteration 4 took 931.7 ms and returned 1 rows

SSO

Query 0 iteration 0 took 992.2 ms and returned 1 rows
Query 0 iteration 1 took 934.3 ms and returned 1 rows
Query 0 iteration 2 took 954.9 ms and returned 1 rows
Query 0 iteration 3 took 998.2 ms and returned 1 rows
Query 0 iteration 4 took 953.8 ms and returned 1 rows
SELECT COUNT(DISTINCT "HitColor"), COUNT(DISTINCT "BrowserCountry"), COUNT(DISTINCT "BrowserLanguage"), COUNT(DISTINCT "URL")  FROM hits

URL lengths are mostly > 8. This improves from 11s to 9s.

HashSet

Query 0 iteration 0 took 11751.8 ms and returned 1 rows
Query 0 iteration 1 took 11154.3 ms and returned 1 rows
Query 0 iteration 2 took 10434.3 ms and returned 1 rows
Query 0 iteration 3 took 10988.1 ms and returned 1 rows
Query 0 iteration 4 took 12159.3 ms and returned 1 rows

SSO

Query 0 iteration 0 took 9415.5 ms and returned 1 rows
Query 0 iteration 1 took 9009.4 ms and returned 1 rows
Query 0 iteration 2 took 9832.9 ms and returned 1 rows
Query 0 iteration 3 took 10004.6 ms and returned 1 rows
Query 0 iteration 4 took 9829.1 ms and returned 1 rows

@alamb (Contributor) commented Jan 15, 2024

URL lengths are mostly > 8. This improves from 11s to 9s.

Nice! I'll add this query to the "extended" benchmark I am working on in #8861

@alamb (Contributor) commented Jan 15, 2024

@jayzhan211 -- I plan to review this PR more carefully over the next day or two

#8861

I am hoping that we can then use the same structure for #7064

@alamb (Contributor) commented Jan 22, 2024

@jayzhan211 -- Thank you for 0e33b12
I did some refinement in a80b39c

@alamb (Contributor) commented Jan 22, 2024

For the memory accounting test, I think we just need to run the clickbench extended queries and find out the best pre-allocated memory size for the raw table?

I have an idea for this test, let me push it up and see what you think

Update: 3e9289a

@alamb (Contributor) commented Jan 22, 2024

OK, I think this PR is now blocked on the following two PRs:

  1. Change Accumulator::evaluate and Accumulator::state to take &mut self #8925
  2. Minor: Add new Extended ClickBench benchmark queries #8950

Once those are merged I think we can merge up from this branch, run some final benchmark numbers, and get it reviewed

Thanks again @jayzhan211

@jayzhan211 (Contributor, Author) commented:
Thanks, @alamb. It seems I misunderstood the goals of both the fuzz test and the memory accounting test 😅

@alamb (Contributor) commented Jan 24, 2024

Thanks, @alamb. It seems I misunderstood the goals of both the fuzz test and the memory accounting test 😅

I think this was my bad for not explaining it well.

@alamb (Contributor) commented Jan 24, 2024

Ok, now I am just waiting on #8950 to merge and then I'll run the benchmarks and I think this PR will be ready for review

@alamb alamb changed the title Optimize COUNT( DISTINCT ...) for strings Optimize COUNT( DISTINCT ...) for strings (up to 9x faster) Jan 25, 2024
@alamb alamb marked this pull request as ready for review January 25, 2024 14:30
@alamb (Contributor) left a review:

I think this PR is now ready for review. The benchmarks are looking very nice

Thanks again for this great teamwork @jayzhan211

Since I wrote a bunch of this PR I think another committer should also approve prior to merging it

@@ -54,6 +54,7 @@ blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
chrono = { workspace = true }
datafusion-common = { workspace = true }
datafusion-execution = { workspace = true }
Inline review comment:

Needed to use the RawTableAllocExt trait

@@ -152,6 +155,9 @@ impl AggregateExpr for DistinctCount {
Float32 => float_distinct_count_accumulator!(Float32Type),
Float64 => float_distinct_count_accumulator!(Float64Type),

Utf8 => Ok(Box::new(StringDistinctCountAccumulator::<i32>::new())),
Inline review comment:

The key contribution in this PR is to add these specialized accumulators

/// Maximum size of a string that can be inlined in the hash table
const SHORT_STRING_LEN: usize = mem::size_of::<usize>();

/// Entry that is stored in a `SSOStringHashSet` that represents a string
Inline review comment:

This explains the core change in this PR and how things work

@thinkharderdev (Contributor) left a review:

very nice!

return Ok(());
}

self.0.insert(values[0].clone());
thinkharderdev commented:

nit: insert should be able to take a reference right?

@alamb (Contributor) replied Jan 28, 2024:

Yes, you are right -- I did this in f5e268d

Thank you
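
For reference, a minimal sketch of the change (assuming, per the hunk above, that the accumulator's set is field `0` and `values[0]` is the input array; not the exact commit):

// before: inserts an owned clone of the input array
self.0.insert(values[0].clone());
// after: insert takes the array by reference, avoiding the clone
self.0.insert(&values[0]);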

@alamb (Contributor) commented Jan 28, 2024

Thank you for the review @thinkharderdev 🙏

@alamb alamb merged commit af0e8a9 into apache:main Jan 29, 2024
22 checks passed
@alamb (Contributor) commented Jan 29, 2024

🚀
