Commit

add fuzz tests
adriangb committed Dec 13, 2024
1 parent bf652e5 commit 969e83c
Showing 2 changed files with 232 additions and 0 deletions.
2 changes: 2 additions & 0 deletions datafusion/core/tests/fuzz_cases/mod.rs
@@ -24,6 +24,8 @@ mod sort_fuzz;
mod aggregation_fuzzer;
mod equivalence;

mod pruning;

mod limit_fuzz;
mod sort_preserving_repartition_fuzz;
mod window_fuzz;
230 changes: 230 additions & 0 deletions datafusion/core/tests/fuzz_cases/pruning.rs
@@ -0,0 +1,230 @@
use std::sync::Arc;

use arrow_array::{Array, RecordBatch, StringArray};
use arrow_schema::{DataType, Field, Schema};
use bytes::{BufMut, Bytes, BytesMut};
use datafusion::{
    datasource::{
        listing::PartitionedFile,
        physical_plan::{parquet::ParquetExecBuilder, FileScanConfig},
    },
    prelude::*,
};
use datafusion_common::DFSchema;
use datafusion_execution::object_store::ObjectStoreUrl;
use datafusion_physical_expr::PhysicalExpr;
use datafusion_physical_plan::{collect, filter::FilterExec, ExecutionPlan};
use itertools::Itertools;
use object_store::{memory::InMemory, path::Path, ObjectStore, PutPayload};
use parquet::{
    arrow::ArrowWriter,
    file::properties::{EnabledStatistics, WriterProperties},
};
use rand::seq::SliceRandom;
use url::Url;

#[tokio::test]
async fn test_fuzz_utf8() {
    // Fuzz testing for UTF8 predicate pruning.
    // The basic idea is that query results should always be the same with or without stats/pruning.
    // If we get this right we at least guarantee that there are no incorrect results.
    // There may still be suboptimal pruning or stats, but that's something we can try to catch
    // with more targeted tests.

    // Since we know where the edge cases might be, we don't do random black box fuzzing.
    // Instead we fuzz along specific predefined axes:
    //
    // - Which characters are in each value. We want to make sure to include characters that,
    //   when incremented, truncated or otherwise manipulated, might cause issues.
    // - The values in each row group. This impacts which min/max stats are generated for each
    //   row group. We'll generate combinations of the characters with lengths ranging from 1 to 3.
    // - Truncation of statistics to 1, 2 or 3 characters, as well as no truncation.

    let mut rng = rand::thread_rng();

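    // The characters below sit at or near UTF-8 encoding boundaries (e.g. the last 1- and
    // 2-byte code points, the maximum code point U+10FFFF, and the last code point before the
    // surrogate range), plus the LIKE wildcards and the null character, so that truncating or
    // incrementing min/max statistics is exercised at the edges.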
    let characters = [
        "z",
        "0",
        "~",
        "ß",
        "℣",
        "%", // this one is useful for like/not like tests since it will result in randomly inserted wildcards
        "_", // this one is useful for like/not like tests since it will result in randomly inserted wildcards
        "\u{7F}",
        "\u{7FF}",
        "\u{FF}",
        "\u{10FFFF}",
        "\u{D7FF}",
        "\u{FDCF}",
        // null character
        "\u{0}",
    ];

    let value_lengths = [1, 2, 3];

    // generate all combinations of characters with lengths ranging from 1 to 3
    let mut values = vec![];
    for length in &value_lengths {
        values.extend(
            characters
                .iter()
                .cloned()
                .combinations(*length)
                // now get all permutations of each combination
                .flat_map(|c| c.into_iter().permutations(*length))
                // and join them into strings
                .map(|c| c.join("")),
        );
    }

println!("Generated {} values", values.len());

// randomly pick 100 values
values.shuffle(&mut rng);
values.truncate(100);

    let mut row_groups = vec![];
    // generate all combinations of values for row groups (1 or 2 values per row group;
    // more is unnecessary since we only get min/max stats out)
    for rg_length in [1, 2] {
        row_groups.extend(values.iter().cloned().combinations(rg_length));
    }

    println!("Generated {} row groups", row_groups.len());

    // Randomly pick 100 row groups (combinations of said values)
    row_groups.shuffle(&mut rng);
    row_groups.truncate(100);

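    // The test table has a single non-nullable Utf8 column `a`; every predicate below targets it.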
    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, false)]));
    let df_schema = DFSchema::try_from(schema.clone()).unwrap();

    let store = InMemory::new();
    let mut files = vec![];
    for (idx, truncation_length) in [Some(1), Some(2), None].iter().enumerate() {
        // parquet files only support 32767 row groups per file, so chunk into multiple files
        // to avoid errors when running on a large number of row groups
        for (rg_idx, row_groups) in row_groups.chunks(32766).enumerate() {
            let buf = write_parquet_file(
                *truncation_length,
                schema.clone(),
                row_groups.to_vec(),
            )
            .await;
            let filename = format!("test_fuzz_utf8_{idx}_{rg_idx}.parquet");
            files.push((filename.clone(), buf.len()));
            let payload = PutPayload::from(buf);
            let path = Path::from(filename);
            store.put(&path, payload).await.unwrap();
        }
    }

println!("Generated {} parquet files", files.len());

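    // Serve the generated files from an in-memory object store registered under `memory://`,
    // so both executions below read exactly the same bytes.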
    let ctx = SessionContext::new();

    ctx.register_object_store(&Url::parse("memory://").unwrap(), Arc::new(store));

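    // Build a broad set of predicates over the sampled values: the comparison operators plus
    // LIKE / NOT LIKE with leading and trailing `%` wildcards.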
    let mut predicates = vec![];
    for value in values {
        predicates.push(col("a").eq(lit(value.clone())));
        predicates.push(col("a").not_eq(lit(value.clone())));
        predicates.push(col("a").lt(lit(value.clone())));
        predicates.push(col("a").lt_eq(lit(value.clone())));
        predicates.push(col("a").gt(lit(value.clone())));
        predicates.push(col("a").gt_eq(lit(value.clone())));
        predicates.push(col("a").like(lit(value.clone())));
        predicates.push(col("a").not_like(lit(value.clone())));
        predicates.push(col("a").like(lit(format!("%{}", value.clone()))));
        predicates.push(col("a").like(lit(format!("{}%", value.clone()))));
        predicates.push(col("a").not_like(lit(format!("%{}", value.clone()))));
        predicates.push(col("a").not_like(lit(format!("{}%", value.clone()))));
    }

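    // Execute every predicate twice over the same files: once with the predicate applied only
    // by `FilterExec` (ground truth) and once with it also pushed into the parquet scan so
    // row groups can be pruned from statistics. The results must be identical.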
    for predicate in predicates {
        println!("Testing predicate {:?}", predicate);
        let phys_expr_predicate = ctx
            .create_physical_expr(predicate.clone(), &df_schema)
            .unwrap();
        let expected = execute_with_predicate(
            &files,
            phys_expr_predicate.clone(),
            false,
            schema.clone(),
            &ctx,
        )
        .await;
        let with_pruning = execute_with_predicate(
            &files,
            phys_expr_predicate,
            true,
            schema.clone(),
            &ctx,
        )
        .await;
        assert_eq!(expected, with_pruning);
    }
}

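/// Scans all `files` and applies `predicate` through a `FilterExec`. When `prune_stats` is
/// true the predicate is also handed to the parquet scan so that row groups can be pruned
/// using their statistics; otherwise the scan reads everything.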
async fn execute_with_predicate(
files: &[(String, usize)],
predicate: Arc<dyn PhysicalExpr>,
prune_stats: bool,
schema: Arc<Schema>,
ctx: &SessionContext,
) -> Vec<String> {
let scan =
FileScanConfig::new(ObjectStoreUrl::parse("memory://").unwrap(), schema.clone())
.with_file_group(
files
.iter()
.map(|(path, size)| PartitionedFile::new(path.clone(), *size as u64))
.collect(),
);
let mut builder = ParquetExecBuilder::new(scan);
if prune_stats {
builder = builder.with_predicate(predicate.clone())
}
let exec = Arc::new(builder.build()) as Arc<dyn ExecutionPlan>;
let exec =
Arc::new(FilterExec::try_new(predicate, exec).unwrap()) as Arc<dyn ExecutionPlan>;

let batches = collect(exec, ctx.task_ctx()).await.unwrap();
let mut values = vec![];
for batch in batches {
let column = batch
.column(0)
.as_any()
.downcast_ref::<StringArray>()
.unwrap();
for i in 0..column.len() {
values.push(column.value(i).to_string());
}
}
values
}

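/// Writes an in-memory parquet file containing one row group per entry in `row_groups`,
/// optionally limiting the size of the stored min/max statistics to `truncation_length`
/// so that they get truncated.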
async fn write_parquet_file(
    truncation_length: Option<usize>,
    schema: Arc<Schema>,
    row_groups: Vec<Vec<String>>,
) -> Bytes {
    let mut buf = BytesMut::new().writer();
    let mut props = WriterProperties::builder();
    if let Some(truncation_length) = truncation_length {
        props = props.set_max_statistics_size(truncation_length);
    }
    props = props.set_statistics_enabled(EnabledStatistics::Chunk); // row group level
    let props = props.build();
    {
        let mut writer =
            ArrowWriter::try_new(&mut buf, schema.clone(), Some(props)).unwrap();
        for rg_values in row_groups.iter() {
            let arr = StringArray::from_iter_values(rg_values.iter());
            let batch =
                RecordBatch::try_new(schema.clone(), vec![Arc::new(arr)]).unwrap();
            writer.write(&batch).unwrap();
            writer.flush().unwrap(); // finishes the current row group and starts a new one
        }
        writer.finish().unwrap();
    }
    buf.into_inner().freeze()
}
