ref(batching): add increment_by to BatchStep
MeredithAnya committed Nov 6, 2024
1 parent a8545c7 commit 3bb8f31
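
In effect, this change lets callers control how much each message counts toward `max_batch_size`: `BatchStep` (and the underlying `Reduce` step) now accept an optional `increment_by` callable mapping a `BaseValue` to an integer weight; when it is omitted, every message still counts as 1, preserving the old behaviour. A minimal usage sketch based on the new signature (the `payload_bytes` helper, the byte-sized batching scenario, and the `Mock` next step are illustrative, not part of this commit):

```python
from unittest.mock import Mock

from arroyo.processing.strategies.batching import BatchStep
from arroyo.types import BaseValue


def payload_bytes(value: BaseValue[bytes]) -> int:
    # Hypothetical weight function: count bytes so batches flush by size, not message count.
    return len(value.payload)


next_step = Mock()  # stands in for any ProcessingStrategy[ValuesBatch[bytes]]

batch_step = BatchStep[bytes](
    max_batch_size=1_000_000,  # interpreted in whatever units increment_by returns
    max_batch_time=10.0,       # seconds
    next_step=next_step,
    increment_by=payload_bytes,  # omit (or pass None) to keep counting one per message
)
```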
Showing 4 changed files with 85 additions and 17 deletions.
6 changes: 3 additions & 3 deletions arroyo/processing/processor.py
@@ -143,9 +143,9 @@ def __init__(
self.__processor_factory = processor_factory
self.__metrics_buffer = MetricsBuffer()

- self.__processing_strategy: Optional[
-     ProcessingStrategy[TStrategyPayload]
- ] = None
+ self.__processing_strategy: Optional[ProcessingStrategy[TStrategyPayload]] = (
+     None
+ )

self.__message: Optional[BrokerValue[TStrategyPayload]] = None

20 changes: 11 additions & 9 deletions arroyo/processing/strategies/batching.py
@@ -1,6 +1,6 @@
from __future__ import annotations

- from typing import MutableSequence, Optional, Union
+ from typing import MutableSequence, Optional, Union, Callable

from arroyo.processing.strategies.abstract import ProcessingStrategy
from arroyo.processing.strategies.reduce import Reduce
@@ -41,21 +41,23 @@ def __init__(
max_batch_size: int,
max_batch_time: float,
next_step: ProcessingStrategy[ValuesBatch[TStrategyPayload]],
increment_by: Optional[Callable[[BaseValue[TStrategyPayload]], int]] = None,
) -> None:
def accumulator(
result: ValuesBatch[TStrategyPayload], value: BaseValue[TStrategyPayload]
) -> ValuesBatch[TStrategyPayload]:
result.append(value)
return result

- self.__reduce_step: Reduce[
-     TStrategyPayload, ValuesBatch[TStrategyPayload]
- ] = Reduce(
-     max_batch_size,
-     max_batch_time,
-     accumulator,
-     lambda: [],
-     next_step,
- )
+ self.__reduce_step: Reduce[TStrategyPayload, ValuesBatch[TStrategyPayload]] = (
+     Reduce(
+         max_batch_size,
+         max_batch_time,
+         accumulator,
+         lambda: [],
+         next_step,
+         increment_by,
+     )
+ )

def submit(
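
As the diff above shows, `BatchStep` stays a thin wrapper: it defines an appending accumulator and passes `increment_by` straight through to `Reduce`. A rough sketch of the equivalent direct `Reduce` construction, keeping the positional argument order used above (the `Mock` next step and the length-based weight are illustrative):

```python
from unittest.mock import Mock

from arroyo.processing.strategies.batching import ValuesBatch
from arroyo.processing.strategies.reduce import Reduce
from arroyo.types import BaseValue

next_step = Mock()


def accumulator(result: ValuesBatch[str], value: BaseValue[str]) -> ValuesBatch[str]:
    # Same accumulator BatchStep builds internally: collect every value into the batch.
    result.append(value)
    return result


reduce_step: Reduce[str, ValuesBatch[str]] = Reduce(
    3,  # max_batch_size, in the units produced by increment_by
    10.0,  # max_batch_time (seconds)
    accumulator,
    lambda: [],  # initial_value: start each batch as an empty list
    next_step,
    lambda value: len(value.payload),  # increment_by: weigh each message by payload length
)
```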
11 changes: 10 additions & 1 deletion arroyo/processing/strategies/reduce.py
@@ -21,11 +21,13 @@ def __init__(
initial_value: Callable[[], TResult],
max_batch_size: int,
max_batch_time: float,
increment_by: Optional[Callable[[BaseValue[TPayload]], int]] = None,
):
self.accumulator = accumulator
self.initial_value = initial_value
self.max_batch_size = max_batch_size
self.max_batch_time = max_batch_time
self.increment_by = increment_by

self._buffer = initial_value()
self._buffer_size = 0
@@ -48,14 +50,19 @@ def is_ready(self) -> bool:

def append(self, message: BaseValue[TPayload]) -> None:
self._buffer = self.accumulator(self._buffer, message)
- self._buffer_size += 1
+ if self.increment_by:
+     buffer_increment = self.increment_by(message)
+ else:
+     buffer_increment = 1
+ self._buffer_size += buffer_increment

def new(self) -> "ReduceBuffer[TPayload, TResult]":
return ReduceBuffer(
accumulator=self.accumulator,
initial_value=self.initial_value,
max_batch_size=self.max_batch_size,
max_batch_time=self.max_batch_time,
increment_by=self.increment_by,
)


@@ -83,13 +90,15 @@ def __init__(
accumulator: Accumulator[TResult, TPayload],
initial_value: Callable[[], TResult],
next_step: ProcessingStrategy[TResult],
increment_by: Optional[Callable[[BaseValue[TPayload]], int]] = None,
) -> None:
self.__buffer_step = Buffer(
buffer=ReduceBuffer(
max_batch_size=max_batch_size,
max_batch_time=max_batch_time,
accumulator=accumulator,
initial_value=initial_value,
increment_by=increment_by,
),
next_step=next_step,
)
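
The size accounting added to `ReduceBuffer.append` reduces to a simple rule: each appended message advances the buffer size by `increment_by(message)` when the callable is set, and by 1 otherwise; `new()` now carries the callable over to the next buffer so later batches keep the same weighting. A self-contained illustration of that rule (plain Python; the `running_sizes` helper is made up for this sketch):

```python
from typing import Callable, List, Optional


def running_sizes(
    payloads: List[str],
    increment_by: Optional[Callable[[str], int]] = None,
) -> List[int]:
    """Mirror ReduceBuffer.append: return the buffer size after each append."""
    size = 0
    sizes = []
    for payload in payloads:
        size += increment_by(payload) if increment_by is not None else 1
        sizes.append(size)
    return sizes


# Default behaviour: every message counts as 1.
assert running_sizes(["1", "11", "222"]) == [1, 2, 3]
# With increment_by=len, messages are weighted by payload length.
assert running_sizes(["1", "11", "222"], increment_by=len) == [1, 3, 6]
```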
65 changes: 61 additions & 4 deletions tests/processing/strategies/test_batching.py
@@ -1,14 +1,14 @@
import time
from datetime import datetime
- from typing import Any, Mapping, Sequence, cast
+ from typing import Any, Mapping, Sequence, cast, Optional, Callable
from unittest.mock import Mock, call, patch

import pytest

from arroyo.processing.strategies.abstract import MessageRejected
from arroyo.processing.strategies.batching import BatchStep, UnbatchStep, ValuesBatch
from arroyo.processing.strategies.run_task import RunTask
- from arroyo.types import BrokerValue, Message, Partition, Topic, Value
+ from arroyo.types import BrokerValue, Message, Partition, Topic, Value, BaseValue

NOW = datetime(2022, 1, 1, 0, 0, 1)

@@ -114,6 +114,10 @@ def message(partition: int, offset: int, payload: str) -> Message[str]:
return Message(broker_value(partition=partition, offset=offset, payload=payload))


def increment_by(message: BaseValue[str]) -> int:
return len(message.payload)


test_batch = [
pytest.param(
datetime(2022, 1, 1, 0, 0, 1),
@@ -122,6 +126,7 @@ def message(partition: int, offset: int, payload: str) -> Message[str]:
message(0, 2, "Message 2"),
],
[],
None,
id="Half full batch",
),
pytest.param(
@@ -146,6 +151,7 @@ def message(partition: int, offset: int, payload: str) -> Message[str]:
)
)
],
None,
id="One full batch",
),
pytest.param(
@@ -186,23 +192,74 @@ def message(partition: int, offset: int, payload: str) -> Message[str]:
)
),
],
None,
id="Two full batches",
),
pytest.param(
datetime(2022, 1, 1, 0, 0, 1),
[
message(0, 1, "1"),
message(0, 2, "11"),
message(0, 3, "222"),
message(1, 1, "33"),
message(1, 2, "333"),
],
[
call(
Message(
Value(
payload=[broker_value(0, 1, "1"), broker_value(0, 2, "11")],
committable={Partition(Topic("test"), 0): 3},
timestamp=NOW,
),
)
),
call(
Message(
Value(
payload=[
broker_value(0, 3, "222"),
],
committable={Partition(Topic("test"), 0): 4},
timestamp=NOW,
),
)
),
call(
Message(
Value(
payload=[
broker_value(1, 1, "33"),
broker_value(1, 2, "333"),
],
committable={Partition(Topic("test"), 1): 3},
timestamp=NOW,
),
)
),
],
increment_by,
id="Three batches using increment by",
),
]


@pytest.mark.parametrize("start_time, messages_in, expected_batches", test_batch)
@pytest.mark.parametrize(
"start_time, messages_in, expected_batches, increment_by", test_batch
)
@patch("time.time")
def test_batch_step(
mock_time: Any,
start_time: datetime,
messages_in: Sequence[Message[str]],
expected_batches: Sequence[ValuesBatch[str]],
increment_by: Optional[Callable[[BaseValue[str]], int]],
) -> None:
start = time.mktime(start_time.timetuple())
mock_time.return_value = start
next_step = Mock()
- batch_step = BatchStep[str](3, 10.0, next_step)
+ batch_step = BatchStep[str](3, 10.0, next_step, increment_by)
for message in messages_in:
batch_step.submit(message)
batch_step.poll()
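
The new "Three batches using increment by" case can be checked by hand, assuming the reduce step flushes as soon as the accumulated weight reaches `max_batch_size` (3 here) after an append; with `increment_by` returning payload length, the five submitted messages split exactly into the three expected batches:

```python
# Payloads submitted in order, weighted by len(payload) as in the test's increment_by.
payloads = ["1", "11", "222", "33", "333"]

batches, current, size = [], [], 0
for payload in payloads:
    current.append(payload)
    size += len(payload)
    if size >= 3:  # reached max_batch_size -> flush the batch to next_step
        batches.append(current)
        current, size = [], 0

assert batches == [["1", "11"], ["222"], ["33", "333"]]
```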
