diff --git a/.benchmarks/Darwin-CPython-3.11-64bit/0002_4beaca18138343fa989b1283ae577de131abb733_20241126_151413_uncommited-changes.json b/.benchmarks/Darwin-CPython-3.11-64bit/0002_4beaca18138343fa989b1283ae577de131abb733_20241126_151413_uncommited-changes.json new file mode 100644 index 00000000..33dd94e1 --- /dev/null +++ b/.benchmarks/Darwin-CPython-3.11-64bit/0002_4beaca18138343fa989b1283ae577de131abb733_20241126_151413_uncommited-changes.json @@ -0,0 +1,148 @@ +{ + "machine_info": { + "node": "Woile-MacBook-Pro.local", + "processor": "arm", + "machine": "arm64", + "python_compiler": "Clang 16.0.6 ", + "python_implementation": "CPython", + "python_implementation_version": "3.11.10", + "python_version": "3.11.10", + "python_build": [ + "main", + "Sep 7 2024 01:03:31" + ], + "release": "24.1.0", + "system": "Darwin", + "cpu": { + "python_version": "3.11.10.final.0 (64 bit)", + "cpuinfo_version": [ + 9, + 0, + 0 + ], + "cpuinfo_version_string": "9.0.0", + "arch": "ARM_8", + "bits": 64, + "count": 12, + "arch_string_raw": "arm64", + "brand_raw": "Apple M3 Pro" + } + }, + "commit_info": { + "id": "4beaca18138343fa989b1283ae577de131abb733", + "time": "2024-11-26T16:04:54+01:00", + "author_time": "2022-09-17T09:45:33+02:00", + "dirty": true, + "project": "kstreams", + "branch": "feat/dependency-injection" + }, + "benchmarks": [ + { + "group": null, + "name": "test_startup_and_processing_single_consumer_record", + "fullname": "tests/test_benchmarks.py::test_startup_and_processing_single_consumer_record", + "params": null, + "param": null, + "extra_info": {}, + "options": { + "disable_gc": false, + "timer": "perf_counter", + "min_rounds": 5, + "max_time": 1.0, + "min_time": 5e-06, + "warmup": false + }, + "stats": { + "min": 0.00010604201816022396, + "max": 0.010822750016814098, + "mean": 0.00016831592185808216, + "stddev": 0.0003149759900475096, + "rounds": 1544, + "median": 0.00013091700384393334, + "iqr": 1.879199407994747e-05, + "q1": 0.00012045800394844264, + "q3": 0.0001392499980283901, + "iqr_outliers": 114, + "stddev_outliers": 63, + "outliers": "63;114", + "ld15iqr": 0.00010604201816022396, + "hd15iqr": 0.000167582998983562, + "ops": 5941.2085853836425, + "total": 0.25987978334887885, + "iterations": 1 + } + }, + { + "group": null, + "name": "test_startup_and_inject_all", + "fullname": "tests/test_benchmarks.py::test_startup_and_inject_all", + "params": null, + "param": null, + "extra_info": {}, + "options": { + "disable_gc": false, + "timer": "perf_counter", + "min_rounds": 5, + "max_time": 1.0, + "min_time": 5e-06, + "warmup": false + }, + "stats": { + "min": 0.0001525410043541342, + "max": 0.03395479201572016, + "mean": 0.00024245004325373947, + "stddev": 0.0008005369642820076, + "rounds": 4560, + "median": 0.00021754149929620326, + "iqr": 5.6313510867767036e-05, + "q1": 0.00018539548909757286, + "q3": 0.0002417089999653399, + "iqr_outliers": 65, + "stddev_outliers": 5, + "outliers": "5;65", + "ld15iqr": 0.0001525410043541342, + "hd15iqr": 0.00034166700788773596, + "ops": 4124.561029479529, + "total": 1.105572197237052, + "iterations": 1 + } + }, + { + "group": null, + "name": "test_consume_many", + "fullname": "tests/test_benchmarks.py::test_consume_many", + "params": null, + "param": null, + "extra_info": {}, + "options": { + "disable_gc": false, + "timer": "perf_counter", + "min_rounds": 5, + "max_time": 1.0, + "min_time": 5e-06, + "warmup": false + }, + "stats": { + "min": 0.0034218749788124114, + "max": 0.004076749988598749, + "mean": 0.0034961712951928296, + "stddev": 7.397271227296705e-05, + "rounds": 268, + "median": 0.00347295799292624, + "iqr": 7.147900760173798e-05, + "q1": 0.003452624994679354, + "q3": 0.003524104002281092, + "iqr_outliers": 14, + "stddev_outliers": 24, + "outliers": "24;14", + "ld15iqr": 0.0034218749788124114, + "hd15iqr": 0.0036370840098243207, + "ops": 286.02717532032295, + "total": 0.9369739071116783, + "iterations": 1 + } + } + ], + "datetime": "2024-11-26T15:14:17.201596+00:00", + "version": "5.1.0" +} \ No newline at end of file diff --git a/.github/workflows/bench-release.yml b/.github/workflows/bench-release.yml index 1d15de8d..4bea8502 100644 --- a/.github/workflows/bench-release.yml +++ b/.github/workflows/bench-release.yml @@ -1,4 +1,4 @@ -name: Bump version +name: Benchmark latest release on: push: @@ -46,5 +46,5 @@ jobs: git config --global user.email "action@github.com" git config --global user.name "GitHub Action" git add .benchmarks/ - git commit -m "bench: bench: add benchmark current release" + git commit -m "bench: current release" git push origin master diff --git a/.github/workflows/pr-tests.yaml b/.github/workflows/pr-tests.yaml index e4920cdb..c02795eb 100644 --- a/.github/workflows/pr-tests.yaml +++ b/.github/workflows/pr-tests.yaml @@ -17,7 +17,7 @@ on: required: true jobs: - build_test_bench: + test: runs-on: ubuntu-latest strategy: matrix: @@ -56,11 +56,6 @@ jobs: git config --global user.email "action@github.com" git config --global user.name "GitHub Action" ./scripts/test - - - name: Benchmark regression test - run: | - ./scripts/bench-compare - - name: Upload coverage to Codecov uses: codecov/codecov-action@v5.0.2 with: @@ -68,3 +63,35 @@ jobs: name: kstreams fail_ci_if_error: true token: ${{secrets.CODECOV_TOKEN}} + bench: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + with: + fetch-depth: 0 + - name: Setup python + uses: actions/setup-python@v5 + with: + python-version: '3.13' + architecture: x64 + - name: Set Cache + uses: actions/cache@v4 + id: cache # name for referring later + with: + path: .venv/ + # The cache key depends on poetry.lock + key: ${{ runner.os }}-cache-${{ hashFiles('poetry.lock') }} + restore-keys: | + ${{ runner.os }}-cache- + ${{ runner.os }}- + - name: Install Dependencies + # if: steps.cache.outputs.cache-hit != 'true' + run: | + python -m pip install -U pip poetry + poetry --version + poetry config --local virtualenvs.in-project true + poetry install + - name: Benchmark regression test + run: | + ./scripts/bench-compare + diff --git a/README.md b/README.md index cfbb6dc1..4f80f042 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,7 @@ if __name__ == "__main__": - [ ] Store (kafka streams pattern) - [ ] Stream Join - [ ] Windowing +- [ ] PEP 593 ## Development diff --git a/kstreams/__init__.py b/kstreams/__init__.py index 54e82471..c8b045eb 100644 --- a/kstreams/__init__.py +++ b/kstreams/__init__.py @@ -1,5 +1,7 @@ from aiokafka.structs import RecordMetadata, TopicPartition +from ._di.parameters import FromHeader, Header +from .backends.kafka import Kafka from .clients import Consumer, Producer from .create import StreamEngine, create_engine from .prometheus.monitor import PrometheusMonitor, PrometheusMonitorType @@ -31,4 +33,8 @@ "TestStreamClient", "TopicPartition", "TopicPartitionOffset", + "Kafka", + "StreamDependencyManager", + "FromHeader", + "Header", ] diff --git a/kstreams/_di/binders/api.py b/kstreams/_di/binders/api.py new file mode 100644 index 00000000..02b959e8 --- /dev/null +++ b/kstreams/_di/binders/api.py @@ -0,0 +1,68 @@ +import inspect +from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar, Union + +from di.api.dependencies import CacheKey +from di.dependent import Dependent, Marker + +from kstreams.types import ConsumerRecord + + +class ExtractorTrait(Protocol): + """Implement to extract data from incoming `ConsumerRecord`. + + Consumers will always work with a consumer Record. + Implementing this would let you extract information from the `ConsumerRecord`. + """ + + def __hash__(self) -> int: + """Required by di in order to cache the deps""" + ... + + def __eq__(self, __o: object) -> bool: + """Required by di in order to cache the deps""" + ... + + async def extract( + self, consumer_record: ConsumerRecord + ) -> Union[Awaitable[Any], AsyncIterator[Any]]: + """This is where the magic should happen. + + For example, you could "extract" here a json from the `ConsumerRecord.value` + """ + ... + + +T = TypeVar("T", covariant=True) + + +class MarkerTrait(Protocol[T]): + def register_parameter(self, param: inspect.Parameter) -> T: ... + + +class Binder(Dependent[Any]): + def __init__( + self, + *, + extractor: ExtractorTrait, + ) -> None: + super().__init__(call=extractor.extract, scope="consumer_record") + self.extractor = extractor + + @property + def cache_key(self) -> CacheKey: + return self.extractor + + +class BinderMarker(Marker): + """Bind together the different dependencies. + + NETX: Add asyncapi marker here, like `MarkerTrait[AsyncApiTrait]`. + Recommendation to wait until 3.0: + - [#618](https://github.com/asyncapi/spec/issues/618) + """ + + def __init__(self, *, extractor_marker: MarkerTrait[ExtractorTrait]) -> None: + self.extractor_marker = extractor_marker + + def register_parameter(self, param: inspect.Parameter) -> Binder: + return Binder(extractor=self.extractor_marker.register_parameter(param)) diff --git a/kstreams/_di/binders/header.py b/kstreams/_di/binders/header.py new file mode 100644 index 00000000..c0f46de6 --- /dev/null +++ b/kstreams/_di/binders/header.py @@ -0,0 +1,44 @@ +import inspect +from typing import Any, NamedTuple, Optional + +from kstreams.exceptions import HeaderNotFound +from kstreams.types import ConsumerRecord + + +class HeaderExtractor(NamedTuple): + name: str + + def __hash__(self) -> int: + return hash((self.__class__, self.name)) + + def __eq__(self, __o: object) -> bool: + return isinstance(__o, HeaderExtractor) and __o.name == self.name + + async def extract(self, consumer_record: ConsumerRecord) -> Any: + headers = dict(consumer_record.headers) + try: + header = headers[self.name] + except KeyError as e: + message = ( + f"No header `{self.name}` found.\n" + "Check if your broker is sending the header.\n" + "Try adding a default value to your parameter like `None`.\n" + "Or set `convert_underscores = False`." + ) + raise HeaderNotFound(message) from e + else: + return header + + +class HeaderMarker(NamedTuple): + alias: Optional[str] + convert_underscores: bool + + def register_parameter(self, param: inspect.Parameter) -> HeaderExtractor: + if self.alias is not None: + name = self.alias + elif self.convert_underscores: + name = param.name.replace("_", "-") + else: + name = param.name + return HeaderExtractor(name=name) diff --git a/kstreams/_di/dependencies/core.py b/kstreams/_di/dependencies/core.py new file mode 100644 index 00000000..310638d4 --- /dev/null +++ b/kstreams/_di/dependencies/core.py @@ -0,0 +1,114 @@ +from typing import Any, Callable, Optional + +from di import Container, bind_by_type +from di.dependent import Dependent +from di.executors import AsyncExecutor + +from kstreams._di.dependencies.hooks import bind_by_generic +from kstreams.streams import Stream +from kstreams.types import ConsumerRecord, Send + +LayerFn = Callable[..., Any] + + +class StreamDependencyManager: + """Core of dependency injection on kstreams. + + This is an internal class of kstreams that manages the dependency injection, + as a user you should not use this class directly. + + Attributes: + container: dependency store. + stream: the stream wrapping the user function. Optional to improve testability. + When instanciating this class you must provide a stream, otherwise users + won't be able to use the `stream` parameter in their functions. + send: send object. Optional to improve testability, same as stream. + + Usage: + + stream and send are ommited for simplicity + + ```python + def user_func(cr: ConsumerRecord): + ... + + sdm = StreamDependencyManager() + sdm.solve(user_func) + sdm.execute(consumer_record) + ``` + """ + + container: Container + + def __init__( + self, + container: Optional[Container] = None, + stream: Optional[Stream] = None, + send: Optional[Send] = None, + ): + self.container = container or Container() + self.async_executor = AsyncExecutor() + self.stream = stream + self.send = send + + def solve_user_fn(self, fn: LayerFn) -> None: + """Build the dependency graph for the given function. + + Objects must be injected before this function is called. + + Attributes: + fn: user defined function, using allowed kstreams params + """ + self._register_consumer_record() + + if isinstance(self.stream, Stream): + self._register_stream(self.stream) + + if self.send is not None: + self._register_send(self.send) + + self.solved_user_fn = self.container.solve( + Dependent(fn, scope="consumer_record"), + scopes=["consumer_record", "stream", "application"], + ) + + async def execute(self, consumer_record: ConsumerRecord) -> Any: + """Execute the dependencies graph with external values. + + Attributes: + consumer_record: A kafka record containing `values`, `headers`, etc. + """ + async with self.container.enter_scope("consumer_record") as state: + return await self.container.execute_async( + self.solved_user_fn, + values={ConsumerRecord: consumer_record}, + executor=self.async_executor, + state=state, + ) + + def _register_stream(self, stream: Stream): + """Register the stream with the container.""" + hook = bind_by_type( + Dependent(lambda: stream, scope="consumer_record", wire=False), Stream + ) + self.container.bind(hook) + + def _register_consumer_record(self): + """Register consumer record with the container. + + We bind_by_generic because we want to bind the `ConsumerRecord` type which + is generic. + + The value must be injected at runtime. + """ + hook = bind_by_generic( + Dependent(ConsumerRecord, scope="consumer_record", wire=False), + ConsumerRecord, + ) + self.container.bind(hook) + + def _register_send(self, send: Send): + hook = bind_by_type( + Dependent(lambda: send, scope="consumer_record", wire=False), Send + ) + self.container.bind(hook) diff --git a/kstreams/_di/dependencies/hooks.py b/kstreams/_di/dependencies/hooks.py new file mode 100644 index 00000000..11e0799e --- /dev/null +++ b/kstreams/_di/dependencies/hooks.py @@ -0,0 +1,31 @@ +import inspect +from typing import Any, get_origin + +from di._container import BindHook +from di._utils.inspect import get_type +from di.api.dependencies import DependentBase + + +def bind_by_generic( + provider: DependentBase[Any], + dependency: type, +) -> BindHook: + """Hook to substitute the matched dependency based on its generic.""" + + def hook( + param: inspect.Parameter | None, dependent: DependentBase[Any] + ) -> DependentBase[Any] | None: + if dependent.call == dependency: + return provider + if param is None: + return None + + type_annotation_option = get_type(param) + if type_annotation_option is None: + return None + type_annotation = type_annotation_option.value + if get_origin(type_annotation) is dependency: + return provider + return None + + return hook diff --git a/kstreams/_di/parameters.py b/kstreams/_di/parameters.py new file mode 100644 index 00000000..b16a1b1e --- /dev/null +++ b/kstreams/_di/parameters.py @@ -0,0 +1,37 @@ +from typing import Optional, TypeVar + +from kstreams._di.binders.api import BinderMarker +from kstreams._di.binders.header import HeaderMarker +from kstreams.typing import Annotated + + +def Header( + *, alias: Optional[str] = None, convert_underscores: bool = True +) -> BinderMarker: + """Construct another type from the headers of a kafka record. + + Args: + alias: Use a different header name + convert_underscores: If True, convert underscores to dashes. + + Usage: + + ```python + from kstream import Header, Annotated + + def user_fn(event_type: Annotated[str, Header(alias="EventType")]): + ... + ``` + """ + header_marker = HeaderMarker(alias=alias, convert_underscores=convert_underscores) + binder = BinderMarker(extractor_marker=header_marker) + return binder + + +T = TypeVar("T") + +FromHeader = Annotated[T, Header()] +FromHeader.__doc__ = """General purpose convenient header type. + +Use `Annotated` to provide custom params. +""" diff --git a/kstreams/engine.py b/kstreams/engine.py index b69095f6..5b8b9057 100644 --- a/kstreams/engine.py +++ b/kstreams/engine.py @@ -5,13 +5,13 @@ from aiokafka.structs import RecordMetadata +from kstreams.middleware.di_middleware import DependencyInjectionHandler from kstreams.structs import TopicPartitionOffset from .backends.kafka import Kafka from .clients import Consumer, Producer from .exceptions import DuplicateStreamException, EngineNotStartedException from .middleware import Middleware -from .middleware.udf_middleware import UdfHandler from .prometheus.monitor import PrometheusMonitor from .rebalance_listener import MetricsRebalanceListener, RebalanceListener from .serializers import Deserializer, Serializer @@ -389,7 +389,13 @@ def add_stream( stream.rebalance_listener.stream = stream stream.rebalance_listener.engine = self - stream.udf_handler = UdfHandler( + # stream.udf_handler = UdfHandler( + # next_call=stream.func, + # send=self.send, + # stream=stream, + # ) + + stream.udf_handler = DependencyInjectionHandler( next_call=stream.func, send=self.send, stream=stream, @@ -397,7 +403,7 @@ def add_stream( # NOTE: When `no typing` support is deprecated this check can # be removed - if stream.udf_handler.type != UDFType.NO_TYPING: + if stream.udf_handler.get_type() != UDFType.NO_TYPING: stream.func = self._build_stream_middleware_stack(stream=stream) def _build_stream_middleware_stack(self, *, stream: Stream) -> NextMiddlewareCall: diff --git a/kstreams/exceptions.py b/kstreams/exceptions.py index 249f2db7..dd5a65d9 100644 --- a/kstreams/exceptions.py +++ b/kstreams/exceptions.py @@ -25,3 +25,6 @@ def __str__(self) -> str: class BackendNotSet(StreamException): ... + + +class HeaderNotFound(StreamException): ... diff --git a/kstreams/middleware/di_middleware.py b/kstreams/middleware/di_middleware.py new file mode 100644 index 00000000..2e53f5ad --- /dev/null +++ b/kstreams/middleware/di_middleware.py @@ -0,0 +1,30 @@ +import inspect +import typing + +from kstreams import types +from kstreams._di.dependencies.core import StreamDependencyManager +from kstreams.streams_utils import UDFType, setup_type + +from .middleware import BaseMiddleware + + +class DependencyInjectionHandler(BaseMiddleware): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + self.dependecy_manager = StreamDependencyManager( + stream=self.stream, send=self.send + ) + + # To be deprecated once streams with type hints are deprecated + signature = inspect.signature(self.next_call) + self.params = list(signature.parameters.values()) + self.type: UDFType = setup_type(self.params) + if self.type == UDFType.WITH_TYPING: + self.dependecy_manager.solve_user_fn(fn=self.next_call) + + def get_type(self) -> UDFType: + return self.type + + async def __call__(self, cr: types.ConsumerRecord) -> typing.Any: + return await self.dependecy_manager.execute(cr) diff --git a/kstreams/middleware/middleware.py b/kstreams/middleware/middleware.py index f5b164bd..0173c15d 100644 --- a/kstreams/middleware/middleware.py +++ b/kstreams/middleware/middleware.py @@ -4,7 +4,7 @@ import typing from kstreams import types -from kstreams.streams_utils import StreamErrorPolicy +from kstreams.streams_utils import StreamErrorPolicy, UDFType if typing.TYPE_CHECKING: from kstreams import Stream, StreamEngine # pragma: no cover @@ -14,6 +14,10 @@ class MiddlewareProtocol(typing.Protocol): + next_call: types.NextMiddlewareCall + send: types.Send + stream: "Stream" + def __init__( self, *, @@ -44,7 +48,11 @@ def __repr__(self) -> str: return f"{middleware_name}({extra_options})" -class BaseMiddleware: +class BaseMiddleware(MiddlewareProtocol): + next_call: types.NextMiddlewareCall + send: types.Send + stream: "Stream" + def __init__( self, *, @@ -145,3 +153,20 @@ async def cleanup_policy(self, exc: Exception) -> None: await self.engine.stop() await self.stream.is_processing.acquire() signal.raise_signal(signal.SIGTERM) + + # acquire the asyncio.Lock `is_processing` again to resume the processing + # and avoid `RuntimeError: Lock is not acquired.` + await self.stream.is_processing.acquire() + + +class BaseDependcyMiddleware(MiddlewareProtocol, typing.Protocol): + """Base class for Dependency Injection Middleware. + + Both old and new DI middlewares make use of the type. + + The `type` is used to identify the way to call the user defined function. + + On top of that, this middleware helps avoid circular dependencies. + """ + + def get_type(self) -> UDFType: ... diff --git a/kstreams/middleware/udf_middleware.py b/kstreams/middleware/udf_middleware.py index 2bc1f295..f12b1f9d 100644 --- a/kstreams/middleware/udf_middleware.py +++ b/kstreams/middleware/udf_middleware.py @@ -21,6 +21,9 @@ def __init__(self, *args, **kwargs) -> None: self.params = list(signature.parameters.values()) self.type: UDFType = setup_type(self.params) + def get_type(self) -> UDFType: + return self.type + def bind_udf_params(self, cr: types.ConsumerRecord) -> typing.List: # NOTE: When `no typing` support is deprecated then this can # be more eficient as the CR will be always there. diff --git a/kstreams/streams.py b/kstreams/streams.py index bf17682e..e6112a21 100644 --- a/kstreams/streams.py +++ b/kstreams/streams.py @@ -10,12 +10,12 @@ from kstreams import TopicPartition from kstreams.exceptions import BackendNotSet -from kstreams.middleware.middleware import ExceptionMiddleware +from kstreams.middleware.middleware import BaseDependcyMiddleware, ExceptionMiddleware from kstreams.structs import TopicPartitionOffset from .backends.kafka import Kafka from .clients import Consumer -from .middleware import Middleware, udf_middleware +from .middleware import Middleware from .rebalance_listener import RebalanceListener from .serializers import Deserializer from .streams_utils import StreamErrorPolicy, UDFType @@ -172,11 +172,14 @@ def __init__( self.seeked_initial_offsets = False self.rebalance_listener = rebalance_listener self.middlewares = middlewares or [] - self.udf_handler: typing.Optional[udf_middleware.UdfHandler] = None + self.udf_handler: typing.Optional[BaseDependcyMiddleware] = None self.topics = [topics] if isinstance(topics, str) else topics self.subscribe_by_pattern = subscribe_by_pattern self.error_policy = error_policy + def __name__(self) -> str: + return self.name + def _create_consumer(self) -> Consumer: if self.backend is None: raise BackendNotSet("A backend has not been set for this stream") @@ -342,7 +345,7 @@ async def start(self) -> None: self.running = True if self.udf_handler is not None: - if self.udf_handler.type == UDFType.NO_TYPING: + if self.udf_handler.get_type() == UDFType.NO_TYPING: # deprecated use case msg = ( "Streams with `async for in` loop approach are deprecated.\n" @@ -356,6 +359,10 @@ async def start(self) -> None: await func else: # Typing cases + + # If it's an async generator, then DON'T await the function + # because we want to start ONLY and let the user retrieve the + # values while iterating the stream if not inspect.isasyncgenfunction(self.udf_handler.next_call): # Is not an async_generator, then create `await` the func await self.func_wrapper_with_typing() @@ -436,7 +443,7 @@ async def __anext__(self) -> ConsumerRecord: if ( self.udf_handler is not None - and self.udf_handler.type == UDFType.NO_TYPING + and self.udf_handler.get_type() == UDFType.NO_TYPING ): return cr return await self.func(cr) diff --git a/kstreams/types.py b/kstreams/types.py index 3562f3b6..90107722 100644 --- a/kstreams/types.py +++ b/kstreams/types.py @@ -8,8 +8,7 @@ Headers = typing.Dict[str, str] EncodedHeaders = typing.Sequence[typing.Tuple[str, bytes]] -StreamFunc = typing.Callable - +StreamFunc = typing.Callable[..., typing.Any] EngineHooks = typing.Sequence[typing.Callable[[], typing.Any]] diff --git a/kstreams/typing.py b/kstreams/typing.py new file mode 100644 index 00000000..1707d56b --- /dev/null +++ b/kstreams/typing.py @@ -0,0 +1,8 @@ +"""Remove this file when python3.8 support is dropped.""" + +import sys + +if sys.version_info < (3, 9): + from typing_extensions import Annotated as Annotated # noqa: F401 +else: + from typing import Annotated as Annotated # noqa: F401 diff --git a/poetry.lock b/poetry.lock index d5718fe3..58943738 100644 --- a/poetry.lock +++ b/poetry.lock @@ -416,6 +416,23 @@ files = [ {file = "decli-0.6.2.tar.gz", hash = "sha256:36f71eb55fd0093895efb4f416ec32b7f6e00147dda448e3365cf73ceab42d6f"}, ] +[[package]] +name = "di" +version = "0.79.2" +description = "Dependency injection toolkit" +optional = false +python-versions = ">=3.8,<4" +files = [ + {file = "di-0.79.2-py3-none-any.whl", hash = "sha256:4b2ac7c46d4d9e941ca47d37c2029ba739c1f8a0e19e5288731224870f00d6e6"}, + {file = "di-0.79.2.tar.gz", hash = "sha256:0c65b9ccb984252dadbdcdb39743eeddef0c1f167f791c59fcd70e97bb0d3af8"}, +] + +[package.dependencies] +graphlib2 = ">=0.4.1,<0.5.0" + +[package.extras] +anyio = ["anyio (>=3.5.0)"] + [[package]] name = "exceptiongroup" version = "1.2.2" @@ -430,6 +447,21 @@ files = [ [package.extras] test = ["pytest (>=6)"] +[[package]] +name = "faker" +version = "33.0.0" +description = "Faker is a Python package that generates fake data for you." +optional = false +python-versions = ">=3.8" +files = [ + {file = "Faker-33.0.0-py3-none-any.whl", hash = "sha256:68e5580cb6b4226710886e595eabc13127149d6e71e9d1db65506a7fbe2c7fce"}, + {file = "faker-33.0.0.tar.gz", hash = "sha256:9b01019c1ddaf2253ca2308c0472116e993f4ad8fc9905f82fa965e0c6f932e9"}, +] + +[package.dependencies] +python-dateutil = ">=2.4" +typing-extensions = "*" + [[package]] name = "fastapi" version = "0.115.5" @@ -478,6 +510,40 @@ python-dateutil = ">=2.8.1" [package.extras] dev = ["flake8", "markdown", "twine", "wheel"] +[[package]] +name = "graphlib2" +version = "0.4.7" +description = "Rust port of the Python stdlib graphlib modules" +optional = false +python-versions = ">=3.7" +files = [ + {file = "graphlib2-0.4.7-cp37-abi3-macosx_10_7_x86_64.whl", hash = "sha256:483710733215783cdc76452ccde1247af8f697685c9c1dfd9bb9ff4f52d990ee"}, + {file = "graphlib2-0.4.7-cp37-abi3-macosx_10_9_x86_64.macosx_11_0_arm64.macosx_10_9_universal2.whl", hash = "sha256:3619c7d3c5aca95e6cbbfc283aa6bf42ffa5b59d7f39c8d0ad615bce65dc406f"}, + {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b19f1b91d0f22ca3d1cfb2965478db98cf5916a5c6cea5fdc7caf4bf1bfbc33"}, + {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:624020f6808ee21ffbb2e455f8dd4196bbb37032a35aa3327f0f5b65fb6a35d1"}, + {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:6efc6a197a619a97f1b105aea14b202101241c1db9014bd100ad19cf29288cbf"}, + {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d7cc38b68775cb2cdfc487bbaca2f7991da0d76d42a68f412c2ca61461e6e026"}, + {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b06bed98d42f4e10adfe2a8332efdca06b5bac6e7c86dd1d22a4dea4de9b275a"}, + {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c9ec3a5645bdf020d8bd9196b2665e26090d60e523fd498df29628f2c5fbecc"}, + {file = "graphlib2-0.4.7-cp37-abi3-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:824df87f767471febfd785a05a2cc77c0c973e0112a548df827763ca0aa8c126"}, + {file = "graphlib2-0.4.7-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2de5e32ca5c0b06d442d2be4b378cc0bc335c5fcbc14a7d531a621eb8294d019"}, + {file = "graphlib2-0.4.7-cp37-abi3-musllinux_1_2_armv7l.whl", hash = "sha256:13a23fcf07c7bef8a5ad0e04ab826d3a2a2bcb493197005300c68b4ea7b8f581"}, + {file = "graphlib2-0.4.7-cp37-abi3-musllinux_1_2_i686.whl", hash = "sha256:15a8a6daa28c1fb5c518d387879f3bbe313264fbbc2fab5635b718bc71a24913"}, + {file = "graphlib2-0.4.7-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:0cb6c4449834077972c3cea4602f86513b4b75fcf2d40b12e4fe4bf1aa5c8da2"}, + {file = "graphlib2-0.4.7-cp37-abi3-win32.whl", hash = "sha256:31b40cea537845d80b69403ae306d7c6a68716b76f5171f68daed1804aadefec"}, + {file = "graphlib2-0.4.7-cp37-abi3-win_amd64.whl", hash = "sha256:d40935a9da81a046ebcaa0216ad593ef504ae8a5425a59bdbd254c0462adedc8"}, + {file = "graphlib2-0.4.7-pp37-pypy37_pp73-macosx_10_7_x86_64.whl", hash = "sha256:9cef08a50632e75a9e11355e68fa1f8c9371d0734642855f8b5c4ead1b058e6f"}, + {file = "graphlib2-0.4.7-pp37-pypy37_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:aeecb604d70317c20ca6bc3556f7f5c40146ad1f0ded837e978b2fe6edf3e567"}, + {file = "graphlib2-0.4.7-pp37-pypy37_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb4ae9df7ed895c6557619049c9f73e1c2e6d1fbed568010fd5d4af94e2f0692"}, + {file = "graphlib2-0.4.7-pp38-pypy38_pp73-macosx_10_7_x86_64.whl", hash = "sha256:3ee3a99fc39df948fef340b01254709cc603263f8b176f72ed26f1eea44070a4"}, + {file = "graphlib2-0.4.7-pp38-pypy38_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5873480df8991273bd1585122df232acd0f946c401c254bd9f0d661c72589dcf"}, + {file = "graphlib2-0.4.7-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:297c817229501255cd3a744c62c8f91e5139ee79bc550488f5bc765ffa33f7c5"}, + {file = "graphlib2-0.4.7-pp39-pypy39_pp73-macosx_10_7_x86_64.whl", hash = "sha256:853ef22df8e9f695706e0b8556cda9342d4d617f7d7bd02803e824bcc0c30b20"}, + {file = "graphlib2-0.4.7-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee62ff1042fde980adf668e30393eca79aee8f1fa1274ab3b98d69091c70c5e8"}, + {file = "graphlib2-0.4.7-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b16e21e70938132d4160c2591fed59f79b5f8b702e4860c8933111b5fedb55c2"}, + {file = "graphlib2-0.4.7.tar.gz", hash = "sha256:a951c18cb4c2c2834eec898b4c75d3f930d6f08beb37496f0e0ce56eb3f571f5"}, +] + [[package]] name = "griffe" version = "1.5.1" @@ -1023,13 +1089,13 @@ files = [ [[package]] name = "pydantic" -version = "2.10.1" +version = "2.10.2" description = "Data validation using Python type hints" optional = false python-versions = ">=3.8" files = [ - {file = "pydantic-2.10.1-py3-none-any.whl", hash = "sha256:a8d20db84de64cf4a7d59e899c2caf0fe9d660c7cfc482528e7020d7dd189a7e"}, - {file = "pydantic-2.10.1.tar.gz", hash = "sha256:a4daca2dc0aa429555e0656d6bf94873a7dc5f54ee42b1f5873d666fb3f35560"}, + {file = "pydantic-2.10.2-py3-none-any.whl", hash = "sha256:cfb96e45951117c3024e6b67b25cdc33a3cb7b2fa62e239f7af1378358a1d99e"}, + {file = "pydantic-2.10.2.tar.gz", hash = "sha256:2bc2d7f17232e0841cbba4641e65ba1eb6fafb3a08de3a091ff3ce14a197c4fa"}, ] [package.dependencies] @@ -1153,6 +1219,38 @@ files = [ [package.dependencies] typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" +[[package]] +name = "pygal" +version = "3.0.5" +description = "A Python svg graph plotting library" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pygal-3.0.5-py3-none-any.whl", hash = "sha256:a3268a5667b470c8fbbb0eca7e987561a7321caeba589d40e4c1bc16dbe71393"}, + {file = "pygal-3.0.5.tar.gz", hash = "sha256:c0a0f34e5bc1c01975c2bfb8342ad521e293ad42e525699dd00c4d7a52c14b71"}, +] + +[package.dependencies] +importlib-metadata = "*" + +[package.extras] +docs = ["pygal-sphinx-directives", "sphinx", "sphinx-rtd-theme"] +lxml = ["lxml"] +moulinrouge = ["flask", "pygal-maps-ch", "pygal-maps-fr", "pygal-maps-world"] +png = ["cairosvg"] +test = ["cairosvg", "coveralls", "lxml", "pyquery", "pytest", "pytest-cov", "ruff (>=0.5.6)"] + +[[package]] +name = "pygaljs" +version = "1.0.2" +description = "Python package providing assets from https://github.com/Kozea/pygal.js" +optional = false +python-versions = "*" +files = [ + {file = "pygaljs-1.0.2-py2.py3-none-any.whl", hash = "sha256:d75e18cb21cc2cda40c45c3ee690771e5e3d4652bf57206f20137cf475c0dbe8"}, + {file = "pygaljs-1.0.2.tar.gz", hash = "sha256:0b71ee32495dcba5fbb4a0476ddbba07658ad65f5675e4ad409baf154dec5111"}, +] + [[package]] name = "pygments" version = "2.18.0" @@ -1238,7 +1336,10 @@ files = [ [package.dependencies] py-cpuinfo = "*" +pygal = {version = "*", optional = true, markers = "extra == \"histogram\""} +pygaljs = {version = "*", optional = true, markers = "extra == \"histogram\""} pytest = ">=8.1" +setuptools = {version = "*", optional = true, markers = "extra == \"histogram\""} [package.extras] aspect = ["aspectlib"] @@ -1532,6 +1633,26 @@ files = [ {file = "ruff-0.7.4.tar.gz", hash = "sha256:cd12e35031f5af6b9b93715d8c4f40360070b2041f81273d0527683d5708fce2"}, ] +[[package]] +name = "setuptools" +version = "75.6.0" +description = "Easily download, build, install, upgrade, and uninstall Python packages" +optional = false +python-versions = ">=3.9" +files = [ + {file = "setuptools-75.6.0-py3-none-any.whl", hash = "sha256:ce74b49e8f7110f9bf04883b730f4765b774ef3ef28f722cce7c273d253aaf7d"}, + {file = "setuptools-75.6.0.tar.gz", hash = "sha256:8199222558df7c86216af4f84c30e9b34a61d8ba19366cc914424cdbd28252f6"}, +] + +[package.extras] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.7.0)"] +core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +cover = ["pytest-cov"] +doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +enabler = ["pytest-enabler (>=2.2)"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.2.0)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (>=1.12,<1.14)", "pytest-mypy"] + [[package]] name = "six" version = "1.16.0" @@ -1762,4 +1883,4 @@ type = ["pytest-mypy"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "6ebc30facbae1ff72fa99b1e741ad7bd1e9c366334caec913ac7a7ed0872a794" +content-hash = "d7442a776d70ab70e1149784d70ef9f1348165421dedcb28c53e53a9673138e7" diff --git a/pyproject.toml b/pyproject.toml index fc8519e8..c32e456c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,11 +29,12 @@ prometheus-client = "<1.0" future = "^1.0.0" PyYAML = ">=5.4,<7.0.0" pydantic = ">=2.0.0,<3.0.0" +di = "^0.79.2" [tool.poetry.group.dev.dependencies] pytest = "^8.3.3" pytest-asyncio = "^0.24.0" -pytest-benchmark = "^5.1.0" +pytest-benchmark = { version = "^5.1.0", extras = ["histogram"] } pytest-cov = "^6" pytest-httpserver = "^1.1.0" mypy = "^1.11.2" @@ -48,6 +49,7 @@ mkdocs-material = "^9.5.39" starlette-prometheus = "^0.10.0" codecov = "^2.1.12" mkdocstrings = { version = "^0.27", extras = ["python"] } +Faker = "^33" [build-system] requires = ["poetry-core>=1.0.0"] diff --git a/scripts/bench-compare b/scripts/bench-compare index 99e44829..0f160625 100755 --- a/scripts/bench-compare +++ b/scripts/bench-compare @@ -16,4 +16,4 @@ if [ -d '.venv' ] ; then fi # Commented out until after merge, so there will be date to compare with. -# ${PREFIX}pytest tests/test_benchmarks.py --benchmark-compare --benchmark-compare-fail=min:5% +${PREFIX}pytest tests/test_benchmarks.py --benchmark-compare --benchmark-compare-fail=min:5% --benchmark-compare-fail=mean:5% diff --git a/scripts/test b/scripts/test index a20ebb12..5cff7269 100755 --- a/scripts/test +++ b/scripts/test @@ -5,7 +5,7 @@ if [ -d '.venv' ] ; then export PREFIX=".venv/bin/" fi -${PREFIX}pytest -x --cov-report term-missing --cov-report=xml:coverage.xml --cov=kstreams ${1-"./tests"} $2 +${PREFIX}pytest --cov-report term-missing --cov-report=xml:coverage.xml --cov=kstreams ${1-"./tests"} $2 ${PREFIX}ruff check kstreams tests ${PREFIX}ruff format --check kstreams tests examples -${PREFIX}mypy kstreams/ +${PREFIX}mypy kstreams/ tests diff --git a/tests/_di/test_dependency_manager.py b/tests/_di/test_dependency_manager.py new file mode 100644 index 00000000..13bf6cc1 --- /dev/null +++ b/tests/_di/test_dependency_manager.py @@ -0,0 +1,108 @@ +from typing import Any, AsyncGenerator, Generator + +import pytest + +from kstreams._di.dependencies.core import StreamDependencyManager +from kstreams.streams import Stream +from kstreams.types import ConsumerRecord + + +class AppWrapper: + """This is a fake class used to check if the ConsumerRecord is injected""" + + def __init__(self) -> None: + self.foo = "bar" + + async def consume(self, cr: ConsumerRecord) -> str: + return self.foo + + +@pytest.fixture +def di_cr(rand_consumer_record) -> Generator[ConsumerRecord, Any, None]: + """Dependency injected ConsumerRecord""" + yield rand_consumer_record() + + +async def test_cr_is_injected(di_cr: ConsumerRecord): + async def user_fn(cr: ConsumerRecord) -> str: + cr.value = "hello" + return cr.value + + stream_manager = StreamDependencyManager() + stream_manager._register_consumer_record() + stream_manager.solve_user_fn(user_fn) + content = await stream_manager.execute(di_cr) + assert content == "hello" + + +async def test_cr_is_injected_in_class(di_cr: ConsumerRecord): + app = AppWrapper() + stream_manager = StreamDependencyManager() + stream_manager._register_consumer_record() + stream_manager.solve_user_fn(app.consume) + content = await stream_manager.execute(di_cr) + assert content == app.foo + + +async def test_cr_generics_is_injected(di_cr: ConsumerRecord): + async def user_fn(cr: ConsumerRecord[Any, Any]) -> str: + cr.value = "hello" + return cr.value + + stream_manager = StreamDependencyManager() + stream_manager._register_consumer_record() + stream_manager.solve_user_fn(user_fn) + content = await stream_manager.execute(di_cr) + assert content == "hello" + +async def test_cr_generics_str_is_injected(di_cr: ConsumerRecord): + async def user_fn(cr: ConsumerRecord[str, str]) -> str: + cr.value = "hello" + return cr.value + + stream_manager = StreamDependencyManager() + stream_manager._register_consumer_record() + stream_manager.solve_user_fn(user_fn) + content = await stream_manager.execute(di_cr) + assert content == "hello" + + +async def test_cr_with_generator(di_cr: ConsumerRecord): + async def user_fn(cr: ConsumerRecord) -> AsyncGenerator[str, None]: + cr.value = "hello" + yield cr.value + + stream_manager = StreamDependencyManager() + stream_manager._register_consumer_record() + stream_manager.solve_user_fn(user_fn) + content = await stream_manager.execute(di_cr) + + assert content == "hello" + + +async def test_stream(di_cr: ConsumerRecord): + async def user_fn(stream: Stream) -> str: + return stream.name + + stream = Stream("my-topic", func=user_fn, name="stream_name") + stream_manager = StreamDependencyManager() + stream_manager._register_stream(stream) + stream_manager._register_consumer_record() + stream_manager.solve_user_fn(user_fn) + content = await stream_manager.execute(di_cr) + assert content == "stream_name" + + +async def test_stream_and_consumer_record(di_cr: ConsumerRecord): + async def user_fn(stream: Stream, record: ConsumerRecord) -> tuple[str, str]: + return (stream.name, record.topic) + + stream = Stream("my-topic", func=user_fn, name="stream_name") + stream_manager = StreamDependencyManager() + stream_manager._register_stream(stream) + stream_manager._register_consumer_record() + stream_manager.solve_user_fn(user_fn) + (stream_name, topic_name) = await stream_manager.execute(di_cr) + + assert stream_name == "stream_name" + assert topic_name == di_cr.topic diff --git a/tests/_di/test_hooks.py b/tests/_di/test_hooks.py new file mode 100644 index 00000000..30e83ad8 --- /dev/null +++ b/tests/_di/test_hooks.py @@ -0,0 +1,75 @@ +import typing + +import pytest +from di import Container +from di.dependent import Dependent +from di.executors import SyncExecutor + +from kstreams._di.dependencies.hooks import bind_by_generic + +KT = typing.TypeVar("KT") +VT = typing.TypeVar("VT") + + +class Record(typing.Generic[KT, VT]): + def __init__(self, key: KT, value: VT): + self.key = key + self.value = value + + +def func_hinted(record: Record[str, int]) -> Record[str, int]: + return record + + +def func_base(record: Record) -> Record: + return record + + +@pytest.mark.parametrize( + "func", + [ + func_hinted, + func_base, + ], +) +def test_bind_generic_ok(func: typing.Callable): + dep = Dependent(func) + container = Container() + container.bind( + bind_by_generic( + Dependent(lambda: Record("foo", 1), wire=False), + Record, + ) + ) + solved = container.solve(dep, scopes=[None]) + with container.enter_scope(None) as state: + instance = solved.execute_sync(executor=SyncExecutor(), state=state) + assert isinstance(instance, Record) + + +def func_str(record: str) -> str: + return record + + +@pytest.mark.parametrize( + "func", + [ + func_str, + ], +) +def test_bind_generic_unrelated(func: typing.Callable): + dep = Dependent(func) + container = Container() + container.bind( + bind_by_generic( + Dependent(lambda: Record("foo", 1), wire=False), + Record, + ) + ) + solved = container.solve(dep, scopes=[None]) + with container.enter_scope(None) as state: + instance = solved.execute_sync(executor=SyncExecutor(), state=state) + print(type(instance)) + print(instance) + assert not isinstance(instance, Record) + assert isinstance(instance, str) diff --git a/tests/_di/test_param_headers.py b/tests/_di/test_param_headers.py new file mode 100644 index 00000000..cef94960 --- /dev/null +++ b/tests/_di/test_param_headers.py @@ -0,0 +1,76 @@ +from typing import Callable + +import pytest + +from kstreams import FromHeader, Header +from kstreams._di.dependencies.core import StreamDependencyManager +from kstreams.exceptions import HeaderNotFound +from kstreams.types import ConsumerRecord +from kstreams.typing import Annotated + +RandConsumerRecordFixture = Callable[..., ConsumerRecord] + + +async def test_from_headers_ok(rand_consumer_record: RandConsumerRecordFixture): + cr = rand_consumer_record(headers=(("event-type", "hello"),)) + + async def user_fn(event_type: FromHeader[str]) -> str: + return event_type + + stream_manager = StreamDependencyManager() + stream_manager.solve_user_fn(user_fn) + header_content = await stream_manager.execute(cr) + assert header_content == "hello" + + +async def test_from_header_not_found(rand_consumer_record: RandConsumerRecordFixture): + cr = rand_consumer_record(headers=(("event-type", "hello"),)) + + def user_fn(a_header: FromHeader[str]) -> str: + return a_header + + stream_manager = StreamDependencyManager() + stream_manager.solve_user_fn(user_fn) + with pytest.raises(HeaderNotFound): + await stream_manager.execute(cr) + + +@pytest.mark.xfail(reason="not implemenetd yet") +async def test_from_headers_numbers(rand_consumer_record: RandConsumerRecordFixture): + cr = rand_consumer_record(headers=(("event-type", "1"),)) + + async def user_fn(event_type: FromHeader[int]) -> int: + return event_type + + stream_manager = StreamDependencyManager() + stream_manager.solve_user_fn(user_fn) + header_content = await stream_manager.execute(cr) + assert header_content == 1 + + +async def test_headers_alias(rand_consumer_record: RandConsumerRecordFixture): + cr = rand_consumer_record(headers=(("EventType", "hello"),)) + + async def user_fn(event_type: Annotated[int, Header(alias="EventType")]) -> int: + return event_type + + stream_manager = StreamDependencyManager() + stream_manager.solve_user_fn(user_fn) + header_content = await stream_manager.execute(cr) + assert header_content == "hello" + + +async def test_headers_convert_underscores( + rand_consumer_record: RandConsumerRecordFixture, +): + cr = rand_consumer_record(headers=(("event_type", "hello"),)) + + async def user_fn( + event_type: Annotated[int, Header(convert_underscores=False)], + ) -> int: + return event_type + + stream_manager = StreamDependencyManager() + stream_manager.solve_user_fn(user_fn) + header_content = await stream_manager.execute(cr) + assert header_content == "hello" diff --git a/tests/conftest.py b/tests/conftest.py index e1bbc43e..3ccd6494 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,22 @@ import asyncio +import logging from collections import namedtuple from dataclasses import field -from typing import Any, Dict, List, NamedTuple, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Sequence, Tuple import pytest import pytest_asyncio +from faker import Faker from pytest_httpserver import HTTPServer from kstreams import clients, create_engine +from kstreams.types import ConsumerRecord from kstreams.utils import create_ssl_context_from_mem +# Silence faker DEBUG logs +logger = logging.getLogger("faker") +logger.setLevel(logging.INFO) + class RecordMetadata(NamedTuple): offset: int = 1 @@ -243,3 +250,55 @@ def _(): benchmark(func, *args, **kwargs) return _wrapper + + +@pytest.fixture +def fake(): + return Faker() + + +@pytest.fixture() +def rand_consumer_record(fake: Faker) -> Callable[..., ConsumerRecord]: + """A random consumer record generator. + + You can inject this fixture in your test, + and then you can override the default values. + + Example: + + ```python + def test_my_consumer(rand_consumer_record): + rand_cr = rand_consumer_record() + custom_attrs_cr = rand_consumer_record(topic="my-topic", value="my-value") + # ... + ``` + """ + + def generate( + topic: Optional[str] = None, + headers: Optional[Sequence[Tuple[str, bytes]]] = None, + partition: Optional[int] = None, + offset: Optional[int] = None, + timestamp: Optional[int] = None, + timestamp_type: Optional[int] = None, + key: Optional[Any] = None, + value: Optional[Any] = None, + checksum: Optional[int] = None, + serialized_key_size: Optional[int] = None, + serialized_value_size: Optional[int] = None, + ) -> ConsumerRecord: + return ConsumerRecord( + topic=topic or fake.slug(), + headers=headers or tuple(), + partition=partition or fake.pyint(max_value=10), + offset=offset or fake.pyint(max_value=99999999), + timestamp=timestamp or int(fake.unix_time()), + timestamp_type=timestamp_type or 1, + key=key or fake.pystr(), + value=value or fake.pystr().encode(), + checksum=checksum, + serialized_key_size=serialized_key_size or fake.pyint(max_value=10), + serialized_value_size=serialized_value_size or fake.pyint(max_value=10), + ) + + return generate diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 0b94361b..bb013d98 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -20,62 +20,68 @@ async def my_coroutine(_): stream_engine.add_stream(stream=stream) await stream.start() + assert stream.consumer is not None await stream_engine.monitor.generate_consumer_metrics(stream.consumer) consumer = stream.consumer for topic_partition in consumer.assignment(): # super ugly notation but for now is the only way to get the metrics met_committed = ( - stream_engine.monitor.MET_COMMITTED.labels( - topic=topic_partition.topic, - partition=topic_partition.partition, - consumer_group=consumer._group_id, - ) - .collect()[0] + list( + stream_engine.monitor.MET_COMMITTED.labels( + topic=topic_partition.topic, + partition=topic_partition.partition, + consumer_group=consumer._group_id, + ).collect() + )[0] .samples[0] .value ) met_position = ( - stream_engine.monitor.MET_POSITION.labels( - topic=topic_partition.topic, - partition=topic_partition.partition, - consumer_group=consumer._group_id, - ) - .collect()[0] + list( + stream_engine.monitor.MET_POSITION.labels( + topic=topic_partition.topic, + partition=topic_partition.partition, + consumer_group=consumer._group_id, + ).collect() + )[0] .samples[0] .value ) met_highwater = ( - stream_engine.monitor.MET_HIGHWATER.labels( - topic=topic_partition.topic, - partition=topic_partition.partition, - consumer_group=consumer._group_id, - ) - .collect()[0] + list( + stream_engine.monitor.MET_HIGHWATER.labels( + topic=topic_partition.topic, + partition=topic_partition.partition, + consumer_group=consumer._group_id, + ).collect() + )[0] .samples[0] .value ) met_lag = ( - stream_engine.monitor.MET_LAG.labels( - topic=topic_partition.topic, - partition=topic_partition.partition, - consumer_group=consumer._group_id, - ) - .collect()[0] + list( + stream_engine.monitor.MET_LAG.labels( + topic=topic_partition.topic, + partition=topic_partition.partition, + consumer_group=consumer._group_id, + ).collect() + )[0] .samples[0] .value ) met_position_lag = ( - stream_engine.monitor.MET_POSITION_LAG.labels( - topic=topic_partition.topic, - partition=topic_partition.partition, - consumer_group=consumer._group_id, - ) - .collect()[0] + list( + stream_engine.monitor.MET_POSITION_LAG.labels( + topic=topic_partition.topic, + partition=topic_partition.partition, + consumer_group=consumer._group_id, + ).collect() + )[0] .samples[0] .value ) @@ -135,56 +141,61 @@ async def my_coroutine(_): for topic_partition in consumer.assignment(): # super ugly notation but for now is the only way to get the metrics met_committed = ( - stream_engine.monitor.MET_COMMITTED.labels( - topic=topic_partition.topic, - partition=topic_partition.partition, - consumer_group=consumer._group_id, - ) - .collect()[0] + list( + stream_engine.monitor.MET_COMMITTED.labels( + topic=topic_partition.topic, + partition=topic_partition.partition, + consumer_group=consumer._group_id, + ).collect() + )[0] .samples[0] .value ) met_position = ( - stream_engine.monitor.MET_POSITION.labels( - topic=topic_partition.topic, - partition=topic_partition.partition, - consumer_group=consumer._group_id, - ) - .collect()[0] + list( + stream_engine.monitor.MET_POSITION.labels( + topic=topic_partition.topic, + partition=topic_partition.partition, + consumer_group=consumer._group_id, + ).collect() + )[0] .samples[0] .value ) met_highwater = ( - stream_engine.monitor.MET_HIGHWATER.labels( - topic=topic_partition.topic, - partition=topic_partition.partition, - consumer_group=consumer._group_id, - ) - .collect()[0] + list( + stream_engine.monitor.MET_HIGHWATER.labels( + topic=topic_partition.topic, + partition=topic_partition.partition, + consumer_group=consumer._group_id, + ).collect() + )[0] .samples[0] .value ) met_lag = ( - stream_engine.monitor.MET_LAG.labels( - topic=topic_partition.topic, - partition=topic_partition.partition, - consumer_group=consumer._group_id, - ) - .collect()[0] + list( + stream_engine.monitor.MET_LAG.labels( + topic=topic_partition.topic, + partition=topic_partition.partition, + consumer_group=consumer._group_id, + ).collect() + )[0] .samples[0] .value ) met_position_lag = ( - stream_engine.monitor.MET_POSITION_LAG.labels( - topic=topic_partition.topic, - partition=topic_partition.partition, - consumer_group=consumer._group_id, - ) - .collect()[0] + list( + stream_engine.monitor.MET_POSITION_LAG.labels( + topic=topic_partition.topic, + partition=topic_partition.partition, + consumer_group=consumer._group_id, + ).collect() + )[0] .samples[0] .value ) @@ -200,9 +211,9 @@ async def my_coroutine(_): met_position_lag == consumer.highwater(topic_partition) - consumer_position ) - assert len(stream_engine.monitor.MET_POSITION_LAG.collect()[0].samples) == 2 + assert len(list(stream_engine.monitor.MET_POSITION_LAG.collect())[0].samples) == 2 await stream_engine.remove_stream(stream) - assert len(stream_engine.monitor.MET_POSITION_LAG.collect()[0].samples) == 0 + assert len(list(stream_engine.monitor.MET_POSITION_LAG.collect())[0].samples) == 0 @pytest.mark.asyncio @@ -223,6 +234,6 @@ async def my_coroutine(_): stream_engine.add_stream(stream=stream) await stream.start() - assert len(stream_engine.monitor.MET_POSITION_LAG.collect()[0].samples) == 0 + assert len(list(stream_engine.monitor.MET_POSITION_LAG.collect())[0].samples) == 0 await stream_engine.remove_stream(stream) assert "Metrics for consumer with group-id: my-group not found" in caplog.text diff --git a/tests/test_streams.py b/tests/test_streams.py index 9f05a5ae..29cdd224 100644 --- a/tests/test_streams.py +++ b/tests/test_streams.py @@ -339,7 +339,6 @@ async def streaming_fn(_): Consumer.stop.assert_awaited() -@pytest.mark.asyncio async def test_stream_decorates_properly(stream_engine: StreamEngine): topic = "local--hello-kpn"