diff --git a/invenio_stats/aggregations.py b/invenio_stats/aggregations.py index 993d44d..0e18503 100644 --- a/invenio_stats/aggregations.py +++ b/invenio_stats/aggregations.py @@ -11,7 +11,6 @@ import math from collections import OrderedDict -from copy import deepcopy from datetime import datetime from functools import wraps @@ -156,6 +155,10 @@ def list_bookmarks(self, start_date=None, end_date=None, limit=None): "sum", } +ALLOWED_BUCKET_AGGS = { + "terms", +} + class StatAggregator(object): """Generic aggregation class. @@ -195,6 +198,7 @@ def __init__( client=None, field=None, metric_fields=None, + bucket_fields=None, copy_fields=None, query_modifiers=None, interval="day", @@ -210,6 +214,9 @@ def __init__( metric aggregation will be computed. The format of the dictionary is "destination field" -> tuple("metric type", "source field", "metric_options"). + :param bucket_fields: dictionary of fields on which a bucket aggregation will + be computed. The format of the dictionary is + "dest field" -> tuple("bucket agg type", "src field", "bucket agg params"). :param copy_fields: list of fields which are copied from the raw events into the aggregation. :param query_modifiers: list of functions modifying the raw events @@ -225,6 +232,7 @@ def __init__( self.index = prefix_index(f"stats-{event}") self.field = field self.metric_fields = metric_fields or {} + self.bucket_fields = bucket_fields or {} self.interval = interval self.doc_id_suffix = SUPPORTED_INTERVALS[interval] self.index_interval = index_interval @@ -245,6 +253,17 @@ def __init__( ) ) + if any( + v not in ALLOWED_BUCKET_AGGS for k, (v, _, _) in self.bucket_fields.items() + ): + raise ( + ValueError( + "Bucket aggregation type should be one of [{}]".format( + ", ".join(ALLOWED_BUCKET_AGGS) + ) + ) + ) + if list(SUPPORTED_INTERVALS.keys()).index(interval) > list( SUPPORTED_INTERVALS.keys() ).index(index_interval): @@ -293,15 +312,34 @@ def _split_date_range(self, lower_limit, upper_limit): res[dt_key] = upper_limit return res + def _handle_bucket_agg(aggregation_buckets): + """Transform the bucket aggregation result into something leaner. + + In the case of a "terms" bucket aggregation, this function will turn the list + of bucket documents into a simple object of the shape + ``{"key1": count1, "key2": count2, ...}``. + """ + result = {} + for bucket in aggregation_buckets: + # NOTE that this is primarily intended for 'terms' buckets and needs to be + # checked in case we want to support further bucket aggregations + keyword = bucket.get("key_as_string", str(bucket["key"])) + result[keyword] = bucket["doc_count"] + + return result + def agg_iter(self, dt): """Aggregate and return dictionary to be indexed in the search engine.""" rounded_dt = format_range_dt(dt, self.interval) - agg_query = dsl.Search(using=self.client, index=self.event_index).filter( - "range", - # Filter for the specific interval (hour, day, month) - timestamp={"gte": rounded_dt, "lte": rounded_dt}, + self.agg_query = ( + dsl.Search(using=self.client, index=self.event_index) + .filter( + "range", + # Filter for the specific interval (hour, day, month) + timestamp={"gte": rounded_dt, "lte": rounded_dt}, + ) + .extra(size=0) ) - self.agg_query = agg_query # apply query modifiers for modifier in self.query_modifiers: @@ -330,6 +368,9 @@ def agg_iter(self, dt): for dst, (metric, src, opts) in self.metric_fields.items(): terms.metric(dst, metric, field=src, **opts) + for dst, (bucket, src, opts) in self.bucket_fields.items(): + terms.bucket(dst, bucket, field=src, **opts) + results = self.agg_query.execute( # NOTE: Without this, the aggregation changes above, do not # invalidate the search's response cache, and thus you would @@ -337,7 +378,8 @@ def agg_iter(self, dt): ignore_cache=True, ) for aggregation in results.aggregations["terms"].buckets: - doc = aggregation.top_hit.hits.hits[0]["_source"] + doc = aggregation.top_hit.hits.hits[0]["_source"].to_dict() + aggregation = aggregation.to_dict() interval_date = datetime.strptime( doc["timestamp"], "%Y-%m-%dT%H:%M:%S" ).replace(**dict.fromkeys(INTERVAL_ROUNDING[self.interval], 0)) @@ -351,6 +393,12 @@ def agg_iter(self, dt): for f in self.metric_fields: aggregation_data[f] = aggregation[f]["value"] + if self.bucket_fields: + for f in self.bucket_fields: + aggregation_data[f] = self._handle_bucket_agg( + aggregation[f]["buckets"] + ) + for destination, source in self.copy_fields.items(): if isinstance(source, str): aggregation_data[destination] = doc[source] diff --git a/invenio_stats/utils.py b/invenio_stats/utils.py index ca64f6a..3ded52b 100644 --- a/invenio_stats/utils.py +++ b/invenio_stats/utils.py @@ -13,12 +13,11 @@ from base64 import b64encode from math import ceil -from flask import current_app, request, session +from flask import request, session from flask_login import current_user from geolite2 import geolite2 from invenio_cache import current_cache from invenio_search.engine import dsl -from werkzeug.utils import import_string def get_anonymization_salt(ts):