Commit

aggregations: add support for "terms" bucket aggregation
max-moser committed Feb 23, 2023
1 parent 1b1c57f · commit 4caeee4
Showing 1 changed file with 43 additions and 0 deletions.
invenio_stats/aggregations.py: 43 additions & 0 deletions
@@ -156,6 +156,10 @@ def list_bookmarks(self, start_date=None, end_date=None, limit=None):
"sum",
}

ALLOWED_BUCKETS = {
    "terms",
}


class StatAggregator(object):
    """Generic aggregation class.
@@ -195,6 +199,7 @@ def __init__(
        client=None,
        field=None,
        metric_fields=None,
        bucket_fields=None,
        copy_fields=None,
        query_modifiers=None,
        interval="day",
@@ -210,6 +215,9 @@ def __init__(
            metric aggregation will be computed. The format of the dictionary
            is "destination field" ->
            tuple("metric type", "source field", "metric_options").
        :param bucket_fields: dictionary of fields on which a bucket
            aggregation will be computed. The format of the dictionary is
            "destination field" ->
            tuple("bucket type", "source field", "bucket_options").
        :param copy_fields: list of fields which are copied from the raw events
            into the aggregation.
        :param query_modifiers: list of functions modifying the raw events
@@ -225,6 +233,7 @@ def __init__(
        self.index = prefix_index(f"stats-{event}")
        self.field = field
        self.metric_fields = metric_fields or {}
        self.bucket_fields = bucket_fields or {}
        self.interval = interval
        self.doc_id_suffix = SUPPORTED_INTERVALS[interval]
        self.index_interval = index_interval
@@ -245,6 +254,15 @@ def __init__(
                )
            )

        if any(v not in ALLOWED_BUCKETS for k, (v, _, _) in self.bucket_fields.items()):
            raise (
                ValueError(
                    "Bucket aggregation type should be one of [{}]".format(
                        ", ".join(ALLOWED_BUCKETS)
                    )
                )
            )

        if list(SUPPORTED_INTERVALS.keys()).index(interval) > list(
            SUPPORTED_INTERVALS.keys()
        ).index(index_interval):
@@ -293,6 +311,22 @@ def _split_date_range(self, lower_limit, upper_limit):
        res[dt_key] = upper_limit
        return res

    def _handle_bucket_agg(self, aggregation_buckets):
        """Transform the bucket aggregation result into something leaner.

        In the case of a "terms" bucket aggregation, this function will turn
        the list of bucket documents into a simple object of the shape
        ``{"key1": count1, "key2": count2, ...}``.
        """
        result = {}
        for bucket in aggregation_buckets:
            # NOTE that this is primarily intended for 'terms' buckets and needs
            # to be checked in case we want to support further bucket aggregations
            keyword = bucket.get("key_as_string", str(bucket["key"]))
            result[keyword] = bucket["doc_count"]

        return result

    def agg_iter(self, dt):
        """Aggregate and return dictionary to be indexed in ES."""
        rounded_dt = format_range_dt(dt, self.interval)
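
For illustration (not part of this commit), the transformation performed by _handle_bucket_agg on a hypothetical "terms" response looks as follows; the keys and counts are made-up values:

    # Hypothetical "terms" bucket list as returned by the search engine:
    buckets = [
        {"key": "AT", "doc_count": 12},
        {"key": "DE", "doc_count": 7},
    ]
    # self._handle_bucket_agg(buckets) collapses it into:
    # {"AT": 12, "DE": 7}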
@@ -333,6 +367,9 @@ def agg_iter(self, dt):
        for dst, (metric, src, opts) in self.metric_fields.items():
            terms.metric(dst, metric, field=src, **opts)

        for dst, (bucket, src, opts) in self.bucket_fields.items():
            terms.bucket(dst, bucket, field=src, **opts)

        results = self.agg_query.execute(
            # NOTE: Without this, the aggregation changes above do not
            # invalidate the search's response cache, and thus you would
@@ -355,6 +392,12 @@
            for f in self.metric_fields:
                aggregation_data[f] = aggregation[f]["value"]

            if self.bucket_fields:
                for f in self.bucket_fields:
                    aggregation_data[f] = self._handle_bucket_agg(
                        aggregation[f]["buckets"]
                    )

            for destination, source in self.copy_fields.items():
                if isinstance(source, str):
                    aggregation_data[destination] = doc[source]
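
As a usage sketch (not part of this commit), the new bucket_fields parameter could be wired up as shown below; the aggregator name, event, and field names are hypothetical, and only the "terms" bucket type passes the ALLOWED_BUCKETS check:

    # A minimal sketch, assuming a registered "file-download" event;
    # all names and parameters here are illustrative.
    aggregator = StatAggregator(
        name="file-download-agg",
        event="file-download",
        field="file_id",
        metric_fields={
            # "destination field" -> ("metric type", "source field", options)
            "unique_count": (
                "cardinality",
                "unique_session_id",
                {"precision_threshold": 1000},
            ),
        },
        bucket_fields={
            # "destination field" -> ("bucket type", "source field", options)
            "countries": ("terms", "country", {"size": 50}),
        },
        interval="day",
    )
    # Each aggregated document then carries a "countries" field shaped like
    # {"AT": 12, "DE": 7, ...}, as produced by _handle_bucket_agg.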
