diff --git a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs index 8a1abb7d965f..105e652a7056 100644 --- a/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs +++ b/datafusion/core/src/datasource/physical_plan/parquet/row_groups.rs @@ -15,10 +15,12 @@ // specific language governing permissions and limitations // under the License. +use arrow::compute::{can_cast_types, cast_with_options, CastOptions}; use arrow::{array::ArrayRef, datatypes::Schema}; -use arrow_array::BooleanArray; -use arrow_schema::FieldRef; -use datafusion_common::{Column, ScalarValue}; +use arrow_array::{Array, BooleanArray}; +use arrow_schema::{DataType, FieldRef}; +use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; +use datafusion_common::{Column, DataFusionError, ScalarValue}; use parquet::file::metadata::ColumnChunkMetaData; use parquet::schema::types::SchemaDescriptor; use parquet::{ @@ -276,15 +278,74 @@ impl<'a> PruningStatistics for RowGroupPruningStatistics<'a> { scalar.to_array().ok() } + /// The basic idea is to check whether all of the `values` are not within the min-max boundary. + /// If any one value is within the min-max boundary, then this row group will not be skipped. + /// Otherwise, this row group will be able to be skipped. fn contained( &self, - _column: &Column, - _values: &HashSet, + column: &Column, + values: &HashSet, ) -> Option { - None + let min_values = self.min_values(column)?; + let max_values = self.max_values(column)?; + // The boundary should be with length of 1 + if min_values.len() != max_values.len() || min_values.len() != 1 { + return None; + } + let min_value = ScalarValue::try_from_array(min_values.as_ref(), 0).ok()?; + let max_value = ScalarValue::try_from_array(max_values.as_ref(), 0).ok()?; + + // The boundary should be with the same data type + if min_value.data_type() != max_value.data_type() { + return None; + } + let target_data_type = min_value.data_type(); + + let (c, _) = self.column(&column.name)?; + let has_null = c.statistics()?.null_count() > 0; + let mut known_not_present = true; + for value in values { + // If it's null, check whether the null exists from the statistics + if has_null && value.is_null() { + known_not_present = false; + break; + } + // The filter values should be cast to the boundary's data type + if !can_cast_types(&value.data_type(), &target_data_type) { + return None; + } + let value = + cast_scalar_value(value, &target_data_type, &DEFAULT_CAST_OPTIONS) + .ok()?; + + // If the filter value is within the boundary, will not be able to filter out this row group + if value >= min_value && value <= max_value { + known_not_present = false; + break; + } + } + + let contains = if known_not_present { Some(false) } else { None }; + + Some(BooleanArray::from(vec![contains])) } } +const DEFAULT_CAST_OPTIONS: CastOptions<'static> = CastOptions { + safe: false, + format_options: DEFAULT_FORMAT_OPTIONS, +}; + +/// Cast scalar value to the given data type using an arrow kernel. +fn cast_scalar_value( + value: &ScalarValue, + data_type: &DataType, + cast_options: &CastOptions, +) -> Result { + let cast_array = cast_with_options(&value.to_array()?, data_type, cast_options)?; + ScalarValue::try_from_array(&cast_array, 0) +} + #[cfg(test)] mod tests { use super::*;