From 4caeee43daa06dade874df014d519e381adb0d85 Mon Sep 17 00:00:00 2001 From: Maximilian Moser Date: Wed, 22 Feb 2023 17:48:23 +0100 Subject: [PATCH] aggregations: add support for "terms" bucket aggregation --- invenio_stats/aggregations.py | 43 +++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/invenio_stats/aggregations.py b/invenio_stats/aggregations.py index bb0511c..133abae 100644 --- a/invenio_stats/aggregations.py +++ b/invenio_stats/aggregations.py @@ -156,6 +156,10 @@ def list_bookmarks(self, start_date=None, end_date=None, limit=None): "sum", } +ALLOWED_BUCKETS = { + "terms", +} + class StatAggregator(object): """Generic aggregation class. @@ -195,6 +199,7 @@ def __init__( client=None, field=None, metric_fields=None, + bucket_fields=None, copy_fields=None, query_modifiers=None, interval="day", @@ -210,6 +215,9 @@ def __init__( metric aggregation will be computed. The format of the dictionary is "destination field" -> tuple("metric type", "source field", "metric_options"). + :param bucket_fields: dictionary of fields on which a bucket aggregation will + be computed. The format of the dictionary is + "dest field" -> tuple("bucket agg type", "src field", "bucket agg params"). :param copy_fields: list of fields which are copied from the raw events into the aggregation. :param query_modifiers: list of functions modifying the raw events @@ -225,6 +233,7 @@ def __init__( self.index = prefix_index(f"stats-{event}") self.field = field self.metric_fields = metric_fields or {} + self.bucket_fields = bucket_fields or {} self.interval = interval self.doc_id_suffix = SUPPORTED_INTERVALS[interval] self.index_interval = index_interval @@ -245,6 +254,15 @@ def __init__( ) ) + if any(v not in ALLOWED_BUCKETS for k, (v, _, _) in self.bucket_fields.items()): + raise ( + ValueError( + "Bucket aggregation type should be one of [{}]".format( + ", ".join(ALLOWED_BUCKETS) + ) + ) + ) + if list(SUPPORTED_INTERVALS.keys()).index(interval) > list( SUPPORTED_INTERVALS.keys() ).index(index_interval): @@ -293,6 +311,22 @@ def _split_date_range(self, lower_limit, upper_limit): res[dt_key] = upper_limit return res + def _handle_bucket_agg(aggregation_buckets): + """Transform the bucket aggregation result into something leaner. + + In the case of a "terms" bucket aggregation, this function will turn the list + of bucket documents into a simple object of the shape + ``{"key1": count1, "key2": count2, ...}``. + """ + result = {} + for bucket in aggregation_buckets: + # NOTE that this is primarily intended for 'terms' buckets and needs to be + # checked in case we want to support further bucket aggregations + keyword = bucket.get("key_as_string", str(bucket["key"])) + result[keyword] = bucket["doc_count"] + + return result + def agg_iter(self, dt): """Aggregate and return dictionary to be indexed in ES.""" rounded_dt = format_range_dt(dt, self.interval) @@ -333,6 +367,9 @@ def agg_iter(self, dt): for dst, (metric, src, opts) in self.metric_fields.items(): terms.metric(dst, metric, field=src, **opts) + for dst, (bucket, src, opts) in self.bucket_fields.items(): + terms.bucket(dst, bucket, field=src, **opts) + results = self.agg_query.execute( # NOTE: Without this, the aggregation changes above, do not # invalidate the search's response cache, and thus you would @@ -355,6 +392,12 @@ def agg_iter(self, dt): for f in self.metric_fields: aggregation_data[f] = aggregation[f]["value"] + if self.bucket_fields: + for f in self.bucket_fields: + aggregation_data[f] = self._handle_bucket_agg( + aggregation[f]["buckets"] + ) + for destination, source in self.copy_fields.items(): if isinstance(source, str): aggregation_data[destination] = doc[source]