Skip to content

Commit

Permalink
add unique group by test for median, approx_median, `approx_disti…
Browse files Browse the repository at this point in the history
…nct`.
  • Loading branch information
Rachelint committed Sep 4, 2024
1 parent 8e80e78 commit be5316f
Show file tree
Hide file tree
Showing 7 changed files with 98 additions and 48 deletions.
2 changes: 1 addition & 1 deletion datafusion/expr-common/src/groups_accumulator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ pub trait GroupsAccumulator: Send {
///
/// [`Accumulator::state`]: crate::accumulator::Accumulator::state
fn convert_to_state(
&mut self,
&self,
_values: &[ArrayRef],
_opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,16 @@ pub mod accumulate;
pub mod bool_op;
pub mod nulls;
pub mod prim_op;
use std::mem;

use arrow::{
array::{ArrayRef, ArrowPrimitiveType, AsArray, BooleanArray, PrimitiveArray, TimestampSecondArray},
array::{
ArrayRef, AsArray, BooleanArray, PrimitiveArray,
},
compute,
datatypes::{DataType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type},
datatypes::{
DataType, Float16Type, Float32Type, Float64Type, Int16Type, Int32Type, Int64Type,
Int8Type, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
},
};
use datafusion_common::{
arrow_datafusion_err, utils::get_arrayref_at_indices, DataFusionError, Result,
Expand Down Expand Up @@ -384,31 +388,29 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
}

fn convert_to_state(
&mut self,
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
let num_rows = values[0].len();

// Make the buffer large enough.
self.ensure_convert_buffer_large_enough(num_rows)?;

// Each row has its respective group.
let mut results = vec![Vec::with_capacity(num_rows); values.len()];
// Each row has its respective group
let mut results = vec![];
for row_idx in 0..num_rows {
// Take the empty to update, and replace with the new empty.
let new_accumulator = (self.factory)()?;
let mut convert_state =
mem::replace(&mut self.convert_state_buffer[row_idx], new_accumulator);
// Create a fresh, empty accumulator for this row
let mut converted_accumulator = (self.factory)()?;

// Convert row to state by applying it to the empty accumulator.
// Convert row to state by applying it to the empty accumulator.
let values_to_accumulate =
slice_and_maybe_filter(&values, opt_filter, &[row_idx, row_idx + 1])?;
convert_state.update_batch(&values_to_accumulate)?;
let row_states = convert_state.state()?;
converted_accumulator.update_batch(&values_to_accumulate)?;
let states = converted_accumulator.state()?;

for (col_idx, col_state) in row_states.into_iter().enumerate() {
results[col_idx].push(col_state);
// Resize results to have enough columns according to the converted states
results.resize_with(states.len(), || Vec::with_capacity(num_rows));
// Add the states to results
for (idx, state_val) in states.into_iter().enumerate() {
results[idx].push(state_val);
}
}

Expand Down Expand Up @@ -469,50 +471,53 @@ pub(crate) fn slice_and_maybe_filter(
.collect();

if let Some(f) = filter_opt {
let filter_array = f.slice(offset, length);
let filter = f.slice(offset, length);

sliced_arrays
.iter()
.map(|array| {
compute::filter(array, &filter_array)
.map_err(|e| arrow_datafusion_err!(e))
})
.map(|array| filter_array(array, &filter))
.collect()
} else {
Ok(sliced_arrays)
}
}


macro_rules! filter_primitive_array {
($Values:ident, $PType:ident, $Filter:expr) => {{
let array = $Values.as_primitive::<$PType>().clone();
let nulls = filtered_null_mask($Filter, &array);
Ok(std::sync::Arc::new(set_nulls(array, nulls)))
}}
}};
}

fn filter_array<P: ArrowPrimitiveType>(
// TODO:
fn filter_array(
values: &ArrayRef,
filter: &BooleanArray,
) -> Result<ArrayRef> {
match values.data_type() {
DataType::Boolean => {
let values_null_buffer_filtered = filtered_null_mask(Some(filter), &values);
let (values_buf, _) = values.into_parts();
Ok(Arc::new(BooleanArray::new(values_buf, values_null_buffer_filtered)))
},
DataType::Int8 => filter_primitive_array!(values, Int8Type, Some(filter)),
DataType::Int16 => filter_primitive_array!(values, Int16Type, Some(filter)),
DataType::Int32 => filter_primitive_array!(values, Int32Type, Some(filter)),
DataType::Int64 => filter_primitive_array!(values, Int64Type, Some(filter)),
DataType::UInt8 => filter_primitive_array!(values, UInt8Type, Some(filter)),
DataType::UInt16 => filter_primitive_array!(values, UInt16Type, Some(filter)),
DataType::UInt32 => filter_primitive_array!(values, UInt32Type, Some(filter)),
DataType::UInt64 => filter_primitive_array!(values, UInt64Type, Some(filter)),
DataType::Float16 => filter_primitive_array!(values, Float16Type, Some(filter)),
DataType::Float32 => filter_primitive_array!(values, Float32Type, Some(filter)),
DataType::Float64 => filter_primitive_array!(values, Float64Type, Some(filter)),
other_type => compute::filter(values, predicate),
DataType::Boolean => {
let array = values.as_boolean().clone();
let array_null_buffer_filtered = filtered_null_mask(Some(filter), &array);
let (array_values_buf, _) = array.into_parts();
Ok(std::sync::Arc::new(BooleanArray::new(
array_values_buf,
array_null_buffer_filtered,
)))
}
DataType::Int8 => filter_primitive_array!(values, Int8Type, Some(filter)),
DataType::Int16 => filter_primitive_array!(values, Int16Type, Some(filter)),
DataType::Int32 => filter_primitive_array!(values, Int32Type, Some(filter)),
DataType::Int64 => filter_primitive_array!(values, Int64Type, Some(filter)),
DataType::UInt8 => filter_primitive_array!(values, UInt8Type, Some(filter)),
DataType::UInt16 => filter_primitive_array!(values, UInt16Type, Some(filter)),
DataType::UInt32 => filter_primitive_array!(values, UInt32Type, Some(filter)),
DataType::UInt64 => filter_primitive_array!(values, UInt64Type, Some(filter)),
DataType::Float16 => filter_primitive_array!(values, Float16Type, Some(filter)),
DataType::Float32 => filter_primitive_array!(values, Float32Type, Some(filter)),
DataType::Float64 => filter_primitive_array!(values, Float64Type, Some(filter)),
_ => {
compute::filter(values, filter).map_err(|e| arrow_datafusion_err!(e))
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ where
}

fn convert_to_state(
&mut self,
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ where
/// - null otherwise
///
fn convert_to_state(
&mut self,
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -453,7 +453,7 @@ impl GroupsAccumulator for CountGroupsAccumulator {
/// * `1` (for non-null, non filtered values)
/// * `0` (for null values)
fn convert_to_state(
&mut self,
&self,
values: &[ArrayRef],
opt_filter: Option<&BooleanArray>,
) -> Result<Vec<ArrayRef>> {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/physical-plan/src/aggregates/row_hash.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1065,7 +1065,7 @@ impl GroupedHashAggregateStream {

let iter = self
.accumulators
.iter_mut()
.iter()
.zip(input_values.iter())
.zip(filter_values.iter());

Expand Down
45 changes: 45 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate_skip_partial.slt
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,51 @@ GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
-2117946883 d -2117946883 NULL NULL NULL
-2098805236 c -2098805236 NULL NULL NULL

query ITIIII
SELECT c5, c1,
MEDIAN(c5),
MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END),
MEDIAN(c5) FILTER (WHERE c1 = 'b'),
MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b')
FROM aggregate_test_100
GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
----
-2141999138 c -2141999138 NULL NULL NULL
-2141451704 a -2141451704 -2141451704 NULL NULL
-2138770630 b -2138770630 NULL -2138770630 NULL
-2117946883 d -2117946883 NULL NULL NULL
-2098805236 c -2098805236 NULL NULL NULL

query ITIIII
SELECT c5, c1,
APPROX_MEDIAN(c5),
APPROX_MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END),
APPROX_MEDIAN(c5) FILTER (WHERE c1 = 'b'),
APPROX_MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b')
FROM aggregate_test_100
GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
----
-2141999138 c -2141999138 NULL NULL NULL
-2141451704 a -2141451704 -2141451704 NULL NULL
-2138770630 b -2138770630 NULL -2138770630 NULL
-2117946883 d -2117946883 NULL NULL NULL
-2098805236 c -2098805236 NULL NULL NULL

query ITIIII
SELECT c5, c1,
APPROX_DISTINCT(c5),
APPROX_DISTINCT(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END),
APPROX_DISTINCT(c5) FILTER (WHERE c1 = 'b'),
APPROX_DISTINCT(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b')
FROM aggregate_test_100
GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
----
-2141999138 c 1 0 0 0
-2141451704 a 1 1 0 0
-2138770630 b 1 0 1 0
-2117946883 d 1 0 0 0
-2098805236 c 1 0 0 0

# FIXME: add a bool_and(v3) column once the issue below is fixed
# ISSUE https://github.com/apache/datafusion/issues/11846
query TBBB rowsort
Expand Down

0 comments on commit be5316f

Please sign in to comment.