Skip to content

Commit

Permalink
feat: add dependency injection framework
Browse files Browse the repository at this point in the history
  • Loading branch information
woile committed Nov 27, 2024
1 parent 079e2a8 commit 3083d43
Show file tree
Hide file tree
Showing 28 changed files with 981 additions and 93 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/bench-release.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Bump version
name: Benchmark latest release

on:
push:
Expand Down Expand Up @@ -46,5 +46,5 @@ jobs:
git config --global user.email "[email protected]"
git config --global user.name "GitHub Action"
git add .benchmarks/
git commit -m "bench: bench: add benchmark current release"
git commit -m "bench: current release"
git push origin master
40 changes: 34 additions & 6 deletions .github/workflows/pr-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ on:
required: true

jobs:
build_test_bench:
test:
runs-on: ubuntu-latest
strategy:
matrix:
Expand Down Expand Up @@ -56,15 +56,43 @@ jobs:
git config --global user.email "[email protected]"
git config --global user.name "GitHub Action"
./scripts/test
- name: Benchmark regression test
run: |
./scripts/bench-compare
- name: Upload coverage to Codecov
uses: codecov/[email protected]
with:
file: ./coverage.xml
name: kstreams
fail_ci_if_error: true
token: ${{secrets.CODECOV_TOKEN}}
bench:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Setup python
uses: actions/setup-python@v5
with:
python-version: '3.13'
architecture: x64
- name: Set Cache
uses: actions/cache@v4
id: cache # name for referring later
with:
path: .venv/
# The cache key depends on poetry.lock
key: ${{ runner.os }}-cache-${{ hashFiles('poetry.lock') }}
restore-keys: |
${{ runner.os }}-cache-
${{ runner.os }}-
- name: Install Dependencies
# if: steps.cache.outputs.cache-hit != 'true'
run: |
python -m pip install -U pip poetry
poetry --version
poetry config --local virtualenvs.in-project true
poetry install
- name: Benchmark regression test
run: |
./scripts/bench-current
./scripts/bench-compare
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ if __name__ == "__main__":
- [ ] Store (kafka streams pattern)
- [ ] Stream Join
- [ ] Windowing
- [ ] PEP 593

## Development

Expand Down
6 changes: 6 additions & 0 deletions kstreams/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from aiokafka.structs import RecordMetadata, TopicPartition

from ._di.parameters import FromHeader, Header
from .backends.kafka import Kafka
from .clients import Consumer, Producer
from .create import StreamEngine, create_engine
from .prometheus.monitor import PrometheusMonitor, PrometheusMonitorType
Expand Down Expand Up @@ -31,4 +33,8 @@
"TestStreamClient",
"TopicPartition",
"TopicPartitionOffset",
"Kafka",
"StreamDependencyManager",
"FromHeader",
"Header",
]
68 changes: 68 additions & 0 deletions kstreams/_di/binders/api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import inspect
from typing import Any, AsyncIterator, Awaitable, Protocol, TypeVar, Union

from di.api.dependencies import CacheKey
from di.dependent import Dependent, Marker

from kstreams.types import ConsumerRecord


class ExtractorTrait(Protocol):
    """Contract for objects that pull a value out of an incoming `ConsumerRecord`.

    Consumers always work with a `ConsumerRecord`; an implementation of this
    protocol decides which information is extracted from it.
    """

    def __hash__(self) -> int:
        """Return the hash of the extractor; `di` requires it to cache the deps."""
        ...

    def __eq__(self, __o: object) -> bool:
        """Compare with `__o`; `di` requires it to cache the deps."""
        ...

    async def extract(
        self, consumer_record: ConsumerRecord
    ) -> Union[Awaitable[Any], AsyncIterator[Any]]:
        """Produce the extracted value for `consumer_record`.

        For example, an implementation could "extract" a json document from
        `ConsumerRecord.value`.
        """
        ...


# Covariant type variable: the kind of object a marker produces for a parameter.
T = TypeVar("T", covariant=True)


class MarkerTrait(Protocol[T]):
    """Structural type for markers that turn a function parameter into a `T`."""

    def register_parameter(self, param: inspect.Parameter) -> T:
        """Build the `T` associated with `param`."""
        ...


class Binder(Dependent[Any]):
    """Dependent whose value is produced by an extractor's `extract` method."""

    def __init__(self, *, extractor: ExtractorTrait) -> None:
        # The extractor's `extract` is the callable that gets resolved, in the
        # "consumer_record" scope.
        extract_fn = extractor.extract
        super().__init__(call=extract_fn, scope="consumer_record")
        self.extractor = extractor

    @property
    def cache_key(self) -> CacheKey:
        """Return the extractor itself as the cache key."""
        return self.extractor


class BinderMarker(Marker):
    """Bind together the different dependencies.

    NEXT: Add asyncapi marker here, like `MarkerTrait[AsyncApiTrait]`.

    Recommendation to wait until 3.0:
        - [#618](https://github.com/asyncapi/spec/issues/618)
    """

    def __init__(self, *, extractor_marker: MarkerTrait[ExtractorTrait]) -> None:
        self.extractor_marker = extractor_marker

    def register_parameter(self, param: inspect.Parameter) -> Binder:
        """Create the `Binder` for `param`.

        The wrapped extractor marker builds the extractor for the parameter,
        and that extractor is handed to a new `Binder`.
        """
        extractor = self.extractor_marker.register_parameter(param)
        return Binder(extractor=extractor)
44 changes: 44 additions & 0 deletions kstreams/_di/binders/header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import inspect
from typing import Any, NamedTuple, Optional

from kstreams.exceptions import HeaderNotFound
from kstreams.types import ConsumerRecord


class HeaderExtractor(NamedTuple):
    """Extractor that reads one header, identified by `name`, from a record."""

    name: str

    def __hash__(self) -> int:
        # Hash on the class together with the header name.
        return hash((type(self), self.name))

    def __eq__(self, __o: object) -> bool:
        # Only another HeaderExtractor with the same header name is equal.
        if not isinstance(__o, HeaderExtractor):
            return False
        return __o.name == self.name

    async def extract(self, consumer_record: ConsumerRecord) -> Any:
        """Return the value of the header `name` from `consumer_record`.

        Raises:
            HeaderNotFound: if the record has no header with that name.
        """
        # `headers` is passed to `dict()`, so it is presumably an iterable of
        # (key, value) pairs.
        available = dict(consumer_record.headers)
        try:
            return available[self.name]
        except KeyError as e:
            message = (
                f"No header `{self.name}` found.\n"
                "Check if your broker is sending the header.\n"
                "Try adding a default value to your parameter like `None`.\n"
                "Or set `convert_underscores = False`."
            )
            raise HeaderNotFound(message) from e


class HeaderMarker(NamedTuple):
    """Marker describing how a parameter maps to a header name."""

    # Explicit header name; when set it takes precedence over the parameter name.
    alias: Optional[str]
    # Whether "_" in the parameter name is replaced by "-" to form the header name.
    convert_underscores: bool

    def register_parameter(self, param: inspect.Parameter) -> HeaderExtractor:
        """Build the `HeaderExtractor` for `param`.

        The header name is the alias when one is set; otherwise it is the
        parameter name, with underscores turned into dashes when
        `convert_underscores` is true.
        """
        if self.alias is not None:
            return HeaderExtractor(name=self.alias)

        header_name = param.name
        if self.convert_underscores:
            header_name = header_name.replace("_", "-")
        return HeaderExtractor(name=header_name)
113 changes: 113 additions & 0 deletions kstreams/_di/dependencies/core.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
from typing import Any, Callable, Optional

from di import Container, bind_by_type
from di.dependent import Dependent
from di.executors import AsyncExecutor

from kstreams._di.dependencies.hooks import bind_by_generic
from kstreams.streams import Stream
from kstreams.types import ConsumerRecord, Send

# Type of the user-defined function whose dependencies are solved by
# `StreamDependencyManager.solve_user_fn`; any signature is accepted.
LayerFn = Callable[..., Any]


class StreamDependencyManager:
    """Core of dependency injection on kstreams.

    This is an internal class of kstreams that manages the dependency injection,
    as a user you should not use this class directly.

    Attributes:
        container: dependency store.
        stream: the stream wrapping the user function. Optional to improve
            testability. When instantiating this class you must provide a
            stream, otherwise users won't be able to use the `stream`
            parameter in their functions.
        send: send object. Optional to improve testability, same as stream.

    Usage:
        stream and send are omitted for simplicity

        ```python
        def user_func(cr: ConsumerRecord):
            ...

        sdm = StreamDependencyManager()
        sdm.solve_user_fn(user_func)
        await sdm.execute(consumer_record)
        ```
    """

    container: Container

    def __init__(
        self,
        container: Optional[Container] = None,
        stream: Optional[Stream] = None,
        send: Optional[Send] = None,
    ) -> None:
        # Fall back to a fresh container when none (or a falsy one) is supplied.
        self.container = container or Container()
        self.async_executor = AsyncExecutor()
        self.stream = stream
        self.send = send

    def solve_user_fn(self, fn: LayerFn) -> None:
        """Build the dependency graph for the given function.

        Objects must be injected before this function is called.
        The solved graph is stored in `self.solved_user_fn`, which `execute`
        reads, so this method has to run before `execute`.

        Args:
            fn: user defined function, using allowed kstreams params
        """
        self._register_consumer_record()

        # The stream is only bound when an actual `Stream` instance was provided.
        if isinstance(self.stream, Stream):
            self._register_stream(self.stream)

        if self.send is not None:
            self._register_send(self.send)

        self.solved_user_fn = self.container.solve(
            Dependent(fn, scope="consumer_record"),
            scopes=["consumer_record", "stream", "application"],
        )

    async def execute(self, consumer_record: ConsumerRecord) -> Any:
        """Execute the dependencies graph with external values.

        Args:
            consumer_record: A kafka record containing `values`, `headers`, etc.

        Returns:
            Whatever executing the solved user function returns.
        """
        async with self.container.enter_scope("consumer_record") as state:
            return await self.solved_user_fn.execute_async(
                executor=self.async_executor,
                state=state,
                # The record is supplied here as the value for `ConsumerRecord`.
                values={ConsumerRecord: consumer_record},
            )

    def _register_stream(self, stream: Stream) -> None:
        """Register the stream with the container."""
        hook = bind_by_type(
            Dependent(lambda: stream, scope="consumer_record", wire=False), Stream
        )
        self.container.bind(hook)

    def _register_consumer_record(self) -> None:
        """Register consumer record with the container.

        We bind_by_generic because we want to bind the `ConsumerRecord` type which
        is generic.

        The value must be injected at runtime.
        """
        hook = bind_by_generic(
            Dependent(ConsumerRecord, scope="consumer_record", wire=False),
            ConsumerRecord,
        )
        self.container.bind(hook)

    def _register_send(self, send: Send) -> None:
        """Register the send object with the container."""
        hook = bind_by_type(
            Dependent(lambda: send, scope="consumer_record", wire=False), Send
        )
        self.container.bind(hook)
31 changes: 31 additions & 0 deletions kstreams/_di/dependencies/hooks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import inspect
from typing import Any, get_origin

from di._container import BindHook
from di._utils.inspect import get_type
from di.api.dependencies import DependentBase


def bind_by_generic(
    provider: DependentBase[Any],
    dependency: type,
) -> BindHook:
    """Create a bind hook that substitutes `provider` for a generic `dependency`.

    The returned hook yields `provider` when the dependent's callable equals
    `dependency`, or when the parameter's type annotation is a parametrized
    generic whose origin is `dependency`. In every other case it yields `None`.
    """

    def hook(
        param: inspect.Parameter | None, dependent: DependentBase[Any]
    ) -> DependentBase[Any] | None:
        # Direct match on the callable being resolved.
        if dependent.call == dependency:
            return provider

        if param is not None:
            annotation = get_type(param)
            # Match on the unsubscripted origin of the annotation.
            if annotation is not None and get_origin(annotation.value) is dependency:
                return provider

        return None

    return hook
37 changes: 37 additions & 0 deletions kstreams/_di/parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from typing import Optional, TypeVar

from kstreams._di.binders.api import BinderMarker
from kstreams._di.binders.header import HeaderMarker
from kstreams.typing import Annotated


def Header(
    *, alias: Optional[str] = None, convert_underscores: bool = True
) -> BinderMarker:
    """Construct another type from the headers of a kafka record.

    Args:
        alias: Use a different header name instead of the parameter name.
        convert_underscores: If True, convert underscores in the parameter name
            to dashes. Not applied when `alias` is given.

    Returns:
        A `BinderMarker` wrapping a `HeaderMarker` built from these options.

    Usage:
        ```python
        from kstreams import Header
        from kstreams.typing import Annotated

        def user_fn(event_type: Annotated[str, Header(alias="EventType")]):
            ...
        ```
    """
    return BinderMarker(
        extractor_marker=HeaderMarker(
            alias=alias, convert_underscores=convert_underscores
        )
    )


# Type variable standing for the annotated parameter type in `FromHeader[...]`.
T = TypeVar("T")

# Generic alias for `Annotated[T, Header()]`: a header bound with the default
# `Header()` options (no alias, underscores converted to dashes).
FromHeader = Annotated[T, Header()]
FromHeader.__doc__ = """General purpose convenient header type.
Use `Annotated` to provide custom params.
"""
Loading

0 comments on commit 3083d43

Please sign in to comment.