diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs index 516310dc81ae..810f74e8515b 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_group_filter.rs @@ -228,6 +228,94 @@ struct BloomFilterStatistics { column_sbbf: HashMap, } +impl BloomFilterStatistics { + /// Helper function for checking if [`Sbbf`] filter contains [`ScalarValue`]. + /// + /// In case the type of scalar is not supported, returns `true`, assuming that the + /// value may be present. + fn check_scalar(sbbf: &Sbbf, value: &ScalarValue, parquet_type: &Type) -> bool { + match value { + ScalarValue::Utf8(Some(v)) + | ScalarValue::Utf8View(Some(v)) + | ScalarValue::LargeUtf8(Some(v)) => sbbf.check(&v.as_str()), + ScalarValue::Binary(Some(v)) + | ScalarValue::BinaryView(Some(v)) + | ScalarValue::LargeBinary(Some(v)) => sbbf.check(v), + ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), + ScalarValue::Boolean(Some(v)) => sbbf.check(v), + ScalarValue::Float64(Some(v)) => sbbf.check(v), + ScalarValue::Float32(Some(v)) => sbbf.check(v), + ScalarValue::Int64(Some(v)) => sbbf.check(v), + ScalarValue::Int32(Some(v)) => sbbf.check(v), + ScalarValue::UInt64(Some(v)) => sbbf.check(v), + ScalarValue::UInt32(Some(v)) => sbbf.check(v), + ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { + Type::INT32 => { + //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 + // All physical type are little-endian + if *p > 9 { + //DECIMAL can be used to annotate the following types: + // + // int32: for 1 <= precision <= 9 + // int64: for 1 <= precision <= 18 + return true; + } + let b = (*v as i32).to_le_bytes(); + // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 + let decimal = Decimal::Int32 { + value: 
b, + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + Type::INT64 => { + if *p > 18 { + return true; + } + let b = (*v as i64).to_le_bytes(); + let decimal = Decimal::Int64 { + value: b, + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + Type::FIXED_LEN_BYTE_ARRAY => { + // keep with from_bytes_to_i128 + let b = v.to_be_bytes().to_vec(); + // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 + let decimal = Decimal::Bytes { + value: b.into(), + precision: *p as i32, + scale: *s as i32, + }; + sbbf.check(&decimal) + } + _ => true, + }, + // One more pattern matching since not all data types are supported + // inside of a Dictionary + ScalarValue::Dictionary(_, inner) => match inner.as_ref() { + ScalarValue::Int32(_) + | ScalarValue::Int64(_) + | ScalarValue::UInt32(_) + | ScalarValue::UInt64(_) + | ScalarValue::Float32(_) + | ScalarValue::Float64(_) + | ScalarValue::Utf8(_) + | ScalarValue::LargeUtf8(_) + | ScalarValue::Binary(_) + | ScalarValue::LargeBinary(_) => { + BloomFilterStatistics::check_scalar(sbbf, inner, parquet_type) + } + _ => true, + }, + _ => true, + } + } +} + impl PruningStatistics for BloomFilterStatistics { fn min_values(&self, _column: &Column) -> Option { None @@ -268,70 +356,7 @@ impl PruningStatistics for BloomFilterStatistics { let known_not_present = values .iter() - .map(|value| { - match value { - ScalarValue::Utf8(Some(v)) | ScalarValue::Utf8View(Some(v)) => { - sbbf.check(&v.as_str()) - } - ScalarValue::Binary(Some(v)) | ScalarValue::BinaryView(Some(v)) => { - sbbf.check(v) - } - ScalarValue::FixedSizeBinary(_size, Some(v)) => sbbf.check(v), - ScalarValue::Boolean(Some(v)) => sbbf.check(v), - ScalarValue::Float64(Some(v)) => sbbf.check(v), - ScalarValue::Float32(Some(v)) => sbbf.check(v), - ScalarValue::Int64(Some(v)) => sbbf.check(v), - ScalarValue::Int32(Some(v)) => sbbf.check(v), - ScalarValue::UInt64(Some(v)) => sbbf.check(v), - 
ScalarValue::UInt32(Some(v)) => sbbf.check(v), - ScalarValue::Decimal128(Some(v), p, s) => match parquet_type { - Type::INT32 => { - //https://github.com/apache/parquet-format/blob/eb4b31c1d64a01088d02a2f9aefc6c17c54cc6fc/Encodings.md?plain=1#L35-L42 - // All physical type are little-endian - if *p > 9 { - //DECIMAL can be used to annotate the following types: - // - // int32: for 1 <= precision <= 9 - // int64: for 1 <= precision <= 18 - return true; - } - let b = (*v as i32).to_le_bytes(); - // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 - let decimal = Decimal::Int32 { - value: b, - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - Type::INT64 => { - if *p > 18 { - return true; - } - let b = (*v as i64).to_le_bytes(); - let decimal = Decimal::Int64 { - value: b, - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - Type::FIXED_LEN_BYTE_ARRAY => { - // keep with from_bytes_to_i128 - let b = v.to_be_bytes().to_vec(); - // Use Decimal constructor after https://github.com/apache/arrow-rs/issues/5325 - let decimal = Decimal::Bytes { - value: b.into(), - precision: *p as i32, - scale: *s as i32, - }; - sbbf.check(&decimal) - } - _ => true, - }, - _ => true, - } - }) + .map(|value| BloomFilterStatistics::check_scalar(sbbf, value, parquet_type)) // The row group doesn't contain any of the values if // all the checks are false .all(|v| !v); diff --git a/datafusion/core/tests/parquet/mod.rs b/datafusion/core/tests/parquet/mod.rs index 46be2433116a..f45eacce18df 100644 --- a/datafusion/core/tests/parquet/mod.rs +++ b/datafusion/core/tests/parquet/mod.rs @@ -17,14 +17,14 @@ //! 
Parquet integration tests use crate::parquet::utils::MetricsFinder; -use arrow::array::Decimal128Array; use arrow::{ array::{ make_array, Array, ArrayRef, BinaryArray, Date32Array, Date64Array, - FixedSizeBinaryArray, Float64Array, Int16Array, Int32Array, Int64Array, - Int8Array, LargeBinaryArray, LargeStringArray, StringArray, - TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray, - TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array, + Decimal128Array, DictionaryArray, FixedSizeBinaryArray, Float64Array, Int16Array, + Int32Array, Int64Array, Int8Array, LargeBinaryArray, LargeStringArray, + StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, + TimestampNanosecondArray, TimestampSecondArray, UInt16Array, UInt32Array, + UInt64Array, UInt8Array, }, datatypes::{DataType, Field, Schema}, record_batch::RecordBatch, @@ -64,7 +64,7 @@ fn init() { // ---------------------- /// What data to use -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone)] enum Scenario { Timestamps, Dates, @@ -84,6 +84,7 @@ enum Scenario { WithNullValues, WithNullValuesPageLevel, UTF8, + Dictionary, } enum Unit { @@ -740,6 +741,54 @@ fn make_utf8_batch(value: Vec>) -> RecordBatch { .unwrap() } +fn make_dictionary_batch(strings: Vec<&str>, integers: Vec) -> RecordBatch { + let keys = Int32Array::from_iter(0..strings.len() as i32); + let small_keys = Int16Array::from_iter(0..strings.len() as i16); + + let utf8_values = StringArray::from(strings.clone()); + let utf8_dict = DictionaryArray::new(keys.clone(), Arc::new(utf8_values)); + + let large_utf8 = LargeStringArray::from(strings.clone()); + let large_utf8_dict = DictionaryArray::new(keys.clone(), Arc::new(large_utf8)); + + let binary = + BinaryArray::from_iter_values(strings.iter().cloned().map(|v| v.as_bytes())); + let binary_dict = DictionaryArray::new(keys.clone(), Arc::new(binary)); + + let large_binary = + LargeBinaryArray::from_iter_values(strings.iter().cloned().map(|v| 
v.as_bytes())); + let large_binary_dict = DictionaryArray::new(keys.clone(), Arc::new(large_binary)); + + let int32 = Int32Array::from_iter_values(integers.clone()); + let int32_dict = DictionaryArray::new(small_keys.clone(), Arc::new(int32)); + + let int64 = Int64Array::from_iter_values(integers.iter().cloned().map(|v| v as i64)); + let int64_dict = DictionaryArray::new(keys.clone(), Arc::new(int64)); + + let uint32 = + UInt32Array::from_iter_values(integers.iter().cloned().map(|v| v as u32)); + let uint32_dict = DictionaryArray::new(small_keys.clone(), Arc::new(uint32)); + + let decimal = Decimal128Array::from_iter_values( + integers.iter().cloned().map(|v| (v * 100) as i128), + ) + .with_precision_and_scale(6, 2) + .unwrap(); + let decimal_dict = DictionaryArray::new(keys.clone(), Arc::new(decimal)); + + RecordBatch::try_from_iter(vec![ + ("utf8", Arc::new(utf8_dict) as _), + ("large_utf8", Arc::new(large_utf8_dict) as _), + ("binary", Arc::new(binary_dict) as _), + ("large_binary", Arc::new(large_binary_dict) as _), + ("int32", Arc::new(int32_dict) as _), + ("int64", Arc::new(int64_dict) as _), + ("uint32", Arc::new(uint32_dict) as _), + ("decimal", Arc::new(decimal_dict) as _), + ]) + .unwrap() +} + fn create_data_batch(scenario: Scenario) -> Vec { match scenario { Scenario::Timestamps => { @@ -961,6 +1010,13 @@ fn create_data_batch(scenario: Scenario) -> Vec { ]), ] } + + Scenario::Dictionary => { + vec![ + make_dictionary_batch(vec!["a", "b", "c", "d", "e"], vec![0, 1, 2, 5, 6]), + make_dictionary_batch(vec!["f", "g", "h", "i", "j"], vec![0, 1, 3, 8, 9]), + ] + } } } diff --git a/datafusion/core/tests/parquet/row_group_pruning.rs b/datafusion/core/tests/parquet/row_group_pruning.rs index 536ac5414a9a..d8ce2970bdf7 100644 --- a/datafusion/core/tests/parquet/row_group_pruning.rs +++ b/datafusion/core/tests/parquet/row_group_pruning.rs @@ -1323,3 +1323,215 @@ async fn test_row_group_with_null_values() { .test_row_group_prune() .await; } + +#[tokio::test] +async 
fn test_bloom_filter_utf8_dict() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE utf8 = 'h'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE utf8 = 'ab'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE large_utf8 = 'b'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE large_utf8 = 'cd'") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn test_bloom_filter_integer_dict() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE int32 = arrow_cast(8, 'Int32')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + 
.with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE int32 = arrow_cast(7, 'Int32')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE int64 = 8") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE int64 = 7") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn test_bloom_filter_unsigned_integer_dict() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE uint32 = arrow_cast(8, 'UInt32')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE uint32 = arrow_cast(7, 'UInt32')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +} + +#[tokio::test] +async fn test_bloom_filter_binary_dict() { + 
RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE binary = arrow_cast('b', 'Binary')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE binary = arrow_cast('banana', 'Binary')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE large_binary = arrow_cast('d', 'LargeBinary')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query( + "SELECT * FROM t WHERE large_binary = arrow_cast('dre', 'LargeBinary')", + ) + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +} + +// Makes sense to enable (or at least try to) after +// https://github.com/apache/datafusion/issues/13821 +#[ignore] +#[tokio::test] +async fn test_bloom_filter_decimal_dict() { + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE decimal = arrow_cast(8, 'Decimal128(6, 2)')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + 
.with_pruned_by_stats(Some(1)) + .with_expected_rows(1) + .with_pruned_by_bloom_filter(Some(0)) + .with_matched_by_bloom_filter(Some(1)) + .test_row_group_prune() + .await; + + RowGroupPruningTest::new() + .with_scenario(Scenario::Dictionary) + .with_query("SELECT * FROM t WHERE decimal = arrow_cast(7, 'Decimal128(6, 2)')") + .with_expected_errors(Some(0)) + .with_matched_by_stats(Some(1)) + .with_pruned_by_stats(Some(1)) + .with_expected_rows(0) + .with_pruned_by_bloom_filter(Some(1)) + .with_matched_by_bloom_filter(Some(0)) + .test_row_group_prune() + .await; +}