diff --git a/docs/aggregations.md b/docs/aggregations.md
new file mode 100644
index 000000000..aadf88791
--- /dev/null
+++ b/docs/aggregations.md
@@ -0,0 +1,405 @@
# Aggregating Data

## Aggregation Types

In Quix Streams, aggregation operations are divided into two groups: **Aggregators** and **Collectors**.

### Aggregators

**Aggregators** incrementally combine the current value with the aggregated data and store the result in the state.
Use them when the aggregation can be performed in an incremental way, like counting items.

### Collectors

**Collectors** accumulate individual values in the state before performing any aggregation.

Use them to batch items into a collection, or when the aggregation operation needs the full dataset, like calculating a median.

**Collectors** are optimized for storing individual values in the state and perform significantly better than **Aggregators** when you need to accumulate values into a list.

!!! note

    The performance benefit comes at a price: **Collectors** only support [`.final()`](windowing.md#emitting-after-the-window-is-closed) mode.
    Using [`.current()`](windowing.md#emitting-updates-for-each-message) is not supported.


## Using Aggregations

!!! info

    Currently, aggregations can be performed only over [windowed](./windowing.md) data.


To calculate an aggregation, you need to define a window.
To learn more about windows, see the [Windowing](./windowing.md) page.

When you have a window, call `.agg()` and pass the configured aggregator or collector as a named parameter.

**Example 1. Count items in the window**

```python
from datetime import timedelta
from quixstreams import Application
from quixstreams.dataframe.windows import Count

app = Application(...)
sdf = app.dataframe(...)


sdf = (

    # Define a tumbling window of 10 minutes
    sdf.tumbling_window(timedelta(minutes=10))

    # Call .agg() and provide an Aggregator or Collector to it.
    # Here we use the built-in aggregator "Count".
    # The parameter name is used as a key in the aggregated state and returned in the result.
    .agg(count=Count())

    # Specify how the windowed results are emitted.
    # Here, emit results only for closed windows.
    .final()
)

# Output:
# {
#   'start': <window start ms>,
#   'end': <window end ms>,
#   'count': 9999 - total number of events in the window
# }
```

**Example 2. Accumulating items in the window**

Use [`Collect()`](api-reference/quixstreams.md#collect) to gather all events within each window period into a list.
`Collect` takes an optional `column` parameter to limit the collection to one column of the input.


```python
from datetime import timedelta
from quixstreams import Application
from quixstreams.dataframe.windows import Collect

app = Application(...)
sdf = app.dataframe(...)

sdf = (
    # Define a tumbling window of 10 minutes
    sdf.tumbling_window(timedelta(minutes=10))

    # Collect events in the window into a list
    .agg(events=Collect())

    # Emit results only for closed windows
    .final()
)
# Output:
# {
#   'start': <window start ms>,
#   'end': <window end ms>,
#   'events': [event1, event2, event3, ..., eventN] - list of all events in the window
# }
```


### Aggregating over a single column

**Aggregators** allow you to select a column using the optional `column` parameter.

When `column` is passed, the Aggregator performs the aggregation only over this column.
In this case, the value is assumed to be a dictionary.

Otherwise, it will use the whole message.
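For instance, if the incoming values are plain numbers rather than dictionaries, you can omit `column` and aggregate the whole message. A minimal sketch using the built-in `Sum`:

```python
from datetime import timedelta
from quixstreams import Application
from quixstreams.dataframe.windows import Sum

app = Application(...)
sdf = app.dataframe(...)

# Input values are plain numbers:
# 3
# 7

sdf = (
    # Define a tumbling window of 10 minutes
    sdf.tumbling_window(timedelta(minutes=10))

    # No "column" is passed, so Sum aggregates the whole message
    .agg(total=Sum())

    # Emit results only for closed windows
    .final()
)

# Output:
# {
#   'start': <window start ms>,
#   'end': <window end ms>,
#   'total': 10
# }
```

When the values are dictionaries, pass `column` to aggregate a single field, as in the following example: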
```python
from datetime import timedelta
from quixstreams import Application
from quixstreams.dataframe.windows import Min

app = Application(...)
sdf = app.dataframe(...)

# Input:
# {"temperature" : 9999}

sdf = (
    # Define a tumbling window of 10 minutes
    sdf.tumbling_window(timedelta(minutes=10))

    # Calculate the Min aggregation over the "temperature" column
    .agg(min_temperature=Min(column="temperature"))

    # Emit results only for closed windows
    .final()
)

# Output:
# {
#   'start': <window start ms>,
#   'end': <window end ms>,
#   'min_temperature': 9999 - minimum temperature
# }
```


### Multiple Aggregations

It is possible to calculate several different aggregations and collections over the same window.

**Collectors** are optimized to store each value only once, even when it is shared by several collectors.

You can define a wide range of aggregations, such as:

- Aggregating over multiple message fields at once
- Calculating multiple aggregates for the same value

**Example**:

Assume you receive temperature data from a sensor, and you need to calculate these aggregates for each 10-minute tumbling window:

- min temperature
- max temperature
- total count of events
- average temperature

```python
from datetime import timedelta
from quixstreams import Application
from quixstreams.dataframe.windows import Min, Max, Count, Mean

app = Application(...)
sdf = app.dataframe(...)

sdf = (

    # Define a tumbling window of 10 minutes
    sdf.tumbling_window(timedelta(minutes=10))

    .agg(
        min_temp=Min("temperature"),
        max_temp=Max("temperature"),
        avg_temp=Mean("temperature"),
        total_events=Count(),
    )

    # Emit results only for closed windows
    .final()
)

# Output:
# {
#   'start': <window start ms>,
#   'end': <window end ms>,
#   'min_temp': 1,
#   'max_temp': 999,
#   'avg_temp': 34.32,
#   'total_events': 999,
# }
```


## Built-in Aggregators and Collectors

**Aggregators:**

- [`Count()`](api-reference/quixstreams.md#count) - to count the number of values within a window.
- [`Min()`](api-reference/quixstreams.md#min) - to get a minimum value within a window.
- [`Max()`](api-reference/quixstreams.md#max) - to get a maximum value within a window.
- [`Mean()`](api-reference/quixstreams.md#mean) - to get a mean value within a window.
- [`Sum()`](api-reference/quixstreams.md#sum) - to sum values within a window.
- [`Reduce()`](api-reference/quixstreams.md#reduce) - to write a custom aggregation (deprecated, use a [custom aggregator](#custom-aggregators) instead).

**Collectors:**

- [`Collect()`](api-reference/quixstreams.md#collect) - to collect all values within a window into a list.


## Custom Aggregators

To implement a custom aggregator, subclass the `Aggregator` class and implement 3 methods:

- [`initialize`](api-reference/quixstreams.md#baseaggregatorinitialize): Called when initializing a new window. Returns the starting value of the aggregation.
- [`agg`](api-reference/quixstreams.md#baseaggregatoragg): Called for every item added to the window. It should merge the new value into the aggregated state.
- [`result`](api-reference/quixstreams.md#baseaggregatorresult): Called to generate the result from the aggregated value.


By default, the aggregation state key includes the aggregation class name.

If your aggregation accepts parameters, like a column name, you can override the [`state_suffix`](api-reference/quixstreams.md#baseaggregatorstate_suffix) property to include those parameters in the state key.
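For illustration, here is a minimal sketch of a parameterized custom aggregator that folds its parameter into the state suffix. The `CountAbove` class and its `threshold` parameter are hypothetical, not part of the library:

```python
from typing import Optional

from quixstreams.dataframe.windows.aggregations import Aggregator


class CountAbove(Aggregator):
    """Count values greater than a threshold (hypothetical example)."""

    def __init__(self, column: Optional[str] = None, threshold: float = 0.0) -> None:
        super().__init__(column=column)
        self.threshold = threshold  # hypothetical parameter

    @property
    def state_suffix(self) -> str:
        # Include the parameter so that changing it also changes the state key
        return f"{self.__class__.__name__}/{self.column}/{self.threshold}"

    def initialize(self) -> int:
        return 0

    def agg(self, aggregated: int, new, timestamp: int) -> int:
        if self.column is not None:
            new = new[self.column]
        return aggregated + 1 if new > self.threshold else aggregated

    def result(self, aggregated: int) -> int:
        return aggregated
```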
Whenever the state key changes, the aggregation's state is reset.


**Example 1. Power sum**

Calculate the sum of the squared values of the incoming data over a 10-minute tumbling window.

```python
from datetime import timedelta
from quixstreams import Application
from quixstreams.dataframe.windows.aggregations import Aggregator

app = Application(...)
sdf = app.dataframe(...)

class PowerSum(Aggregator):
    def initialize(self):
        return 0

    def agg(self, aggregated, new, timestamp):
        if self.column is not None:
            new = new[self.column]
        return aggregated + (new * new)

    def result(self, aggregated):
        return aggregated

# Input:
# {"amount" : 2}
# {"amount" : 3}

sdf = (
    # Define a tumbling window of 10 minutes
    sdf.tumbling_window(timedelta(minutes=10))

    # Aggregate the custom sum over the "amount" column
    .agg(sum=PowerSum(column="amount"))

    # Emit results only for closed windows
    .final()
)
# Output:
# {
#   'start': <window start ms>,
#   'end': <window end ms>,
#   'sum': 13
# }
```


**Example 2. Custom aggregation over multiple message fields**


```python
from datetime import timedelta
from quixstreams import Application
from quixstreams.dataframe.windows import Aggregator

class TemperatureAggregator(Aggregator):
    def initialize(self):
        return {
            "min_temp": float("inf"),
            "max_temp": float("-inf"),
            "total_events": 0,
            "sum_temp": 0,
        }

    def agg(self, old, new, ts):
        if self.column is not None:
            new = new[self.column]

        old["min_temp"] = min(old["min_temp"], new)
        old["max_temp"] = max(old["max_temp"], new)
        old["total_events"] += 1
        old["sum_temp"] += new
        return old

    def result(self, stored):
        return {
            "min_temp": stored["min_temp"],
            "max_temp": stored["max_temp"],
            "total_events": stored["total_events"],
            "avg_temp": stored["sum_temp"] / stored["total_events"],
        }


app = Application(...)
sdf = app.dataframe(...)

sdf = (

    # Define a tumbling window of 10 minutes
    sdf.tumbling_window(timedelta(minutes=10))

    .agg(
        value=TemperatureAggregator(column="Temperature")
    )

    # Emit results only for closed windows
    .final()
)

# Output:
# {
#   'start': <window start ms>,
#   'end': <window end ms>,
#   'value': {
#     'min_temp': 1,
#     'max_temp': 999,
#     'avg_temp': 34.32,
#     'total_events': 999,
#   }
# }
```


## Custom Collectors

To implement a custom **Collector**, subclass the [`Collector`](api-reference/quixstreams.md#collector) class and implement the [`result`](api-reference/quixstreams.md#basecollectorresult) method.

It is called when the window is closed, with an iterable of all the items collected in this window.

By default, **Collectors** always store the full message.

If you only need a specific column, you can override the [`column`](api-reference/quixstreams.md#basecollectorcolumn) property to specify which column needs to be stored.


**Example:**

Collect all events over a 10-minute tumbling window into a list in reversed order.

```python
from datetime import timedelta
from quixstreams import Application
from quixstreams.dataframe.windows.aggregations import Collector

app = Application(...)
sdf = app.dataframe(...)
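
# A custom collector only needs to override result(); it is called once the
# window closes, with an iterable of everything collected in that window.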
class ReversedCollect(Collector):
    def result(self, items):
        # items is the list of all items collected during the window
        return list(reversed(items))

sdf = (
    # Define a tumbling window of 10 minutes
    sdf.tumbling_window(timedelta(minutes=10))

    # Collect events in the window into a reversed list
    .agg(events=ReversedCollect())

    # Emit results only for closed windows
    .final()
)
# Output:
# {
#   'start': <window start ms>,
#   'end': <window end ms>,
#   'events': [eventN, ..., event3, event2, event1] - reversed list of all events in the window
# }
```

## Reduce

!!! warning

    `Reduce` is deprecated. Use [multiple aggregations](aggregations.md#multiple-aggregations) and [custom Aggregators](aggregations.md#custom-aggregators) instead. They provide more control over parameters and better state management.

[`Reduce()`](api-reference/quixstreams.md#reduce) allows you to perform complex aggregations using custom "reducer" and "initializer" functions:

- The **"initializer"** function receives the **first** value for the given window, and it must return an initial state for this window.
This state will be later passed to the "reducer" function.
**It is called only once for each window.**

- The **"reducer"** function receives an aggregated state and a current value, and it must combine them and return a new aggregated state.
This function should contain the actual aggregation logic.
It will be called for each message coming into the window, except the first one.

diff --git a/docs/groupby.md b/docs/groupby.md
index db8925469..dd1d23d7d 100644
--- a/docs/groupby.md
+++ b/docs/groupby.md
@@ -247,20 +247,19 @@ what `store_id` it came from) ***over the past hour***.

In this case, we need to get a windowed sum based on a single column identifier: `item`.

This can be done by simply passing the `item` column name to `.groupby()`, followed by
-a [`tumbling_window()`](windowing.md#tumbling-windows) [`.sum()`](windowing.md#min-max-mean-and-sum) over the past `3600` seconds:
+a [`tumbling_window()`](windowing.md#time-based-tumbling-windows) [`Sum()`](aggregations.md#built-in-aggregators-and-collectors) aggregation over the past `3600` seconds:

```python
sdf = StreamingDataFrame()
sdf = sdf.group_by("item")
-sdf = sdf.tumbling_window(duration_ms=3600).sum().final()
-sdf = sdf.apply(lambda window_result: {"total_quantity": window_result["value"]})
+sdf = sdf.tumbling_window(duration_ms=3600).agg(total_quantity=agg.Sum()).final()
```

which generates data like:

```python
-{"key": "A", "value": {"total_quantity": 9}}
+{"key": "A", "total_quantity": 9}
# ...etc...
-{"key": "B", "value": {"total_quantity": 4}}
+{"key": "B", "total_quantity": 4}
# ...etc...
```

diff --git a/docs/tutorials/anomaly-detection/tutorial.md b/docs/tutorials/anomaly-detection/tutorial.md
index 78e3ddb81..b489e3dce 100644
--- a/docs/tutorials/anomaly-detection/tutorial.md
+++ b/docs/tutorials/anomaly-detection/tutorial.md
@@ -180,31 +180,12 @@ which means we will be consuming data from a non-Kafka origin.

Let's go over the SDF operations in this example in detail.
- -### Prep Data for Windowing - -```python -sdf = sdf.apply(lambda data: data["Temperature_C"]) -``` - -To use the built-in windowing functions, our incoming event needs to be transformed: at this point the unaltered event dictionary will look something like (and this should be familiar!): - -`>>> {"Temperature_C": 65, "Timestamp": 1710856626905833677}` - -But it needs to be just the temperature: - -`>>> 65` - -So we'll perform a generic SDF transformation using [`SDF.apply(F)`](../../processing.md#streamingdataframeapply), -(`F` should take your current message value as an argument, and return your new message value): -our `F` is a simple `lambda`, in this case. - - - ### Windowing ```python -sdf = sdf.hopping_window(duration_ms=5000, step_ms=1000).mean().current() +import quixstreams.dataframe.windows.aggregations as agg + +sdf = sdf.hopping_window(duration_ms=5000, step_ms=1000).agg(value=agg.Mean(column="Temperature_C")).current() ``` Now we do a (5 second) windowing operation on our temperature value. A few very important notes here: diff --git a/docs/tutorials/anomaly-detection/tutorial_app.py b/docs/tutorials/anomaly-detection/tutorial_app.py index d6c6e68fb..97cf2280f 100644 --- a/docs/tutorials/anomaly-detection/tutorial_app.py +++ b/docs/tutorials/anomaly-detection/tutorial_app.py @@ -2,6 +2,7 @@ import random import time +import quixstreams.dataframe.windows.aggregations as agg from quixstreams import Application from quixstreams.sources import Source @@ -105,8 +106,11 @@ def main(): # If reading from a Kafka topic, pass topic= instead of a source sdf = app.dataframe(source=TemperatureGenerator()) - sdf = sdf.apply(lambda data: data["Temperature_C"]) - sdf = sdf.hopping_window(duration_ms=5000, step_ms=1000).mean().current() + sdf = ( + sdf.hopping_window(duration_ms=5000, step_ms=1000) + .agg(value=agg.Mean("Temperature_C")) + .current() + ) sdf = sdf.apply(lambda result: round(result["value"], 2)).filter( should_alert, metadata=True ) diff --git a/docs/windowing.md b/docs/windowing.md index 930fe4bd4..ea5a9a72f 100644 --- a/docs/windowing.md +++ b/docs/windowing.md @@ -96,8 +96,9 @@ Since version 2.6, all windowed aggregations always set timestamps equal to the **Example:** ```python -from quixstreams import Application from datetime import timedelta +from quixstreams import Application +from quixstreams.dataframe.windows import Min app = Application(...) @@ -107,14 +108,11 @@ sdf = app.dataframe(...) # value={"temperature" : 9999}, key="sensor_1", timestamp=10001 sdf = ( - # Extract the "temperature" column from the dictionary - sdf.apply(lambda value: value['temperature']) - # Define a tumbling window of 10 seconds .tumbling_window(timedelta(seconds=10)) # Calculate the minimum temperature - .min() + .agg(minimum_temperature=Min("temperature")) # Emit results for every incoming message .current() @@ -123,22 +121,13 @@ sdf = ( # value={ # 'start': 10000, # 'end': 20000, -# 'value': 9999 - minimum temperature +# 'minimum_temperature': 9999 - minimum temperature # }, # key="sensor_1", # timestamp=10000 - timestamp equals to the window start timestamp ``` - -### Message headers of the aggregation results - -Currently, windowed aggregations do not store the original headers of the messages. -The results of the windowed aggregations will have headers set to `None`. - -You may set messages headers by using the `StreamingDataFrame.set_headers()` API, as -described in [the "Updating Kafka Headers" section](./processing.md#updating-kafka-headers). 
- ## Time-based Tumbling Windows Tumbling windows slice time into non-overlapping intervals of a fixed size. @@ -177,9 +166,9 @@ Input: Expected output: ```json -{"avg_temperature": 30, "window_start_ms": 0, "window_end_ms": 3600000} -{"avg_temperature": 29.5, "window_start_ms": 0, "window_end_ms": 3600000} -{"avg_temperature": 29, "window_start_ms": 0, "window_end_ms": 3600000} +{"avg_temperature": 30, "start": 0, "end": 3600000} +{"avg_temperature": 29.5, "start": 0, "end": 3600000} +{"avg_temperature": 29, "start": 0, "end": 3600000} ``` Here is how to do it using tumbling windows: @@ -187,33 +176,22 @@ Here is how to do it using tumbling windows: ```python from datetime import timedelta from quixstreams import Application +from quixstreams.dataframe.windows import Mean app = Application(...) sdf = app.dataframe(...) sdf = ( - # Extract "temperature" value from the message - sdf.apply(lambda value: value["temperature"]) - # Define a tumbling window of 1 hour # You can also pass duration_ms as an integer of milliseconds .tumbling_window(duration_ms=timedelta(hours=1)) # Specify the "mean" aggregate function - .mean() + .agg(avg_temperature=Mean("temperature")) # Emit updates for each incoming message .current() - - # Unwrap the aggregated result to match the expected output format - .apply( - lambda result: { - "avg_temperature": result["value"], - "window_start_ms": result["start"], - "window_end_ms": result["end"], - } - ) ) ``` @@ -250,7 +228,7 @@ Input: Expected window output: ```json -{"data": [100, 50, 200], "window_start_ms": 121, "window_end_ms": 583} +{"data": [100, 50, 200], "start": 121, "end": 583} ``` Here is how to do it using tumbling windows: @@ -262,6 +240,7 @@ import urllib.request from datetime import timedelta from quixstreams import Application +from quixstreams.dataframe.windows import Collect def external_api(value): with urllib.request.urlopen("https://example.com", data=json.dumps(value["data"])) as rep: @@ -272,26 +251,14 @@ sdf = app.dataframe(...) sdf = ( - # Extract "experience" value from the message - sdf.apply(lambda value: value["data"]) - # Define a count-based tumbling window of 3 events .tumbling_count_window(count=3) # Specify the "collect" aggregate function - .collect() + .agg(data=Collect()) # Emit updates once the window is closed .final() - - # Unwrap the aggregated result to match the expected output format - .apply( - lambda result: { - "data": result["value"], - "window_start_ms": result["start"], - "window_end_ms": result["end"], - } - ) ) # Send a request to the external API @@ -343,45 +310,34 @@ Input: Expected output: ```json -{"avg_temperature": 30, "window_start_ms": 0, "window_end_ms": 3600000} +{"avg_temperature": 30, "start": 0, "end": 3600000} -{"avg_temperature": 29.5, "window_start_ms": 0, "window_end_ms": 3600000} -{"avg_temperature": 30, "window_start_ms": 60000, "window_end_ms": 4200000} +{"avg_temperature": 29.5, "start": 0, "end": 3600000} +{"avg_temperature": 30, "start": 60000, "end": 4200000} -{"avg_temperature": 29, "window_start_ms": 0, "window_end_ms": 3600000} -{"avg_temperature": 28.5, "window_start_ms": 60000, "window_end_ms": 4200000} +{"avg_temperature": 29, "start": 0, "end": 3600000} +{"avg_temperature": 28.5, "start": 60000, "end": 4200000} ``` ```python from datetime import timedelta from quixstreams import Application +from quixstreams.dataframe.windows import Mean app = Application(...) sdf = app.dataframe(...) 
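
# Note: hopping windows overlap. With a 1-hour window and a 10-minute step,
# each message falls into up to 6 windows, so .current() may emit several
# updated results per incoming message.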
sdf = ( - # Extract "temperature" value from the message - sdf.apply(lambda value: value["temperature"]) - # Define a hopping window of 1h with 10m step # You can also pass duration_ms and step_ms as integers of milliseconds .hopping_window(duration_ms=timedelta(hours=1), step_ms=timedelta(minutes=10)) # Specify the "mean" aggregate function - .mean() + .agg(avg_temperature=Mean("temperature")) # Emit updates for each incoming message .current() - - # Unwrap the aggregated result to match the expected output format - .apply( - lambda result: { - "avg_temperature": result["value"], - "window_start_ms": result["start"], - "window_end_ms": result["end"], - } - ) ) ``` @@ -448,43 +404,32 @@ Input: Expected output: ```json -{"avg_temperature": 30, "window_start_ms": 0, "window_end_ms": 3600000} -{"avg_temperature": 29.5, "window_start_ms": 1200000, "window_end_ms": 4800000} -{"avg_temperature": 29, "window_start_ms": 1200001, "window_end_ms": 4800001} -{"avg_temperature": 28.5, "window_start_ms": 3600000, "window_end_ms": 7200000} -{"avg_temperature": 27.5, "window_start_ms": 3600001, "window_end_ms": 7200001} # reading 30 is outside of the window +{"avg_temperature": 30, "start": 0, "end": 3600000} +{"avg_temperature": 29.5, "start": 1200000, "end": 4800000} +{"avg_temperature": 29, "start": 1200001, "end": 4800001} +{"avg_temperature": 28.5, "start": 3600000, "end": 7200000} +{"avg_temperature": 27.5, "start": 3600001, "end": 7200001} # reading 30 is outside of the window ``` ```python from datetime import timedelta from quixstreams import Application +from quixstreams.dataframe.windows import Mean app = Application(...) sdf = app.dataframe(...) sdf = ( - # Extract "temperature" value from the message - sdf.apply(lambda value: value["temperature"]) - # Define a sliding window of 1h # You can also pass duration_ms as integer of milliseconds .sliding_window(duration_ms=timedelta(hours=1)) # Specify the "mean" aggregate function - .mean() + .agg(avg_temperature=Mean("temperature")) # Emit updates for each incoming message .current() - - # Unwrap the aggregated result to match the expected output format - .apply( - lambda result: { - "avg_temperature": result["value"], - "window_start_ms": result["start"], - "window_end_ms": result["end"], - } - ) ) ``` @@ -524,10 +469,10 @@ Input: Expected window output: ```json -{"average": 120, "window_start_ms": 121, "window_end_ms": 583} -{"average": 100, "window_start_ms": 165, "window_end_ms": 723} -{"average": 120, "window_start_ms": 583, "window_end_ms": 1009} -{"average": 80, "window_start_ms": 723, "window_end_ms": 1242} +{"average": 120, "start": 121, "end": 583} +{"average": 100, "start": 165, "end": 723} +{"average": 120, "start": 583, "end": 1009} +{"average": 80, "start": 723, "end": 1242} ``` @@ -536,302 +481,25 @@ Here is how to do it using sliding windows: ```python from datetime import timedelta from quixstreams import Application +from quixstreams.dataframe.windows import Mean app = Application(...) sdf = app.dataframe(...) 
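
# Note: with count=3, each window covers 3 consecutive events per message
# key, so one averaged result is emitted per run of 3 messages
# (matching the expected output above).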
sdf = ( - # Extract "experience" value from the message - sdf.apply(lambda value: value["data"]) - # Define a count-based sliding window of 3 events .sliding_count_window(count=3) # Specify the "mean" aggregate function - .mean() + .agg(average=Mean("amount")) # Emit updates once the window is closed .final() - - # Unwrap the aggregated result to match the expected output format - .apply( - lambda result: { - "average": result["value"], - "window_start_ms": result["start"], - "window_end_ms": result["end"], - } - ) -) - -``` - -## Supported Aggregations - -Currently, windows support the following aggregation functions: - -- [`reduce()`](api-reference/quixstreams.md#fixedtimewindowdefinitionreduce) - to perform custom aggregations using "reducer" and "initializer" functions -- [`collect()`](api-reference/quixstreams.md#fixedtimewindowdefinitioncollect) - to collect all values within a window into a list -- [`min()`](api-reference/quixstreams.md#fixedtimewindowdefinitionmin) - to get a minimum value within a window -- [`max()`](api-reference/quixstreams.md#fixedtimewindowdefinitionmax) - to get a maximum value within a window -- [`mean()`](api-reference/quixstreams.md#fixedtimewindowdefinitionmean) - to get a mean value within a window -- [`sum()`](api-reference/quixstreams.md#fixedtimewindowdefinitionsum) - to sum values within a window -- [`count()`](api-reference/quixstreams.md#fixedtimewindowdefinitioncount) - to count the number of values within a window - -We will go over each ot them in more detail below. - -### Reduce() - -`.reduce()` allows you to perform complex aggregations using custom "reducer" and "initializer" functions: - -- The **"initializer"** function receives the **first** value for the given window, and it must return an initial state for this window. -This state will be later passed to the "reducer" function. -**It is called only once for each window.** - -- The **"reducer"** function receives an aggregated state and a current value, and it must combine them and return a new aggregated state. -This function should contain the actual aggregation logic. -It will be called for each message coming into the window, except the first one. - -With `reduce()`, you can define a wide range of aggregations, such as: - -- Aggregating over multiple message fields at once -- Using multiple message fields to create a single aggregate -- Calculating multiple aggregates for the same value - -**Example**: - -Assume you receive the temperature data from the sensor, and you need to calculate these aggregates for each 10-minute tumbling window: - -- min temperature -- max temperature -- total count of events -- average temperature - -Here is how you can do that with `reduce()`: - -```python -from datetime import timedelta -from quixstreams import Application - -app = Application(...) -sdf = app.dataframe(...) - - -def initializer(value: dict) -> dict: - """ - Initialize the state for aggregation when a new window starts. - - It will prime the aggregation when the first record arrives - in the window. - """ - return { - 'min_temp': value['temperature'], - 'max_temp': value['temperature'], - 'total_events': 1, - '_sum_temp': value['temperature'], - 'avg_temp': value['temperature'] - } - - -def reducer(aggregated: dict, value: dict) -> dict: - """ - Calculate "min", "max", "total" and "average" over temperature values. 
- - Reducer always receives two arguments: - - previously aggregated value (the "aggregated" argument) - - current value (the "value" argument) - It combines them into a new aggregated value and returns it. - This aggregated value will be also returned as a result of the window. - """ - total_events = aggregated['count'] + 1 - sum_temp = aggregated['_sum_temp'] + value - avg_temp = sum_temp / total_events - return { - 'min_temp': min(aggregated['min_temp'], value['temperature']), - 'max_temp': max(aggregated['max_temp'], value['temperature']), - 'total_events': total_events, - 'avg_temp': avg_temp, - '_sum_temp': sum_temp - } - - -sdf = ( - - # Define a tumbling window of 10 minutes - sdf.tumbling_window(timedelta(minutes=10)) - - # Create a "reduce" aggregation with "reducer" and "initializer" functions - .reduce(reducer=reducer, initializer=initializer) - - # Emit results only for closed windows - .final() ) -# Output: -# { -# 'start': , -# 'end': , -# 'value': {'min_temp': 1, 'max_temp': 999, 'total_events': 999, 'avg_temp': 34.32, '_sum_temp': 9999}, -# } - -``` - - -### Collect() -Use `.collect()` to gather all events in the window into a list. This operation is optimized for collecting values and performs significantly better than using `reduce()` to build a list. - -!!! note - Performance benefit comes at a price: `.collect()` only supports `.final()` mode. Using `.current()` is not supported. - -**Example:** - -Collect all events over a 10-minute tumbling window into a list. - -```python -from datetime import timedelta -from quixstreams import Application - -app = Application(...) -sdf = app.dataframe(...) - -sdf = ( - # Define a tumbling window of 10 minutes - sdf.tumbling_window(timedelta(minutes=10)) - - # Collect events in the window into a list - .collect() - - # Emit results only for closed windows - .final() -) -# Output: -# { -# 'start': , -# 'end': , -# 'value': [event1, event2, event3, ...] - list of all events in the window -# } ``` - -### Count() -Use `.count()` to calculate total number of events in the window. - -**Example:** - -Count all received events over a 10-minute tumbling window. - -```python -from datetime import timedelta -from quixstreams import Application - -app = Application(...) -sdf = app.dataframe(...) - - -sdf = ( - - # Define a tumbling window of 10 minutes - sdf.tumbling_window(timedelta(minutes=10)) - - # Count events in the window - .count() - - # Emit results only for closed windows - .final() -) -# Output: -# { -# 'start': , -# 'end': , -# 'value': 9999 - total number of events in the window -# } -``` - -### Min(), Max(), Mean() and Sum() - -Methods `.min()`, `.max()`, `.mean()`, and `.sum()` provide short API to calculate these aggregates over the streaming windows. - - -**These methods assume that incoming values are numbers.** - -When they are not, extract the numeric values first using `.apply()` function. - -**Example:** - -Imagine you receive the temperature data from the sensor, and you need to calculate only a minimum temperature for each 10-minute tumbling window. - -```python -from datetime import timedelta -from quixstreams import Application - -app = Application(...) -sdf = app.dataframe(...) 
- -# Input: -# {"temperature" : 9999} - -sdf = ( - # Extract the "temperature" column from the dictionary - sdf.apply(lambda value: value['temperature']) - - # Define a tumbling window of 10 minutes - .tumbling_window(timedelta(minutes=10)) - - # Calculate the minimum temperature - .min() - - # Emit results only for closed windows - .final() -) -# Output: -# { -# 'start': , -# 'end': , -# 'value': 9999 - minimum temperature -# } -``` - - - -## Transforming the result of a windowed aggregation -Windowed aggregations return aggregated results in the following format/schema: - -```python -{"start": , "end": , "value": } -``` - -Since it is rather generic, you may need to transform it into your own schema. -Here is how you can do that: - -```python -from datetime import timedelta -from quixstreams import Application - -app = Application(...) -sdf = app.dataframe(...) - -sdf = ( - # Define a tumbling window of 10 minutes - sdf.tumbling_window(timedelta(minutes=10)) - # Specify the "count" aggregation function - .count() - # Emit results only for closed windows - .final() -) - -# Input format: -# {"start": , "end": , "value": -sdf = sdf.apply( - lambda value: { - "count": value["value"], - "window": (value["start"], value["end"]), - } -) -# Output format: -# {"count": , "window": (, )} -``` - - ## Lateness and Out-of-Order Processing When working with event time, some events may be processed later than they're supposed to. Such events are called **"out-of-order"** because they violate the expected order of time in the data stream. @@ -941,13 +609,14 @@ To emit results for each processed message in the stream, use the following API: ```python from datetime import timedelta from quixstreams import Application +from quixstreams.dataframe.windows import Sum app = Application(...) sdf = app.dataframe(...) # Calculate a sum of values over a window of 10 seconds # and use .current() to emit results immediately -sdf = sdf.tumbling_window(timedelta(seconds=10)).sum().current() +sdf = sdf.tumbling_window(timedelta(seconds=10)).agg(value=Sum()).current() # Results: # -> Timestamp=100, value=1 -> emit {"start": 0, "end": 10000, "value": 1} @@ -969,13 +638,14 @@ Here is how to emit results only once for each window interval after it's closed ```python from datetime import timedelta from quixstreams import Application +from quixstreams.dataframe.windows import Sum app = Application(...) sdf = app.dataframe(...) # Calculate a sum of values over a window of 10 seconds # and use .final() to emit results only when the window is complete -sdf = sdf.tumbling_window(timedelta(seconds=10)).sum().final() +sdf = sdf.tumbling_window(timedelta(seconds=10)).agg(value=Sum()).final() # Results: # -> Timestamp=100, value=1 -> emit nothing (the window is not closed yet) @@ -1000,13 +670,14 @@ If some message keys appear irregularly in the stream, the latest windows can re ```python from datetime import timedelta from quixstreams import Application +from quixstreams.dataframe.windows import Sum app = Application(...) sdf = app.dataframe(...) 
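
# With closing_strategy="key", windows for each message key are tracked and
# closed independently: only a later message with the same key can close them.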
# Calculate a sum of values over a window of 10 seconds
# and use .final() to emit results only when the window is complete
-sdf = sdf.tumbling_window(timedelta(seconds=10)).sum().final(closing_strategy="key")
+sdf = sdf.tumbling_window(timedelta(seconds=10)).agg(value=Sum()).final(closing_strategy="key")

# Details:
# -> Timestamp=100, Key="A", value=1 -> emit nothing (the window is not closed yet)
@@ -1030,13 +701,14 @@ If messages aren't ordered accross keys some message can be skipped if the windo
```python
from datetime import timedelta
from quixstreams import Application
+from quixstreams.dataframe.windows import Sum

app = Application(...)
sdf = app.dataframe(...)

# Calculate a sum of values over a window of 10 seconds
# and use .final() to emit results only when the window is complete
-sdf = sdf.tumbling_window(timedelta(seconds=10)).sum().final(closing_strategy="partition")
+sdf = sdf.tumbling_window(timedelta(seconds=10)).agg(value=Sum()).final(closing_strategy="partition")

# Details:
# -> Timestamp=100, Key="A", value=1 -> emit nothing (the window is not closed yet)
@@ -1055,6 +727,55 @@ sdf = sdf.tumbling_window(timedelta(seconds=10)).sum().final(closing_strategy="p
# (key="C", value={"start": 0, "end": 10000, "value": 3})
```

+## Transforming the result of a windowed aggregation
+Windowed aggregations return aggregated results in the following format/schema:
+
+```python
+{"start": <window start>, "end": <window end>, <result column>: <aggregated value>}
+```
+
+Since it is rather generic, you may need to transform it into your own schema.
+Here is how you can do that:
+
+```python
+from datetime import timedelta
+from quixstreams import Application
+from quixstreams.dataframe.windows import Count
+
+app = Application(...)
+sdf = app.dataframe(...)
+
+sdf = (
+    # Define a tumbling window of 10 minutes
+    sdf.tumbling_window(timedelta(minutes=10))
+    # Specify the "count" aggregation function
+    .agg(count=Count())
+    # Emit results only for closed windows
+    .final()
+)
+
+# Input format:
+# {"start": <window start>, "end": <window end>, "count": <count>}
+sdf = sdf.apply(
+    lambda value: {
+        "count": value["count"],
+        "window": (value["start"], value["end"]),
+    }
+)
+# Output format:
+# {"count": <count>, "window": (<window start>, <window end>)}
+```
+
+
+### Message headers of the aggregation results
+
+Currently, windowed aggregations do not store the original headers of the messages.
+The results of the windowed aggregations will have headers set to `None`.
+
+You may set message headers by using the `StreamingDataFrame.set_headers()` API, as
+described in [the "Updating Kafka Headers" section](./processing.md#updating-kafka-headers).
+

## Implementation Details

Here are some general concepts about how windowed aggregations are implemented in Quix Streams:
@@ -1075,9 +796,8 @@ The state store name is auto-generated by default using the following window att

- Window type: `"tumbling"` or `"hopping"`
- Window parameters: `duration_ms` and `step_ms`
-- Aggregation function name: `"sum"`, `"count"`, `"reduce"`, etc.

-E.g. a store name for `sum` aggregation over a hopping window of 30 seconds with a 5 second step will be `hopping_window_30000_5000_sum`.
+E.g. a store name for a hopping window of 30 seconds with a 5 second step will be `hopping_window_30000_5000`.

### Updating Window Definitions

Quix Streams handles some of the situations, like:

- Updating window type (e.g.
from tumbling to hopping)
- Updating window period or step
-- Updating an aggregation function (except the `reduce()`)
+- Adding/Removing/Updating an aggregation function (except `Reduce()`)
+
+Updating the window type and parameters will change the name of the underlying state store, and the new window definition will use a different one.

-All of the above will change the name of the underlying state store, and the new window definition will use a different one.
+Updating an aggregation parameter will change the aggregation state key and reset the modified aggregation's state; other aggregations are not affected.

-But in some cases, these measures are not enough. For example, updating a code used in `reduce()` will not change the store name, but the data can still become inconsistent.
+But in some cases, these measures are not enough. For example, updating the code used in `Reduce()` will not change the store name, but the data can still become inconsistent.

In this case, you may need to update the `consumer_group` passed to the `Application` class. It will re-create all the state stores from scratch.
diff --git a/mkdocs.yml b/mkdocs.yml
index 883cb8640..2fb74a480 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -37,7 +37,8 @@ nav:
       - Process & Transform Data: processing.md
       - Inspecting Data & Debugging: debugging.md
       - GroupBy Operation: groupby.md
-      - Windows & Aggregations: windowing.md
+      - Windows: windowing.md
+      - Aggregations: aggregations.md
       - Configuration: configuration.md
       - StreamingDataFrame Branching: branching.md
       - Consuming Multiple Topics: consuming-multiple-topics.md
diff --git a/quixstreams/dataframe/dataframe.py b/quixstreams/dataframe/dataframe.py
index 7622b0fcd..fc66c434b 100644
--- a/quixstreams/dataframe/dataframe.py
+++ b/quixstreams/dataframe/dataframe.py
@@ -986,10 +986,12 @@ def tumbling_window(
         - The time windows always use the current event time.
-
         Example Snippet:

         ```python
+        from quixstreams import Application
+        import quixstreams.dataframe.windows.aggregations as agg
+
         app = Application()
         sdf = app.dataframe(...)
@@ -1000,7 +1002,7 @@
             )

             # Specify the aggregation function
-            .sum()
+            .agg(value=agg.Sum())

             # Specify how the results should be emitted downstream.
@@ -1064,15 +1066,21 @@ def tumbling_count_window(
         - The end timestamp of the aggregation result is set to the latest timestamp.
         - Every window is grouped by the current Kafka message key.
         - Messages with `None` key will be ignored.
+
+        Example Snippet:
+
+        ```python
+        from quixstreams import Application
+        import quixstreams.dataframe.windows.aggregations as agg
+
         app = Application()
         sdf = app.dataframe(...)
         sdf = (
             # Define a tumbling window of 10 messages
             sdf.tumbling_count_window(count=10)
             # Specify the aggregation function
-            .sum()
+            .agg(value=agg.Sum())
             # Specify how the results should be emitted downstream.
             # "current()" will emit results as they come for each updated window,
             # possibly producing multiple messages per key-window pair
@@ -1122,6 +1130,9 @@ def hopping_window(
         Example Snippet:

         ```python
+        from quixstreams import Application
+        import quixstreams.dataframe.windows.aggregations as agg
+
         app = Application()
         sdf = app.dataframe(...)
@@ -1134,7 +1145,7 @@
             )

             # Specify the aggregation function
-            .sum()
+            .agg(value=agg.Sum())

             # Specify how the results should be emitted downstream.
             # "current()" will emit results as they come for each updated window,
@@ -1209,8 +1220,14 @@ def hopping_count_window(
         - The end timestamp of the aggregation result is set to the latest timestamp.
         - Every window is grouped by the current Kafka message key.
         - Messages with `None` key will be ignored.
+
+        Example Snippet:
+
+        ```python
+        from quixstreams import Application
+        import quixstreams.dataframe.windows.aggregations as agg
+
         app = Application()
         sdf = app.dataframe(...)
         sdf = (
             # Define a hopping window of 10 messages with a step of 5 messages
             sdf.hopping_count_window(
                 count=10,
                 step=5,
             )
             # Specify the aggregation function
-            .sum()
+            .agg(value=agg.Sum())
             # Specify how the results should be emitted downstream.
             # "current()" will emit results as they come for each updated window,
             # possibly producing multiple messages per key-window pair
@@ -1276,6 +1293,9 @@ def sliding_window(
         Example Snippet:

         ```python
+        from quixstreams import Application
+        import quixstreams.dataframe.windows.aggregations as agg
+
         app = Application()
         sdf = app.dataframe(...)
@@ -1287,7 +1307,7 @@
             )

             # Specify the aggregation function
-            .sum()
+            .agg(value=agg.Sum())

             # Specify how the results should be emitted downstream.
             # "current()" will emit results as they come for each updated window,
@@ -1354,15 +1374,21 @@ def sliding_count_window(
         - Every window is grouped by the current Kafka message key.
         - Messages with `None` key will be ignored.
         - Every window contains a distinct aggregation.
+
+        Example Snippet:
+
+        ```python
+        from quixstreams import Application
+        import quixstreams.dataframe.windows.aggregations as agg
+
         app = Application()
         sdf = app.dataframe(...)
         sdf = (
             # Define a sliding window of 10 messages
             sdf.sliding_count_window(count=10)
             # Specify the aggregation function
-            .sum()
+            .agg(value=agg.Sum())
             # Specify how the results should be emitted downstream.
             # "current()" will emit results as they come for each updated window,
             # possibly producing multiple messages per key-window pair
diff --git a/quixstreams/dataframe/windows/__init__.py b/quixstreams/dataframe/windows/__init__.py
index ab5cb142d..8ac5f771a 100644
--- a/quixstreams/dataframe/windows/__init__.py
+++ b/quixstreams/dataframe/windows/__init__.py
@@ -1,3 +1,14 @@
+from .aggregations import (
+    Aggregator,
+    Collect,
+    Collector,
+    Count,
+    Max,
+    Mean,
+    Min,
+    Reduce,
+    Sum,
+)
 from .definitions import (
     HoppingCountWindowDefinition,
     HoppingTimeWindowDefinition,
@@ -8,10 +19,19 @@
 )

 __all__ = [
-    "TumblingCountWindowDefinition",
+    "Collect",
+    "Count",
+    "Max",
+    "Mean",
+    "Min",
+    "Reduce",
+    "Sum",
+    "Aggregator",
+    "Collector",
     "HoppingCountWindowDefinition",
-    "SlidingCountWindowDefinition",
     "HoppingTimeWindowDefinition",
+    "SlidingCountWindowDefinition",
     "SlidingTimeWindowDefinition",
+    "TumblingCountWindowDefinition",
     "TumblingTimeWindowDefinition",
 ]
diff --git a/quixstreams/dataframe/windows/aggregations.py b/quixstreams/dataframe/windows/aggregations.py
index 2345ead37..260f619f3 100644
--- a/quixstreams/dataframe/windows/aggregations.py
+++ b/quixstreams/dataframe/windows/aggregations.py
@@ -3,17 +3,31 @@
     Any,
     Callable,
     Generic,
-    Hashable,
     Iterable,
     Optional,
     TypeVar,
     Union,
 )

-from typing_extensions import TypeAlias
+__all__ = [
+    "Collect",
+    "Count",
+    "Max",
+    "Mean",
+    "Min",
+    "Reduce",
+    "Sum",
+    "Aggregator",
+    "BaseAggregator",
+    "Collector",
+    "BaseCollector",
+]

-class Aggregator(ABC):
+S = TypeVar("S")
+
+
+class BaseAggregator(ABC, Generic[S]):
     """
     Base class for window aggregation.
@@ -25,8 +39,22 @@ class Aggregator(ABC):
     To store all incoming items without reducing them use a `Collector`.
     """

+    @property
     @abstractmethod
-    def initialize(self) -> Any:
+    def state_suffix(self) -> str:
+        """
+        The state suffix is used to store the aggregation state in the window.
+
+        The complete state key is built using the result column name and this suffix.
+        If these values change, the state key will also change, and the aggregation state will restart from zero.
+
+        Aggregations should change the state suffix when their parameters change to avoid
+        conflicts with previous state values.
+        """
+        ...
+
+    @abstractmethod
+    def initialize(self) -> S:
         """
         This method is triggered once to build the aggregation starting value.
         It should return the initial value for the aggregation.
@@ -34,7 +62,7 @@ def initialize(self) -> Any:
         ...

     @abstractmethod
-    def agg(self, old: Any, new: Any) -> Any:
+    def agg(self, old: S, new: Any, timestamp: int) -> S:
         """
         This method is triggered when a window is updated with a new value.
         It should return the updated aggregated value.
@@ -42,7 +70,7 @@ def agg(self, old: Any, new: Any) -> Any:
         ...

     @abstractmethod
-    def result(self, value: Any) -> Any:
+    def result(self, value: S) -> Any:
         """
         This method is triggered when a window is closed.
         It should return the final aggregation result.
@@ -50,51 +78,86 @@ def result(self, value: Any) -> Any:
         ...


-V = TypeVar("V", int, float)
-
+class Aggregator(BaseAggregator):
+    """
+    Implementation of the `BaseAggregator` interface.

-class ROOT:
-    pass
+    Provides a default implementation of the `state_suffix` property.
+    """

+    def __init__(self, column: Optional[str] = None) -> None:
+        self.column = column

-Column: TypeAlias = Union[Hashable, type[ROOT]]
+    @property
+    def state_suffix(self) -> str:
+        if self.column is None:
+            return self.__class__.__name__
+        return f"{self.__class__.__name__}/{self.column}"


-class Sum(Aggregator):
-    def __init__(self, column: Column = ROOT) -> None:
-        self.column = column
+class Count(Aggregator):
+    """
+    Use `Count()` to aggregate the total number of events within each window period.
+    """

     def initialize(self) -> int:
         return 0

-    def agg(self, old: V, new: Any) -> V:
-        new = new if self.column is ROOT else new.get(self.column)
-        return old + (new or 0)
+    def agg(self, old: int, new: Any, timestamp: int) -> int:
+        if self.column is not None:
+            new = new.get(self.column)

-    def result(self, value: V) -> V:
+        if new is None:
+            return old
+
+        return old + 1
+
+    def result(self, value: int) -> int:
         return value


-class Count(Aggregator):
+V = TypeVar("V", int, float)
+
+
+class Sum(Aggregator):
+    """
+    Use `Sum()` to aggregate the sum of the events, or a column of the events, within each window period.
+
+    :param column: The column to sum. Use `None` to sum the whole message.
+        Default - `None`
+    """
+
     def initialize(self) -> int:
         return 0

-    def agg(self, old: int, new: Any) -> int:
-        return old + 1
+    def agg(self, old: V, new: Any, timestamp: int) -> V:
+        if self.column is not None:
+            new = new.get(self.column)

-    def result(self, value: int) -> int:
+        if new is None:
+            return old
+
+        return old + new
+
+    def result(self, value: V) -> V:
         return value


 class Mean(Aggregator):
-    def __init__(self, column: Column = ROOT) -> None:
-        self.column = column
+    """
+    Use `Mean()` to aggregate the mean of the events, or a column of the events, within each window period.
+
+    :param column: The column to average. Use `None` to average the whole message.
+ Default - `None` + """ def initialize(self) -> tuple[float, int]: return 0.0, 0 - def agg(self, old: tuple[V, int], new: Any) -> tuple[V, int]: - new = new if self.column is ROOT else new.get(self.column) + def agg(self, old: tuple[V, int], new: Any, timestamp: int) -> tuple[V, int]: + if self.column is not None: + new = new.get(self.column) + if new is None: return old @@ -108,70 +171,86 @@ def result(self, value: tuple[Union[int, float], int]) -> Optional[float]: return sum_ / count_ -R = TypeVar("R", int, float) - +class Max(Aggregator): + """ + Use `Max()` to aggregate the max of the events, or a column of the events, within each window period. -class Reduce(Aggregator, Generic[R]): - def __init__( - self, - reducer: Callable[[R, Any], R], - initializer: Callable[[Any], R], - ) -> None: - self._initializer: Callable[[Any], R] = initializer - self._reducer: Callable[[R, Any], R] = reducer + :param column: The column to max. Use `None` to max the whole message. + Default - `None` + """ - def initialize(self) -> Any: + def initialize(self) -> None: return None - def agg(self, old: R, new: Any) -> Any: - return self._initializer(new) if old is None else self._reducer(old, new) + def agg(self, old: Optional[V], new: Any, timestamp: int) -> Optional[V]: + if self.column is not None: + new = new.get(self.column) - def result(self, value: R) -> R: + if new is None: + return old + if old is None: + return new + return max(old, new) + + def result(self, value: V) -> V: return value -class Max(Aggregator): - def __init__(self, column: Column = ROOT) -> None: - self.column = column +class Min(Aggregator): + """ + Use `Min()` to aggregate the min of the events, or a column of the events, within each window period. + + :param column: The column to min. Use `None` to min the whole message. + Default - `None` + """ def initialize(self) -> None: return None - def agg(self, old: Optional[V], new: Any) -> V: - new = new if self.column is ROOT else new.get(self.column) + def agg(self, old: Optional[V], new: Any, timestamp: int) -> Optional[V]: + if self.column is not None: + new = new.get(self.column) + + if new is None: + return old if old is None: return new - elif new is None: - return old - return max(old, new) + return min(old, new) def result(self, value: V) -> V: return value -class Min(Aggregator): - def __init__(self, column: Column = ROOT) -> None: - self.column = column +R = TypeVar("R") + + +class Reduce(Aggregator, Generic[R]): + """ + `Reduce()` allows you to perform complex aggregations using custom "reducer" and "initializer" functions. + """ + + def __init__( + self, + reducer: Callable[[R, Any], R], + initializer: Callable[[Any], R], + ) -> None: + self._initializer: Callable[[Any], R] = initializer + self._reducer: Callable[[R, Any], R] = reducer def initialize(self) -> None: return None - def agg(self, old: Optional[V], new: Any) -> V: - new = new if self.column is ROOT else new.get(self.column) - if old is None: - return new - elif new is None: - return old - return min(old, new) + def agg(self, old: Optional[R], new: Any, timestamp: int) -> R: + return self._initializer(new) if old is None else self._reducer(old, new) - def result(self, value: V) -> V: + def result(self, value: R) -> R: return value I = TypeVar("I") -class Collector(ABC, Generic[I]): +class BaseCollector(ABC, Generic[I]): """ Base class for window collections. 
@@ -184,11 +263,11 @@ class Collector(ABC, Generic[I]):

     @property
     @abstractmethod
-    def column(self) -> Column:
+    def column(self) -> Optional[str]:
         """
         The column to collect.

-        Use `ROOT` to collect the whole message.
+        Use `None` to collect the whole message.
         """
         ...

@@ -201,13 +280,28 @@ def result(self, items: Iterable[I]) -> Any:
         ...


-class Collect(Collector):
-    def __init__(self, column: Column = ROOT) -> None:
+class Collector(BaseCollector):
+    """
+    Implementation of the `BaseCollector` interface.
+
+    Provides a default implementation of the `column` property.
+    """
+
+    def __init__(self, column: Optional[str] = None) -> None:
         self._column = column

     @property
-    def column(self) -> Column:
+    def column(self) -> Optional[str]:
         return self._column

+
+class Collect(Collector):
+    """
+    Use `Collect()` to gather all events within each window period into a list.
+
+    :param column: The column to collect. Use `None` to collect the whole message.
+        Default - `None`
+    """
+
     def result(self, items: Iterable[Any]) -> list[Any]:
         return list(items)
diff --git a/quixstreams/dataframe/windows/base.py b/quixstreams/dataframe/windows/base.py
index 03b35e12b..47498e5b3 100644
--- a/quixstreams/dataframe/windows/base.py
+++ b/quixstreams/dataframe/windows/base.py
@@ -11,7 +11,6 @@
     Iterable,
     Optional,
     Protocol,
-    TypedDict,
     cast,
 )

@@ -19,10 +18,11 @@
 from quixstreams.context import message_context
 from quixstreams.core.stream import TransformExpandedCallback
+from quixstreams.dataframe.exceptions import InvalidOperation
 from quixstreams.processing import ProcessingContext
 from quixstreams.state import WindowedPartitionTransaction

-from .aggregations import Aggregator, Collector
+from .aggregations import BaseAggregator, BaseCollector

 if TYPE_CHECKING:
     from quixstreams.dataframe.dataframe import StreamingDataFrame
@@ -30,12 +30,7 @@
 logger = logging.getLogger(__name__)


-class WindowResult(TypedDict):
-    start: int
-    end: int
-    value: Any
-
-
+WindowResult: TypeAlias = dict[str, Any]
 WindowKeyResult: TypeAlias = tuple[Any, WindowResult]
 Message: TypeAlias = tuple[WindowResult, Any, int, Any]

@@ -52,8 +47,6 @@ def __init__(
         self,
         name: str,
         dataframe: "StreamingDataFrame",
-        aggregators: dict[str, Aggregator],
-        collectors: dict[str, Collector],
     ) -> None:
         if not name:
             raise ValueError("Window name must not be empty")
@@ -61,17 +54,6 @@ def __init__(
         self._name = name
         self._dataframe = dataframe

-        self._aggregators = aggregators
-        self._aggregate = len(aggregators) > 0
-
-        self._collectors = collectors
-        self._collect = len(collectors) > 0
-
-        if not self._collect and not self._aggregate:
-            raise ValueError("At least one aggregation or collector must be defined")
-        elif len(collectors) + len(aggregators) > 1:
-            raise ValueError("Only one aggregation or collector can be defined")
-
     @property
     def name(self) -> str:
         return self._name
@@ -174,6 +156,11 @@ def current(self) -> "StreamingDataFrame":
         regardless of whether the window is closed or not.
         """
+        if self.collect:
+            raise InvalidOperation(
+                "BaseCollectors are not supported by `current` windows"
+            )
+
         def window_callback(
             value: Any,
             key: Any,
@@ -196,6 +183,201 @@ def window_callback(

         return self._apply_window(func=window_callback, name=self._name)

+    # Implemented by SingleAggregationWindowMixin and MultiAggregationWindowMixin.
+    # Single aggregation and multi aggregation windows store aggregation and collection
+    # values in different formats.
+    @property
+    @abstractmethod
+    def collect(self) -> bool: ...
+ + @property + @abstractmethod + def aggregate(self) -> bool: ... + + @abstractmethod + def _initialize_value(self) -> Any: ... + + @abstractmethod + def _aggregate_value(self, state_value: Any, value: Any, timestamp) -> Any: ... + + @abstractmethod + def _collect_value(self, value: Any): ... + + @abstractmethod + def _results( + self, + aggregated: Any, + collected: list[Any], + start: int, + end: int, + ) -> WindowResult: ... + + +class SingleAggregationWindowMixin: + """ + DEPRECATED: Use MultiAggregationWindowMixin instead. + + Single aggregation window mixin for windows with a single aggregation or collection. + Store aggregated value directly in the window value. + """ + + def __init__( + self, + *, + aggregators: dict[str, BaseAggregator], + collectors: dict[str, BaseCollector], + **kwargs, + ) -> None: + if (len(collectors) + len(aggregators)) > 1: + raise ValueError("Only one aggregator or collector can be defined") + + if len(aggregators) > 0: + self._aggregator: Optional[BaseAggregator] = aggregators["value"] + self._collector: Optional[BaseCollector] = None + elif len(collectors) > 0: + self._collector = collectors["value"] + self._aggregator = None + else: + raise ValueError("At least one aggregator or collector must be defined") + + super().__init__(**kwargs) + + @property + def aggregate(self) -> bool: + return self._aggregator is not None + + @property + def collect(self) -> bool: + return self._collector is not None + + def _initialize_value(self) -> Any: + if self._aggregator: + return self._aggregator.initialize() + return None + + def _aggregate_value(self, state_value: Any, value: Any, timestamp: int) -> Any: + if self._aggregator: + return self._aggregator.agg(state_value, value, timestamp) + return None + + def _collect_value(self, value: Any): + # Single aggregation collect() always stores the full message + return value + + def _results( + self, + aggregated: Any, + collected: list[Any], + start: int, + end: int, + ) -> WindowResult: + result = {"start": start, "end": end} + if self._aggregator: + result["value"] = self._aggregator.result(aggregated) + elif self._collector: + result["value"] = self._collector.result(collected) + + return result + + +class MultiAggregationWindowMixin: + def __init__( + self, + *, + aggregators: dict[str, BaseAggregator], + collectors: dict[str, BaseCollector], + **kwargs, + ) -> None: + if not collectors and not aggregators: + raise ValueError("At least one aggregator or collector must be defined") + + self._aggregators: dict[str, tuple[str, BaseAggregator]] = { + f"{result_column}/{agg.state_suffix}": (result_column, agg) + for result_column, agg in aggregators.items() + } + + self._collect_all = False + self._collect_columns: set[str] = set() + self._collectors: dict[str, tuple[Optional[str], BaseCollector]] = {} + for result_column, col in collectors.items(): + input_column = col.column + if input_column is None: + self._collect_all = True + else: + self._collect_columns.add(input_column) + + self._collectors[result_column] = (input_column, col) + + super().__init__(**kwargs) + + @property + def aggregate(self) -> bool: + return bool(self._aggregators) + + @property + def collect(self) -> bool: + return bool(self._collectors) + + def _initialize_value(self) -> dict[str, Any]: + return {k: agg.initialize() for k, (_, agg) in self._aggregators.items()} + + def _aggregate_value( + self, state_values: dict[str, Any], value: Any, timestamp: int + ) -> dict[str, Any]: + return { + k: agg.agg(state_values[k], value, timestamp) + if k 
+            else agg.agg(agg.initialize(), value, timestamp)
+            for k, (_, agg) in self._aggregators.items()
+        }
+
+    def _collect_value(self, value: Any) -> dict[str, Any]:
+        if self._collect_all:
+            return value
+        return {col: value[col] for col in self._collect_columns}
+
+    def _results(
+        self,
+        aggregated: dict[str, Any],
+        collected: list[dict[str, Any]],
+        start: int,
+        end: int,
+    ) -> WindowResult:
+        result = {k: v for k, v in self._build_results(aggregated, collected)}
+        result["start"] = start
+        result["end"] = end
+        return result
+
+    def _build_results(
+        self,
+        aggregated: dict[str, Any],
+        collected: list[dict[str, Any]],
+    ) -> Iterable[tuple[str, Any]]:
+        for key, (result_col, agg) in self._aggregators.items():
+            if key in aggregated:
+                yield result_col, agg.result(aggregated[key])
+            else:
+                yield result_col, agg.result(agg.initialize())
+
+        collected_columns = self._collected_by_columns(collected)
+        for result_col, (input_col, col) in self._collectors.items():
+            if input_col is None:
+                yield result_col, col.result(collected)
+            else:
+                yield result_col, col.result(collected_columns[input_col])
+
+    def _collected_by_columns(
+        self, collected: list[dict[str, Any]]
+    ) -> dict[str, list[Any]]:
+        if not self._collect_columns:
+            return {}
+
+        columns: dict[str, list[Any]] = {col: [] for col in self._collect_columns}
+        for c in collected:
+            for col in columns:
+                columns[col].append(c[col])
+        return columns
+
 
 def _noop() -> Any:
     """
diff --git a/quixstreams/dataframe/windows/count_based.py b/quixstreams/dataframe/windows/count_based.py
index c2f1577eb..04384e202 100644
--- a/quixstreams/dataframe/windows/count_based.py
+++ b/quixstreams/dataframe/windows/count_based.py
@@ -1,13 +1,13 @@
 import logging
-from typing import TYPE_CHECKING, Any, Iterable, Optional, TypedDict
+from typing import TYPE_CHECKING, Any, Iterable, Optional, TypedDict, Union, cast
 
 from quixstreams.state import WindowedPartitionTransaction
 
-from .aggregations import Aggregator, Collector
 from .base import (
+    MultiAggregationWindowMixin,
+    SingleAggregationWindowMixin,
     Window,
     WindowKeyResult,
-    WindowResult,
 )
 
 if TYPE_CHECKING:
@@ -16,12 +16,19 @@
 logger = logging.getLogger(__name__)
 
+_MISSING = object()
+
 
 class CountWindowData(TypedDict):
     count: int
     start: int
    end: int
-    value: Any
+
+    # Can be None for single aggregation windows not migrated
+    aggregations: Union[Any, dict[str, Any]]
+    collection_start_id: int
+
+    # `value: Any` is deprecated. It was only used in single-aggregation windows,
+    # for both collection id tracking and aggregation.
 
 
 class CountWindowsData(TypedDict):
@@ -36,15 +43,11 @@ def __init__(
         self,
         count: int,
         name: str,
         dataframe: "StreamingDataFrame",
-        aggregators: dict[str, Aggregator],
-        collectors: dict[str, Collector],
         step: Optional[int] = None,
     ):
         super().__init__(
             name=name,
             dataframe=dataframe,
-            aggregators=aggregators,
-            collectors=collectors,
         )
 
         self._max_count = count
@@ -79,45 +82,51 @@ def process_window(
         """
         state = transaction.as_state(prefix=key)
         data = state.get(key=self.STATE_KEY, default=CountWindowsData(windows=[]))
+        collect = self.collect
+        aggregate = self.aggregate
 
-        msg_id = None
+        # Start at -1 to indicate that we don't have a collection id yet. If we go
+        # from a no-collection window to a collection window, we add the count to the
+        # previous window's collection id to get the new collection id.
+        # The count is always greater than or equal to 1, so we can safely use -1 as a marker.
+        collection_start_id = -1
         if len(data["windows"]) == 0:
-            # for new tumbling window, reset the collection id to 0
-            if self._collect:
-                window_value = msg_id = 0
-            else:
-                window_value = self._aggregators["value"].initialize()
-
+            collection_start_id = 0
             data["windows"].append(
                 CountWindowData(
                     count=0,
                     start=timestamp_ms,
                     end=timestamp_ms,
-                    value=window_value,
+                    aggregations=self._initialize_value(),
+                    collection_start_id=collection_start_id,
                 )
             )
         elif self._step is not None and data["windows"][0]["count"] % self._step == 0:
-            if self._collect:
-                window_value = msg_id = (
-                    data["windows"][0]["value"] + data["windows"][0]["count"]
+            if collect:
+                collection_start_id = (
+                    self._get_collection_start_id(data["windows"][0])
+                    + data["windows"][0]["count"]
                 )
-            else:
-                window_value = self._aggregators["value"].initialize()
 
             data["windows"].append(
                 CountWindowData(
                     count=0,
                     start=timestamp_ms,
                     end=timestamp_ms,
-                    value=window_value,
+                    aggregations=self._initialize_value(),
+                    collection_start_id=collection_start_id,
                 )
             )
 
-        if self._collect:
-            if msg_id is None:
-                msg_id = data["windows"][0]["value"] + data["windows"][0]["count"]
+        if collect:
+            if collection_start_id == -1:
+                collection_start_id = (
+                    self._get_collection_start_id(data["windows"][0])
+                    + data["windows"][0]["count"]
+                )
 
-            state.add_to_collection(id=msg_id, value=value)
+            state.add_to_collection(
+                id=collection_start_id, value=self._collect_value(value)
+            )
 
         updated_windows, expired_windows, to_remove = [], [], []
         for index, window in enumerate(data["windows"]):
@@ -127,55 +136,76 @@
             elif timestamp_ms > window["end"]:
                 window["end"] = timestamp_ms
 
-            if self._collect:
-                # window must close
-                if window["count"] >= self._max_count:
-                    values = state.get_from_collection(
-                        start=window["value"],
-                        end=window["value"] + self._max_count,
+            if aggregate:
+                window["aggregations"] = self._aggregate_value(
+                    self._get_aggregations(window), value, timestamp_ms
+                )
+                updated_windows.append(
+                    (
+                        key,
+                        self._results(
+                            window["aggregations"], [], window["start"], window["end"]
+                        ),
                     )
+                )
 
-                    expired_windows.append(
-                        (
-                            key,
-                            WindowResult(
-                                start=window["start"],
-                                end=window["end"],
-                                value=self._collectors["value"].result(values),
-                            ),
-                        )
-                    )
-                    to_remove.append(index)
+            if window["count"] >= self._max_count:
+                to_remove.append(index)
+
+                if collect:
+                    collection_start_id = self._get_collection_start_id(window)
+                    collected = state.get_from_collection(
+                        start=collection_start_id,
+                        end=collection_start_id + self._max_count,
+                    )
 
                     # for tumbling window we need to force deletion from 0
                     delete_start = 0 if self._step is None else None
                     # for hopping windows we can only delete the value in the first step, the rest is
                     # needed by follow up hopping windows
                     step = self._max_count if self._step is None else self._step
-                    delete_end = window["value"] + step
+                    delete_end = collection_start_id + step
                     state.delete_from_collection(end=delete_end, start=delete_start)
-            else:
-                window["value"] = self._aggregators["value"].agg(window["value"], value)
-
-                result = (
-                    key,
-                    WindowResult(
-                        start=window["start"],
-                        end=window["end"],
-                        value=self._aggregators["value"].result(window["value"]),
-                    ),
+                else:
+                    collected = []
+
+                expired_windows.append(
+                    (
+                        key,
+                        self._results(
+                            window["aggregations"],
+                            collected,
+                            window["start"],
+                            window["end"],
+                        ),
+                    )
                 )
-                updated_windows.append(result)
-
-                if window["count"] >= self._max_count:
-                    expired_windows.append(result)
-                    to_remove.append(index)
-
         for i in to_remove:
             del data["windows"][i]
state.set(key=self.STATE_KEY, value=data) return updated_windows, expired_windows + + def _get_collection_start_id(self, window: CountWindowData) -> int: + start_id = window.get("collection_start_id", _MISSING) + if start_id is _MISSING: + start_id = cast(int, window["value"]) # type: ignore[typeddict-item] + window["collection_start_id"] = start_id + return start_id # type: ignore[return-value] + + def _get_aggregations(self, window: CountWindowData) -> Union[Any, dict[str, Any]]: + aggregations = window.get("aggregations", _MISSING) + if aggregations is _MISSING: + return window["value"] # type: ignore[typeddict-item] + return aggregations + + +class CountWindowSingleAggregation(SingleAggregationWindowMixin, CountWindow): + pass + + +class CountWindowMultiAggregation(MultiAggregationWindowMixin, CountWindow): + pass diff --git a/quixstreams/dataframe/windows/definitions.py b/quixstreams/dataframe/windows/definitions.py index 180811054..e3fe378e9 100644 --- a/quixstreams/dataframe/windows/definitions.py +++ b/quixstreams/dataframe/windows/definitions.py @@ -1,12 +1,11 @@ import abc from abc import abstractmethod -from typing import TYPE_CHECKING, Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional, Union from .aggregations import ( - ROOT, - Aggregator, + BaseAggregator, + BaseCollector, Collect, - Collector, Count, Max, Mean, @@ -18,13 +17,34 @@ Window, WindowOnLateCallback, ) -from .count_based import CountWindow -from .sliding import SlidingWindow -from .time_based import TimeWindow +from .count_based import ( + CountWindow, + CountWindowMultiAggregation, + CountWindowSingleAggregation, +) +from .sliding import ( + SlidingWindow, + SlidingWindowMultiAggregation, + SlidingWindowSingleAggregation, +) +from .time_based import ( + TimeWindow, + TimeWindowMultiAggregation, + TimeWindowSingleAggregation, +) if TYPE_CHECKING: from quixstreams.dataframe.dataframe import StreamingDataFrame +__all__ = [ + "TumblingCountWindowDefinition", + "HoppingCountWindowDefinition", + "SlidingCountWindowDefinition", + "HoppingTimeWindowDefinition", + "SlidingTimeWindowDefinition", + "TumblingTimeWindowDefinition", +] + class WindowDefinition(abc.ABC): def __init__( @@ -42,9 +62,9 @@ def __init__( @abstractmethod def _create_window( self, - func_name: str, - aggregators: Optional[dict[str, Aggregator]] = None, - collectors: Optional[dict[str, Collector]] = None, + func_name: Optional[str], + aggregators: Optional[dict[str, BaseAggregator]] = None, + collectors: Optional[dict[str, BaseCollector]] = None, ) -> Window: ... 
def sum(self) -> "Window": @@ -57,7 +77,7 @@ def sum(self) -> "Window": return self._create_window( func_name="sum", - aggregators={"value": Sum(column=ROOT)}, + aggregators={"value": Sum(column=None)}, ) def count(self) -> "Window": @@ -84,7 +104,7 @@ def mean(self) -> "Window": return self._create_window( func_name="mean", - aggregators={"value": Mean(column=ROOT)}, + aggregators={"value": Mean(column=None)}, ) def reduce( @@ -141,7 +161,7 @@ def max(self) -> "Window": return self._create_window( func_name="max", - aggregators={"value": Max(column=ROOT)}, + aggregators={"value": Max(column=None)}, ) def min(self) -> "Window": @@ -154,7 +174,7 @@ def min(self) -> "Window": return self._create_window( func_name="min", - aggregators={"value": Min(column=ROOT)}, + aggregators={"value": Min(column=None)}, ) def collect(self) -> "Window": @@ -168,7 +188,7 @@ def collect(self) -> "Window": Example Snippet: ```python # Collect all values in 1-second windows - window = df.tumbling_window(duration_ms=1000).collect() + window = sdf.tumbling_window(duration_ms=1000).collect() # Each window will contain a list of all values that occurred # within that second ``` @@ -179,7 +199,32 @@ def collect(self) -> "Window": return self._create_window( func_name="collect", - collectors={"value": Collect(column=ROOT)}, + collectors={"value": Collect(column=None)}, + ) + + def agg(self, **operations: Union[BaseAggregator, BaseCollector]) -> "Window": + if "start" in operations or "end" in operations: + raise ValueError( + "`start` and `end` are reserved keywords for the window boundaries" + ) + + aggregators: dict[str, BaseAggregator] = {} + collectors: dict[str, BaseCollector] = {} + + for column, op in operations.items(): + if isinstance(op, BaseAggregator): + aggregators[column] = op + elif isinstance(op, BaseCollector): + collectors[column] = op + else: + raise TypeError( + f"operation `{column}:{op}` must be either BaseAggregator or BaseCollector" + ) + + return self._create_window( + func_name=None, + aggregators=aggregators, + collectors=collectors, ) @@ -244,17 +289,27 @@ def __init__( on_late=on_late, ) - def _get_name(self, func_name: str) -> str: + def _get_name(self, func_name: Optional[str]) -> str: prefix = f"{self._name}_hopping_window" if self._name else "hopping_window" - return f"{prefix}_{self._duration_ms}_{self._step_ms}_{func_name}" + if func_name: + return f"{prefix}_{self._duration_ms}_{self._step_ms}_{func_name}" + else: + return f"{prefix}_{self._duration_ms}_{self._step_ms}" def _create_window( self, - func_name: str, - aggregators: Optional[dict[str, Aggregator]] = None, - collectors: Optional[dict[str, Collector]] = None, + func_name: Optional[str], + aggregators: Optional[dict[str, BaseAggregator]] = None, + collectors: Optional[dict[str, BaseCollector]] = None, ) -> TimeWindow: - return TimeWindow( + if func_name: + window_type: Union[ + type[TimeWindowSingleAggregation], type[TimeWindowMultiAggregation] + ] = TimeWindowSingleAggregation + else: + window_type = TimeWindowMultiAggregation + + return window_type( duration_ms=self._duration_ms, grace_ms=self._grace_ms, step_ms=self._step_ms, @@ -283,17 +338,27 @@ def __init__( on_late=on_late, ) - def _get_name(self, func_name: str) -> str: + def _get_name(self, func_name: Optional[str]) -> str: prefix = f"{self._name}_tumbling_window" if self._name else "tumbling_window" - return f"{prefix}_{self._duration_ms}_{func_name}" + if func_name: + return f"{prefix}_{self._duration_ms}_{func_name}" + else: + return 
f"{prefix}_{self._duration_ms}" def _create_window( self, - func_name: str, - aggregators: Optional[dict[str, Aggregator]] = None, - collectors: Optional[dict[str, Collector]] = None, + func_name: Optional[str], + aggregators: Optional[dict[str, BaseAggregator]] = None, + collectors: Optional[dict[str, BaseCollector]] = None, ) -> TimeWindow: - return TimeWindow( + if func_name: + window_type: Union[ + type[TimeWindowSingleAggregation], type[TimeWindowMultiAggregation] + ] = TimeWindowSingleAggregation + else: + window_type = TimeWindowMultiAggregation + + return window_type( duration_ms=self._duration_ms, grace_ms=self._grace_ms, name=self._get_name(func_name=func_name), @@ -321,17 +386,28 @@ def __init__( on_late=on_late, ) - def _get_name(self, func_name: str) -> str: + def _get_name(self, func_name: Optional[str]) -> str: prefix = f"{self._name}_sliding_window" if self._name else "sliding_window" - return f"{prefix}_{self._duration_ms}_{func_name}" + if func_name: + return f"{prefix}_{self._duration_ms}_{func_name}" + else: + return f"{prefix}_{self._duration_ms}" def _create_window( self, - func_name: str, - aggregators: Optional[dict[str, Aggregator]] = None, - collectors: Optional[dict[str, Collector]] = None, + func_name: Optional[str], + aggregators: Optional[dict[str, BaseAggregator]] = None, + collectors: Optional[dict[str, BaseCollector]] = None, ) -> SlidingWindow: - return SlidingWindow( + if func_name: + window_type: Union[ + type[SlidingWindowSingleAggregation], + type[SlidingWindowMultiAggregation], + ] = SlidingWindowSingleAggregation + else: + window_type = SlidingWindowMultiAggregation + + return window_type( duration_ms=self._duration_ms, grace_ms=self._grace_ms, name=self._get_name(func_name=func_name), @@ -357,11 +433,18 @@ def __init__( class TumblingCountWindowDefinition(CountWindowDefinition): def _create_window( self, - func_name: str, - aggregators: Optional[dict[str, Aggregator]] = None, - collectors: Optional[dict[str, Collector]] = None, - ) -> Window: - return CountWindow( + func_name: Optional[str], + aggregators: Optional[dict[str, BaseAggregator]] = None, + collectors: Optional[dict[str, BaseCollector]] = None, + ) -> CountWindow: + if func_name: + window_type: Union[ + type[CountWindowSingleAggregation], type[CountWindowMultiAggregation] + ] = CountWindowSingleAggregation + else: + window_type = CountWindowMultiAggregation + + return window_type( name=self._get_name(func_name=func_name), count=self._count, aggregators=aggregators or {}, @@ -369,13 +452,16 @@ def _create_window( dataframe=self._dataframe, ) - def _get_name(self, func_name: str) -> str: + def _get_name(self, func_name: Optional[str]) -> str: prefix = ( f"{self._name}_tumbling_count_window" if self._name else "tumbling_count_window" ) - return f"{prefix}_{func_name}" + if func_name: + return f"{prefix}_{func_name}" + else: + return prefix class HoppingCountWindowDefinition(CountWindowDefinition): @@ -395,11 +481,18 @@ def __init__( def _create_window( self, - func_name: str, - aggregators: Optional[dict[str, Aggregator]] = None, - collectors: Optional[dict[str, Collector]] = None, - ) -> Window: - return CountWindow( + func_name: Optional[str], + aggregators: Optional[dict[str, BaseAggregator]] = None, + collectors: Optional[dict[str, BaseCollector]] = None, + ) -> CountWindow: + if func_name: + window_type: Union[ + type[CountWindowSingleAggregation], type[CountWindowMultiAggregation] + ] = CountWindowSingleAggregation + else: + window_type = CountWindowMultiAggregation + + return 
window_type( name=self._get_name(func_name=func_name), count=self._count, aggregators=aggregators or {}, @@ -408,13 +501,15 @@ def _create_window( step=self._step, ) - def _get_name(self, func_name: str) -> str: + def _get_name(self, func_name: Optional[str]) -> str: prefix = ( f"{self._name}_hopping_count_window" if self._name else "hopping_count_window" ) - return f"{prefix}_{func_name}" + if func_name: + return f"{prefix}_{func_name}" + return prefix class SlidingCountWindowDefinition(HoppingCountWindowDefinition): @@ -423,10 +518,12 @@ def __init__( ): super().__init__(count=count, dataframe=dataframe, step=1, name=name) - def _get_name(self, func_name: str) -> str: + def _get_name(self, func_name: Optional[str]) -> str: prefix = ( f"{self._name}_sliding_count_window" if self._name else "sliding_count_window" ) - return f"{prefix}_{func_name}" + if func_name: + return f"{prefix}_{func_name}" + return prefix diff --git a/quixstreams/dataframe/windows/sliding.py b/quixstreams/dataframe/windows/sliding.py index c64c0a3f2..6e1753bcc 100644 --- a/quixstreams/dataframe/windows/sliding.py +++ b/quixstreams/dataframe/windows/sliding.py @@ -2,7 +2,11 @@ from quixstreams.state import WindowedPartitionTransaction, WindowedState -from .base import WindowKeyResult, WindowResult +from .base import ( + MultiAggregationWindowMixin, + SingleAggregationWindowMixin, + WindowKeyResult, +) from .time_based import ClosingStrategyValues, TimeWindow if TYPE_CHECKING: @@ -78,8 +82,8 @@ def process_window( duration = self._duration_ms grace = self._grace_ms - aggregate = self._aggregators["value"].agg if self._aggregate else None - collect = self._collect + aggregate = self.aggregate + collect = self.collect # Sliding windows are inclusive on both ends, so values with # timestamps equal to latest_timestamp - duration - grace @@ -150,7 +154,7 @@ def process_window( state=state, start=start, end=end, - value=aggregate(aggregation, value) if aggregate else None, + value=self._aggregate_value(aggregation, value, timestamp_ms), timestamp=timestamp_ms, max_timestamp=max_timestamp, ) @@ -176,9 +180,9 @@ def process_window( state=state, start=right_start, end=right_start + duration, - value=aggregate(self._aggregators["value"].initialize(), value) - if aggregate - else None, + value=self._aggregate_value( + self._initialize_value(), value, timestamp_ms + ), timestamp=timestamp_ms, max_timestamp=timestamp_ms, ) @@ -192,7 +196,9 @@ def process_window( state=state, start=start, end=end, - value=aggregate(aggregation, value) if aggregate else None, + value=self._aggregate_value( + aggregation, value, timestamp_ms + ), timestamp=timestamp_ms, max_timestamp=timestamp_ms, ) @@ -218,7 +224,9 @@ def process_window( state=state, start=right_start, end=right_start + duration, - value=aggregate(self._aggregators["value"].initialize(), value) + value=self._aggregate_value( + self._initialize_value(), value, timestamp_ms + ) if aggregate else None, timestamp=timestamp_ms, @@ -227,7 +235,7 @@ def process_window( # Create a left window with existing aggregation if it falls within the window if left_start > max_timestamp: - aggregation = self._aggregators["value"].initialize() + aggregation = self._initialize_value() updated_windows.append( self._update_window( @@ -235,7 +243,7 @@ def process_window( state=state, start=left_start, end=left_end, - value=aggregate(aggregation, value) if aggregate else None, + value=self._aggregate_value(aggregation, value, timestamp_ms), timestamp=timestamp_ms, max_timestamp=timestamp_ms, ) @@ -257,46 
+265,40 @@ def process_window( state=state, start=left_start, end=left_end, - value=aggregate(self._aggregators["value"].initialize(), value) - if aggregate - else None, + value=self._aggregate_value( + self._initialize_value(), value, timestamp_ms + ), timestamp=timestamp_ms, max_timestamp=timestamp_ms, ) ) if collect: - state.add_to_collection(value=value, id=timestamp_ms) - - expired_windows = [ - ( - key, - WindowResult( - start=start, - end=end, - value=self._collectors["value"].result(aggregation) - if collect - else self._aggregators["value"].result(aggregation), - ), - ) - for (start, end), (max_timestamp, aggregation), _ in state.expire_windows( - max_start_time=max_expired_window_start, - delete=False, - collect=collect, - end_inclusive=True, - ) - if end == max_timestamp # Emit only left windows - ] + state.add_to_collection(value=self._collect_value(value), id=timestamp_ms) + + # build a complete list otherwise expired windows could be deleted + # in state.delete_windows() and never be fetched. + expired_windows = list( + self._expired_windows(state, max_expired_window_start, collect) + ) state.delete_windows( max_start_time=max_deleted_window_start, delete_values=collect, ) - if collect: - return [], expired_windows - else: - return reversed(updated_windows), expired_windows + return reversed(updated_windows), expired_windows + + def _expired_windows(self, state, max_expired_window_start, collect): + for window in state.expire_windows( + max_start_time=max_expired_window_start, + delete=False, + collect=collect, + end_inclusive=True, + ): + (start, end), (max_timestamp, aggregated), collected, key = window + if end == max_timestamp: + yield key, self._results(aggregated, collected, start, end) def _update_window( self, @@ -314,13 +316,12 @@ def _update_window( value=[max_timestamp, value], timestamp_ms=timestamp, ) - return ( - key, - WindowResult( - start=start, - end=end, - value=self._aggregators["value"].result(value) - if self._aggregate - else None, - ), - ) + return (key, self._results(value, [], start, end)) + + +class SlidingWindowSingleAggregation(SingleAggregationWindowMixin, SlidingWindow): + pass + + +class SlidingWindowMultiAggregation(MultiAggregationWindowMixin, SlidingWindow): + pass diff --git a/quixstreams/dataframe/windows/time_based.py b/quixstreams/dataframe/windows/time_based.py index 3fd034edb..80ee42f63 100644 --- a/quixstreams/dataframe/windows/time_based.py +++ b/quixstreams/dataframe/windows/time_based.py @@ -6,12 +6,12 @@ from quixstreams.context import message_context from quixstreams.state import WindowedPartitionTransaction, WindowedState -from .aggregations import Aggregator, Collector from .base import ( + MultiAggregationWindowMixin, + SingleAggregationWindowMixin, Window, WindowKeyResult, WindowOnLateCallback, - WindowResult, get_window_ranges, ) @@ -45,16 +45,12 @@ def __init__( grace_ms: int, name: str, dataframe: "StreamingDataFrame", - aggregators: dict[str, Aggregator], - collectors: dict[str, Collector], step_ms: Optional[int] = None, on_late: Optional[WindowOnLateCallback] = None, ): super().__init__( name=name, dataframe=dataframe, - aggregators=aggregators, - collectors=collectors, ) self._duration_ms = duration_ms @@ -138,8 +134,8 @@ def process_window( duration_ms = self._duration_ms grace_ms = self._grace_ms - collect = self._collect - aggregate = self._aggregate + collect = self.collect + aggregate = self.aggregate ranges = get_window_ranges( timestamp_ms=timestamp_ms, @@ -177,23 +173,22 @@ def process_window( if aggregate: 
current_value = state.get_window(start, end) if current_value is None: - current_value = self._aggregators["value"].initialize() + current_value = self._initialize_value() - aggregated = self._aggregators["value"].agg(current_value, value) + aggregated = self._aggregate_value(current_value, value, timestamp_ms) updated_windows.append( ( key, - WindowResult( - start=start, - end=end, - value=self._aggregators["value"].result(aggregated), - ), + self._results(aggregated, [], start, end), ) ) state.update_window(start, end, value=aggregated, timestamp_ms=timestamp_ms) if collect: - state.add_to_collection(value=value, id=timestamp_ms) + state.add_to_collection( + value=self._collect_value(value), + id=timestamp_ms, + ) if self._closing_strategy == ClosingStrategy.PARTITION: expired_windows = self.expire_by_partition( @@ -218,26 +213,14 @@ def expire_by_partition( for ( window_start, window_end, - ), aggregated, key in transaction.expire_all_windows( + ), aggregated, collected, key in transaction.expire_all_windows( max_end_time=max_expired_end, step_ms=self._step_ms if self._step_ms else self._duration_ms, collect=collect, delete=True, ): - if collect: - value = self._collectors["value"].result(aggregated) - else: - value = self._aggregators["value"].result(aggregated) - count += 1 - yield ( - key, - WindowResult( - start=window_start, - end=window_end, - value=value, - ), - ) + yield key, self._results(aggregated, collected, window_start, window_end) if count: logger.debug( @@ -254,23 +237,15 @@ def expire_by_key( start = time.monotonic() count = 0 - for (window_start, window_end), aggregated, _ in state.expire_windows( + for ( + window_start, + window_end, + ), aggregated, collected, _ in state.expire_windows( max_start_time=max_expired_start, collect=collect, ): - if collect: - value = self._collectors["value"].result(aggregated) - else: - value = self._aggregators["value"].result(aggregated) - - yield ( - key, - WindowResult( - start=window_start, - end=window_end, - value=value, - ), - ) + count += 1 + yield (key, self._results(aggregated, collected, window_start, window_end)) if count: logger.debug( @@ -313,3 +288,11 @@ def _on_expired_window( f"partition={ctx.topic}[{ctx.partition}] " f"offset={ctx.offset}" ) + + +class TimeWindowSingleAggregation(SingleAggregationWindowMixin, TimeWindow): + pass + + +class TimeWindowMultiAggregation(MultiAggregationWindowMixin, TimeWindow): + pass diff --git a/quixstreams/state/rocksdb/windowed/state.py b/quixstreams/state/rocksdb/windowed/state.py index a233cd13d..3e3021b20 100644 --- a/quixstreams/state/rocksdb/windowed/state.py +++ b/quixstreams/state/rocksdb/windowed/state.py @@ -1,7 +1,7 @@ -from typing import TYPE_CHECKING, Any, Optional +from typing import TYPE_CHECKING, Any, Iterable, Optional from quixstreams.state.base import TransactionState -from quixstreams.state.types import WindowDetail, WindowedState +from quixstreams.state.types import ExpiredWindowDetail, WindowDetail, WindowedState if TYPE_CHECKING: from .transaction import WindowedRocksDBPartitionTransaction @@ -125,7 +125,7 @@ def expire_windows( delete: bool = True, collect: bool = False, end_inclusive: bool = False, - ) -> list[WindowDetail]: + ) -> Iterable[ExpiredWindowDetail]: """ Get all expired windows from RocksDB up to the specified `max_start_time` timestamp. 
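The signature changes above turn window expiration into a lazy, generator-based flow: each expired window is now yielded as an `ExpiredWindowDetail` of `((start, end), aggregated, collected, key)` instead of being returned as a fully built list. A minimal consumption sketch follows, assuming only the `ExpiredWindowDetail` shape introduced in this diff; the `drain` helper and its output field names are illustrative and not part of the codebase:

```python
from typing import Any, Iterable

# Mirrors the ExpiredWindowDetail alias added in quixstreams/state/types.py:
# ((start, end), aggregated, collected, key)
ExpiredWindowDetail = tuple[tuple[int, int], Any, list[Any], bytes]


def drain(expired: Iterable[ExpiredWindowDetail]) -> list[dict[str, Any]]:
    """Materialize expired windows before the underlying state is deleted."""
    results = []
    for (start, end), aggregated, collected, key in expired:
        # `aggregated` carries the per-aggregator states; `collected` carries
        # the raw values gathered for collectors ([] when collect=False).
        results.append(
            {
                "key": key,
                "start": start,
                "end": end,
                "aggregated": aggregated,
                "collected": collected,
            }
        )
    return results
```

Because the result is now an iterator, callers that still need the windows after `delete_windows()` runs must materialize them first, which is exactly what the `list(...)` call added in `sliding.py` above does.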
diff --git a/quixstreams/state/rocksdb/windowed/transaction.py b/quixstreams/state/rocksdb/windowed/transaction.py index e48be6cea..d2ba25e7d 100644 --- a/quixstreams/state/rocksdb/windowed/transaction.py +++ b/quixstreams/state/rocksdb/windowed/transaction.py @@ -16,7 +16,7 @@ LoadsFunc, serialize, ) -from quixstreams.state.types import WindowDetail +from quixstreams.state.types import ExpiredWindowDetail, WindowDetail from .metadata import ( GLOBAL_COUNTER_CF_NAME, @@ -234,7 +234,7 @@ def expire_windows( delete: bool = True, collect: bool = False, end_inclusive: bool = False, - ) -> list[WindowDetail]: + ) -> Iterable[ExpiredWindowDetail]: """ Get all expired windows with a set prefix from RocksDB up to the specified `max_start_time` timestamp. @@ -277,16 +277,16 @@ def expire_windows( # Use the latest expired timestamp to limit the iteration over # only those windows that have not been expired before - expired_windows = self.get_windows( + windows = self.get_windows( start_from_ms=start_from, start_to_ms=max_start_time, prefix=prefix, ) - if not expired_windows: - return [] + if not windows: + return # Save the start of the latest expired window to the expiration index - latest_window = expired_windows[-1] + latest_window = windows[-1] last_expired__gt = latest_window[0][0] self._set_timestamp( @@ -297,9 +297,8 @@ def expire_windows( # Collect values into windows if collect: - collected_expired_windows: list[WindowDetail] = [] - for (start, end), value, key in expired_windows: - collection = self.get_from_collection( + for (start, end), aggregated, key in windows: + collected = self.get_from_collection( start=start, # Sliding windows are inclusive on both ends # (including timestamps of messages equal to `end`). @@ -308,32 +307,26 @@ def expire_windows( end=end + 1 if end_inclusive else end, prefix=prefix, ) - if value is None: - value = collection - else: - # Sliding windows are timestamped: - # value is [max_timestamp, value] where max_timestamp - # is the timestamp of the latest message in the window - value[1] = collection - collected_expired_windows.append(((start, end), value, key)) - expired_windows = collected_expired_windows + yield ((start, end), aggregated, collected, key) + + else: + for window, aggregated, key in windows: + yield (window, aggregated, [], key) # Delete expired windows from the state if delete: - for (start, end), _, _ in expired_windows: + for (start, end), _, _ in windows: self.delete_window(start, end, prefix=prefix) if collect: self.delete_from_collection(end=start, prefix=prefix) - return expired_windows - def expire_all_windows( self, max_end_time: int, step_ms: int, delete: bool = True, collect: bool = False, - ) -> Iterable[WindowDetail]: + ) -> Iterable[ExpiredWindowDetail]: """ Get all expired windows for all prefix from RocksDB up to the specified `max_end_time` timestamp. 
@@ -346,6 +339,7 @@ def expire_all_windows( ) to_delete: set[tuple[bytes, int, int]] = set() + collected = [] if last_expired: windows = windows_to_expire(last_expired, max_end_time, step_ms) @@ -357,17 +351,17 @@ def expire_all_windows( if key[-8:] in suffixes: prefix, start, end = parse_window_key(key) to_delete.add((prefix, start, end)) + aggregated = self.get( + encode_integer_pair(start, end), prefix=prefix + ) if collect: - value: Any = self.get_from_collection( + collected = self.get_from_collection( start=start, end=end, prefix=prefix, ) - else: - value = self.get(encode_integer_pair(start, end), prefix=prefix) - assert value is not None # noqa: S101 + yield (start, end), aggregated, collected, prefix - yield (start, end), value, prefix else: # If we don't have a saved last_expired value it means one of two cases # 1. It's a new window, iterating over all the keys is fast. @@ -378,17 +372,17 @@ def expire_all_windows( prefix, start, end = parse_window_key(key) if end <= last_expired: to_delete.add((prefix, start, end)) + aggregated = self.get( + encode_integer_pair(start, end), prefix=prefix + ) if collect: - value = self.get_from_collection( + collected = self.get_from_collection( start=start, end=end, prefix=prefix, ) - else: - value = self.get(encode_integer_pair(start, end), prefix=prefix) - assert value is not None # noqa: S101 - yield (start, end), value, prefix + yield (start, end), aggregated, collected, prefix if delete: for prefix, start, end in to_delete: diff --git a/quixstreams/state/types.py b/quixstreams/state/types.py index c88509e53..c80c9e2ad 100644 --- a/quixstreams/state/types.py +++ b/quixstreams/state/types.py @@ -8,7 +8,12 @@ K = TypeVar("K", contravariant=True) V = TypeVar("V") -WindowDetail: TypeAlias = tuple[tuple[int, int], V, bytes] # (start, end), value, key +WindowDetail: TypeAlias = tuple[ + tuple[int, int], V, bytes +] # (start, end), aggregated, key +ExpiredWindowDetail: TypeAlias = tuple[ + tuple[int, int], V, list[V], bytes +] # (start, end), aggregated, collected, key class WindowedState(Protocol[K, V]): @@ -152,7 +157,7 @@ def expire_windows( delete: bool = True, collect: bool = False, end_inclusive: bool = False, - ) -> list[WindowDetail[V]]: + ) -> Iterable[ExpiredWindowDetail[V]]: """ Get all expired windows from RocksDB up to the specified `max_start_time` timestamp. @@ -350,7 +355,7 @@ def expire_windows( delete: bool = True, collect: bool = False, end_inclusive: bool = False, - ) -> list[WindowDetail[V]]: + ) -> Iterable[ExpiredWindowDetail[V]]: """ Get all expired windows with a set prefix from RocksDB up to the specified `max_start_time` timestamp. @@ -373,7 +378,7 @@ def expire_all_windows( step_ms: int, delete: bool = True, collect: bool = False, - ) -> Iterable[WindowDetail[V]]: + ) -> Iterable[ExpiredWindowDetail[V]]: """ Get all expired windows for all prefix from RocksDB up to the specified `max_start_time` timestamp. 
diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_aggregations.py b/tests/test_quixstreams/test_dataframe/test_windows/test_aggregations.py index 138718980..c3645d9c7 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_aggregations.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_aggregations.py @@ -1,6 +1,7 @@ import pytest from quixstreams.dataframe.windows.aggregations import ( + Collect, Count, Max, Mean, @@ -10,62 +11,93 @@ ) -@pytest.mark.parametrize( - ["aggregator", "values", "expected"], - [ - (Sum(), [1, 2], 3), - (Sum(), [0], 0), - (Sum(), [], 0), - (Sum(), [1, None, 2], 3), - (Sum(), [None], 0), - (Sum(column="foo"), [{"foo": 1}, {"foo": 2}], 3), - (Sum(column="foo"), [{"foo": 1}, {"foo": None}], 1), - (Sum(column="foo"), [{"foo": 1}, {"bar": 2}], 1), - (Sum(column="foo"), [{"foo": 1}, {}], 1), - (Count(), [1, "2", None], 3), - (Mean(), [1, 2], 1.5), - (Mean(), [0], 0), - (Mean(), [1, None, 2], 1.5), - (Mean(), [None], None), - (Mean(column="foo"), [{"foo": 1}, {"foo": 2}], 1.5), - (Mean(column="foo"), [{"foo": 1}, {"foo": None}], 1), - (Mean(column="foo"), [{"foo": 1}, {"bar": 2}], 1), - (Mean(column="foo"), [{"foo": 1}, {}], 1), - ( - Reduce( - reducer=lambda old, new: old + new, - initializer=lambda x: x, +class TestAggregators: + @pytest.mark.parametrize( + "aggregator, values, expected", + [ + (Sum(), [1, 2], 3), + (Sum(), [0], 0), + (Sum(), [], 0), + (Sum(), [1, None, 2], 3), + (Sum(), [None], 0), + (Sum(column="foo"), [{"foo": 1}, {"foo": 2}], 3), + (Sum(column="foo"), [{"foo": 1}, {"foo": None}], 1), + (Sum(column="foo"), [{"foo": 1}, {"bar": 2}], 1), + (Sum(column="foo"), [{"foo": 1}, {}], 1), + (Count(), [1, "2", None], 2), + (Count(), [1, "2", object()], 3), + (Count(column="foo"), [{"foo": 1}, {"foo": 2}, {"foo": 3}], 3), + (Count(column="foo"), [{"foo": 1}, {"foo": None}, {"foo": 3}], 2), + (Count(column="foo"), [{"bar": 1}, {"foo": 2}, {}], 1), + (Mean(), [1, 2], 1.5), + (Mean(), [0], 0), + (Mean(), [1, None, 2], 1.5), + (Mean(), [None], None), + (Mean(column="foo"), [{"foo": 1}, {"foo": 2}], 1.5), + (Mean(column="foo"), [{"foo": 1}, {"foo": None}], 1), + (Mean(column="foo"), [{"foo": 1}, {"bar": 2}], 1), + (Mean(column="foo"), [{"foo": 1}, {}], 1), + ( + Reduce( + reducer=lambda old, new: old + new, + initializer=lambda x: x, + ), + ["A", "B", "C"], + "ABC", ), - ["A", "B", "C"], - "ABC", - ), - (Max(), [3, 1, 2], 3), - (Max(), [3, None, 2], 3), - (Max(), [None, 3, 2], 3), - (Max(), [None], None), - (Max(column="foo"), [{"foo": 3}, {"foo": 1}], 3), - (Max(column="foo"), [{"foo": 3}, {"foo": None}], 3), - (Max(column="foo"), [{"foo": 3}, {"bar": 2}], 3), - (Max(column="foo"), [{"foo": 3}, {}], 3), - (Min(), [3, 1, 2], 1), - (Min(), [3, None, 2], 2), - (Min(), [None, 3, 2], 2), - (Min(), [None], None), - (Min(column="foo"), [{"foo": 3}, {"foo": 1}], 1), - (Min(column="foo"), [{"foo": 3}, {"foo": None}], 3), - (Min(column="foo"), [{"foo": 3}, {"bar": 2}], 3), - (Min(column="foo"), [{"foo": 3}, {}], 3), - ], -) -def test_aggregators(aggregator, values, expected): - old = aggregator.initialize() - for new in values: - old = aggregator.agg(old, new) + (Max(), [3, 1, 2], 3), + (Max(), [3, None, 2], 3), + (Max(), [None, 3, 2], 3), + (Max(), [None], None), + (Max(column="foo"), [{"foo": 3}, {"foo": 1}], 3), + (Max(column="foo"), [{"foo": 3}, {"foo": None}], 3), + (Max(column="foo"), [{"foo": 3}, {"bar": 2}], 3), + (Max(column="foo"), [{"foo": 3}, {}], 3), + (Min(), [3, 1, 2], 1), + (Min(), [3, None, 2], 2), + (Min(), 
[None, 3, 2], 2), + (Min(), [None], None), + (Min(column="foo"), [{"foo": 3}, {"foo": 1}], 1), + (Min(column="foo"), [{"foo": 3}, {"foo": None}], 3), + (Min(column="foo"), [{"foo": 3}, {"bar": 2}], 3), + (Min(column="foo"), [{"foo": 3}, {}], 3), + ], + ) + def test_aggregation(self, aggregator, values, expected): + old = aggregator.initialize() + for new in values: + old = aggregator.agg(old, new, 0) + + assert aggregator.result(old) == expected - assert aggregator.result(old) == expected + @pytest.mark.parametrize( + "aggregation, result", + [ + (Count(), "Count"), + (Sum(), "Sum"), + (Mean(), "Mean"), + (Max(), "Max"), + (Min(), "Min"), + (Count("value"), "Count/value"), + (Sum("value"), "Sum/value"), + (Mean("value"), "Mean/value"), + (Min("value"), "Min/value"), + (Max("value"), "Max/value"), + ], + ) + def test_state_suffix(self, aggregation, result): + assert aggregation.state_suffix == result -# @pytest.mark.parametrize("aggregator", [Sum(), Mean()]) -# def test_aggregators_exceptions(aggregator): -# with pytest.raises(TypeError): -# aggregator.agg(aggregator.initialize(), "1") +class TestCollectors: + @pytest.mark.parametrize( + "inputs, result", + [ + ([], []), + ([0, 1, 2, 3], [0, 1, 2, 3]), + (range(4), [0, 1, 2, 3]), + ], + ) + def test_collect(self, inputs, result): + col = Collect() + assert col.result(inputs) == result diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py index fffaf257c..c54630bdf 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_hopping.py @@ -1,5 +1,6 @@ import pytest +import quixstreams.dataframe.windows.aggregations as agg from quixstreams.dataframe.windows import ( HoppingCountWindowDefinition, HoppingTimeWindowDefinition, @@ -57,12 +58,208 @@ def test_hopping_window_definition_get_name( name = twd._get_name(func_name) assert name == expected_name + def test_multiaggregation( + self, + hopping_window_definition_factory, + state_manager, + ): + window = hopping_window_definition_factory(duration_ms=10, step_ms=5).agg( + count=agg.Count(), + sum=agg.Sum(), + mean=agg.Mean(), + max=agg.Max(), + min=agg.Min(), + collect=agg.Collect(), + ) + window.final() + assert window.name == "hopping_window_10_5" + + store = state_manager.get_store(topic="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=2 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 0, + "end": 10, + "count": 1, + "sum": 1, + "mean": 1.0, + "max": 1, + "min": 1, + "collect": [], + }, + ), + ] + + updated, expired = process( + window, value=4, key=key, transaction=tx, timestamp_ms=6 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 0, + "end": 10, + "count": 2, + "sum": 5, + "mean": 2.5, + "max": 4, + "min": 1, + "collect": [], + }, + ), + ( + key, + { + "start": 5, + "end": 15, + "count": 1, + "sum": 4, + "mean": 4, + "max": 4, + "min": 4, + "collect": [], + }, + ), + ] + + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=12 + ) + assert expired == [ + ( + key, + { + "start": 0, + "end": 10, + "count": 2, + "sum": 5, + "mean": 2.5, + "max": 4, + "min": 1, + "collect": [1, 4], + }, + ), + ] + assert updated == [ + ( + key, + { + "start": 5, + "end": 15, + "count": 2, + 
"sum": 6, + "mean": 3.0, + "max": 4, + "min": 2, + "collect": [], + }, + ), + ( + key, + { + "start": 10, + "end": 20, + "count": 1, + "sum": 2, + "mean": 2, + "max": 2, + "min": 2, + "collect": [], + }, + ), + ] + + # Update window definition + # * delete an aggregation (min) + # * change aggregation but keep the name with new aggregation (mean -> max) + # * add new aggregations (sum2, collect2) + window = hopping_window_definition_factory(duration_ms=10, step_ms=5).agg( + count=agg.Count(), + sum=agg.Sum(), + mean=agg.Max(), + max=agg.Max(), + collect=agg.Collect(), + sum2=agg.Sum(), + collect2=agg.Collect(), + ) + assert window.name == "hopping_window_10_5" # still the same window and store + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=16 + ) + assert ( + expired + == [ + ( + key, + { + "start": 5, + "end": 15, + "count": 2, + "sum": 6, + "sum2": 0, # sum2 only aggregates the values after the update. Use initial value. + "mean": None, # mean was replace by max. The aggregation restarts with the new values. Use initial value. + "max": 4, + "collect": [4, 2], + "collect2": [ + 4, + 2, + ], # Collect2 has all the values as they were fully collected before the update + }, + ) + ] + ) + assert ( + updated + == [ + ( + key, + { + "start": 10, + "end": 20, + "count": 2, + "sum": 3, + "sum2": 1, # sum2 only aggregates the values after the update + "mean": 1, # mean was replace by max. The aggregation restarts with the new values. + "max": 2, + "collect": [], + "collect2": [], + }, + ), + ( + key, + { + "start": 15, + "end": 25, + "count": 1, + "sum": 1, + "sum2": 1, # sum2 only aggregates the values after the update + "mean": 1, # mean was replace by max. The aggregation restarts with the new values. 
+ "max": 1, + "collect": [], + "collect2": [], + }, + ), + ] + ) + @pytest.mark.parametrize("expiration", ("key", "partition")) def test_hoppingwindow_count( self, expiration, hopping_window_definition_factory, state_manager ): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.count() + assert window.name == "hopping_window_10_5_count" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -88,6 +285,8 @@ def test_hoppingwindow_sum( ): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.sum() + assert window.name == "hopping_window_10_5_sum" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -113,6 +312,8 @@ def test_hoppingwindow_mean( ): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.mean() + assert window.name == "hopping_window_10_5_mean" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -141,6 +342,8 @@ def test_hoppingwindow_reduce( reducer=lambda agg, current: agg + [current], initializer=lambda value: [value], ) + assert window.name == "hopping_window_10_5_reduce" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -165,6 +368,8 @@ def test_hoppingwindow_max( ): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.max() + assert window.name == "hopping_window_10_5_max" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -189,6 +394,8 @@ def test_hoppingwindow_min( ): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.min() + assert window.name == "hopping_window_10_5_min" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -213,6 +420,8 @@ def test_hoppingwindow_collect( ): window_def = hopping_window_definition_factory(duration_ms=10, step_ms=5) window = window_def.collect() + assert window.name == "hopping_window_10_5_collect" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -471,10 +680,214 @@ def test_init_invalid(self, count, step, name, dataframe_factory): dataframe=dataframe_factory(), ) + def test_multiaggregation( + self, + count_hopping_window_definition_factory, + state_manager, + ): + window = count_hopping_window_definition_factory(count=3, step=2).agg( + count=agg.Count(), + sum=agg.Sum(), + mean=agg.Mean(), + max=agg.Max(), + min=agg.Min(), + collect=agg.Collect(), + ) + window.final() + assert window.name == "hopping_count_window" + + store = state_manager.get_store(topic="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=2 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 2, + "end": 2, + "count": 1, + "sum": 1, + "mean": 1.0, + "max": 1, + "min": 1, + "collect": [], + }, + ), + ] + + updated, expired = process( + window, value=5, key=key, 
transaction=tx, timestamp_ms=6
+            )
+            assert not expired
+            assert updated == [
+                (
+                    key,
+                    {
+                        "start": 2,
+                        "end": 6,
+                        "count": 2,
+                        "sum": 6,
+                        "mean": 3.0,
+                        "max": 5,
+                        "min": 1,
+                        "collect": [],
+                    },
+                ),
+            ]
+
+            updated, expired = process(
+                window, value=3, key=key, transaction=tx, timestamp_ms=12
+            )
+            assert expired == [
+                (
+                    key,
+                    {
+                        "start": 2,
+                        "end": 12,
+                        "count": 3,
+                        "sum": 9,
+                        "mean": 3.0,
+                        "max": 5,
+                        "min": 1,
+                        "collect": [1, 5, 3],
+                    },
+                ),
+            ]
+            assert updated == [
+                (
+                    key,
+                    {
+                        "start": 2,
+                        "end": 12,
+                        "count": 3,
+                        "sum": 9,
+                        "mean": 3,
+                        "max": 5,
+                        "min": 1,
+                        "collect": [],
+                    },
+                ),
+                (
+                    key,
+                    {
+                        "start": 12,
+                        "end": 12,
+                        "count": 1,
+                        "sum": 3,
+                        "mean": 3,
+                        "max": 3,
+                        "min": 3,
+                        "collect": [],
+                    },
+                ),
+            ]
+
+        # Update window definition
+        # * delete an aggregation (min)
+        # * change aggregation but keep the name with new aggregation (mean -> max)
+        # * add new aggregations (sum2, collect2)
+        window = count_hopping_window_definition_factory(count=3, step=2).agg(
+            count=agg.Count(),
+            sum=agg.Sum(),
+            mean=agg.Max(),
+            max=agg.Max(),
+            collect=agg.Collect(),
+            sum2=agg.Sum(),
+            collect2=agg.Collect(),
+        )
+        assert window.name == "hopping_count_window"  # still the same window and store
+        with store.start_partition_transaction(0) as tx:
+            updated, expired = process(
+                window, value=1, key=key, transaction=tx, timestamp_ms=16
+            )
+            assert not expired
+            assert (
+                updated
+                == [
+                    (
+                        key,
+                        {
+                            "start": 12,
+                            "end": 16,
+                            "count": 2,
+                            "sum": 4,
+                            "sum2": 1,  # sum2 only aggregates the values after the update
+                            "mean": 1,  # mean was replaced by max. The aggregation restarts with the new values.
+                            "max": 3,
+                            "collect": [],
+                            "collect2": [],
+                        },
+                    ),
+                ]
+            )
+
+            updated, expired = process(
+                window, value=4, key=key, transaction=tx, timestamp_ms=22
+            )
+            assert (
+                expired
+                == [
+                    (
+                        key,
+                        {
+                            "start": 12,
+                            "end": 22,
+                            "count": 3,
+                            "sum": 8,
+                            "sum2": 5,  # sum2 only aggregates the values after the update
+                            "mean": 4,  # mean was replaced by max. The aggregation restarts with the new values.
+                            "max": 4,
+                            "collect": [3, 1, 4],
+                            "collect2": [3, 1, 4],
+                        },
+                    ),
+                ]
+            )
+            assert (
+                updated
+                == [
+                    (
+                        key,
+                        {
+                            "start": 12,
+                            "end": 22,
+                            "count": 3,
+                            "sum": 8,
+                            "sum2": 5,  # sum2 only aggregates the values after the update
+                            "mean": 4,  # mean was replaced by max. The aggregation restarts with the new values.
+                            "max": 4,
+                            "collect": [],
+                            "collect2": [],
+                        },
+                    ),
+                    (
+                        key,
+                        {
+                            "start": 22,
+                            "end": 22,
+                            "count": 1,
+                            "sum": 4,
+                            "sum2": 4,
+                            "mean": 4,
+                            "max": 4,
+                            "collect": [],
+                            "collect2": [],
+                        },
+                    ),
+                ]
+            )
+
     def test_count(self, count_hopping_window_definition_factory, state_manager):
         window_def = count_hopping_window_definition_factory(count=4, step=2)
         window = window_def.count()
-        window.register_store()
+        assert window.name == "hopping_count_window_count"
+
+        window.final()
         store = state_manager.get_store(topic="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -529,7 +942,9 @@
     def test_sum(self, count_hopping_window_definition_factory, state_manager):
         window_def = count_hopping_window_definition_factory(count=4, step=2)
         window = window_def.sum()
-        window.register_store()
+        assert window.name == "hopping_count_window_sum"
+
+        window.final()
         store = state_manager.get_store(topic="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -584,7 +999,9 @@
     def test_mean(self, count_hopping_window_definition_factory, state_manager):
         window_def = count_hopping_window_definition_factory(count=4, step=2)
         window = window_def.mean()
-        window.register_store()
+        assert window.name == "hopping_count_window_mean"
+
+        window.final()
         store = state_manager.get_store(topic="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -631,8 +1048,10 @@
             window, key="", value=6, transaction=tx, timestamp_ms=100
         )
         assert len(updated) == 2
-        assert updated[0][1]["value"] == 4.5  # (3 + 4 + 5 + 6) / 4
-        assert updated[1][1]["value"] == 5.5  # (5 + 6) / 2
+        assert (
+            updated[0][1]["value"] == 4.5
+        )  # (3 + 4 + 5 + 6) / 4
+        assert updated[1][1]["value"] == 5.5  # (5 + 6) / 2
 
         assert len(expired) == 1
         assert expired[0][1]["value"] == 4.5
@@ -642,7 +1060,9 @@
     def test_reduce(self, count_hopping_window_definition_factory, state_manager):
         window_def = count_hopping_window_definition_factory(count=4, step=2)
         window = window_def.reduce(
             reducer=lambda agg, current: agg + [current],
             initializer=lambda value: [value],
         )
-        window.register_store()
+        assert window.name == "hopping_count_window_reduce"
+
+        window.final()
         store = state_manager.get_store(topic="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -697,7 +1117,9 @@
     def test_max(self, count_hopping_window_definition_factory, state_manager):
         window_def = count_hopping_window_definition_factory(count=4, step=2)
         window = window_def.max()
-        window.register_store()
+        assert window.name == "hopping_count_window_max"
+
+        window.final()
         store = state_manager.get_store(topic="test", store_name=window.name)
         store.assign_partition(0)
         with store.start_partition_transaction(0) as tx:
@@ -752,7 +1174,9 @@
     def test_min(self, count_hopping_window_definition_factory, state_manager):
         window_def = count_hopping_window_definition_factory(count=4, step=2)
         window = window_def.min()
-        window.register_store()
+        assert window.name == "hopping_count_window_min"
+
+        window.final()
         store = state_manager.get_store(topic="test", store_name=window.name)
         store.assign_partition(0)
         with 
store.start_partition_transaction(0) as tx: @@ -807,7 +1231,9 @@ def test_min(self, count_hopping_window_definition_factory, state_manager): def test_collect(self, count_hopping_window_definition_factory, state_manager): window_def = count_hopping_window_definition_factory(count=4, step=2) window = window_def.collect() - window.register_store() + assert window.name == "hopping_count_window_collect" + + window.final() store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py b/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py index bbfb2837d..0ddf95ff7 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_sliding.py @@ -6,6 +6,7 @@ import pytest +import quixstreams.dataframe.windows.aggregations as agg from quixstreams.dataframe.windows import SlidingTimeWindowDefinition A, B, C, D = "A", "B", "C", "D" @@ -19,6 +20,13 @@ } +def process(window, value, key, transaction, timestamp_ms): + updated, expired = window.process_window( + value=value, key=key, transaction=transaction, timestamp_ms=timestamp_ms + ) + return list(updated), list(expired) + + @dataclass class Message: """ @@ -849,7 +857,8 @@ def test_sliding_window_reduce( key = b"key" for message in messages: with transaction_factory(window) as tx: - updated, expired = window.process_window( + updated, expired = process( + window=window, value=message.value, key=key, timestamp_ms=message.timestamp, @@ -974,7 +983,7 @@ def test_sliding_window_reduce( pytest.param(10, 5, COLLECTION_AGGREGATION, id="collection-aggregation"), ], ) -def test_sliding_windw_collect( +def test_sliding_window_collect( window_factory, transaction_factory, duration_ms, @@ -988,14 +997,15 @@ def test_sliding_windw_collect( ) for message in messages: with transaction_factory(window) as tx: - updated, expired = window.process_window( + updated, expired = process( + window=window, value=message.value, key=key, timestamp_ms=message.timestamp, transaction=tx, ) - assert list(updated) == [] # updates are not supported for collections + # assert list(updated) == [] # updates are not supported for collections assert [msg[1] for msg in expired] == message.expired with transaction_factory(window) as tx: @@ -1018,3 +1028,217 @@ def test_sliding_windw_collect( -1, 99, state._prefix ) assert all_values_in_state == message.expected_values_in_state + + +def test_sliding_window_multiaggregation( + sliding_window_definition_factory, transaction_factory +): + window = sliding_window_definition_factory(duration_ms=10, grace_ms=0).agg( + count=agg.Count(), + sum=agg.Sum(), + mean=agg.Mean(), + max=agg.Max(), + min=agg.Min(), + collect=agg.Collect(), + ) + window.final() + assert window.name == "sliding_window_10" + + key = b"key" + with transaction_factory(window) as tx: + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=2 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 0, + "end": 2, + "count": 1, + "sum": 1, + "mean": 1.0, + "max": 1, + "min": 1, + "collect": [], + }, + ), + ] + + updated, expired = process( + window, value=3, key=key, transaction=tx, timestamp_ms=3 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 0, + "end": 3, + "count": 2, + "sum": 4, + "mean": 2.0, + "max": 3, + "min": 1, + "collect": [], + }, + ), + ] + + 
updated, expired = process(
+            window, value=5, key=key, transaction=tx, timestamp_ms=11
+        )
+        assert expired == [
+            (
+                key,
+                {
+                    "start": 0,
+                    "end": 2,
+                    "count": 1,
+                    "sum": 1,
+                    "mean": 1.0,
+                    "max": 1,
+                    "min": 1,
+                    "collect": [
+                        1,
+                    ],
+                },
+            ),
+            (
+                key,
+                {
+                    "start": 0,
+                    "end": 3,
+                    "count": 2,
+                    "sum": 4,
+                    "mean": 2.0,
+                    "max": 3,
+                    "min": 1,
+                    "collect": [
+                        1,
+                        3,
+                    ],
+                },
+            ),
+        ]
+
+        assert updated == [
+            (
+                key,
+                {
+                    "start": 1,
+                    "end": 11,
+                    "count": 3,
+                    "sum": 9,
+                    "mean": 3.0,
+                    "max": 5,
+                    "min": 1,
+                    "collect": [],
+                },
+            ),
+        ]
+
+    # Update window definition
+    # * delete an aggregation (min)
+    # * change aggregation but keep the name with new aggregation (mean -> max)
+    # * add new aggregations (sum2, collect2)
+    window = sliding_window_definition_factory(duration_ms=10, grace_ms=0).agg(
+        count=agg.Count(),
+        sum=agg.Sum(),
+        mean=agg.Max(),
+        max=agg.Max(),
+        collect=agg.Collect(),
+        sum2=agg.Sum(),
+        collect2=agg.Collect(),
+    )
+    assert window.name == "sliding_window_10"  # still the same window and store
+
+    with transaction_factory(window) as tx:
+        updated, expired = process(
+            window, value=1, key=key, transaction=tx, timestamp_ms=14
+        )
+        assert (
+            expired
+            == [
+                (
+                    key,
+                    {
+                        "start": 1,
+                        "end": 11,
+                        "count": 3,
+                        "sum": 9,
+                        "sum2": 0,  # sum2 only aggregates the values after the update. Use initial value.
+                        "mean": None,  # mean was replaced by max. The aggregation restarts with the new values. Use initial value.
+                        "max": 5,
+                        "collect": [1, 3, 5],
+                        "collect2": [
+                            1,
+                            3,
+                            5,
+                        ],  # Collect2 has all the values as they were fully collected before the update
+                    },
+                ),
+            ]
+        )
+        assert updated == [
+            (
+                key,
+                {
+                    "start": 4,
+                    "end": 14,
+                    "count": 2,
+                    "sum": 6,
+                    "sum2": 1,
+                    "mean": 1,
+                    "max": 5,
+                    "collect": [],
+                    "collect2": [],
+                },
+            ),
+        ]
+
+        updated, expired = process(
+            window, value=2, key=key, transaction=tx, timestamp_ms=18
+        )
+        assert (
+            expired
+            == [
+                (
+                    key,
+                    {
+                        "start": 4,
+                        "end": 14,
+                        "count": 2,
+                        "sum": 6,
+                        "sum2": 1,  # sum2 only aggregates the values after the update.
+                        "mean": 1,  # mean was replaced by max. The aggregation restarts with the new values.
+                        "max": 5,
+                        "collect": [5, 1],
+                        "collect2": [
+                            5,
+                            1,
+                        ],  # Collect2 has all the values as they were fully collected before the update
+                    },
+                ),
+            ]
+        )
+        assert (
+            updated
+            == [
+                (
+                    key,
+                    {
+                        "start": 8,
+                        "end": 18,
+                        "count": 3,
+                        "sum": 8,
+                        "sum2": 3,  # sum2 only aggregates the values after the update.
+                        "mean": 2,  # mean was replaced by max. The aggregation restarts with the new values.
+ "max": 5, + "collect": [], + "collect2": [], + }, + ), + ] + ) diff --git a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py index 5a9cd8f2b..ff724c598 100644 --- a/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py +++ b/tests/test_quixstreams/test_dataframe/test_windows/test_tumbling.py @@ -1,5 +1,6 @@ import pytest +import quixstreams.dataframe.windows.aggregations as agg from quixstreams.dataframe.windows import ( TumblingCountWindowDefinition, TumblingTimeWindowDefinition, @@ -53,12 +54,189 @@ def test_tumbling_window_definition_get_name( name = twd._get_name(func_name) assert name == expected_name + def test_multiaggregation( + self, + tumbling_window_definition_factory, + state_manager, + ): + window = tumbling_window_definition_factory(duration_ms=10).agg( + count=agg.Count(), + sum=agg.Sum(), + mean=agg.Mean(), + max=agg.Max(), + min=agg.Min(), + collect=agg.Collect(), + ) + window.final(closing_strategy="key") + assert window.name == "tumbling_window_10" + + store = state_manager.get_store(topic="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=2 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 0, + "end": 10, + "count": 1, + "sum": 1, + "mean": 1.0, + "max": 1, + "min": 1, + "collect": [], + }, + ) + ] + + updated, expired = process( + window, value=4, key=key, transaction=tx, timestamp_ms=4 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 0, + "end": 10, + "count": 2, + "sum": 5, + "mean": 2.5, + "max": 4, + "min": 1, + "collect": [], + }, + ) + ] + + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=12 + ) + assert expired == [ + ( + key, + { + "start": 0, + "end": 10, + "count": 2, + "sum": 5, + "mean": 2.5, + "max": 4, + "min": 1, + "collect": [1, 4], + }, + ) + ] + assert updated == [ + ( + key, + { + "start": 10, + "end": 20, + "count": 1, + "sum": 2, + "mean": 2.0, + "max": 2, + "min": 2, + "collect": [], + }, + ) + ] + + # Update window definition + # * delete an aggregation (min) + # * change aggregation but keep the name with new aggregation (mean -> max) + # * add new aggregations (sum2, collect2) + window = tumbling_window_definition_factory(duration_ms=10).agg( + count=agg.Count(), + sum=agg.Sum(), + mean=agg.Max(), + max=agg.Max(), + collect=agg.Collect(), + sum2=agg.Sum(), + collect2=agg.Collect(), + ) + assert window.name == "tumbling_window_10" # still the same window and store + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, value=1, key=key, transaction=tx, timestamp_ms=13 + ) + assert not expired + assert ( + updated + == [ + ( + key, + { + "start": 10, + "end": 20, + "count": 2, + "sum": 3, + "sum2": 1, # sum2 only aggregates the values after the update + "mean": 1, # mean was replace by max. The aggregation restarts with the new values. + "max": 2, + "collect": [], + "collect2": [], + }, + ) + ] + ) + + updated, expired = process( + window, value=2, key=key, transaction=tx, timestamp_ms=22 + ) + assert ( + expired + == [ + ( + key, + { + "start": 10, + "end": 20, + "count": 2, + "sum": 3, + "sum2": 1, # sum2 only aggregates the values after the update + "mean": 1, # mean was replace by max. The aggregation restarts with the new values. 
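+                            # max is unchanged by the redefinition, so its stored state (including the pre-update value 2) carries over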
+ "max": 2, + "collect": [2, 1], + "collect2": [ + 2, + 1, + ], # Collect2 has all the values as they were fully collected before the update + }, + ) + ] + ) + assert updated == [ + ( + key, + { + "start": 20, + "end": 30, + "count": 1, + "sum": 2, + "sum2": 2, + "mean": 2, + "max": 2, + "collect": [], + "collect2": [], + }, + ) + ] + @pytest.mark.parametrize("expiration", ("key", "partition")) def test_tumblingwindow_count( self, expiration, tumbling_window_definition_factory, state_manager ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.count() + assert window.name == "tumbling_window_10_count" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -78,6 +256,8 @@ def test_tumblingwindow_sum( ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.sum() + assert window.name == "tumbling_window_10_sum" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -97,6 +277,8 @@ def test_tumblingwindow_mean( ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.mean() + assert window.name == "tumbling_window_10_mean" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -119,6 +301,8 @@ def test_tumblingwindow_reduce( reducer=lambda agg, current: agg + [current], initializer=lambda value: [value], ) + assert window.name == "tumbling_window_10_reduce" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -138,6 +322,8 @@ def test_tumblingwindow_max( ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.max() + assert window.name == "tumbling_window_10_max" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -157,6 +343,8 @@ def test_tumblingwindow_min( ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.min() + assert window.name == "tumbling_window_10_min" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -176,6 +364,8 @@ def test_tumblingwindow_collect( ): window_def = tumbling_window_definition_factory(duration_ms=10, grace_ms=5) window = window_def.collect() + assert window.name == "tumbling_window_10_collect" + window.final(closing_strategy=expiration) store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) @@ -398,10 +588,187 @@ def test_init_invalid(self, count, name, dataframe_factory): dataframe=dataframe_factory(), ) + def test_multiaggregation( + self, + count_tumbling_window_definition_factory, + state_manager, + ): + window = count_tumbling_window_definition_factory(count=2).agg( + count=agg.Count(), + sum=agg.Sum(), + mean=agg.Mean(), + max=agg.Max(), + min=agg.Min(), + collect=agg.Collect(), + ) + window.final() + assert window.name == "tumbling_count_window" + + store = state_manager.get_store(topic="test", store_name=window.name) + store.assign_partition(0) + key = b"key" + with store.start_partition_transaction(0) as tx: + updated, expired = process( + window, value=1, key=key, 
transaction=tx, timestamp_ms=2
+            )
+            assert not expired
+            assert updated == [
+                (
+                    key,
+                    {
+                        "start": 2,
+                        "end": 2,
+                        "count": 1,
+                        "sum": 1,
+                        "mean": 1.0,
+                        "max": 1,
+                        "min": 1,
+                        "collect": [],
+                    },
+                )
+            ]
+
+            updated, expired = process(
+                window, value=4, key=key, transaction=tx, timestamp_ms=4
+            )
+            assert expired == [
+                (
+                    key,
+                    {
+                        "start": 2,
+                        "end": 4,
+                        "count": 2,
+                        "sum": 5,
+                        "mean": 2.5,
+                        "max": 4,
+                        "min": 1,
+                        "collect": [1, 4],
+                    },
+                )
+            ]
+            assert updated == [
+                (
+                    key,
+                    {
+                        "start": 2,
+                        "end": 4,
+                        "count": 2,
+                        "sum": 5,
+                        "mean": 2.5,
+                        "max": 4,
+                        "min": 1,
+                        "collect": [],
+                    },
+                )
+            ]
+
+            updated, expired = process(
+                window, value=2, key=key, transaction=tx, timestamp_ms=12
+            )
+            assert not expired
+            assert updated == [
+                (
+                    key,
+                    {
+                        "start": 12,
+                        "end": 12,
+                        "count": 1,
+                        "sum": 2,
+                        "mean": 2.0,
+                        "max": 2,
+                        "min": 2,
+                        "collect": [],
+                    },
+                )
+            ]
+
+        # Update window definition
+        # * delete an aggregation (min)
+        # * replace an aggregation while keeping its name (mean -> max)
+        # * add new aggregations (sum2, collect2)
+        window = count_tumbling_window_definition_factory(count=2).agg(
+            count=agg.Count(),
+            sum=agg.Sum(),
+            mean=agg.Max(),
+            max=agg.Max(),
+            collect=agg.Collect(),
+            sum2=agg.Sum(),
+            collect2=agg.Collect(),
+        )
+        assert window.name == "tumbling_count_window"  # still the same window and store
+        with store.start_partition_transaction(0) as tx:
+            updated, expired = process(
+                window, value=1, key=key, transaction=tx, timestamp_ms=13
+            )
+            assert (
+                expired
+                == [
+                    (
+                        key,
+                        {
+                            "start": 12,
+                            "end": 13,
+                            "count": 2,
+                            "sum": 3,
+                            "sum2": 1,  # sum2 only aggregates the values after the update
+                            "mean": 1,  # mean was replaced by max; the aggregation restarts with the new values.
+                            "max": 2,
+                            "collect": [2, 1],
+                            "collect2": [
+                                2,
+                                1,
+                            ],  # Collect2 has all the values as they were fully collected before the update
+                        },
+                    )
+                ]
+            )
+            assert (
+                updated
+                == [
+                    (
+                        key,
+                        {
+                            "start": 12,
+                            "end": 13,
+                            "count": 2,
+                            "sum": 3,
+                            "sum2": 1,  # sum2 only aggregates the values after the update
+                            "mean": 1,  # mean was replaced by max; the aggregation restarts with the new values.
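+                            # Max state survives the redefinition, so the pre-update value 2 still wins here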
+ "max": 2, + "collect": [], + "collect2": [], + }, + ) + ] + ) + + updated, expired = process( + window, value=5, key=key, transaction=tx, timestamp_ms=15 + ) + assert not expired + assert updated == [ + ( + key, + { + "start": 15, + "end": 15, + "count": 1, + "sum": 5, + "sum2": 5, + "mean": 5, + "max": 5, + "collect": [], + "collect2": [], + }, + ) + ] + def test_count(self, count_tumbling_window_definition_factory, state_manager): window_def = count_tumbling_window_definition_factory(count=10) window = window_def.count() - window.register_store() + assert window.name == "tumbling_count_window_count" + + window.final() store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -416,7 +783,9 @@ def test_count(self, count_tumbling_window_definition_factory, state_manager): def test_sum(self, count_tumbling_window_definition_factory, state_manager): window_def = count_tumbling_window_definition_factory(count=10) window = window_def.sum() - window.register_store() + assert window.name == "tumbling_count_window_sum" + + window.final() store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -431,7 +800,9 @@ def test_sum(self, count_tumbling_window_definition_factory, state_manager): def test_mean(self, count_tumbling_window_definition_factory, state_manager): window_def = count_tumbling_window_definition_factory(count=10) window = window_def.mean() - window.register_store() + assert window.name == "tumbling_count_window_mean" + + window.final() store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -449,7 +820,9 @@ def test_reduce(self, count_tumbling_window_definition_factory, state_manager): reducer=lambda agg, current: agg + [current], initializer=lambda value: [value], ) - window.register_store() + assert window.name == "tumbling_count_window_reduce" + + window.final() store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -464,7 +837,9 @@ def test_reduce(self, count_tumbling_window_definition_factory, state_manager): def test_max(self, count_tumbling_window_definition_factory, state_manager): window_def = count_tumbling_window_definition_factory(count=10) window = window_def.max() - window.register_store() + assert window.name == "tumbling_count_window_max" + + window.final() store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -479,7 +854,9 @@ def test_max(self, count_tumbling_window_definition_factory, state_manager): def test_min(self, count_tumbling_window_definition_factory, state_manager): window_def = count_tumbling_window_definition_factory(count=10) window = window_def.min() - window.register_store() + assert window.name == "tumbling_count_window_min" + + window.final() store = state_manager.get_store(topic="test", store_name=window.name) store.assign_partition(0) with store.start_partition_transaction(0) as tx: @@ -491,6 +868,29 @@ def test_min(self, count_tumbling_window_definition_factory, state_manager): assert updated[0][1]["value"] == 1 assert not expired + def test_collect(self, count_tumbling_window_definition_factory, state_manager): + window_def = count_tumbling_window_definition_factory(count=3) + window = 
window_def.collect() + assert window.name == "tumbling_count_window_collect" + + window.final() + store = state_manager.get_store(topic="test", store_name=window.name) + store.assign_partition(0) + with store.start_partition_transaction(0) as tx: + process(window, key="", value=1, transaction=tx, timestamp_ms=100) + process(window, key="", value=2, transaction=tx, timestamp_ms=100) + updated, expired = process( + window, key="", value=3, transaction=tx, timestamp_ms=101 + ) + + assert not updated + assert expired == [("", {"start": 100, "end": 101, "value": [1, 2, 3]})] + + with store.start_partition_transaction(0) as tx: + state = tx.as_state(prefix=b"") + remaining_items = state.get_from_collection(start=0, end=1000) + assert remaining_items == [] + def test_window_expired( self, count_tumbling_window_definition_factory, @@ -527,27 +927,6 @@ def test_window_expired( assert expired[0][1]["start"] == 100 assert expired[0][1]["end"] == 110 - def test_collect(self, count_tumbling_window_definition_factory, state_manager): - window_def = count_tumbling_window_definition_factory(count=3) - window = window_def.collect() - window.register_store() - store = state_manager.get_store(topic="test", store_name=window.name) - store.assign_partition(0) - with store.start_partition_transaction(0) as tx: - process(window, key="", value=1, transaction=tx, timestamp_ms=100) - process(window, key="", value=2, transaction=tx, timestamp_ms=100) - updated, expired = process( - window, key="", value=3, transaction=tx, timestamp_ms=101 - ) - - assert not updated - assert expired == [("", {"start": 100, "end": 101, "value": [1, 2, 3]})] - - with store.start_partition_transaction(0) as tx: - state = tx.as_state(prefix=b"") - remaining_items = state.get_from_collection(start=0, end=1000) - assert remaining_items == [] - def test_multiple_keys_sum( self, count_tumbling_window_definition_factory, state_manager ): diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py index faf3d2dd2..a44ce8ae9 100644 --- a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py +++ b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_state.py @@ -60,15 +60,19 @@ def test_expire_windows(transaction_state, delete): with transaction_state() as state: state.update_window(start_ms=20, end_ms=30, value=3, timestamp_ms=20) max_start_time = state.get_latest_timestamp() - duration_ms - expired = state.expire_windows(max_start_time=max_start_time, delete=delete) + expired = list( + state.expire_windows(max_start_time=max_start_time, delete=delete) + ) # "expire_windows" must update the expiration index so that the same # windows are not expired twice - assert not state.expire_windows(max_start_time=max_start_time, delete=delete) + assert not list( + state.expire_windows(max_start_time=max_start_time, delete=delete) + ) assert len(expired) == 2 assert expired == [ - ((0, 10), 1, b"__key__"), - ((10, 20), 2, b"__key__"), + ((0, 10), 1, [], b"__key__"), + ((10, 20), 2, [], b"__key__"), ] with transaction_state() as state: @@ -97,17 +101,19 @@ def test_expire_windows_with_collect(transaction_state, end_inclusive): with transaction_state() as state: state.update_window(start_ms=20, end_ms=30, value=None, timestamp_ms=20) max_start_time = state.get_latest_timestamp() - duration_ms - expired = state.expire_windows( - max_start_time=max_start_time, - collect=True, - end_inclusive=end_inclusive, + expired = 
list( + state.expire_windows( + max_start_time=max_start_time, + collect=True, + end_inclusive=end_inclusive, + ) ) window_1_value = ["a", "b"] if end_inclusive else ["a"] window_2_value = ["b", "c"] if end_inclusive else ["b"] assert expired == [ - ((0, 10), window_1_value, b"__key__"), - ((10, 20), [777, window_2_value], b"__key__"), + ((0, 10), None, window_1_value, b"__key__"), + ((10, 20), [777, None], window_2_value, b"__key__"), ] @@ -123,10 +129,10 @@ def test_same_keys_in_db_and_update_cache(transaction_state): state.update_window(start_ms=10, end_ms=20, value=2, timestamp_ms=10) max_start_time = state.get_latest_timestamp() - duration_ms - expired = state.expire_windows(max_start_time=max_start_time) + expired = list(state.expire_windows(max_start_time=max_start_time)) # Value from the cache takes precedence over the value in the db - assert expired == [((0, 10), 3, b"__key__")] + assert expired == [((0, 10), 3, [], b"__key__")] def test_get_latest_timestamp(windowed_rocksdb_store_factory): diff --git a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py index 32362618c..a58906367 100644 --- a/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py +++ b/tests/test_quixstreams/test_state/test_rocksdb/test_windowed/test_transaction.py @@ -63,19 +63,23 @@ def test_expire_windows_expired(self, windowed_rocksdb_store_factory, delete): start_ms=20, end_ms=30, value=3, timestamp_ms=20, prefix=prefix ) max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms - expired = tx.expire_windows( - max_start_time=max_start_time, prefix=prefix, delete=delete + expired = list( + tx.expire_windows( + max_start_time=max_start_time, prefix=prefix, delete=delete + ) ) # "expire_windows" must update the expiration index so that the same # windows are not expired twice - assert not tx.expire_windows( - max_start_time=max_start_time, prefix=prefix, delete=delete + assert not list( + tx.expire_windows( + max_start_time=max_start_time, prefix=prefix, delete=delete + ) ) assert len(expired) == 2 assert expired == [ - ((0, 10), 1, prefix), - ((10, 20), 2, prefix), + ((0, 10), 1, [], prefix), + ((10, 20), 2, [], prefix), ] with store.start_partition_transaction(0) as tx: @@ -113,18 +117,22 @@ def test_expire_windows_cached(self, windowed_rocksdb_store_factory, delete): start_ms=20, end_ms=30, value=3, timestamp_ms=20, prefix=prefix ) max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms - expired = tx.expire_windows( - max_start_time=max_start_time, prefix=prefix, delete=delete + expired = list( + tx.expire_windows( + max_start_time=max_start_time, prefix=prefix, delete=delete + ) ) # "expire_windows" must update the expiration index so that the same # windows are not expired twice - assert not tx.expire_windows( - max_start_time=max_start_time, prefix=prefix, delete=delete + assert not list( + tx.expire_windows( + max_start_time=max_start_time, prefix=prefix, delete=delete + ) ) assert len(expired) == 2 assert expired == [ - ((0, 10), 1, prefix), - ((10, 20), 2, prefix), + ((0, 10), 1, [], prefix), + ((10, 20), 2, [], prefix), ] assert ( tx.get_window(start_ms=0, end_ms=10, prefix=prefix) == None @@ -157,7 +165,9 @@ def test_expire_windows_empty(self, windowed_rocksdb_store_factory): start_ms=3, end_ms=13, value=1, timestamp_ms=3, prefix=prefix ) max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms - assert not 
tx.expire_windows(max_start_time=max_start_time, prefix=prefix) + assert not list( + tx.expire_windows(max_start_time=max_start_time, prefix=prefix) + ) def test_expire_windows_with_grace_expired(self, windowed_rocksdb_store_factory): store = windowed_rocksdb_store_factory() @@ -178,10 +188,12 @@ def test_expire_windows_with_grace_expired(self, windowed_rocksdb_store_factory) max_start_time = ( tx.get_latest_timestamp(prefix=prefix) - duration_ms - grace_ms ) - expired = tx.expire_windows(max_start_time=max_start_time, prefix=prefix) + expired = list( + tx.expire_windows(max_start_time=max_start_time, prefix=prefix) + ) assert len(expired) == 1 - assert expired == [((0, 10), 1, prefix)] + assert expired == [((0, 10), 1, [], prefix)] def test_expire_windows_with_grace_empty(self, windowed_rocksdb_store_factory): store = windowed_rocksdb_store_factory() @@ -202,7 +214,9 @@ def test_expire_windows_with_grace_empty(self, windowed_rocksdb_store_factory): max_start_time = ( tx.get_latest_timestamp(prefix=prefix) - duration_ms - grace_ms ) - expired = tx.expire_windows(max_start_time=max_start_time, prefix=prefix) + expired = list( + tx.expire_windows(max_start_time=max_start_time, prefix=prefix) + ) assert not expired @@ -281,7 +295,9 @@ def test_expire_windows_no_expired(self, windowed_rocksdb_store_factory): # "expire_windows" must update the expiration index so that the same # windows are not expired twice max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms - assert not tx.expire_windows(max_start_time=max_start_time, prefix=prefix) + assert not list( + tx.expire_windows(max_start_time=max_start_time, prefix=prefix) + ) def test_expire_windows_multiple_windows(self, windowed_rocksdb_store_factory): store = windowed_rocksdb_store_factory() @@ -307,12 +323,14 @@ def test_expire_windows_multiple_windows(self, windowed_rocksdb_store_factory): # "expire_windows" must update the expiration index so that the same # windows are not expired twice max_start_time = tx.get_latest_timestamp(prefix=prefix) - duration_ms - expired = tx.expire_windows(max_start_time=max_start_time, prefix=prefix) + expired = list( + tx.expire_windows(max_start_time=max_start_time, prefix=prefix) + ) assert len(expired) == 3 - assert expired[0] == ((0, 10), 1, prefix) - assert expired[1] == ((10, 20), 1, prefix) - assert expired[2] == ((20, 30), 1, prefix) + assert expired[0] == ((0, 10), 1, [], prefix) + assert expired[1] == ((10, 20), 1, [], prefix) + assert expired[2] == ((20, 30), 1, [], prefix) def test_get_latest_timestamp_update(self, windowed_rocksdb_store_factory): store = windowed_rocksdb_store_factory()