From 3353c061830dbd1ef912736fee8426ee0405a3f9 Mon Sep 17 00:00:00 2001 From: kamille Date: Wed, 9 Oct 2024 22:07:06 +0800 Subject: [PATCH] Add Aggregation fuzzer framework (#12667) * impl primitive arrays generator. * sort out the test record batch generating codes. * draft for `DataSetsGenerator`. * tmp * improve the data generator, and start to impl the session context generator. * impl context generator. * tmp * define the `AggregationFuzzer`. * add ut for data generator. * improve comments for `SessionContextGenerator`. * define `GeneratedSessionContextBuilder` to reduce repeated codes. * extract the check equality logic for reusing. * add ut for `SessionContextGenerator`. * tmp * finish the main logic of `AggregationFuzzer`. * try to rewrite some test using the fuzzer. * fix header. * expose table name through `AggregationFuzzerBuilder`. * throw err to aggr fuzzer, and expect them then. * switch to Arc to slightly improve performance. * throw more errors to fuzzer. * print task informantion before panic. * improve comments. * support printing generated session context params in error reporting. * add todo. * add some new fuzz case based on `AggregationFuzzer`. * fix lint. * print more information in error report. * fix clippy. * improve comment of `SessionContextGenerator`. * just use fixed `data_gen_rounds` and `ctx_gen_rounds` currently, because we will hardly set them. * improve comments for rounds constants. * small improvements. * select sql from some candidates ranther than fixed one. * make `data_gen_rounds` able to set again, and add more tests. * add no group cases. * add fuzz test for basic string aggr. * make `data_gen_rounds` smaller. * add comments. * fix typo. * fix comment. --- .../core/tests/fuzz_cases/aggregate_fuzz.rs | 302 ++++++++++++ .../aggregation_fuzzer/context_generator.rs | 343 +++++++++++++ .../aggregation_fuzzer/data_generator.rs | 459 ++++++++++++++++++ .../fuzz_cases/aggregation_fuzzer/fuzzer.rs | 281 +++++++++++ .../fuzz_cases/aggregation_fuzzer/mod.rs | 69 +++ datafusion/core/tests/fuzz_cases/mod.rs | 1 + test-utils/Cargo.toml | 1 + test-utils/src/array_gen/mod.rs | 22 + test-utils/src/array_gen/primitive.rs | 80 +++ test-utils/src/array_gen/string.rs | 78 +++ test-utils/src/lib.rs | 1 + test-utils/src/string_gen.rs | 72 +-- 12 files changed, 1646 insertions(+), 63 deletions(-) create mode 100644 datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs create mode 100644 datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs create mode 100644 datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs create mode 100644 datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs create mode 100644 test-utils/src/array_gen/mod.rs create mode 100644 test-utils/src/array_gen/primitive.rs create mode 100644 test-utils/src/array_gen/string.rs diff --git a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs index 62e9be63983c..5cc5157c3af9 100644 --- a/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs +++ b/datafusion/core/tests/fuzz_cases/aggregate_fuzz.rs @@ -44,6 +44,307 @@ use rand::rngs::StdRng; use rand::{Rng, SeedableRng}; use tokio::task::JoinSet; +use crate::fuzz_cases::aggregation_fuzzer::{ + AggregationFuzzerBuilder, ColumnDescr, DatasetGeneratorConfig, +}; + +// ======================================================================== +// The new aggregation fuzz tests based on [`AggregationFuzzer`] +// 
======================================================================== + +// TODO: write more test case to cover more `group by`s and `aggregation function`s +// TODO: maybe we can use macro to simply the case creating + +/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `no group by` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_prim_aggr_no_group() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = vec![ColumnDescr::new("a", DataType::Int32)]; + + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set: Vec::new(), + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(16) + .add_sql("SELECT sum(a) FROM fuzz_table") + .add_sql("SELECT sum(distinct a) FROM fuzz_table") + .add_sql("SELECT max(a) FROM fuzz_table") + .add_sql("SELECT min(a) FROM fuzz_table") + .add_sql("SELECT count(a) FROM fuzz_table") + .add_sql("SELECT count(distinct a) FROM fuzz_table") + .add_sql("SELECT avg(a) FROM fuzz_table") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `group by single int64` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_prim_aggr_group_by_single_int64() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = vec![ + ColumnDescr::new("a", DataType::Int32), + ColumnDescr::new("b", DataType::Int64), + ColumnDescr::new("c", DataType::Int64), + ]; + let sort_keys_set = vec![ + vec!["b".to_string()], + vec!["c".to_string(), "b".to_string()], + ]; + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set, + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(16) + .add_sql("SELECT b, sum(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, sum(distinct a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, avg(a) FROM fuzz_table GROUP BY b") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `group by single string` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_prim_aggr_group_by_single_string() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = vec![ + ColumnDescr::new("a", DataType::Int32), + ColumnDescr::new("b", DataType::Utf8), + ColumnDescr::new("c", DataType::Int64), + ]; + let sort_keys_set = vec![ + vec!["b".to_string()], + vec!["c".to_string(), "b".to_string()], + ]; + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set, + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(16) + .add_sql("SELECT b, sum(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, sum(distinct a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b") + 
.add_sql("SELECT b, avg(a) FROM fuzz_table GROUP BY b") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +/// Fuzz test for `basic prim aggr(sum/sum distinct/max/min/count/avg)` + `group by string + int64` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_prim_aggr_group_by_mixed_string_int64() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = vec![ + ColumnDescr::new("a", DataType::Int32), + ColumnDescr::new("b", DataType::Utf8), + ColumnDescr::new("c", DataType::Int64), + ColumnDescr::new("d", DataType::Int32), + ]; + let sort_keys_set = vec![ + vec!["b".to_string(), "c".to_string()], + vec!["d".to_string(), "b".to_string(), "c".to_string()], + ]; + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set, + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(16) + .add_sql("SELECT b, c, sum(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, sum(distinct a) FROM fuzz_table GROUP BY b,c") + .add_sql("SELECT b, c, max(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, min(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, count(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, count(distinct a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, avg(a) FROM fuzz_table GROUP BY b, c") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `no group by` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_string_aggr_no_group() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = vec![ColumnDescr::new("a", DataType::Utf8)]; + + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set: Vec::new(), + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(8) + .add_sql("SELECT max(a) FROM fuzz_table") + .add_sql("SELECT min(a) FROM fuzz_table") + .add_sql("SELECT count(a) FROM fuzz_table") + .add_sql("SELECT count(distinct a) FROM fuzz_table") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `group by single int64` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_string_aggr_group_by_single_int64() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = vec![ + ColumnDescr::new("a", DataType::Utf8), + ColumnDescr::new("b", DataType::Int64), + ColumnDescr::new("c", DataType::Int64), + ]; + let sort_keys_set = vec![ + vec!["b".to_string()], + vec!["c".to_string(), "b".to_string()], + ]; + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set, + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(8) + // FIXME: Encounter error in min/max + // ArrowError(InvalidArgumentError("number of columns(1) must match number of fields(2) in schema")) + // .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") + // .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +/// Fuzz test for `basic string 
aggr(count/count distinct/min/max)` + `group by single string` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_string_aggr_group_by_single_string() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = vec![ + ColumnDescr::new("a", DataType::Utf8), + ColumnDescr::new("b", DataType::Utf8), + ColumnDescr::new("c", DataType::Int64), + ]; + let sort_keys_set = vec![ + vec!["b".to_string()], + vec!["c".to_string(), "b".to_string()], + ]; + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set, + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(16) + // FIXME: Encounter error in min/max + // ArrowError(InvalidArgumentError("number of columns(1) must match number of fields(2) in schema")) + // .add_sql("SELECT b, max(a) FROM fuzz_table GROUP BY b") + // .add_sql("SELECT b, min(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, count(a) FROM fuzz_table GROUP BY b") + .add_sql("SELECT b, count(distinct a) FROM fuzz_table GROUP BY b") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +/// Fuzz test for `basic string aggr(count/count distinct/min/max)` + `group by string + int64` +#[tokio::test(flavor = "multi_thread")] +async fn test_basic_string_aggr_group_by_mixed_string_int64() { + let builder = AggregationFuzzerBuilder::default(); + + // Define data generator config + let columns = vec![ + ColumnDescr::new("a", DataType::Utf8), + ColumnDescr::new("b", DataType::Utf8), + ColumnDescr::new("c", DataType::Int64), + ColumnDescr::new("d", DataType::Int32), + ]; + let sort_keys_set = vec![ + vec!["b".to_string(), "c".to_string()], + vec!["d".to_string(), "b".to_string(), "c".to_string()], + ]; + let data_gen_config = DatasetGeneratorConfig { + columns, + rows_num_range: (512, 1024), + sort_keys_set, + }; + + // Build fuzzer + let fuzzer = builder + .data_gen_config(data_gen_config) + .data_gen_rounds(16) + // FIXME: Encounter error in min/max + // ArrowError(InvalidArgumentError("number of columns(1) must match number of fields(2) in schema")) + // .add_sql("SELECT b, c, max(a) FROM fuzz_table GROUP BY b, c") + // .add_sql("SELECT b, c, min(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, count(a) FROM fuzz_table GROUP BY b, c") + .add_sql("SELECT b, c, count(distinct a) FROM fuzz_table GROUP BY b, c") + .table_name("fuzz_table") + .build(); + + fuzzer.run().await; +} + +// ======================================================================== +// The old aggregation fuzz tests +// ======================================================================== +/// Tracks if this stream is generating input or output /// Tests that streaming aggregate and batch (non streaming) aggregate produce /// same results #[tokio::test(flavor = "multi_thread")] @@ -311,6 +612,7 @@ async fn group_by_string_test( let actual = extract_result_counts(results); assert_eq!(expected, actual); } + async fn verify_ordered_aggregate(frame: &DataFrame, expected_sort: bool) { struct Visitor { expected_sort: bool, diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs new file mode 100644 index 000000000000..af454bee7ce8 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/context_generator.rs @@ -0,0 +1,343 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license 
agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{cmp, sync::Arc};
+
+use datafusion::{
+    datasource::MemTable,
+    prelude::{SessionConfig, SessionContext},
+};
+use datafusion_catalog::TableProvider;
+use datafusion_common::error::Result;
+use datafusion_common::ScalarValue;
+use datafusion_expr::col;
+use rand::{thread_rng, Rng};
+
+use crate::fuzz_cases::aggregation_fuzzer::data_generator::Dataset;
+
+/// SessionContext generator
+///
+/// During testing, `generate_baseline` is called first to generate a standard
+/// [`SessionContext`], and we run `sql` on it to get the `expected result`.
+/// Then `generate` is called several times to generate random [`SessionContext`]s,
+/// and we run the same `sql` on them to get the `actual results`.
+/// Finally, we compare the `actual results` with the `expected result`;
+/// the test succeeds only if all of them match the expected one.
+///
+/// The following [`SessionContext`] parameters used in query running are generated randomly:
+///   - `batch_size`
+///   - `target_partitions`
+///   - `skip_partial` parameters
+///   - hint `sorted` or not
+///   - `spilling` or not (TODO, I think a special `MemoryPool` may be needed
+///     to support this)
+///
+pub struct SessionContextGenerator {
+    /// Current testing dataset
+    dataset: Arc<Dataset>,
+
+    /// Table name of the test table
+    table_name: String,
+
+    /// Used to generate the random `batch_size`
+    ///
+    /// The generated `batch_size` is between (0, total_rows_num]
+    max_batch_size: usize,
+
+    /// Candidate `SkipPartialParams`, one of which will be picked randomly
+    candidate_skip_partial_params: Vec<SkipPartialParams>,
+
+    /// The upper bound of the randomly generated target partitions,
+    /// and the lower bound will be 1
+    max_target_partitions: usize,
+}
+
+impl SessionContextGenerator {
+    pub fn new(dataset_ref: Arc<Dataset>, table_name: &str) -> Self {
+        let candidate_skip_partial_params = vec![
+            SkipPartialParams::ensure_trigger(),
+            SkipPartialParams::ensure_not_trigger(),
+        ];
+
+        let max_batch_size = cmp::max(1, dataset_ref.total_rows_num);
+        let max_target_partitions = num_cpus::get();
+
+        Self {
+            dataset: dataset_ref,
+            table_name: table_name.to_string(),
+            max_batch_size,
+            candidate_skip_partial_params,
+            max_target_partitions,
+        }
+    }
+}
+
+impl SessionContextGenerator {
+    /// Generate the `SessionContext` for the baseline run
+    pub fn generate_baseline(&self) -> Result<SessionContextWithParams> {
+        let schema = self.dataset.batches[0].schema();
+        let batches = self.dataset.batches.clone();
+        let provider = MemTable::try_new(schema, vec![batches])?;
+
+        // The baseline context should try its best to disable all optimizations,
+        // prioritizing correctness.
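+        // Concretely: one batch covering the whole dataset, a single target
+        // partition (no parallel-aggregation effects), and partial-aggregation
+        // skipping that can never trigger.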
+        let batch_size = self.max_batch_size;
+        let target_partitions = 1;
+        let skip_partial_params = SkipPartialParams::ensure_not_trigger();
+
+        let builder = GeneratedSessionContextBuilder {
+            batch_size,
+            target_partitions,
+            skip_partial_params,
+            sort_hint: false,
+            table_name: self.table_name.clone(),
+            table_provider: Arc::new(provider),
+        };
+
+        builder.build()
+    }
+
+    /// Randomly generate session context
+    pub fn generate(&self) -> Result<SessionContextWithParams> {
+        let mut rng = thread_rng();
+        let schema = self.dataset.batches[0].schema();
+        let batches = self.dataset.batches.clone();
+        let provider = MemTable::try_new(schema, vec![batches])?;
+
+        // We will randomly generate the following options:
+        //   - `batch_size`, from range: [1, `total_rows_num`]
+        //   - `target_partitions`, from range: [1, cpu_num]
+        //   - `skip_partial`, trigger or not trigger, currently for simplicity
+        //   - `sorted`, if the dataset is sorted, we will or will not push down this information
+        //   - `spilling` (TODO)
+        let batch_size = rng.gen_range(1..=self.max_batch_size);
+
+        let target_partitions = rng.gen_range(1..=self.max_target_partitions);
+
+        let skip_partial_params_idx =
+            rng.gen_range(0..self.candidate_skip_partial_params.len());
+        let skip_partial_params =
+            self.candidate_skip_partial_params[skip_partial_params_idx];
+
+        let (provider, sort_hint) =
+            if rng.gen_bool(0.5) && !self.dataset.sort_keys.is_empty() {
+                // Sort keys exist, and we randomly decided to push down this information
+                let sort_exprs = self
+                    .dataset
+                    .sort_keys
+                    .iter()
+                    .map(|key| col(key).sort(true, true))
+                    .collect::<Vec<_>>();
+                (provider.with_sort_order(vec![sort_exprs]), true)
+            } else {
+                (provider, false)
+            };
+
+        let builder = GeneratedSessionContextBuilder {
+            batch_size,
+            target_partitions,
+            sort_hint,
+            skip_partial_params,
+            table_name: self.table_name.clone(),
+            table_provider: Arc::new(provider),
+        };
+
+        builder.build()
+    }
+}
+
+/// The generated [`SessionContext`] with its params
+///
+/// Storing the generated `params` is necessary for
+/// reporting the broken test case.
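+///
+/// A sketch of the intended use (the query and variable names are illustrative):
+///
+/// ```ignore
+/// let wrapped = ctx_generator.generate()?;
+/// let actual = run_sql("SELECT count(a) FROM fuzz_table", &wrapped.ctx).await?;
+/// // On a mismatch, `wrapped.params` identifies the generated settings
+/// // needed to reproduce the failure
+/// println!("{:?}", wrapped.params);
+/// ```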
+pub struct SessionContextWithParams { + pub ctx: SessionContext, + pub params: SessionContextParams, +} + +/// Collect the generated params, and build the [`SessionContext`] +struct GeneratedSessionContextBuilder { + batch_size: usize, + target_partitions: usize, + sort_hint: bool, + skip_partial_params: SkipPartialParams, + table_name: String, + table_provider: Arc, +} + +impl GeneratedSessionContextBuilder { + fn build(self) -> Result { + // Build session context + let mut session_config = SessionConfig::default(); + session_config = session_config.set( + "datafusion.execution.batch_size", + &ScalarValue::UInt64(Some(self.batch_size as u64)), + ); + session_config = session_config.set( + "datafusion.execution.target_partitions", + &ScalarValue::UInt64(Some(self.target_partitions as u64)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_rows_threshold", + &ScalarValue::UInt64(Some(self.skip_partial_params.rows_threshold as u64)), + ); + session_config = session_config.set( + "datafusion.execution.skip_partial_aggregation_probe_ratio_threshold", + &ScalarValue::Float64(Some(self.skip_partial_params.ratio_threshold)), + ); + + let ctx = SessionContext::new_with_config(session_config); + ctx.register_table(self.table_name, self.table_provider)?; + + let params = SessionContextParams { + batch_size: self.batch_size, + target_partitions: self.target_partitions, + sort_hint: self.sort_hint, + skip_partial_params: self.skip_partial_params, + }; + + Ok(SessionContextWithParams { ctx, params }) + } +} + +/// The generated params for [`SessionContext`] +#[derive(Debug)] +#[allow(dead_code)] +pub struct SessionContextParams { + batch_size: usize, + target_partitions: usize, + sort_hint: bool, + skip_partial_params: SkipPartialParams, +} + +/// Partial skipping parameters +#[derive(Debug, Clone, Copy)] +pub struct SkipPartialParams { + /// Related to `skip_partial_aggregation_probe_ratio_threshold` in `ExecutionOptions` + pub ratio_threshold: f64, + + /// Related to `skip_partial_aggregation_probe_rows_threshold` in `ExecutionOptions` + pub rows_threshold: usize, +} + +impl SkipPartialParams { + /// Generate `SkipPartialParams` ensuring to trigger partial skipping + pub fn ensure_trigger() -> Self { + Self { + ratio_threshold: 0.0, + rows_threshold: 0, + } + } + + /// Generate `SkipPartialParams` ensuring not to trigger partial skipping + pub fn ensure_not_trigger() -> Self { + Self { + ratio_threshold: 1.0, + rows_threshold: usize::MAX, + } + } +} + +#[cfg(test)] +mod test { + use arrow_array::{RecordBatch, StringArray, UInt32Array}; + use arrow_schema::{DataType, Field, Schema}; + + use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; + + use super::*; + + #[tokio::test] + async fn test_generated_context() { + // 1. 
Define a test dataset firstly + let a_col: StringArray = [ + Some("rust"), + Some("java"), + Some("cpp"), + Some("go"), + Some("go1"), + Some("python"), + Some("python1"), + Some("python2"), + ] + .into_iter() + .collect(); + // Sort by "b" + let b_col: UInt32Array = [ + Some(1), + Some(2), + Some(4), + Some(8), + Some(8), + Some(16), + Some(16), + Some(16), + ] + .into_iter() + .collect(); + let schema = Schema::new(vec![ + Field::new("a", DataType::Utf8, true), + Field::new("b", DataType::UInt32, true), + ]); + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![Arc::new(a_col), Arc::new(b_col)], + ) + .unwrap(); + + // One row a group to create batches + let mut batches = Vec::with_capacity(batch.num_rows()); + for start in 0..batch.num_rows() { + let sub_batch = batch.slice(start, 1); + batches.push(sub_batch); + } + + let dataset = Dataset::new(batches, vec!["b".to_string()]); + + // 2. Generate baseline context, and some randomly session contexts. + // Run the same query on them, and all randoms' results should equal to baseline's + let ctx_generator = SessionContextGenerator::new(Arc::new(dataset), "fuzz_table"); + + let query = "select b, count(a) from fuzz_table group by b"; + let baseline_wrapped_ctx = ctx_generator.generate_baseline().unwrap(); + let mut random_wrapped_ctxs = Vec::with_capacity(8); + for _ in 0..8 { + let ctx = ctx_generator.generate().unwrap(); + random_wrapped_ctxs.push(ctx); + } + + let base_result = baseline_wrapped_ctx + .ctx + .sql(query) + .await + .unwrap() + .collect() + .await + .unwrap(); + + for wrapped_ctx in random_wrapped_ctxs { + let random_result = wrapped_ctx + .ctx + .sql(query) + .await + .unwrap() + .collect() + .await + .unwrap(); + check_equality_of_batches(&base_result, &random_result).unwrap(); + } + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs new file mode 100644 index 000000000000..9d45779295e7 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -0,0 +1,459 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
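+
+//! Random dataset generation for the aggregation fuzzer: a base record batch
+//! plus one additionally sorted variant per `sort_keys_set` entry.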
+
+use std::sync::Arc;
+
+use arrow_array::{ArrayRef, RecordBatch};
+use arrow_schema::{DataType, Field, Schema};
+use datafusion_common::{arrow_datafusion_err, DataFusionError, Result};
+use datafusion_physical_expr::{expressions::col, PhysicalSortExpr};
+use datafusion_physical_plan::sorts::sort::sort_batch;
+use rand::{
+    rngs::{StdRng, ThreadRng},
+    thread_rng, Rng, SeedableRng,
+};
+use test_utils::{
+    array_gen::{PrimitiveArrayGenerator, StringArrayGenerator},
+    stagger_batch,
+};
+
+/// Config for the dataset generator
+///
+/// # Parameters
+///   - `columns`, you just need to define the `column name`s and `column data type`s
+///     for the test datasets, and they will then be generated randomly by the generator
+///     when you call the `generate` function
+///
+///   - `rows_num_range`, the row count of the datasets will be randomly chosen
+///     from this range
+///
+///   - `sort_keys`, if `sort_keys` are defined, the generator will first generate
+///     one `base dataset` when you call `generate`. The `base dataset` will then be
+///     sorted by each `sort_key` respectively, and finally `len(sort_keys) + 1`
+///     datasets will be returned
+///
+#[derive(Debug, Clone)]
+pub struct DatasetGeneratorConfig {
+    // Descriptions of columns in datasets, it's `required`
+    pub columns: Vec<ColumnDescr>,
+
+    // Rows num range of the generated datasets, it's `required`
+    pub rows_num_range: (usize, usize),
+
+    // Sort keys used to generate the sorted data set, it's optional
+    pub sort_keys_set: Vec<Vec<String>>,
+}
+
+/// Dataset generator
+///
+/// It generates random [`Dataset`]s when the `generate` function is called.
+///
+/// The generation logic in `generate`:
+///
+///   - First, randomly generate a base record batch from `batch_generator`.
+///     `columns` and `rows_num_range` in `config` (see [`DatasetGeneratorConfig`]
+///     for details) are used in the generation.
+///
+///   - Sort the base batch by each of the `sort_keys` in `config` to generate
+///     another `len(sort_keys)` sorted batches.
+///
+///   - Split each batch into multiple sub-batches with random row counts, and
+///     use them to create the `Dataset`.
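+///
+/// A minimal usage sketch (field values are illustrative):
+///
+/// ```ignore
+/// let config = DatasetGeneratorConfig {
+///     columns: vec![ColumnDescr::new("a", DataType::Int32)],
+///     rows_num_range: (16, 32),
+///     sort_keys_set: vec![vec!["a".to_string()]],
+/// };
+/// // Returns `len(sort_keys_set) + 1` datasets: one unsorted, one sorted by "a"
+/// let datasets = DatasetGenerator::new(config).generate()?;
+/// ```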
+/// +pub struct DatasetGenerator { + batch_generator: RecordBatchGenerator, + sort_keys_set: Vec>, +} + +impl DatasetGenerator { + pub fn new(config: DatasetGeneratorConfig) -> Self { + let batch_generator = RecordBatchGenerator::new( + config.rows_num_range.0, + config.rows_num_range.1, + config.columns, + ); + + Self { + batch_generator, + sort_keys_set: config.sort_keys_set, + } + } + + pub fn generate(&self) -> Result> { + let mut datasets = Vec::with_capacity(self.sort_keys_set.len() + 1); + + // Generate the base batch + let base_batch = self.batch_generator.generate()?; + let batches = stagger_batch(base_batch.clone()); + let dataset = Dataset::new(batches, Vec::new()); + datasets.push(dataset); + + // Generate the related sorted batches + let schema = base_batch.schema_ref(); + for sort_keys in self.sort_keys_set.clone() { + let sort_exprs = sort_keys + .iter() + .map(|key| { + let col_expr = col(key, schema)?; + Ok(PhysicalSortExpr::new_default(col_expr)) + }) + .collect::>>()?; + let sorted_batch = sort_batch(&base_batch, &sort_exprs, None)?; + + let batches = stagger_batch(sorted_batch); + let dataset = Dataset::new(batches, sort_keys); + datasets.push(dataset); + } + + Ok(datasets) + } +} + +/// Single test data set +#[derive(Debug)] +pub struct Dataset { + pub batches: Vec, + pub total_rows_num: usize, + pub sort_keys: Vec, +} + +impl Dataset { + pub fn new(batches: Vec, sort_keys: Vec) -> Self { + let total_rows_num = batches.iter().map(|batch| batch.num_rows()).sum::(); + + Self { + batches, + total_rows_num, + sort_keys, + } + } +} + +#[derive(Debug, Clone)] +pub struct ColumnDescr { + // Column name + name: String, + + // Data type of this column + column_type: DataType, +} + +impl ColumnDescr { + #[inline] + pub fn new(name: &str, column_type: DataType) -> Self { + Self { + name: name.to_string(), + column_type, + } + } +} + +/// Record batch generator +struct RecordBatchGenerator { + min_rows_nun: usize, + + max_rows_num: usize, + + columns: Vec, + + candidate_null_pcts: Vec, +} + +macro_rules! generate_string_array { + ($SELF:ident, $NUM_ROWS:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $OFFSET_TYPE:ty) => {{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + let max_len = $BATCH_GEN_RNG.gen_range(1..50); + let num_distinct_strings = if $NUM_ROWS > 1 { + $BATCH_GEN_RNG.gen_range(1..$NUM_ROWS) + } else { + $NUM_ROWS + }; + + let mut generator = StringArrayGenerator { + max_len, + num_strings: $NUM_ROWS, + num_distinct_strings, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.gen_data::<$OFFSET_TYPE>() + }}; +} + +macro_rules! generate_primitive_array { + ($SELF:ident, $NUM_ROWS:ident, $BATCH_GEN_RNG:ident, $ARRAY_GEN_RNG:ident, $DATA_TYPE:ident) => { + paste::paste! 
{{ + let null_pct_idx = $BATCH_GEN_RNG.gen_range(0..$SELF.candidate_null_pcts.len()); + let null_pct = $SELF.candidate_null_pcts[null_pct_idx]; + let num_distinct_primitives = if $NUM_ROWS > 1 { + $BATCH_GEN_RNG.gen_range(1..$NUM_ROWS) + } else { + $NUM_ROWS + }; + + let mut generator = PrimitiveArrayGenerator { + num_primitives: $NUM_ROWS, + num_distinct_primitives, + null_pct, + rng: $ARRAY_GEN_RNG, + }; + + generator.[< gen_data_ $DATA_TYPE >]() + }}} +} + +impl RecordBatchGenerator { + fn new(min_rows_nun: usize, max_rows_num: usize, columns: Vec) -> Self { + let candidate_null_pcts = vec![0.0, 0.01, 0.1, 0.5]; + + Self { + min_rows_nun, + max_rows_num, + columns, + candidate_null_pcts, + } + } + + fn generate(&self) -> Result { + let mut rng = thread_rng(); + let num_rows = rng.gen_range(self.min_rows_nun..=self.max_rows_num); + let array_gen_rng = StdRng::from_seed(rng.gen()); + + // Build arrays + let mut arrays = Vec::with_capacity(self.columns.len()); + for col in self.columns.iter() { + let array = self.generate_array_of_type( + col.column_type.clone(), + num_rows, + &mut rng, + array_gen_rng.clone(), + ); + arrays.push(array); + } + + // Build schema + let fields = self + .columns + .iter() + .map(|col| Field::new(col.name.clone(), col.column_type.clone(), true)) + .collect::>(); + let schema = Arc::new(Schema::new(fields)); + + RecordBatch::try_new(schema, arrays).map_err(|e| arrow_datafusion_err!(e)) + } + + fn generate_array_of_type( + &self, + data_type: DataType, + num_rows: usize, + batch_gen_rng: &mut ThreadRng, + array_gen_rng: StdRng, + ) -> ArrayRef { + match data_type { + DataType::Int8 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + i8 + ) + } + DataType::Int16 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + i16 + ) + } + DataType::Int32 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + i32 + ) + } + DataType::Int64 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + i64 + ) + } + DataType::UInt8 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + u8 + ) + } + DataType::UInt16 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + u16 + ) + } + DataType::UInt32 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + u32 + ) + } + DataType::UInt64 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + u64 + ) + } + DataType::Float32 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + f32 + ) + } + DataType::Float64 => { + generate_primitive_array!( + self, + num_rows, + batch_gen_rng, + array_gen_rng, + f64 + ) + } + DataType::Utf8 => { + generate_string_array!(self, num_rows, batch_gen_rng, array_gen_rng, i32) + } + DataType::LargeUtf8 => { + generate_string_array!(self, num_rows, batch_gen_rng, array_gen_rng, i64) + } + _ => unreachable!(), + } + } +} + +#[cfg(test)] +mod test { + use arrow_array::UInt32Array; + + use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; + + use super::*; + + #[test] + fn test_generated_datasets() { + // The test datasets generation config + // We expect that after calling `generate` + // - Generate 2 datasets + // - They have 2 column "a" and "b", + // "a"'s type is `Utf8`, and "b"'s type is `UInt32` + // - One of them is unsorted, another is sorted by column "b" + // - Their rows num 
should be same and between [16, 32] + let config = DatasetGeneratorConfig { + columns: vec![ + ColumnDescr { + name: "a".to_string(), + column_type: DataType::Utf8, + }, + ColumnDescr { + name: "b".to_string(), + column_type: DataType::UInt32, + }, + ], + rows_num_range: (16, 32), + sort_keys_set: vec![vec!["b".to_string()]], + }; + + let gen = DatasetGenerator::new(config); + let datasets = gen.generate().unwrap(); + + // Should Generate 2 datasets + assert_eq!(datasets.len(), 2); + + // Should have 2 column "a" and "b", + // "a"'s type is `Utf8`, and "b"'s type is `UInt32` + let check_fields = |batch: &RecordBatch| { + assert_eq!(batch.num_columns(), 2); + let fields = batch.schema().fields().clone(); + assert_eq!(fields[0].name(), "a"); + assert_eq!(*fields[0].data_type(), DataType::Utf8); + assert_eq!(fields[1].name(), "b"); + assert_eq!(*fields[1].data_type(), DataType::UInt32); + }; + + let batch = &datasets[0].batches[0]; + check_fields(batch); + let batch = &datasets[1].batches[0]; + check_fields(batch); + + // One batches should be sort by "b" + let sorted_batches = &datasets[1].batches; + let b_vals = sorted_batches.iter().flat_map(|batch| { + let uint_array = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + uint_array.iter() + }); + let mut prev_b_val = u32::MIN; + for b_val in b_vals { + let b_val = b_val.unwrap_or(u32::MIN); + assert!(b_val >= prev_b_val); + prev_b_val = b_val; + } + + // Two batches should be same after sorting + check_equality_of_batches(&datasets[0].batches, &datasets[1].batches).unwrap(); + + // Rows num should between [16, 32] + let rows_num0 = datasets[0] + .batches + .iter() + .map(|batch| batch.num_rows()) + .sum::(); + let rows_num1 = datasets[1] + .batches + .iter() + .map(|batch| batch.num_rows()) + .sum::(); + assert_eq!(rows_num0, rows_num1); + assert!(rows_num0 >= 16); + assert!(rows_num0 <= 32); + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs new file mode 100644 index 000000000000..abb34048284d --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/fuzzer.rs @@ -0,0 +1,281 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
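+
+//! [`AggregationFuzzer`]: generates random datasets and session contexts,
+//! then cross-checks aggregation results against a baseline context.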
+
+use std::sync::Arc;
+
+use arrow::util::pretty::pretty_format_batches;
+use arrow_array::RecordBatch;
+use rand::{thread_rng, Rng};
+use tokio::task::JoinSet;
+
+use crate::fuzz_cases::aggregation_fuzzer::{
+    check_equality_of_batches,
+    context_generator::{SessionContextGenerator, SessionContextWithParams},
+    data_generator::{Dataset, DatasetGenerator, DatasetGeneratorConfig},
+    run_sql,
+};
+
+/// Rounds to call `generate` of [`SessionContextGenerator`]
+/// in [`AggregationFuzzer`]; `CTX_GEN_ROUNDS` random [`SessionContext`]s
+/// will be generated for each dataset for testing.
+const CTX_GEN_ROUNDS: usize = 16;
+
+/// Aggregation fuzzer's builder
+pub struct AggregationFuzzerBuilder {
+    /// See `candidate_sqls` in [`AggregationFuzzer`], no default, and required to set
+    candidate_sqls: Vec<Arc<str>>,
+
+    /// See `table_name` in [`AggregationFuzzer`], no default, and required to set
+    table_name: Option<Arc<str>>,
+
+    /// Used to generate `dataset_generator` in [`AggregationFuzzer`],
+    /// no default, and required to set
+    data_gen_config: Option<DatasetGeneratorConfig>,
+
+    /// See `data_gen_rounds` in [`AggregationFuzzer`], default 16
+    data_gen_rounds: usize,
+}
+
+impl AggregationFuzzerBuilder {
+    fn new() -> Self {
+        Self {
+            candidate_sqls: Vec::new(),
+            table_name: None,
+            data_gen_config: None,
+            data_gen_rounds: 16,
+        }
+    }
+
+    pub fn add_sql(mut self, sql: &str) -> Self {
+        self.candidate_sqls.push(Arc::from(sql));
+        self
+    }
+
+    pub fn table_name(mut self, table_name: &str) -> Self {
+        self.table_name = Some(Arc::from(table_name));
+        self
+    }
+
+    pub fn data_gen_config(mut self, data_gen_config: DatasetGeneratorConfig) -> Self {
+        self.data_gen_config = Some(data_gen_config);
+        self
+    }
+
+    pub fn data_gen_rounds(mut self, data_gen_rounds: usize) -> Self {
+        self.data_gen_rounds = data_gen_rounds;
+        self
+    }
+
+    pub fn build(self) -> AggregationFuzzer {
+        assert!(!self.candidate_sqls.is_empty());
+        let candidate_sqls = self.candidate_sqls;
+        let table_name = self.table_name.expect("table_name is required");
+        let data_gen_config = self.data_gen_config.expect("data_gen_config is required");
+        let data_gen_rounds = self.data_gen_rounds;
+
+        let dataset_generator = DatasetGenerator::new(data_gen_config);
+
+        AggregationFuzzer {
+            candidate_sqls,
+            table_name,
+            dataset_generator,
+            data_gen_rounds,
+        }
+    }
+}
+
+impl Default for AggregationFuzzerBuilder {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// AggregationFuzzer randomly generates multiple [`AggregationFuzzTestTask`]s,
+/// and runs them to check the correctness of the optimizations
+/// (e.g. sorted, partial skipping, spilling...)
+pub struct AggregationFuzzer {
+    /// Candidate test queries represented as sqls
+    candidate_sqls: Vec<Arc<str>>,
+
+    /// The queried table name
+    table_name: Arc<str>,
+
+    /// Dataset generator used to randomly generate datasets
+    dataset_generator: DatasetGenerator,
+
+    /// Rounds to call `generate` of [`DatasetGenerator`];
+    /// `len(sort_keys_set) + 1` datasets will be generated per round for testing.
+    ///
+    /// It is suggested to set this value to 2x or more the number of
+    /// `candidate_sqls` for better test coverage.
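+    /// (For example, the basic prim tests pair 7 candidate sqls with
+    /// `data_gen_rounds = 16`.)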
+    data_gen_rounds: usize,
+}
+
+/// Query group including the tested dataset and its sql
+struct QueryGroup {
+    dataset: Dataset,
+    sql: Arc<str>,
+}
+
+impl AggregationFuzzer {
+    pub async fn run(&self) {
+        let mut join_set = JoinSet::new();
+        let mut rng = thread_rng();
+
+        // Loop to generate datasets and their queries
+        for _ in 0..self.data_gen_rounds {
+            // Generate datasets first
+            let datasets = self
+                .dataset_generator
+                .generate()
+                .expect("dataset generation should succeed");
+
+            // Then randomly select a test sql for each of them
+            let query_groups = datasets
+                .into_iter()
+                .map(|dataset| {
+                    let sql_idx = rng.gen_range(0..self.candidate_sqls.len());
+                    let sql = self.candidate_sqls[sql_idx].clone();
+
+                    QueryGroup { dataset, sql }
+                })
+                .collect::<Vec<_>>();
+
+            let tasks = self.generate_fuzz_tasks(query_groups).await;
+            for task in tasks {
+                join_set.spawn(async move {
+                    task.run().await;
+                });
+            }
+        }
+
+        while let Some(join_handle) = join_set.join_next().await {
+            // propagate errors
+            join_handle.unwrap();
+        }
+    }
+
+    async fn generate_fuzz_tasks(
+        &self,
+        query_groups: Vec<QueryGroup>,
+    ) -> Vec<AggregationFuzzTestTask> {
+        let mut tasks = Vec::with_capacity(query_groups.len() * CTX_GEN_ROUNDS);
+        for QueryGroup { dataset, sql } in query_groups {
+            let dataset_ref = Arc::new(dataset);
+            let ctx_generator =
+                SessionContextGenerator::new(dataset_ref.clone(), &self.table_name);
+
+            // Generate the baseline context, and get the baseline result first
+            let baseline_ctx_with_params = ctx_generator
+                .generate_baseline()
+                .expect("generating the baseline session context should succeed");
+            let baseline_result = run_sql(&sql, &baseline_ctx_with_params.ctx)
+                .await
+                .expect("running the baseline sql should succeed");
+            let baseline_result = Arc::new(baseline_result);
+            // Generate test tasks
+            for _ in 0..CTX_GEN_ROUNDS {
+                let ctx_with_params = ctx_generator
+                    .generate()
+                    .expect("generating a session context should succeed");
+                let task = AggregationFuzzTestTask {
+                    dataset_ref: dataset_ref.clone(),
+                    expected_result: baseline_result.clone(),
+                    sql: sql.clone(),
+                    ctx_with_params,
+                };
+
+                tasks.push(task);
+            }
+        }
+        tasks
+    }
+}
+
+/// One test task generated by [`AggregationFuzzer`]
+///
+/// It includes:
+///   - `expected_result`, the expected result generated by the baseline [`SessionContext`]
+///     (with all possible optimizations disabled to ensure correctness).
+///
+///   - `ctx`, a randomly generated [`SessionContext`]; `sql` will be run
+///     on it afterwards, and we check whether the result equals the expected one.
+///
+///   - `sql`, the selected test sql
+///
+///   - `dataset_ref`, the input dataset, stored for error reporting when an
+///     inconsistency is found between the result for `ctx` and the `expected result`.
+///
+struct AggregationFuzzTestTask {
+    /// Generated session context in the current test case
+    ctx_with_params: SessionContextWithParams,
+
+    /// Expected result in the current test case,
+    /// generated from `query` + `baseline session context`
+    expected_result: Arc<Vec<RecordBatch>>,
+
+    /// The test query,
+    /// currently represented as sql
+    sql: Arc<str>,
+
+    /// The test dataset for error reporting
+    dataset_ref: Arc<Dataset>,
+}
+
+impl AggregationFuzzTestTask {
+    async fn run(&self) {
+        let task_result = run_sql(&self.sql, &self.ctx_with_params.ctx)
+            .await
+            .expect("running the sql should succeed");
+        self.check_result(&task_result, &self.expected_result);
+    }
+
+    // TODO: maybe we should persist the `expected_result` and `task_result`,
+    // because the readability is not so good if we just print them.
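+    //
+    // Note: `check_equality_of_batches` compares pretty-printed rows after
+    // sorting them, so the check is insensitive to row ordering.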
+ fn check_result(&self, task_result: &[RecordBatch], expected_result: &[RecordBatch]) { + let result = check_equality_of_batches(task_result, expected_result); + if let Err(e) = result { + // If we found inconsistent result, we print the test details for reproducing at first + println!( + "##### AggregationFuzzer error report ##### + ### Sql:\n{}\n\ + ### Schema:\n{}\n\ + ### Session context params:\n{:?}\n\ + ### Inconsistent row:\n\ + - row_idx:{}\n\ + - task_row:{}\n\ + - expected_row:{}\n\ + ### Task total result:\n{}\n\ + ### Expected total result:\n{}\n\ + ### Input:\n{}\n\ + ", + self.sql, + self.dataset_ref.batches[0].schema_ref(), + self.ctx_with_params.params, + e.row_idx, + e.lhs_row, + e.rhs_row, + pretty_format_batches(task_result).unwrap(), + pretty_format_batches(expected_result).unwrap(), + pretty_format_batches(&self.dataset_ref.batches).unwrap(), + ); + + // Then we just panic + panic!(); + } + } +} diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs new file mode 100644 index 000000000000..d93a5b7b9360 --- /dev/null +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/mod.rs @@ -0,0 +1,69 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
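+
+//! Fuzzer for aggregation queries: shared helpers (`check_equality_of_batches`,
+//! `run_sql`) plus the `context_generator`, `data_generator` and `fuzzer` submodules.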
+ +use arrow::util::pretty::pretty_format_batches; +use arrow_array::RecordBatch; +use datafusion::prelude::SessionContext; +use datafusion_common::error::Result; + +mod context_generator; +mod data_generator; +mod fuzzer; + +pub use data_generator::{ColumnDescr, DatasetGeneratorConfig}; +pub use fuzzer::*; + +#[derive(Debug)] +pub(crate) struct InconsistentResult { + pub row_idx: usize, + pub lhs_row: String, + pub rhs_row: String, +} + +pub(crate) fn check_equality_of_batches( + lhs: &[RecordBatch], + rhs: &[RecordBatch], +) -> std::result::Result<(), InconsistentResult> { + let lhs_formatted_batches = pretty_format_batches(lhs).unwrap().to_string(); + let mut lhs_formatted_batches_sorted: Vec<&str> = + lhs_formatted_batches.trim().lines().collect(); + lhs_formatted_batches_sorted.sort_unstable(); + let rhs_formatted_batches = pretty_format_batches(rhs).unwrap().to_string(); + let mut rhs_formatted_batches_sorted: Vec<&str> = + rhs_formatted_batches.trim().lines().collect(); + rhs_formatted_batches_sorted.sort_unstable(); + + for (row_idx, (lhs_row, rhs_row)) in lhs_formatted_batches_sorted + .iter() + .zip(&rhs_formatted_batches_sorted) + .enumerate() + { + if lhs_row != rhs_row { + return Err(InconsistentResult { + row_idx, + lhs_row: lhs_row.to_string(), + rhs_row: rhs_row.to_string(), + }); + } + } + + Ok(()) +} + +pub(crate) async fn run_sql(sql: &str, ctx: &SessionContext) -> Result> { + ctx.sql(sql).await?.collect().await +} diff --git a/datafusion/core/tests/fuzz_cases/mod.rs b/datafusion/core/tests/fuzz_cases/mod.rs index 69241571b4af..5bc36b963c44 100644 --- a/datafusion/core/tests/fuzz_cases/mod.rs +++ b/datafusion/core/tests/fuzz_cases/mod.rs @@ -21,6 +21,7 @@ mod join_fuzz; mod merge_fuzz; mod sort_fuzz; +mod aggregation_fuzzer; mod limit_fuzz; mod sort_preserving_repartition_fuzz; mod window_fuzz; diff --git a/test-utils/Cargo.toml b/test-utils/Cargo.toml index 325a2cc2fcc4..414fa5569cfe 100644 --- a/test-utils/Cargo.toml +++ b/test-utils/Cargo.toml @@ -29,4 +29,5 @@ workspace = true arrow = { workspace = true } datafusion-common = { workspace = true, default-features = true } env_logger = { workspace = true } +paste = "1.0.15" rand = { workspace = true } diff --git a/test-utils/src/array_gen/mod.rs b/test-utils/src/array_gen/mod.rs new file mode 100644 index 000000000000..4a799ae737d7 --- /dev/null +++ b/test-utils/src/array_gen/mod.rs @@ -0,0 +1,22 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
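+
+//! Generators for random, reusable Arrow arrays
+//! ([`PrimitiveArrayGenerator`], [`StringArrayGenerator`]).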
+
+mod primitive;
+mod string;
+
+pub use primitive::PrimitiveArrayGenerator;
+pub use string::StringArrayGenerator;
diff --git a/test-utils/src/array_gen/primitive.rs b/test-utils/src/array_gen/primitive.rs
new file mode 100644
index 000000000000..f70ebf6686d0
--- /dev/null
+++ b/test-utils/src/array_gen/primitive.rs
@@ -0,0 +1,80 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayRef, PrimitiveArray, UInt32Array};
+use arrow::datatypes::{
+    Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type,
+    UInt32Type, UInt64Type, UInt8Type,
+};
+use rand::rngs::StdRng;
+use rand::Rng;
+
+/// Randomly generate primitive arrays
+pub struct PrimitiveArrayGenerator {
+    /// the total number of primitives in the output
+    pub num_primitives: usize,
+    /// The number of distinct primitives in the columns
+    pub num_distinct_primitives: usize,
+    /// The percentage of nulls in the columns
+    pub null_pct: f64,
+    /// Random number generator
+    pub rng: StdRng,
+}
+
+macro_rules! impl_gen_data {
+    ($NATIVE_TYPE:ty, $ARROW_TYPE:ident) => {
+        paste::paste! {
+            pub fn [< gen_data_ $NATIVE_TYPE >](&mut self) -> ArrayRef {
+                // table of primitives from which to draw
+                let distinct_primitives: PrimitiveArray<$ARROW_TYPE> = (0..self.num_distinct_primitives)
+                    .map(|_| Some(self.rng.gen::<$NATIVE_TYPE>()))
+                    .collect();
+
+                // pick num_primitives randomly from the distinct primitive table
+                let indices: UInt32Array = (0..self.num_primitives)
+                    .map(|_| {
+                        if self.rng.gen::<f64>() < self.null_pct {
+                            None
+                        } else if self.num_distinct_primitives > 1 {
+                            let range = 1..(self.num_distinct_primitives as u32);
+                            Some(self.rng.gen_range(range))
+                        } else {
+                            Some(0)
+                        }
+                    })
+                    .collect();
+
+                let options = None;
+                arrow::compute::take(&distinct_primitives, &indices, options).unwrap()
+            }
+        }
+    };
+}
+
+// TODO: support generating more primitive arrays
+impl PrimitiveArrayGenerator {
+    impl_gen_data!(i8, Int8Type);
+    impl_gen_data!(i16, Int16Type);
+    impl_gen_data!(i32, Int32Type);
+    impl_gen_data!(i64, Int64Type);
+    impl_gen_data!(u8, UInt8Type);
+    impl_gen_data!(u16, UInt16Type);
+    impl_gen_data!(u32, UInt32Type);
+    impl_gen_data!(u64, UInt64Type);
+    impl_gen_data!(f32, Float32Type);
+    impl_gen_data!(f64, Float64Type);
+}
diff --git a/test-utils/src/array_gen/string.rs b/test-utils/src/array_gen/string.rs
new file mode 100644
index 000000000000..fbfa2bb941e0
--- /dev/null
+++ b/test-utils/src/array_gen/string.rs
@@ -0,0 +1,78 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, UInt32Array};
+use rand::rngs::StdRng;
+use rand::Rng;
+
+/// Randomly generate string arrays
+pub struct StringArrayGenerator {
+    /// The maximum length of the strings
+    pub max_len: usize,
+    /// the total number of strings in the output
+    pub num_strings: usize,
+    /// The number of distinct strings in the columns
+    pub num_distinct_strings: usize,
+    /// The percentage of nulls in the columns
+    pub null_pct: f64,
+    /// Random number generator
+    pub rng: StdRng,
+}
+
+impl StringArrayGenerator {
+    /// Creates a StringArray or LargeStringArray with random strings according
+    /// to the parameters of the BatchGenerator
+    pub fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
+        // table of strings from which to draw
+        let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
+            .map(|_| Some(random_string(&mut self.rng, self.max_len)))
+            .collect();
+
+        // pick num_strings randomly from the distinct string table
+        let indices: UInt32Array = (0..self.num_strings)
+            .map(|_| {
+                if self.rng.gen::<f64>() < self.null_pct {
+                    None
+                } else if self.num_distinct_strings > 1 {
+                    let range = 1..(self.num_distinct_strings as u32);
+                    Some(self.rng.gen_range(range))
+                } else {
+                    Some(0)
+                }
+            })
+            .collect();
+
+        let options = None;
+        arrow::compute::take(&distinct_strings, &indices, options).unwrap()
+    }
+}
+
+/// Return a string of random characters of length 1..=max_len
+fn random_string(rng: &mut StdRng, max_len: usize) -> String {
+    // pick characters at random (not just ascii)
+    match max_len {
+        0 => "".to_string(),
+        1 => String::from(rng.gen::<char>()),
+        _ => {
+            let len = rng.gen_range(1..=max_len);
+            rng.sample_iter::<char, _>(rand::distributions::Standard)
+                .take(len)
+                .map(char::from)
+                .collect::<String>()
+        }
+    }
+}
diff --git a/test-utils/src/lib.rs b/test-utils/src/lib.rs
index 3ddba2fec800..9db8920833ae 100644
--- a/test-utils/src/lib.rs
+++ b/test-utils/src/lib.rs
@@ -22,6 +22,7 @@ use datafusion_common::cast::as_int32_array;
 use rand::prelude::StdRng;
 use rand::{Rng, SeedableRng};
 
+pub mod array_gen;
 mod data_gen;
 mod string_gen;
 pub mod tpcds;
diff --git a/test-utils/src/string_gen.rs b/test-utils/src/string_gen.rs
index 530fc1535387..725eb22b85af 100644
--- a/test-utils/src/string_gen.rs
+++ b/test-utils/src/string_gen.rs
@@ -1,3 +1,4 @@
+use crate::array_gen::StringArrayGenerator;
 // Licensed to the Apache Software Foundation (ASF) under one
 // or more contributor license agreements. See the NOTICE file
 // distributed with this work for additional information
@@ -14,27 +15,14 @@
 // KIND, either express or implied. See the License for the
 // specific language governing permissions and limitations
 // under the License.
-//
-// use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, RecordBatch, UInt32Array};
+
 use crate::stagger_batch;
-use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, UInt32Array};
 use arrow::record_batch::RecordBatch;
 use rand::rngs::StdRng;
 use rand::{thread_rng, Rng, SeedableRng};
 
 /// Randomly generate strings
-pub struct StringBatchGenerator {
-    //// The maximum length of the strings
-    pub max_len: usize,
-    /// the total number of strings in the output
-    pub num_strings: usize,
-    /// The number of distinct strings in the columns
-    pub num_distinct_strings: usize,
-    /// The percentage of nulls in the columns
-    pub null_pct: f64,
-    /// Random number generator
-    pub rng: StdRng,
-}
+pub struct StringBatchGenerator(StringArrayGenerator);
 
 impl StringBatchGenerator {
     /// Make batches of random strings with a random length columns "a" and "b".
@@ -44,8 +32,8 @@ impl StringBatchGenerator {
     pub fn make_input_batches(&mut self) -> Vec<RecordBatch> {
         // use a random number generator to pick a random sized output
         let batch = RecordBatch::try_from_iter(vec![
-            ("a", self.gen_data::<i32>()),
-            ("b", self.gen_data::<i32>()),
+            ("a", self.0.gen_data::<i32>()),
+            ("b", self.0.gen_data::<i32>()),
         ])
         .unwrap();
         stagger_batch(batch)
@@ -57,9 +45,9 @@ impl StringBatchGenerator {
     /// if large is true, the array is a LargeStringArray
     pub fn make_sorted_input_batches(&mut self, large: bool) -> Vec<RecordBatch> {
         let array = if large {
-            self.gen_data::<i64>()
+            self.0.gen_data::<i64>()
         } else {
-            self.gen_data::<i32>()
+            self.0.gen_data::<i32>()
         };
 
         let array = arrow::compute::sort(&array, None).unwrap();
@@ -68,32 +56,6 @@ impl StringBatchGenerator {
         stagger_batch(batch)
     }
 
-    /// Creates a StringArray or LargeStringArray with random strings according
-    /// to the parameters of the BatchGenerator
-    fn gen_data<O: OffsetSizeTrait>(&mut self) -> ArrayRef {
-        // table of strings from which to draw
-        let distinct_strings: GenericStringArray<O> = (0..self.num_distinct_strings)
-            .map(|_| Some(random_string(&mut self.rng, self.max_len)))
-            .collect();
-
-        // pick num_strings randomly from the distinct string table
-        let indicies: UInt32Array = (0..self.num_strings)
-            .map(|_| {
-                if self.rng.gen::<f64>() < self.null_pct {
-                    None
-                } else if self.num_distinct_strings > 1 {
-                    let range = 1..(self.num_distinct_strings as u32);
-                    Some(self.rng.gen_range(range))
-                } else {
-                    Some(0)
-                }
-            })
-            .collect();
-
-        let options = None;
-        arrow::compute::take(&distinct_strings, &indicies, options).unwrap()
-    }
-
     /// Return an set of `BatchGenerator`s that cover a range of interesting
     /// cases
     pub fn interesting_cases() -> Vec<Self> {
@@ -109,31 +71,15 @@
             } else {
                 num_strings
             };
-            cases.push(StringBatchGenerator {
+            cases.push(StringBatchGenerator(StringArrayGenerator {
                 max_len,
                 num_strings,
                 num_distinct_strings,
                 null_pct,
                 rng: StdRng::from_seed(rng.gen()),
-            })
+            }))
         }
     }
     cases
     }
 }
-
-/// Return a string of random characters of length 1..=max_len
-fn random_string(rng: &mut StdRng, max_len: usize) -> String {
-    // pick characters at random (not just ascii)
-    match max_len {
-        0 => "".to_string(),
-        1 => String::from(rng.gen::<char>()),
-        _ => {
-            let len = rng.gen_range(1..=max_len);
-            rng.sample_iter::<char, _>(rand::distributions::Standard)
-                .take(len)
-                .map(char::from)
-                .collect::<String>()
-        }
-    }
-}