From 3bb8f312ca13c67adaf91a04fdeb3fc320e8962e Mon Sep 17 00:00:00 2001 From: Meredith Heller Date: Wed, 6 Nov 2024 14:53:05 -0800 Subject: [PATCH] ref(batching): add increment_by to BatchStep --- arroyo/processing/processor.py | 6 +- arroyo/processing/strategies/batching.py | 20 +++--- arroyo/processing/strategies/reduce.py | 11 +++- tests/processing/strategies/test_batching.py | 65 ++++++++++++++++++-- 4 files changed, 85 insertions(+), 17 deletions(-) diff --git a/arroyo/processing/processor.py b/arroyo/processing/processor.py index 631d3396..621f8344 100644 --- a/arroyo/processing/processor.py +++ b/arroyo/processing/processor.py @@ -143,9 +143,9 @@ def __init__( self.__processor_factory = processor_factory self.__metrics_buffer = MetricsBuffer() - self.__processing_strategy: Optional[ - ProcessingStrategy[TStrategyPayload] - ] = None + self.__processing_strategy: Optional[ProcessingStrategy[TStrategyPayload]] = ( + None + ) self.__message: Optional[BrokerValue[TStrategyPayload]] = None diff --git a/arroyo/processing/strategies/batching.py b/arroyo/processing/strategies/batching.py index 185cab2b..27dfad6d 100644 --- a/arroyo/processing/strategies/batching.py +++ b/arroyo/processing/strategies/batching.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import MutableSequence, Optional, Union +from typing import MutableSequence, Optional, Union, Callable from arroyo.processing.strategies.abstract import ProcessingStrategy from arroyo.processing.strategies.reduce import Reduce @@ -41,6 +41,7 @@ def __init__( max_batch_size: int, max_batch_time: float, next_step: ProcessingStrategy[ValuesBatch[TStrategyPayload]], + increment_by: Optional[Callable[[BaseValue[TStrategyPayload]], int]] = None, ) -> None: def accumulator( result: ValuesBatch[TStrategyPayload], value: BaseValue[TStrategyPayload] @@ -48,14 +49,15 @@ def accumulator( result.append(value) return result - self.__reduce_step: Reduce[ - TStrategyPayload, ValuesBatch[TStrategyPayload] - ] = Reduce( - max_batch_size, - max_batch_time, - accumulator, - lambda: [], - next_step, + self.__reduce_step: Reduce[TStrategyPayload, ValuesBatch[TStrategyPayload]] = ( + Reduce( + max_batch_size, + max_batch_time, + accumulator, + lambda: [], + next_step, + increment_by, + ) ) def submit( diff --git a/arroyo/processing/strategies/reduce.py b/arroyo/processing/strategies/reduce.py index a339f4ac..70ff1792 100644 --- a/arroyo/processing/strategies/reduce.py +++ b/arroyo/processing/strategies/reduce.py @@ -21,11 +21,13 @@ def __init__( initial_value: Callable[[], TResult], max_batch_size: int, max_batch_time: float, + increment_by: Optional[Callable[[BaseValue[TPayload]], int]] = None, ): self.accumulator = accumulator self.initial_value = initial_value self.max_batch_size = max_batch_size self.max_batch_time = max_batch_time + self.increment_by = increment_by self._buffer = initial_value() self._buffer_size = 0 @@ -48,7 +50,11 @@ def is_ready(self) -> bool: def append(self, message: BaseValue[TPayload]) -> None: self._buffer = self.accumulator(self._buffer, message) - self._buffer_size += 1 + if self.increment_by: + buffer_increment = self.increment_by(message) + else: + buffer_increment = 1 + self._buffer_size += buffer_increment def new(self) -> "ReduceBuffer[TPayload, TResult]": return ReduceBuffer( @@ -56,6 +62,7 @@ def new(self) -> "ReduceBuffer[TPayload, TResult]": initial_value=self.initial_value, max_batch_size=self.max_batch_size, max_batch_time=self.max_batch_time, + increment_by=self.increment_by, ) @@ -83,6 +90,7 @@ def __init__( accumulator: Accumulator[TResult, TPayload], initial_value: Callable[[], TResult], next_step: ProcessingStrategy[TResult], + increment_by: Optional[Callable[[BaseValue[TPayload]], int]] = None, ) -> None: self.__buffer_step = Buffer( buffer=ReduceBuffer( @@ -90,6 +98,7 @@ def __init__( max_batch_time=max_batch_time, accumulator=accumulator, initial_value=initial_value, + increment_by=increment_by, ), next_step=next_step, ) diff --git a/tests/processing/strategies/test_batching.py b/tests/processing/strategies/test_batching.py index cdff34a0..779ffde9 100644 --- a/tests/processing/strategies/test_batching.py +++ b/tests/processing/strategies/test_batching.py @@ -1,6 +1,6 @@ import time from datetime import datetime -from typing import Any, Mapping, Sequence, cast +from typing import Any, Mapping, Sequence, cast, Optional, Callable from unittest.mock import Mock, call, patch import pytest @@ -8,7 +8,7 @@ from arroyo.processing.strategies.abstract import MessageRejected from arroyo.processing.strategies.batching import BatchStep, UnbatchStep, ValuesBatch from arroyo.processing.strategies.run_task import RunTask -from arroyo.types import BrokerValue, Message, Partition, Topic, Value +from arroyo.types import BrokerValue, Message, Partition, Topic, Value, BaseValue NOW = datetime(2022, 1, 1, 0, 0, 1) @@ -114,6 +114,10 @@ def message(partition: int, offset: int, payload: str) -> Message[str]: return Message(broker_value(partition=partition, offset=offset, payload=payload)) +def increment_by(message: BaseValue[str]) -> int: + return len(message.payload) + + test_batch = [ pytest.param( datetime(2022, 1, 1, 0, 0, 1), @@ -122,6 +126,7 @@ def message(partition: int, offset: int, payload: str) -> Message[str]: message(0, 2, "Message 2"), ], [], + None, id="Half full batch", ), pytest.param( @@ -146,6 +151,7 @@ def message(partition: int, offset: int, payload: str) -> Message[str]: ) ) ], + None, id="One full batch", ), pytest.param( @@ -186,23 +192,74 @@ def message(partition: int, offset: int, payload: str) -> Message[str]: ) ), ], + None, id="Two full batches", ), + pytest.param( + datetime(2022, 1, 1, 0, 0, 1), + [ + message(0, 1, "1"), + message(0, 2, "11"), + message(0, 3, "222"), + message(1, 1, "33"), + message(1, 2, "333"), + ], + [ + call( + Message( + Value( + payload=[broker_value(0, 1, "1"), broker_value(0, 2, "11")], + committable={Partition(Topic("test"), 0): 3}, + timestamp=NOW, + ), + ) + ), + call( + Message( + Value( + payload=[ + broker_value(0, 3, "222"), + ], + committable={Partition(Topic("test"), 0): 4}, + timestamp=NOW, + ), + ) + ), + call( + Message( + Value( + payload=[ + broker_value(1, 1, "33"), + broker_value(1, 2, "333"), + ], + committable={Partition(Topic("test"), 1): 3}, + timestamp=NOW, + ), + ) + ), + ], + increment_by, + id="Three batches using increment by", + ), ] -@pytest.mark.parametrize("start_time, messages_in, expected_batches", test_batch) +@pytest.mark.parametrize( + "start_time, messages_in, expected_batches, increment_by", test_batch +) @patch("time.time") def test_batch_step( mock_time: Any, start_time: datetime, messages_in: Sequence[Message[str]], expected_batches: Sequence[ValuesBatch[str]], + increment_by: Optional[Callable[[BaseValue[str]], int]], ) -> None: start = time.mktime(start_time.timetuple()) mock_time.return_value = start next_step = Mock() - batch_step = BatchStep[str](3, 10.0, next_step) + print("incrementby", increment_by) + batch_step = BatchStep[str](3, 10.0, next_step, increment_by) for message in messages_in: batch_step.submit(message) batch_step.poll()