Skip to content

Commit

Permalink
Implement GroupsAccumulator for corr(x,y) aggregate function (#13581)
Browse files Browse the repository at this point in the history
* Implement GroupsAccumulator for corr(x,y)

* feedbacks

* fix CI MSRV

* review

* avoid collect in accumulation

* add back cast
  • Loading branch information
2010YOUY01 authored Dec 12, 2024
1 parent 44f4be2 commit 36a1361
Show file tree
Hide file tree
Showing 2 changed files with 499 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,75 @@ pub fn accumulate<T, F>(
}
}

/// Accumulates with multiple accumulate(value) columns. (e.g. `corr(c1, c2)`)
///
/// A row is handed to `value_fn` only when *all* of the following hold:
///
/// 1. Every array in `value_columns` is non-null at that row.
/// 2. If `opt_filter` is present, the filter is `Some(true)` at that row.
///    Rows where the filter is `Some(false)` **or null** are skipped.
///
/// Rows are visited in input order. If `value_columns` is empty and no filter
/// is given, every row is visited.
///
/// # Arguments
///
/// * `group_indices` - To which groups do the rows in `value_columns` belong
/// * `value_columns` - The input arrays to accumulate
/// * `opt_filter` - Optional filter array. If present, only rows where filter is `Some(true)` are included
/// * `value_fn` - Callback function for each valid row, with parameters:
///     * `group_idx`: The group index for the current row
///     * `batch_idx`: The index of the current row in the input arrays
///     * `columns`: Reference to all input arrays for accessing values
///
/// # Panics
///
/// In debug builds, panics if any value column (or the filter) does not have
/// the same length as `group_indices`.
pub fn accumulate_multiple<T, F>(
    group_indices: &[usize],
    value_columns: &[&PrimitiveArray<T>],
    opt_filter: Option<&BooleanArray>,
    mut value_fn: F,
) where
    T: ArrowPrimitiveType + Send,
    F: FnMut(usize, usize, &[&PrimitiveArray<T>]) + Send,
{
    // Validate the input shapes up front, before any bitmap arithmetic that
    // implicitly relies on all inputs having the same length.
    for col in value_columns.iter() {
        debug_assert_eq!(col.len(), group_indices.len());
    }
    if let Some(filter) = opt_filter {
        debug_assert_eq!(filter.len(), group_indices.len());
    }

    // Calculate `valid_indices` to accumulate, non-valid indices are ignored.
    // `valid_indices` is a bit mask corresponding to the `group_indices`. An index
    // is considered valid if:
    // 1. All columns are non-null at this index.
    // 2. Not filtered out by `opt_filter`

    // Union of the null buffers of `value_columns`, i.e. the AND of their
    // validity bitmaps. `None` means no column contains a null.
    let combined_nulls = value_columns
        .iter()
        .map(|arr| arr.logical_nulls())
        .fold(None, |acc, nulls| {
            NullBuffer::union(acc.as_ref(), nulls.as_ref())
        });

    // Collapse the filter into a single bit mask that is set only where the
    // filter is `Some(true)`. The value bit of a null slot is unspecified, so
    // it must be masked with the filter's validity bitmap; otherwise a null
    // filter entry could be treated as `true`.
    let filter_mask = opt_filter.map(|filter| match filter.nulls() {
        Some(filter_nulls) => filter.values() & filter_nulls.inner(),
        None => filter.values().clone(),
    });

    // Take AND from previous combined nulls and the filter mask.
    let valid_indices = match (combined_nulls, filter_mask) {
        (None, None) => None,
        (None, Some(mask)) => Some(mask),
        (Some(nulls), None) => Some(nulls.inner().clone()),
        (Some(nulls), Some(mask)) => Some(nulls.inner() & &mask),
    };

    match valid_indices {
        // Fast path: nothing to skip, visit every row.
        None => {
            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
                value_fn(group_idx, batch_idx, value_columns);
            }
        }
        Some(valid_indices) => {
            for (batch_idx, &group_idx) in group_indices.iter().enumerate() {
                if valid_indices.value(batch_idx) {
                    value_fn(group_idx, batch_idx, value_columns);
                }
            }
        }
    }
}

/// This function is called to update the accumulator state per row
/// when the value is not needed (e.g. COUNT)
///
Expand Down Expand Up @@ -528,7 +597,7 @@ fn initialize_builder(
mod test {
use super::*;

use arrow::array::UInt32Array;
use arrow::array::{Int32Array, UInt32Array};
use rand::{rngs::ThreadRng, Rng};
use std::collections::HashSet;

Expand Down Expand Up @@ -940,4 +1009,107 @@ mod test {
.collect()
}
}

#[test]
fn test_accumulate_multiple_no_nulls_no_filter() {
    // Two fully-populated columns; rows alternate between groups 0 and 1.
    let groups: [usize; 4] = [0, 1, 0, 1];
    let first = Int32Array::from(vec![1, 2, 3, 4]);
    let second = Int32Array::from(vec![10, 20, 30, 40]);
    let columns = [&first, &second];

    // Record each callback invocation as (group, [value of every column]).
    let mut seen: Vec<(usize, Vec<i32>)> = Vec::new();
    accumulate_multiple(&groups, &columns, None, |group, row, cols| {
        let row_values = cols.iter().map(|c| c.value(row)).collect::<Vec<i32>>();
        seen.push((group, row_values));
    });

    // Without nulls or a filter every row is visited, in input order.
    assert_eq!(
        seen,
        [
            (0, vec![1, 10]),
            (1, vec![2, 20]),
            (0, vec![3, 30]),
            (1, vec![4, 40]),
        ]
    );
}

#[test]
fn test_accumulate_multiple_with_nulls() {
    // Row 1 is null in the first column, row 2 is null in the second column.
    let groups: [usize; 4] = [0, 1, 0, 1];
    let first = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
    let second = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
    let columns = [&first, &second];

    // Record each callback invocation as (group, [value of every column]).
    let mut seen: Vec<(usize, Vec<i32>)> = Vec::new();
    accumulate_multiple(&groups, &columns, None, |group, row, cols| {
        let row_values = cols.iter().map(|c| c.value(row)).collect::<Vec<i32>>();
        seen.push((group, row_values));
    });

    // A row survives only if it is non-null in *both* columns (rows 0 and 3).
    assert_eq!(seen, [(0, vec![1, 10]), (1, vec![4, 40])]);
}

#[test]
fn test_accumulate_multiple_with_filter() {
    // No nulls in either column; the filter keeps rows 0 and 2 only.
    let groups: [usize; 4] = [0, 1, 0, 1];
    let first = Int32Array::from(vec![1, 2, 3, 4]);
    let second = Int32Array::from(vec![10, 20, 30, 40]);
    let columns = [&first, &second];
    let keep = BooleanArray::from(vec![true, false, true, false]);

    // Record each callback invocation as (group, [value of every column]).
    let mut seen: Vec<(usize, Vec<i32>)> = Vec::new();
    accumulate_multiple(&groups, &columns, Some(&keep), |group, row, cols| {
        let row_values = cols.iter().map(|c| c.value(row)).collect::<Vec<i32>>();
        seen.push((group, row_values));
    });

    // Only the rows whose filter entry is true are accumulated.
    assert_eq!(seen, [(0, vec![1, 10]), (0, vec![3, 30])]);
}

#[test]
fn test_accumulate_multiple_with_nulls_and_filter() {
    // Row 1 is null in the first column, row 2 is null in the second column,
    // and the filter additionally rejects row 3.
    let groups: [usize; 4] = [0, 1, 0, 1];
    let first = Int32Array::from(vec![Some(1), None, Some(3), Some(4)]);
    let second = Int32Array::from(vec![Some(10), Some(20), None, Some(40)]);
    let columns = [&first, &second];
    let keep = BooleanArray::from(vec![true, true, true, false]);

    // Record each callback invocation as (group, [value of every column]).
    let mut seen: Vec<(usize, Vec<i32>)> = Vec::new();
    accumulate_multiple(&groups, &columns, Some(&keep), |group, row, cols| {
        let row_values = cols.iter().map(|c| c.value(row)).collect::<Vec<i32>>();
        seen.push((group, row_values));
    });

    // A row is accumulated only when the filter is true AND both columns are
    // non-null; row 0 is the sole row satisfying both conditions.
    assert_eq!(seen, [(0, vec![1, 10])]);
}
}
Loading

0 comments on commit 36a1361

Please sign in to comment.