Skip to content

Commit

Permalink
fixing the unified bitmap
Browse files Browse the repository at this point in the history
  • Loading branch information
akuzm committed Oct 22, 2024
1 parent 2340280 commit a2806ea
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 53 deletions.
29 changes: 29 additions & 0 deletions tsl/src/compression/arrow_c_data_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,35 @@ arrow_set_row_validity(uint64 *bitmap, size_t row_number, bool value)
Assert(arrow_row_is_valid(bitmap, row_number) == value);
}

static inline uint64 *
arrow_combine_validity(size_t num_words, uint64 *restrict storage, const uint64 *filter1,
const uint64 *filter2, const uint64 *filter3)
{
if (filter1 == NULL && filter2 == NULL && filter3 == NULL)
{
return NULL;
}

for (size_t i = 0; i < num_words; i++)
{
uint64 word = ~0;
if (filter1 != NULL)
{
word &= filter1[i];
}
if (filter2 != NULL)
{
word &= filter2[i];
}
if (filter3 != NULL)
{
word &= filter3[i];
}
storage[i] = word;
}

return storage;
}
/* Increase the `source_value` to be an even multiple of `pad_to`. */
static inline uint64
pad_to_multiple(uint64 pad_to, uint64 source_value)
Expand Down
22 changes: 10 additions & 12 deletions tsl/src/nodes/vector_agg/function/functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -92,28 +92,25 @@ count_any_scalar(void *agg_state, Datum constvalue, bool constisnull, int n,
}

static void
count_any_many_vector(void *agg_state, const ArrowArray *vector, const uint64 *filter,
MemoryContext agg_extra_mctx)
count_any_vector(void *agg_state, const ArrowArray *vector, const uint64 *filter,
MemoryContext agg_extra_mctx)
{
CountState *state = (CountState *) agg_state;
const int n = vector->length;
const uint64 *restrict validity = (uint64 *) vector->buffers[0];
/* First, process the full words. */
for (int i = 0; i < n / 64; i++)
{
const uint64 validity_word = validity ? validity[i] : ~0ULL;
const uint64 filter_word = filter ? filter[i] : ~0ULL;
const uint64 resulting_word = validity_word & filter_word;

#ifdef HAVE__BUILTIN_POPCOUNT
state->count += __builtin_popcountll(resulting_word);
state->count += __builtin_popcountll(filter_word);
#else
/*
* Unfortunately, we have to have this fallback for Windows.
*/
for (uint16 i = 0; i < 64; i++)
{
const bool this_bit = (resulting_word >> i) & 1;
const bool this_bit = (filter_word >> i) & 1;
state->count += this_bit;
}
#endif
Expand All @@ -125,13 +122,14 @@ count_any_many_vector(void *agg_state, const ArrowArray *vector, const uint64 *f
*/
for (int i = 64 * (n / 64); i < n; i++)
{
state->count += arrow_row_is_valid(validity, i) * arrow_row_is_valid(filter, i);
state->count += arrow_row_is_valid(filter, i);
}
}

static void
count_any_many(void *restrict agg_states, const uint32 *offsets, const uint64 *filter,
int start_row, int end_row, const ArrowArray *vector, MemoryContext agg_extra_mctx)
count_any_many_vector(void *restrict agg_states, const uint32 *offsets, const uint64 *filter,
int start_row, int end_row, const ArrowArray *vector,
MemoryContext agg_extra_mctx)
{
for (int row = start_row; row < end_row; row++)
{
Expand All @@ -148,8 +146,8 @@ VectorAggFunctions count_any_agg = {
.agg_init = count_init,
.agg_emit = count_emit,
.agg_scalar = count_any_scalar,
.agg_vector = count_any_many_vector,
.agg_many_vector = count_any_many,
.agg_vector = count_any_vector,
.agg_many_vector = count_any_many_vector,
};

/*
Expand Down
45 changes: 28 additions & 17 deletions tsl/src/nodes/vector_agg/grouping_policy_batch.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ typedef struct
* the grouping policy is reset.
*/
MemoryContext agg_extra_mctx;

uint64 *tmp_filter;
uint64 num_tmp_filter_words;
} GroupingPolicyBatch;

static const GroupingPolicy grouping_policy_batch_functions;
Expand Down Expand Up @@ -91,27 +94,13 @@ gp_batch_reset(GroupingPolicy *obj)
}

static void
compute_single_aggregate(DecompressBatchState *batch_state, VectorAggDef *agg_def, void *agg_state,
MemoryContext agg_extra_mctx)
compute_single_aggregate(GroupingPolicyBatch *policy, DecompressBatchState *batch_state,
VectorAggDef *agg_def, void *agg_state, MemoryContext agg_extra_mctx)
{
ArrowArray *arg_arrow = NULL;
Datum arg_datum = 0;
bool arg_isnull = true;

uint64 *filter = batch_state->vector_qual_result;
if (agg_def->filter_result != NULL)
{
filter = agg_def->filter_result;
if (batch_state->vector_qual_result != NULL)
{
const size_t num_words = (batch_state->total_batch_rows + 63) / 64;
for (size_t i = 0; i < num_words; i++)
{
filter[i] &= batch_state->vector_qual_result[i];
}
}
}

/*
* We have functions with one argument, and one function with no arguments
* (count(*)). Collect the arguments.
Expand All @@ -134,6 +123,16 @@ compute_single_aggregate(DecompressBatchState *batch_state, VectorAggDef *agg_de
}
}

/*
* Compute the unified validity bitmap.
*/
const size_t num_words = (batch_state->total_batch_rows + 63) / 64;
const uint64 *filter = arrow_combine_validity(num_words,
policy->tmp_filter,
batch_state->vector_qual_result,
agg_def->filter_result,
arg_arrow != NULL ? arg_arrow->buffers[0] : NULL);

/*
* Now call the function.
*/
Expand Down Expand Up @@ -167,12 +166,24 @@ static void
gp_batch_add_batch(GroupingPolicy *gp, DecompressBatchState *batch_state)
{
GroupingPolicyBatch *policy = (GroupingPolicyBatch *) gp;

/*
* Allocate the temporary filter array for computing the combined results of
* batch filter, aggregate filter and column validity.
*/
const size_t num_words = (batch_state->total_batch_rows + 63) / 64;
if (num_words > policy->num_tmp_filter_words)
{
policy->tmp_filter = palloc(sizeof(*policy->tmp_filter) * (num_words * 2 + 1));
policy->num_tmp_filter_words = (num_words * 2 + 1);
}

const int naggs = list_length(policy->agg_defs);
for (int i = 0; i < naggs; i++)
{
VectorAggDef *agg_def = (VectorAggDef *) list_nth(policy->agg_defs, i);
void *agg_state = (void *) list_nth(policy->agg_states, i);
compute_single_aggregate(batch_state, agg_def, agg_state, policy->agg_extra_mctx);
compute_single_aggregate(policy, batch_state, agg_def, agg_state, policy->agg_extra_mctx);
}

const int ngrp = list_length(policy->output_grouping_columns);
Expand Down
34 changes: 10 additions & 24 deletions tsl/src/nodes/vector_agg/grouping_policy_hash.c
Original file line number Diff line number Diff line change
Expand Up @@ -146,30 +146,16 @@ compute_single_aggregate(GroupingPolicyHash *policy, const DecompressBatchState
}
}

uint64 *restrict filter = NULL;
if (batch_state->vector_qual_result != NULL || agg_def->filter_result != NULL ||
(arg_arrow != NULL && arg_arrow->buffers[0] != NULL))
{
filter = policy->tmp_filter;
const size_t num_words = (batch_state->total_batch_rows + 63) / 64;
for (size_t i = 0; i < num_words; i++)
{
uint64 word = -1;
if (batch_state->vector_qual_result != NULL)
{
word &= batch_state->vector_qual_result[i];
}
if (agg_def->filter_result != NULL)
{
word &= agg_def->filter_result[i];
}
if (arg_arrow != NULL && arg_arrow->buffers[0] != NULL)
{
word &= ((uint64 *) arg_arrow->buffers[0])[i];
}
filter[i] = word;
}
}
/*
* Compute the unified validity bitmap.
*/
const size_t num_words = (batch_state->total_batch_rows + 63) / 64;
uint64 *restrict filter =
arrow_combine_validity(num_words,
policy->tmp_filter,
batch_state->vector_qual_result,
agg_def->filter_result,
arg_arrow != NULL ? arg_arrow->buffers[0] : NULL);

/*
* Now call the function.
Expand Down

0 comments on commit a2806ea

Please sign in to comment.