Skip to content

Commit

Permalink
try another unify group index computation.
Browse files Browse the repository at this point in the history
  • Loading branch information
Rachelint committed Aug 30, 2024
1 parent c9f0b06 commit c6bae6e
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,10 @@ use datafusion_expr_common::accumulator::Accumulator;
use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator};

pub const MAX_PREALLOC_BLOCK_SIZE: usize = 8192;
const GROUP_INDEX_DATA_MASK: u64 = 0x7fffffffffffffff;
const FLAT_GROUP_INDEX_ID_MASK: u64 = 0;
const FLAT_GROUP_INDEX_OFFSET_MASK: u64 = u64::MAX;
const BLOCKED_GROUP_INDEX_ID_MASK: u64 = 0xffffffff00000000;
const BLOCKED_GROUP_INDEX_OFFSET_MASK: u64 = 0x00000000ffffffff;

/// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`]
///
Expand Down Expand Up @@ -453,36 +456,14 @@ impl EmitToExt for EmitTo {
pub struct BlockedGroupIndex {
pub block_id: u32,
pub block_offset: u64,
pub flag: u8,
}

impl BlockedGroupIndex {
#[inline]
pub fn new_from_parts(flag: u8, block_id: u32, block_offset: u64) -> Self {
pub fn new_from_parts(block_id: u32, block_offset: u64) -> Self {
Self {
block_id,
block_offset,
flag,
}
}

#[inline]
pub fn new(raw_index: usize) -> Self {
let raw_index = raw_index as u64;
let flag = raw_index >> 63;
let data = raw_index & GROUP_INDEX_DATA_MASK;
let (highs, lows) = ((data >> 32) as u32, data as u32);

let block_id = (highs & (u32::MAX.wrapping_add(1 - flag as u32))) as u32;
let block_offset = {
let offset_high = highs as u64 & (u64::MAX.wrapping_add(flag));
(offset_high << 32) | (lows as u64)
};

Self {
block_id,
block_offset,
flag: flag as u8,
}
}

Expand All @@ -498,8 +479,39 @@ impl BlockedGroupIndex {

#[inline]
pub fn as_packed_index(&self) -> usize {
(((self.flag as u64) << 63) | ((self.block_id as u64) << 32) | self.block_offset)
as usize
(((self.block_id as u64) << 32) | self.block_offset) as usize
}
}

pub struct BlockedGroupIndexBuilder {
block_id_mask: u64,
block_offset_mask: u64,
}

impl BlockedGroupIndexBuilder {
pub fn new(is_blocked: bool) -> Self {
if is_blocked {
Self {
block_id_mask: BLOCKED_GROUP_INDEX_ID_MASK,
block_offset_mask: BLOCKED_GROUP_INDEX_OFFSET_MASK,
}
} else {
Self {
block_id_mask: FLAT_GROUP_INDEX_ID_MASK,
block_offset_mask: FLAT_GROUP_INDEX_OFFSET_MASK,
}
}
}

pub fn build(&self, packed_index: usize) -> BlockedGroupIndex {
let packed_index = packed_index as u64;
let block_id = ((packed_index & self.block_id_mask) >> 32) as u32;
let block_offset = packed_index & self.block_offset_mask;

BlockedGroupIndex {
block_id,
block_offset,
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ use arrow::datatypes::ArrowPrimitiveType;
use datafusion_expr_common::groups_accumulator::EmitTo;

use crate::aggregate::groups_accumulator::{
BlockedGroupIndex, Blocks, MAX_PREALLOC_BLOCK_SIZE,
BlockedGroupIndex, BlockedGroupIndexBuilder, Blocks, MAX_PREALLOC_BLOCK_SIZE,
};

/// Track the accumulator null state per row: if any values for that
Expand Down Expand Up @@ -303,11 +303,14 @@ impl BlockedNullState {
false,
);
let seen_values_blocks = &mut self.seen_values_blocks;
let group_index_builder =
BlockedGroupIndexBuilder::new(self.block_size.is_some());

do_blocked_accumulate(
group_indices,
values,
opt_filter,
&group_index_builder,
value_fn,
|group_index| {
seen_values_blocks[group_index.block_id()]
Expand Down Expand Up @@ -587,6 +590,7 @@ fn do_blocked_accumulate<T, F1, F2>(
group_indices: &[usize],
values: &PrimitiveArray<T>,
opt_filter: Option<&BooleanArray>,
group_index_builder: &BlockedGroupIndexBuilder,
mut value_fn: F1,
mut set_valid_fn: F2,
) where
Expand All @@ -600,7 +604,7 @@ fn do_blocked_accumulate<T, F1, F2>(
(false, None) => {
let iter = group_indices.iter().zip(data.iter());
for (&group_index, &new_value) in iter {
let blocked_index = BlockedGroupIndex::new(group_index);
let blocked_index = group_index_builder.build(group_index);
set_valid_fn(&blocked_index);
value_fn(&blocked_index, new_value);
}
Expand Down Expand Up @@ -628,7 +632,8 @@ fn do_blocked_accumulate<T, F1, F2>(
// valid bit was set, real value
let is_valid = (mask & index_mask) != 0;
if is_valid {
let blocked_index = BlockedGroupIndex::new(group_index);
let blocked_index =
group_index_builder.build(group_index);
set_valid_fn(&blocked_index);
value_fn(&blocked_index, new_value);
}
Expand All @@ -646,7 +651,7 @@ fn do_blocked_accumulate<T, F1, F2>(
.for_each(|(i, (&group_index, &new_value))| {
let is_valid = remainder_bits & (1 << i) != 0;
if is_valid {
let blocked_index = BlockedGroupIndex::new(group_index);
let blocked_index = group_index_builder.build(group_index);
set_valid_fn(&blocked_index);
value_fn(&blocked_index, new_value);
}
Expand All @@ -664,7 +669,7 @@ fn do_blocked_accumulate<T, F1, F2>(
.zip(filter.iter())
.for_each(|((&group_index, &new_value), filter_value)| {
if let Some(true) = filter_value {
let blocked_index = BlockedGroupIndex::new(group_index);
let blocked_index = group_index_builder.build(group_index);
set_valid_fn(&blocked_index);
value_fn(&blocked_index, new_value);
}
Expand All @@ -683,7 +688,7 @@ fn do_blocked_accumulate<T, F1, F2>(
.for_each(|((filter_value, &group_index), new_value)| {
if let Some(true) = filter_value {
if let Some(new_value) = new_value {
let blocked_index = BlockedGroupIndex::new(group_index);
let blocked_index = group_index_builder.build(group_index);
set_valid_fn(&blocked_index);
value_fn(&blocked_index, new_value);
}
Expand Down Expand Up @@ -914,7 +919,6 @@ mod test {
let block_id = *idx / self.block_size;
let block_offset = *idx % self.block_size;
BlockedGroupIndex::new_from_parts(
1,
block_id as u32,
block_offset as u64,
)
Expand Down
10 changes: 7 additions & 3 deletions datafusion/functions-aggregate/src/count.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
use ahash::RandomState;
use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{
ensure_enough_room_for_values, BlockedGroupIndex, Blocks, EmitToExt, VecBlocks,
ensure_enough_room_for_values, BlockedGroupIndexBuilder, Blocks, EmitToExt, VecBlocks,
};
use std::collections::HashSet;
use std::ops::BitAnd;
Expand Down Expand Up @@ -399,12 +399,14 @@ impl GroupsAccumulator for CountGroupsAccumulator {
0,
);

let group_index_builder =
BlockedGroupIndexBuilder::new(self.block_size.is_some());
accumulate_indices(
group_indices,
values.logical_nulls().as_ref(),
opt_filter,
|group_index| {
let blocked_index = BlockedGroupIndex::new(group_index);
let blocked_index = group_index_builder.build(group_index);
let count = &mut self.counts[blocked_index.block_id()]
[blocked_index.block_offset()];
*count += 1;
Expand Down Expand Up @@ -436,12 +438,14 @@ impl GroupsAccumulator for CountGroupsAccumulator {
0,
);

let group_index_builder =
BlockedGroupIndexBuilder::new(self.block_size.is_some());
do_count_merge_batch(
values,
group_indices,
opt_filter,
|group_index, partial_count| {
let blocked_index = BlockedGroupIndex::new(group_index);
let blocked_index = group_index_builder.build(group_index);
let count = &mut self.counts[blocked_index.block_id()]
[blocked_index.block_offset()];
*count += partial_count;
Expand Down
14 changes: 7 additions & 7 deletions datafusion/physical-plan/src/aggregates/group_values/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use datafusion_common::{DataFusionError, Result};
use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt};
use datafusion_expr::EmitTo;
use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{
BlockedGroupIndex, Blocks,
BlockedGroupIndex, BlockedGroupIndexBuilder, Blocks,
};
use hashbrown::raw::RawTable;

Expand Down Expand Up @@ -135,7 +135,8 @@ impl GroupValues for GroupValuesRows {
batch_hashes.resize(n_rows, 0);
create_hashes(cols, &self.random_state, batch_hashes)?;

let flag = self.block_size.is_some() as u8;
let group_index_builder =
BlockedGroupIndexBuilder::new(self.block_size.is_some());
for (row, &target_hash) in batch_hashes.iter().enumerate() {
let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| {
// Somewhat surprisingly, this closure can be called even if the
Expand All @@ -149,7 +150,7 @@ impl GroupValues for GroupValuesRows {
// verify that the group that we are inserting with hash is
// actually the same key value as the group in
// existing_idx (aka group_values @ row)
let blocked_index = BlockedGroupIndex::new(*group_idx);
let blocked_index = group_index_builder.build(*group_idx);
group_rows.row(row)
== group_values[blocked_index.block_id()]
.row(blocked_index.block_offset())
Expand Down Expand Up @@ -179,7 +180,6 @@ impl GroupValues for GroupValuesRows {
cur_blk.push(group_rows.row(row));

let blocked_index = BlockedGroupIndex::new_from_parts(
flag,
blk_id as u32,
blk_offset as u64,
);
Expand Down Expand Up @@ -284,16 +284,16 @@ impl GroupValues for GroupValuesRows {
let output = self.row_converter.convert_rows(cur_blk.iter())?;

unsafe {
let flag = self.block_size.is_some() as u8;
let group_index_builder =
BlockedGroupIndexBuilder::new(self.block_size.is_some());
for bucket in self.map.iter() {
// Decrement group index by n
let group_idx = bucket.as_ref().1;
let old_blk_idx = BlockedGroupIndex::new(group_idx);
let old_blk_idx = group_index_builder.build(group_idx);
match old_blk_idx.block_id().checked_sub(1) {
// Group index was >= n, shift value down
Some(new_blk_id) => {
let new_group_idx = BlockedGroupIndex::new_from_parts(
flag,
new_blk_id as u32,
old_blk_idx.block_offset,
);
Expand Down

0 comments on commit c6bae6e

Please sign in to comment.