diff --git a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs index 42160b7593a4e..dbbd869c7b38e 100644 --- a/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs +++ b/datafusion/core/tests/fuzz_cases/aggregation_fuzzer/data_generator.rs @@ -99,11 +99,7 @@ impl DatasetGenerator { let base_batch = self.batch_generator.generate(); let total_rows_num = base_batch.num_rows(); let batches = stagger_batch(base_batch.clone()); - let dataset = Dataset { - batches, - total_rows_num, - sort_keys: Vec::new(), - }; + let dataset = Dataset::new(batches, Vec::new()); datasets.push(dataset); // Generate the related sorted batches @@ -123,11 +119,7 @@ impl DatasetGenerator { .expect("sort batch should not fail"); let batches = stagger_batch(sorted_batch); - let dataset = Dataset { - batches, - total_rows_num, - sort_keys: Vec::new(), - }; + let dataset = Dataset::new(batches, sort_keys); datasets.push(dataset); } @@ -143,6 +135,18 @@ pub struct Dataset { pub sort_keys: Vec, } +impl Dataset { + pub fn new(batches: Vec, sort_keys: Vec) -> Self { + let total_rows_num = batches.iter().map(|batch| batch.num_rows()).sum::(); + + Self { + batches, + total_rows_num, + sort_keys, + } + } +} + #[derive(Debug, Clone)] pub struct ColumnDescr { // Column name @@ -362,6 +366,8 @@ mod test { use arrow::util::pretty::pretty_format_batches; use arrow_array::UInt32Array; + use crate::fuzz_cases::aggregation_fuzzer::check_equality_of_batches; + use super::*; #[test] @@ -428,20 +434,8 @@ mod test { } // Two batches should be same after sorting - let formatted_batches0 = pretty_format_batches(&datasets[0].batches) - .unwrap() - .to_string(); - let mut formatted_batches0_sorted: Vec<&str> = - formatted_batches0.trim().lines().collect(); - formatted_batches0_sorted.sort_unstable(); - let formatted_batches1 = pretty_format_batches(&datasets[1].batches) - .unwrap() - .to_string(); - let mut formatted_batches1_sorted: Vec<&str> = - formatted_batches1.trim().lines().collect(); - formatted_batches1_sorted.sort_unstable(); - assert_eq!(formatted_batches0_sorted, formatted_batches1_sorted); - + check_equality_of_batches(&datasets[0].batches, &datasets[1].batches); + // Rows num should between [16, 32] let rows_num0 = datasets[0] .batches