diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs index fb18fee60fc1b..5218ebacc4fb2 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs @@ -44,7 +44,10 @@ use datafusion_expr_common::accumulator::Accumulator; use datafusion_expr_common::groups_accumulator::{EmitTo, GroupsAccumulator}; pub const MAX_PREALLOC_BLOCK_SIZE: usize = 8192; -const GROUP_INDEX_DATA_MASK: u64 = 0x7fffffffffffffff; +const FLAT_GROUP_INDEX_ID_MASK: u64 = 0; +const FLAT_GROUP_INDEX_OFFSET_MASK: u64 = u64::MAX; +const BLOCKED_GROUP_INDEX_ID_MASK: u64 = 0xffffffff00000000; +const BLOCKED_GROUP_INDEX_OFFSET_MASK: u64 = 0x00000000ffffffff; /// An adapter that implements [`GroupsAccumulator`] for any [`Accumulator`] /// @@ -453,36 +456,14 @@ impl EmitToExt for EmitTo { pub struct BlockedGroupIndex { pub block_id: u32, pub block_offset: u64, - pub flag: u8, } impl BlockedGroupIndex { #[inline] - pub fn new_from_parts(flag: u8, block_id: u32, block_offset: u64) -> Self { + pub fn new_from_parts(block_id: u32, block_offset: u64) -> Self { Self { block_id, block_offset, - flag, - } - } - - #[inline] - pub fn new(raw_index: usize) -> Self { - let raw_index = raw_index as u64; - let flag = raw_index >> 63; - let data = raw_index & GROUP_INDEX_DATA_MASK; - let (highs, lows) = ((data >> 32) as u32, data as u32); - - let block_id = (highs & (u32::MAX.wrapping_add(1 - flag as u32))) as u32; - let block_offset = { - let offset_high = highs as u64 & (u64::MAX.wrapping_add(flag)); - (offset_high << 32) | (lows as u64) - }; - - Self { - block_id, - block_offset, - flag: flag as u8, } } @@ -498,8 +479,39 @@ impl BlockedGroupIndex { #[inline] pub fn as_packed_index(&self) -> usize { - (((self.flag as u64) << 63) | ((self.block_id as u64) << 32) | self.block_offset) - as usize + (((self.block_id as u64) << 32) | self.block_offset) as usize + } +} + +pub struct BlockedGroupIndexBuilder { + block_id_mask: u64, + block_offset_mask: u64, +} + +impl BlockedGroupIndexBuilder { + pub fn new(is_blocked: bool) -> Self { + if is_blocked { + Self { + block_id_mask: BLOCKED_GROUP_INDEX_ID_MASK, + block_offset_mask: BLOCKED_GROUP_INDEX_OFFSET_MASK, + } + } else { + Self { + block_id_mask: FLAT_GROUP_INDEX_ID_MASK, + block_offset_mask: FLAT_GROUP_INDEX_OFFSET_MASK, + } + } + } + + pub fn build(&self, packed_index: usize) -> BlockedGroupIndex { + let packed_index = packed_index as u64; + let block_id = ((packed_index & self.block_id_mask) >> 32) as u32; + let block_offset = packed_index & self.block_offset_mask; + + BlockedGroupIndex { + block_id, + block_offset, + } } } diff --git a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs index b9e6bf64adf85..770a104548802 100644 --- a/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs +++ b/datafusion/functions-aggregate-common/src/aggregate/groups_accumulator/accumulate.rs @@ -26,7 +26,7 @@ use arrow::datatypes::ArrowPrimitiveType; use datafusion_expr_common::groups_accumulator::EmitTo; use crate::aggregate::groups_accumulator::{ - BlockedGroupIndex, Blocks, MAX_PREALLOC_BLOCK_SIZE, + BlockedGroupIndex, BlockedGroupIndexBuilder, Blocks, MAX_PREALLOC_BLOCK_SIZE, }; /// Track the accumulator null state per row: if any values for that @@ -303,11 +303,14 @@ impl BlockedNullState { false, ); let seen_values_blocks = &mut self.seen_values_blocks; + let group_index_builder = + BlockedGroupIndexBuilder::new(self.block_size.is_some()); do_blocked_accumulate( group_indices, values, opt_filter, + &group_index_builder, value_fn, |group_index| { seen_values_blocks[group_index.block_id()] @@ -587,6 +590,7 @@ fn do_blocked_accumulate( group_indices: &[usize], values: &PrimitiveArray, opt_filter: Option<&BooleanArray>, + group_index_builder: &BlockedGroupIndexBuilder, mut value_fn: F1, mut set_valid_fn: F2, ) where @@ -600,7 +604,7 @@ fn do_blocked_accumulate( (false, None) => { let iter = group_indices.iter().zip(data.iter()); for (&group_index, &new_value) in iter { - let blocked_index = BlockedGroupIndex::new(group_index); + let blocked_index = group_index_builder.build(group_index); set_valid_fn(&blocked_index); value_fn(&blocked_index, new_value); } @@ -628,7 +632,8 @@ fn do_blocked_accumulate( // valid bit was set, real value let is_valid = (mask & index_mask) != 0; if is_valid { - let blocked_index = BlockedGroupIndex::new(group_index); + let blocked_index = + group_index_builder.build(group_index); set_valid_fn(&blocked_index); value_fn(&blocked_index, new_value); } @@ -646,7 +651,7 @@ fn do_blocked_accumulate( .for_each(|(i, (&group_index, &new_value))| { let is_valid = remainder_bits & (1 << i) != 0; if is_valid { - let blocked_index = BlockedGroupIndex::new(group_index); + let blocked_index = group_index_builder.build(group_index); set_valid_fn(&blocked_index); value_fn(&blocked_index, new_value); } @@ -664,7 +669,7 @@ fn do_blocked_accumulate( .zip(filter.iter()) .for_each(|((&group_index, &new_value), filter_value)| { if let Some(true) = filter_value { - let blocked_index = BlockedGroupIndex::new(group_index); + let blocked_index = group_index_builder.build(group_index); set_valid_fn(&blocked_index); value_fn(&blocked_index, new_value); } @@ -683,7 +688,7 @@ fn do_blocked_accumulate( .for_each(|((filter_value, &group_index), new_value)| { if let Some(true) = filter_value { if let Some(new_value) = new_value { - let blocked_index = BlockedGroupIndex::new(group_index); + let blocked_index = group_index_builder.build(group_index); set_valid_fn(&blocked_index); value_fn(&blocked_index, new_value); } @@ -914,7 +919,6 @@ mod test { let block_id = *idx / self.block_size; let block_offset = *idx % self.block_size; BlockedGroupIndex::new_from_parts( - 1, block_id as u32, block_offset as u64, ) diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index 38c050c41609c..e6d44b71ec6bd 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -18,7 +18,7 @@ use ahash::RandomState; use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{ - ensure_enough_room_for_values, BlockedGroupIndex, Blocks, EmitToExt, VecBlocks, + ensure_enough_room_for_values, BlockedGroupIndexBuilder, Blocks, EmitToExt, VecBlocks, }; use std::collections::HashSet; use std::ops::BitAnd; @@ -399,12 +399,14 @@ impl GroupsAccumulator for CountGroupsAccumulator { 0, ); + let group_index_builder = + BlockedGroupIndexBuilder::new(self.block_size.is_some()); accumulate_indices( group_indices, values.logical_nulls().as_ref(), opt_filter, |group_index| { - let blocked_index = BlockedGroupIndex::new(group_index); + let blocked_index = group_index_builder.build(group_index); let count = &mut self.counts[blocked_index.block_id()] [blocked_index.block_offset()]; *count += 1; @@ -436,12 +438,14 @@ impl GroupsAccumulator for CountGroupsAccumulator { 0, ); + let group_index_builder = + BlockedGroupIndexBuilder::new(self.block_size.is_some()); do_count_merge_batch( values, group_indices, opt_filter, |group_index, partial_count| { - let blocked_index = BlockedGroupIndex::new(group_index); + let blocked_index = group_index_builder.build(group_index); let count = &mut self.counts[blocked_index.block_id()] [blocked_index.block_offset()]; *count += partial_count; diff --git a/datafusion/physical-plan/src/aggregates/group_values/row.rs b/datafusion/physical-plan/src/aggregates/group_values/row.rs index 313ba315686e1..2ff433ac7a87e 100644 --- a/datafusion/physical-plan/src/aggregates/group_values/row.rs +++ b/datafusion/physical-plan/src/aggregates/group_values/row.rs @@ -29,7 +29,7 @@ use datafusion_common::{DataFusionError, Result}; use datafusion_execution::memory_pool::proxy::{RawTableAllocExt, VecAllocExt}; use datafusion_expr::EmitTo; use datafusion_functions_aggregate_common::aggregate::groups_accumulator::{ - BlockedGroupIndex, Blocks, + BlockedGroupIndex, BlockedGroupIndexBuilder, Blocks, }; use hashbrown::raw::RawTable; @@ -135,7 +135,8 @@ impl GroupValues for GroupValuesRows { batch_hashes.resize(n_rows, 0); create_hashes(cols, &self.random_state, batch_hashes)?; - let flag = self.block_size.is_some() as u8; + let group_index_builder = + BlockedGroupIndexBuilder::new(self.block_size.is_some()); for (row, &target_hash) in batch_hashes.iter().enumerate() { let entry = self.map.get_mut(target_hash, |(exist_hash, group_idx)| { // Somewhat surprisingly, this closure can be called even if the @@ -149,7 +150,7 @@ impl GroupValues for GroupValuesRows { // verify that the group that we are inserting with hash is // actually the same key value as the group in // existing_idx (aka group_values @ row) - let blocked_index = BlockedGroupIndex::new(*group_idx); + let blocked_index = group_index_builder.build(*group_idx); group_rows.row(row) == group_values[blocked_index.block_id()] .row(blocked_index.block_offset()) @@ -179,7 +180,6 @@ impl GroupValues for GroupValuesRows { cur_blk.push(group_rows.row(row)); let blocked_index = BlockedGroupIndex::new_from_parts( - flag, blk_id as u32, blk_offset as u64, ); @@ -284,16 +284,16 @@ impl GroupValues for GroupValuesRows { let output = self.row_converter.convert_rows(cur_blk.iter())?; unsafe { - let flag = self.block_size.is_some() as u8; + let group_index_builder = + BlockedGroupIndexBuilder::new(self.block_size.is_some()); for bucket in self.map.iter() { // Decrement group index by n let group_idx = bucket.as_ref().1; - let old_blk_idx = BlockedGroupIndex::new(group_idx); + let old_blk_idx = group_index_builder.build(group_idx); match old_blk_idx.block_id().checked_sub(1) { // Group index was >= n, shift value down Some(new_blk_id) => { let new_group_idx = BlockedGroupIndex::new_from_parts( - flag, new_blk_id as u32, old_blk_idx.block_offset, );