
Commit

Merge pull request #39 from kpn/fix/typing-issues
fix: typing
woile authored Aug 11, 2022
2 parents 357f11f + 25a6dc8 commit f291bf7
Showing 18 changed files with 115 additions and 82 deletions.
2 changes: 1 addition & 1 deletion examples/json_serialization.py
@@ -5,7 +5,7 @@
import aiokafka

from kstreams import Stream, consts, create_engine
from kstreams.custom_types import Headers
from kstreams.types import Headers


class JsonSerializer:
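The Headers import simply moves from the kstreams.custom_types module (deleted in this commit) to the new kstreams.types. A minimal sketch of how the example's serializer lines up with the new import; the serialize signature here is inferred from the engine.send call later in this diff, so treat it as an assumption rather than the file's actual code:

import json
from typing import Any, Dict, Optional

from kstreams.types import Headers  # was: from kstreams.custom_types import Headers


class JsonSerializer:
    async def serialize(
        self,
        payload: Any,
        headers: Optional[Headers] = None,
        value_serializer_kwargs: Optional[Dict] = None,
    ) -> bytes:
        # Headers are accepted but unused here; the engine passes them through
        return json.dumps(payload).encode("utf-8")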
5 changes: 0 additions & 5 deletions kstreams/custom_types.py

This file was deleted.

24 changes: 15 additions & 9 deletions kstreams/engine.py
@@ -1,17 +1,19 @@
import asyncio
import inspect
import logging
from typing import Any, Coroutine, Dict, List, Optional, Type, Union
from typing import Any, Callable, Dict, List, Optional, Type, Union

from aiokafka.structs import RecordMetadata

from .backends.kafka import Kafka
from .clients import ConsumerType, ProducerType
from .custom_types import DecoratedCallable, Headers
from .exceptions import DuplicateStreamException
from .exceptions import DuplicateStreamException, EngineNotStartedException
from .prometheus.monitor import PrometheusMonitor
from .prometheus.tasks import metrics_task
from .serializers import ValueDeserializer, ValueSerializer
from .singlenton import Singleton
from .streams import Stream
from .types import Headers
from .utils import encode_headers

logger = logging.getLogger(__name__)
@@ -51,6 +53,9 @@ async def send(
value_serializer: Optional[ValueSerializer] = None,
value_serializer_kwargs: Optional[Dict] = None,
):
if self._producer is None:
raise EngineNotStartedException()

value_serializer = value_serializer or self.value_serializer

# serialize only when value and value_serializer are present
@@ -59,18 +64,19 @@ async def send(
value, headers=headers, value_serializer_kwargs=value_serializer_kwargs
)

encoded_headers = None
if headers is not None:
headers = encode_headers(headers)
encoded_headers = encode_headers(headers)

fut = await self._producer.send(
topic,
value=value,
key=key,
partition=partition,
timestamp_ms=timestamp_ms,
headers=headers,
headers=encoded_headers,
)
metadata = await fut
metadata: RecordMetadata = await fut
self.monitor.add_topic_partition_offset(
topic, metadata.partition, metadata.offset
)
@@ -151,7 +157,7 @@ def _create_stream(
self,
topics: Union[List[str], str],
*,
func: Coroutine[Stream, Any, Any],
func: Callable[[Stream], None],
name: Optional[str] = None,
value_deserializer: Optional[ValueDeserializer] = None,
**kwargs,
@@ -182,8 +188,8 @@ def stream(
name: Optional[str] = None,
value_deserializer: Optional[ValueDeserializer] = None,
**kwargs,
) -> DecoratedCallable:
def decorator(func: Coroutine[Stream, Any, Any]) -> Stream:
) -> Callable[[Callable[[Stream], None]], Stream]:
def decorator(func: Callable[[Stream], None]) -> Stream:
stream = self._create_stream(
topics,
func=func,
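Read together, the two typing changes say the same thing twice: stream() is a decorator factory whose decorator accepts an async function of one Stream argument and returns the registered Stream. A hedged usage sketch (engine setup and topic name are illustrative, not from this diff):

from kstreams import Stream, create_engine

stream_engine = create_engine(title="my-engine")  # illustrative setup


@stream_engine.stream("local--my-topic")  # the decorator hands back a Stream
async def consume(stream: Stream) -> None:  # matches Callable[[Stream], None]
    async for record in stream:
        print(record.value)

Strictly, an async def returns a Coroutine, so Callable[[Stream], None] is a deliberate simplification; it replaces the previous Coroutine[Stream, Any, Any], which misused Coroutine's type parameters (they are yield, send, and return types, not the argument list).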
5 changes: 5 additions & 0 deletions kstreams/exceptions.py
@@ -13,3 +13,8 @@ def __str__(self) -> str:
)

return msg


class EngineNotStartedException(Exception):
def __str__(self) -> str:
return "Engine has not been started. Try with `await engine.start()`"
11 changes: 7 additions & 4 deletions kstreams/prometheus/tasks.py
@@ -1,7 +1,7 @@
import asyncio
from typing import Any, DefaultDict, List
from typing import Any, DefaultDict, List, Type

from kstreams.clients import Consumer
from kstreams.clients import ConsumerType
from kstreams.streams import Stream

from .monitor import PrometheusMonitorType
@@ -10,11 +10,14 @@
async def metrics_task(streams: List[Stream], monitor: PrometheusMonitorType):
while True:
for stream in streams:
await generate_consumer_metrics(stream.consumer, monitor=monitor)
if stream.consumer is not None:
await generate_consumer_metrics(stream.consumer, monitor=monitor)
await asyncio.sleep(3)


async def generate_consumer_metrics(consumer: Consumer, monitor: PrometheusMonitorType):
async def generate_consumer_metrics(
consumer: Type[ConsumerType], monitor: PrometheusMonitorType
):
"""
Generate Consumer Metrics for Prometheus
2 changes: 1 addition & 1 deletion kstreams/serializers.py
@@ -2,7 +2,7 @@

import aiokafka

from .custom_types import Headers
from .types import Headers


class ValueDeserializer(Protocol):
22 changes: 14 additions & 8 deletions kstreams/streams.py
@@ -2,7 +2,7 @@
import inspect
import logging
import uuid
from typing import Any, AsyncGenerator, Coroutine, Dict, List, Optional, Type, Union
from typing import Any, AsyncGenerator, Callable, Dict, List, Optional, Type, Union

from aiokafka import errors, structs

@@ -19,7 +19,7 @@ def __init__(
topics: Union[List[str], str],
*,
backend: Kafka,
func: Union[Coroutine[Any, Any, Type[ConsumerType]], AsyncGenerator],
func: Callable[["Stream"], None],
consumer_class: Type[ConsumerType] = Consumer,
name: Optional[str] = None,
config: Optional[Dict] = None,
@@ -29,6 +29,7 @@
self.func = func
self.backend = backend
self.consumer_class = consumer_class
self.consumer: Optional[Type[ConsumerType]] = None
self.config = config or {}
self._consumer_task: Optional[asyncio.Task] = None
self.name = name or str(uuid.uuid4())
@@ -41,21 +42,22 @@ def __init__(
# so we always create a list and then we expand it with *topics
self.topics = [topics] if isinstance(topics, str) else topics

def _create_consumer(self) -> ConsumerType:
def _create_consumer(self) -> Type[ConsumerType]:
config = {**self.backend.dict(), **self.config}
return self.consumer_class(*self.topics, **config)

async def stop(self) -> None:
if not self.running:
return None

await self.consumer.stop()
self.running = False
if self.consumer is not None:
await self.consumer.stop()
self.running = False

if self._consumer_task is not None:
self._consumer_task.cancel()

async def start(self) -> None:
async def start(self) -> Optional[AsyncGenerator]:
async def func_wrapper(func):
try:
# await for the end user coroutine
@@ -78,6 +80,7 @@ async def func_wrapper(func):
# It is not an async_generator so we need to
# create an asyncio.Task with func
self._consumer_task = asyncio.create_task(func_wrapper(func))
return None

async def __aenter__(self) -> AsyncGenerator:
"""
@@ -97,7 +100,8 @@ async def stream(consumer):
"""
logger.info("Starting async_gen Stream....")
async_gen = await self.start()
return async_gen
# For now ignoring the typing issue. The start method might be splited
return async_gen # type: ignore

async def __aexit__(self, exc_type, exc, tb) -> None:
logger.info("Stopping async_gen Stream....")
@@ -114,7 +118,9 @@ async def __anext__(self) -> structs.ConsumerRecord:
try:
# value is a ConsumerRecord:
# namedtuple["topic", "partition", "offset", "key", "value"]
consumer_record: structs.ConsumerRecord = await self.consumer.getone()
consumer_record: structs.ConsumerRecord = (
await self.consumer.getone() # type: ignore
)

# deserialize only when value and value_deserializer are present
if (
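Both call sites now tolerate consumer being None until start() runs, and __aenter__ keeps returning whatever start() yields. A hedged sketch of the two consumption styles this file supports (stream construction is elided, since the engine normally builds Stream objects, and __aiter__ is assumed to accompany the __anext__ shown here):

from kstreams.streams import Stream


async def consume_via_context(stream: Stream) -> None:
    # Context-manager style: __aenter__ awaits start() and hands back the
    # async generator (hence the `type: ignore` in this hunk)
    async with stream as async_gen:
        async for consumer_record in async_gen:
            print(consumer_record.value)


async def consume_via_iteration(stream: Stream) -> None:
    # Iterator style: after start(), __anext__ pulls records with getone()
    await stream.start()
    async for consumer_record in stream:
        print(consumer_record.value)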
19 changes: 1 addition & 18 deletions kstreams/test_utils/structs.py
@@ -1,10 +1,4 @@
from dataclasses import dataclass
from typing import Generic, NamedTuple, Optional, TypeVar

from kstreams.custom_types import KafkaHeaders

KT = TypeVar("KT")
VT = TypeVar("VT")
from typing import NamedTuple


class TopicPartition(NamedTuple):
@@ -17,14 +11,3 @@ class RecordMetadata(NamedTuple):
partition: int
topic: str
timestamp: int


@dataclass
class ConsumerRecord(Generic[KT, VT]):
topic: str
partition: int
offset: int
timestamp: int
key: Optional[KT]
value: Optional[VT]
headers: KafkaHeaders
40 changes: 25 additions & 15 deletions kstreams/test_utils/test_clients.py
@@ -1,11 +1,13 @@
from datetime import datetime
from typing import Any, Coroutine, Dict, List, Optional, Tuple, Union
from typing import Any, Coroutine, Dict, List, Optional, Tuple

from aiokafka.structs import ConsumerRecord

from kstreams.clients import Consumer, Producer
from kstreams.custom_types import Headers
from kstreams.serializers import ValueSerializer
from kstreams.types import Headers

from .structs import ConsumerRecord, RecordMetadata, TopicPartition
from .structs import RecordMetadata, TopicPartition
from .topics import TopicManager


@@ -21,12 +23,12 @@ async def send(
value: Any = None,
key: Any = None,
partition: int = 1,
timestamp_ms: Optional[int] = None,
timestamp_ms: Optional[float] = None,
headers: Optional[Headers] = None,
value_serializer: Optional[ValueSerializer] = None,
value_serializer_kwargs: Optional[Dict] = None,
) -> Coroutine:
topic = TopicManager.get_or_create_topic(topic_name)
topic = TopicManager.get_or_create(topic_name)
timestamp_ms = timestamp_ms or datetime.now().timestamp()
total_messages = topic.total_messages + 1

@@ -38,6 +40,10 @@
partition=partition,
timestamp=timestamp_ms,
offset=total_messages,
timestamp_type=None,
checksum=None,
serialized_key_size=None,
serialized_value_size=None,
)

await topic.put(consumer_record)
@@ -56,12 +62,12 @@ async def fut():
class TestConsumer(Base, Consumer):
def __init__(self, *topics: str, group_id: Optional[str] = None, **kwargs) -> None:
# copy the aiokafka behavior
self.topics: Tuple[str] = topics
self._group_id: str = group_id
self.topics: Tuple[str, ...] = topics
self._group_id: Optional[str] = group_id
self._assigments: List[TopicPartition] = []

for topic_name in topics:
TopicManager.create_topic(topic_name, consumer=self)
TopicManager.create(topic_name, consumer=self)
self._assigments.append(TopicPartition(topic=topic_name, partition=1))

# Called to make sure that has all the kafka attributes like _coordinator
@@ -72,8 +78,11 @@ def assignment(self) -> List[TopicPartition]:
return self._assigments

def last_stable_offset(self, topic_partition: TopicPartition) -> int:
topic = TopicManager.get_topic(topic_partition.topic)
return topic.total_messages
topic = TopicManager.get(topic_partition.topic)

if topic is not None:
return topic.total_messages
return -1

async def position(self, topic_partition: TopicPartition) -> int:
return self.last_stable_offset(topic_partition)
@@ -83,13 +92,14 @@ def highwater(self, topic_partition: TopicPartition) -> int:

async def getone(
self,
) -> Union[bytes, Dict]: # The return type must be fixed later on
) -> Optional[ConsumerRecord]: # The return type must be fixed later on
topic = None
for topic_partition in self._assigments:
topic = TopicManager.get_topic(topic_partition.topic)
if topic is None:
raise AttributeError("There should be a topic")
topic = TopicManager.get(topic_partition.topic)

if not topic.consumed:
break

return await topic.get()
if topic is not None:
return await topic.get()
return None
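The in-house ConsumerRecord deleted in the structs.py hunk above is replaced by aiokafka's, which is why send() now has to populate four extra fields. A sketch of constructing one directly, with illustrative values (the field list matches the send() hunk above):

from aiokafka.structs import ConsumerRecord

record = ConsumerRecord(
    topic="my-topic",
    partition=1,
    offset=1,
    timestamp=1660204800000,      # epoch millis, illustrative
    timestamp_type=None,
    key=None,
    value=b"payload",
    checksum=None,
    serialized_key_size=None,
    serialized_value_size=None,
    headers=None,
)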
2 changes: 1 addition & 1 deletion kstreams/test_utils/test_utils.py
@@ -3,9 +3,9 @@
from typing import Any, Dict, List, Optional, Type

from kstreams.create import create_engine
from kstreams.custom_types import Headers
from kstreams.serializers import ValueSerializer
from kstreams.streams import Stream
from kstreams.types import Headers

from .structs import RecordMetadata
from .test_clients import TestConsumer, TestProducer
19 changes: 12 additions & 7 deletions kstreams/test_utils/topics.py
@@ -47,25 +47,30 @@ class TopicManager:
topics: ClassVar[Dict[str, Topic]] = {}

@classmethod
def get_topic(cls, name: str) -> Optional[Topic]:
return cls.topics.get(name)
def get(cls, name: str) -> Topic:
topic = cls.topics.get(name)

if topic is not None:
return topic
raise ValueError(f"Topic {name} not found")

@classmethod
def create_topic(
def create(
cls, name: str, consumer: Optional["test_clients.Consumer"] = None
) -> Topic:
topic = Topic(name=name, queue=asyncio.Queue(), consumer=consumer)
cls.topics[name] = topic
return topic

@classmethod
def get_or_create_topic(cls, name: str) -> Topic:
def get_or_create(cls, name: str) -> Topic:
"""
Add a new queue if does not exist and return it
"""
topic = cls.get_topic(name)
if topic is None:
topic = cls.create_topic(name)
try:
topic = cls.get(name)
except ValueError:
topic = cls.create(name)
return topic

@classmethod
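TopicManager's accessors lose their redundant "_topic" suffix, and get() now raises instead of returning None, with get_or_create() catching that error. A short usage sketch (topic names illustrative):

import asyncio

from kstreams.test_utils.topics import TopicManager


async def demo() -> None:
    topic = TopicManager.get_or_create("events")  # created on first use
    assert TopicManager.get("events") is topic    # get() returns Topic, never None
    try:
        TopicManager.get("missing")
    except ValueError as exc:
        print(exc)  # Topic missing not found


asyncio.run(demo())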
4 changes: 4 additions & 0 deletions kstreams/types.py
@@ -0,0 +1,4 @@
from typing import Dict, Sequence, Tuple

Headers = Dict[str, str]
EncodedHeaders = Sequence[Tuple[str, bytes]]
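These two aliases name both sides of the header conversion that engine.send performs before producing. The real encode_headers lives in kstreams/utils.py, which this diff does not show; a plausible sketch under that assumption:

from kstreams.types import EncodedHeaders, Headers


def encode_headers(headers: Headers) -> EncodedHeaders:
    # aiokafka expects (str, bytes) pairs on the wire, not a str-to-str dict
    return [(key, value.encode("utf-8")) for key, value in headers.items()]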