Skip to content

Commit

Permalink
Cache the argument types in first/last functions (#7096)
Browse files Browse the repository at this point in the history
The argument types cannot change after the first call, but currently we
look them up for every row, which adds up to a major part of these
functions' run time.
  • Loading branch information
akuzm authored Aug 15, 2024
1 parent c2269db commit e0017d8
Show file tree
Hide file tree
Showing 3 changed files with 1,136 additions and 96 deletions.
168 changes: 96 additions & 72 deletions src/agg_bookend.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,22 @@ TS_FUNCTION_INFO_V1(ts_bookend_deserializefunc);
/* A PolyDatum represents a polymorphic datum */
typedef struct PolyDatum
{
Oid type_oid;
bool is_null;
Datum datum;
} PolyDatum;

/*
 * Cached type information for one datum: the type OID together with the
 * length and pass-by-value flag that get_typlenbyval() reports for it.
 * Keeping these lets the aggregate copy datums without a per-row lookup.
 */
typedef struct TypeInfoCache
{
Oid typoid; /* type OID; users of the cache assert OidIsValid() on it */
int16 typlen; /* type length, handed to datumCopy() */
bool typbyval; /* pass-by-value flag; when false, the old datum is pfree'd before being replaced */
} TypeInfoCache;

/* PolyDatumIOState is internal state used by polydatum_serialize and polydatum_deserialize */
typedef struct PolyDatumIOState
{
Oid type_oid;
TypeInfoCache type;

FmgrInfo proc;
Oid typeioparam;
} PolyDatumIOState;
Expand All @@ -54,7 +61,6 @@ polydatum_from_arg(int argno, FunctionCallInfo fcinfo)
{
PolyDatum value;

value.type_oid = get_fn_expr_argtype(fcinfo->flinfo, argno);
value.is_null = PG_ARGISNULL(argno);
if (!value.is_null)
value.datum = PG_GETARG_DATUM(argno);
Expand Down Expand Up @@ -92,7 +98,8 @@ polydatum_serialize(PolyDatum *pd, StringInfo buf, PolyDatumIOState *state, Func
{
bytea *outputbytes;

polydatum_serialize_type(buf, pd->type_oid);
Assert(OidIsValid(state->type.typoid));
polydatum_serialize_type(buf, state->type.typoid);

if (pd->is_null)
{
Expand All @@ -101,15 +108,6 @@ polydatum_serialize(PolyDatum *pd, StringInfo buf, PolyDatumIOState *state, Func
return;
}

if (state->type_oid != pd->type_oid)
{
Oid func;
bool is_varlena;

getTypeBinaryOutputInfo(pd->type_oid, &func, &is_varlena);
fmgr_info_cxt(func, &state->proc, fcinfo->flinfo->fn_mcxt);
state->type_oid = pd->type_oid;
}
outputbytes = SendFunctionCall(&state->proc, pd->datum);
pq_sendint32(buf, VARSIZE(outputbytes) - VARHDRSZ);
pq_sendbytes(buf, VARDATA(outputbytes), VARSIZE(outputbytes) - VARHDRSZ);
Expand Down Expand Up @@ -149,7 +147,7 @@ polydatum_deserialize(MemoryContext mem_ctx, PolyDatum *result, StringInfo buf,

MemoryContext old_context = MemoryContextSwitchTo(mem_ctx);

result->type_oid = polydatum_deserialize_type(buf);
Oid deserialized_type = polydatum_deserialize_type(buf);

/* Following is copied/adapted from record_recv in core postgres */

Expand Down Expand Up @@ -190,13 +188,15 @@ polydatum_deserialize(MemoryContext mem_ctx, PolyDatum *result, StringInfo buf,
}

/* Now call the column's receiveproc */
if (state->type_oid != result->type_oid)
if (state->type.typoid != deserialized_type)
{
Oid func;
Assert(!OidIsValid(state->type.typoid));

getTypeBinaryInputInfo(result->type_oid, &func, &state->typeioparam);
Oid func;
getTypeBinaryInputInfo(deserialized_type, &func, &state->typeioparam);
fmgr_info_cxt(func, &state->proc, fcinfo->flinfo->fn_mcxt);
state->type_oid = result->type_oid;
state->type.typoid = deserialized_type;
get_typlenbyval(state->type.typoid, &state->type.typlen, &state->type.typbyval);
}

result->datum = ReceiveFunctionCall(&state->proc, bufptr, state->typeioparam, -1);
Expand All @@ -217,9 +217,17 @@ polydatum_deserialize(MemoryContext mem_ctx, PolyDatum *result, StringInfo buf,
return result;
}

/*
 * Type information for the two aggregate arguments plus the comparison
 * procedure. Embedded in InternalCmpAggStore so that it travels with the
 * aggregate state.
 */
typedef struct TransCache
{
TypeInfoCache value_type_cache; /* type of the aggregated value (sfunc argument 1) */
TypeInfoCache cmp_type_cache; /* type of the comparison element (sfunc argument 2) */
FmgrInfo cmp_proc; /* comparison operator; set up via cmpproc_init() while fn_addr is still NULL */
} TransCache;

/*
 * Internal state for bookend aggregates: the currently selected value, the
 * comparison element it was selected by, and the cached type information
 * used to copy and compare them.
 */
typedef struct InternalCmpAggStore
{
TransCache aggstate_type_cache; /* argument type info, filled in when the state is created (sfunc, combinefunc or deserializefunc) */
PolyDatum value; /* the value that goes with the current comparison element */
PolyDatum cmp; /* the comparison element. e.g. time */
} InternalCmpAggStore;
Expand All @@ -242,22 +250,12 @@ typedef struct InternalCmpAggStoreIOState
PolyDatumIOState cmp; /* the comparison element. e.g. time */
} InternalCmpAggStoreIOState;

typedef struct TypeInfoCache
{
Oid type_oid;
int16 typelen;
bool typebyval;
} TypeInfoCache;

inline static void
typeinfocache_polydatumcopy(TypeInfoCache *tic, PolyDatum input, PolyDatum *output)
{
if (tic->type_oid != input.type_oid)
{
tic->type_oid = input.type_oid;
get_typlenbyval(tic->type_oid, &tic->typelen, &tic->typebyval);
}
if (!tic->typebyval && !output->is_null)
Assert(OidIsValid(tic->typoid));

if (!tic->typbyval && !output->is_null)
{
pfree(DatumGetPointer(output->datum));
}
Expand All @@ -266,7 +264,7 @@ typeinfocache_polydatumcopy(TypeInfoCache *tic, PolyDatum input, PolyDatum *outp

if (!input.is_null)
{
output->datum = datumCopy(input.datum, tic->typebyval, tic->typelen);
output->datum = datumCopy(input.datum, tic->typbyval, tic->typlen);
output->is_null = false;
}
else
Expand Down Expand Up @@ -302,48 +300,45 @@ cmpproc_cmp(FmgrInfo *cmp_proc, FunctionCallInfo fcinfo, PolyDatum left, PolyDat
return DatumGetBool(FunctionCall2Coll(cmp_proc, fcinfo->fncollation, left.datum, right.datum));
}

typedef struct TransCache
{
TypeInfoCache value_type_cache;
TypeInfoCache cmp_type_cache;
FmgrInfo cmp_proc;
} TransCache;

static TransCache *
transcache_get(FunctionCallInfo fcinfo)
{
TransCache *my_extra = (TransCache *) fcinfo->flinfo->fn_extra;

if (my_extra == NULL)
{
fcinfo->flinfo->fn_extra =
MemoryContextAllocZero(fcinfo->flinfo->fn_mcxt, sizeof(TransCache));
my_extra = (TransCache *) fcinfo->flinfo->fn_extra;
}
return my_extra;
}

/*
* bookend_sfunc - internal function called by ts_last_sfunc and ts_first_sfunc.
*/
static inline Datum
bookend_sfunc(MemoryContext aggcontext, InternalCmpAggStore *state, PolyDatum value, PolyDatum cmp,
char *opname, FunctionCallInfo fcinfo)
bookend_sfunc(MemoryContext aggcontext, InternalCmpAggStore *state, char *opname,
FunctionCallInfo fcinfo)
{
PolyDatum value = polydatum_from_arg(1, fcinfo);
PolyDatum cmp = polydatum_from_arg(2, fcinfo);

MemoryContext old_context;
TransCache *cache = transcache_get(fcinfo);

old_context = MemoryContextSwitchTo(aggcontext);

if (state == NULL)
{
state = init_store(aggcontext);
cmpproc_init(fcinfo, &cache->cmp_proc, cmp.type_oid, opname);
TransCache *cache = &state->aggstate_type_cache;

TypeInfoCache *v = &cache->value_type_cache;
v->typoid = get_fn_expr_argtype(fcinfo->flinfo, 1);
get_typlenbyval(v->typoid, &v->typlen, &v->typbyval);

TypeInfoCache *c = &cache->cmp_type_cache;
c->typoid = get_fn_expr_argtype(fcinfo->flinfo, 2);
get_typlenbyval(c->typoid, &c->typlen, &c->typbyval);

typeinfocache_polydatumcopy(&cache->value_type_cache, value, &state->value);
typeinfocache_polydatumcopy(&cache->cmp_type_cache, cmp, &state->cmp);
}
else if (!cmp.is_null)
{
TransCache *cache = &state->aggstate_type_cache;

if (cache->cmp_proc.fn_addr == NULL)
{
cmpproc_init(fcinfo, &cache->cmp_proc, cache->cmp_type_cache.typoid, opname);
}

/* only do comparison if cmp is not NULL */
if (state->cmp.is_null || cmpproc_cmp(&cache->cmp_proc, fcinfo, cmp, state->cmp))
{
Expand All @@ -364,13 +359,10 @@ bookend_combinefunc(MemoryContext aggcontext, InternalCmpAggStore *state1,
InternalCmpAggStore *state2, char *opname, FunctionCallInfo fcinfo)
{
MemoryContext old_context;
TransCache *cache;

if (state2 == NULL)
PG_RETURN_POINTER(state1);

cache = transcache_get(fcinfo);

/*
* manually copy all fields from state2 to state1, as per other combine
* func like int8_avg_combine
Expand All @@ -380,8 +372,21 @@ bookend_combinefunc(MemoryContext aggcontext, InternalCmpAggStore *state1,
old_context = MemoryContextSwitchTo(aggcontext);

state1 = init_store(aggcontext);
typeinfocache_polydatumcopy(&cache->value_type_cache, state2->value, &state1->value);
typeinfocache_polydatumcopy(&cache->cmp_type_cache, state2->cmp, &state1->cmp);
Assert(OidIsValid(state2->aggstate_type_cache.value_type_cache.typoid));
Assert(OidIsValid(state2->aggstate_type_cache.cmp_type_cache.typoid));
TransCache *cache1 = &state1->aggstate_type_cache;
TransCache *cache2 = &state2->aggstate_type_cache;
/*
* Initialize the type information from the right-hand state. Note that
* we will have to re-lookup the comparison procedure on demand, because
* the comparison procedure from the right-hand state might have been
* allocated in a different memory context.
*/
cache1->value_type_cache = cache2->value_type_cache;
cache1->cmp_type_cache = cache2->cmp_type_cache;

typeinfocache_polydatumcopy(&cache1->value_type_cache, state2->value, &state1->value);
typeinfocache_polydatumcopy(&cache1->cmp_type_cache, state2->cmp, &state1->cmp);

MemoryContextSwitchTo(old_context);
PG_RETURN_POINTER(state1);
Expand All @@ -399,12 +404,16 @@ bookend_combinefunc(MemoryContext aggcontext, InternalCmpAggStore *state1,
PG_RETURN_POINTER(state1);
}

cmpproc_init(fcinfo, &cache->cmp_proc, state1->cmp.type_oid, opname);
if (cmpproc_cmp(&cache->cmp_proc, fcinfo, state2->cmp, state1->cmp))
TransCache *cache1 = &state1->aggstate_type_cache;
if (cache1->cmp_proc.fn_addr == NULL)
{
cmpproc_init(fcinfo, &cache1->cmp_proc, cache1->cmp_type_cache.typoid, opname);
}
if (cmpproc_cmp(&cache1->cmp_proc, fcinfo, state2->cmp, state1->cmp))
{
old_context = MemoryContextSwitchTo(aggcontext);
typeinfocache_polydatumcopy(&cache->value_type_cache, state2->value, &state1->value);
typeinfocache_polydatumcopy(&cache->cmp_type_cache, state2->cmp, &state1->cmp);
typeinfocache_polydatumcopy(&cache1->value_type_cache, state2->value, &state1->value);
typeinfocache_polydatumcopy(&cache1->cmp_type_cache, state2->cmp, &state1->cmp);
MemoryContextSwitchTo(old_context);
}

Expand All @@ -417,8 +426,6 @@ ts_first_sfunc(PG_FUNCTION_ARGS)
{
InternalCmpAggStore *store =
PG_ARGISNULL(0) ? NULL : (InternalCmpAggStore *) PG_GETARG_POINTER(0);
PolyDatum value = polydatum_from_arg(1, fcinfo);
PolyDatum cmp = polydatum_from_arg(2, fcinfo);
MemoryContext aggcontext;

if (!AggCheckCallContext(fcinfo, &aggcontext))
Expand All @@ -427,7 +434,7 @@ ts_first_sfunc(PG_FUNCTION_ARGS)
elog(ERROR, "first_sfun called in non-aggregate context");
}

return bookend_sfunc(aggcontext, store, value, cmp, "<", fcinfo);
return bookend_sfunc(aggcontext, store, "<", fcinfo);
}

/* last(internal internal_state, anyelement value, "any" comparison_element) */
Expand All @@ -436,8 +443,6 @@ ts_last_sfunc(PG_FUNCTION_ARGS)
{
InternalCmpAggStore *store =
PG_ARGISNULL(0) ? NULL : (InternalCmpAggStore *) PG_GETARG_POINTER(0);
PolyDatum value = polydatum_from_arg(1, fcinfo);
PolyDatum cmp = polydatum_from_arg(2, fcinfo);
MemoryContext aggcontext;

if (!AggCheckCallContext(fcinfo, &aggcontext))
Expand All @@ -446,7 +451,7 @@ ts_last_sfunc(PG_FUNCTION_ARGS)
elog(ERROR, "last_sfun called in non-aggregate context");
}

return bookend_sfunc(aggcontext, store, value, cmp, ">", fcinfo);
return bookend_sfunc(aggcontext, store, ">", fcinfo);
}

/* first_combinerfunc(internal, internal) => internal */
Expand Down Expand Up @@ -502,6 +507,21 @@ ts_bookend_serializefunc(PG_FUNCTION_ARGS)
fcinfo->flinfo->fn_extra =
MemoryContextAllocZero(fcinfo->flinfo->fn_mcxt, sizeof(InternalCmpAggStoreIOState));
my_extra = (InternalCmpAggStoreIOState *) fcinfo->flinfo->fn_extra;

Oid func;
bool is_varlena;

my_extra->value.type = state->aggstate_type_cache.value_type_cache;
Assert(OidIsValid(my_extra->value.type.typoid));

getTypeBinaryOutputInfo(my_extra->value.type.typoid, &func, &is_varlena);
fmgr_info_cxt(func, &my_extra->value.proc, fcinfo->flinfo->fn_mcxt);

my_extra->cmp.type = state->aggstate_type_cache.cmp_type_cache;
Assert(OidIsValid(my_extra->cmp.type.typoid));

getTypeBinaryOutputInfo(my_extra->cmp.type.typoid, &func, &is_varlena);
fmgr_info_cxt(func, &my_extra->cmp.proc, fcinfo->flinfo->fn_mcxt);
}
pq_begintypsend(&buf);
polydatum_serialize(&state->value, &buf, &my_extra->value, fcinfo);
Expand Down Expand Up @@ -542,6 +562,10 @@ ts_bookend_deserializefunc(PG_FUNCTION_ARGS)
result = MemoryContextAllocZero(aggcontext, sizeof(InternalCmpAggStore));
polydatum_deserialize(aggcontext, &result->value, &buf, &my_extra->value, fcinfo);
polydatum_deserialize(aggcontext, &result->cmp, &buf, &my_extra->cmp, fcinfo);

result->aggstate_type_cache.value_type_cache = my_extra->value.type;
result->aggstate_type_cache.cmp_type_cache = my_extra->cmp.type;

PG_RETURN_POINTER(result);
}

Expand Down
Loading

0 comments on commit e0017d8

Please sign in to comment.