diff --git a/.github/workflows/deploy-latest.yml b/.github/workflows/deploy-latest.yml
deleted file mode 100644
index 19af55b..0000000
--- a/.github/workflows/deploy-latest.yml
+++ /dev/null
@@ -1,31 +0,0 @@
-
-name: "Deploy latest"
-
-on:
- push:
- branches:
- - main
-
-jobs:
- deploy_latest:
-
- runs-on: ubuntu-latest
- steps:
- - uses: actions/checkout@v3
- - name: Set up Python
- uses: actions/setup-python@v4
- with:
- python-version: '3.x'
- - name: Install dependencies
- run: |
- python -m pip install --upgrade pip
- pip install build
- - name: Build package
- run: |
- echo "999999-dev-$(date +%s)" > heizer/VERSION
- python -m build
- - name: Publish package
- uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29
- with:
- user: __token__
- password: ${{ secrets.PYPI_API_TOKEN }}
diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml
index 323a9f2..482e4df 100644
--- a/.github/workflows/main.yml
+++ b/.github/workflows/main.yml
@@ -83,7 +83,7 @@ jobs:
run: |
python -m pytest -s tests --cov=heizer --cov-report=xml --junit-xml=report.xml
- test_build_doc:
+ build_doc_and_publish_on_main:
runs-on: ubuntu-latest
services:
zookeeper:
@@ -116,13 +116,16 @@ jobs:
sudo apt-get install -y zip gzip tar
python -m pip install --upgrade pip
pip install -e .[all,doc]
- sphinx-multiversion ./docs/source ./doc/
- zip -r doc.zip doc/
+ echo ${{ github.ref_name }} > heizer/VERSION
+ python docs/create_versions_file.py
+ sphinx-build ./docs/source "./public/${{ github.ref_name }}"
+ mv ./docs/versions.json ./public/versions.json
+ zip -r public.zip ./public/
- name: Upload Artifact
uses: actions/upload-artifact@v3
with:
- name: doc_zip
- path: doc.zip
+ name: public_zip
+ path: public.zip
retention-days: 5
- name: deploy
if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/'))
@@ -131,7 +134,7 @@ jobs:
host: ${{ secrets.SERVER_IP }}
username: "root"
password: ${{ secrets.SERVER_PASS }}
- source: "doc/"
+ source: "public/"
target: "/var/www/html/docs/"
strip_components: 1
overwrite: true
diff --git a/.gitignore b/.gitignore
index 07d18ad..207576e 100644
--- a/.gitignore
+++ b/.gitignore
@@ -156,3 +156,5 @@ cython_debug/
# Editors
.idea
.vscode
+.fleet
+public
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..9fde319
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,40 @@
+# Main/Dev
+
+
+### Features
+
+### Improvements
+
+### Fixes
+
+---
+
+# Releases
+
+## 0.2.0
+
+### Breaking Changes
+
+- rewrite the heizer producer as a class instead of a decorator; there was no clear benefit to
+  creating the producer as a decorator
+- updated class names
+
+### New Features
+
+- support retry in consumer
+- add helpers to list and delete topics
+- add consumer signal
+- write a consumer status log file for health checking
+- support async produce in producer
+
+### Improvements
+
+- add more logs
+
+---
+
+## 0.1.5
+
+### Improvements
+
+- improve consumer
diff --git a/docs/create_versions_file.py b/docs/create_versions_file.py
new file mode 100644
index 0000000..4c31c24
--- /dev/null
+++ b/docs/create_versions_file.py
@@ -0,0 +1,21 @@
+import json
+import pathlib
+
+import requests
+
+
+def create():
+ versions = [{"name": "main", "url": "https://heizer.claudezss.com/docs/main"}]
+
+ rsp = requests.get("https://api.github.com/repos/claudezss/heizer/releases?per_page=100")
+ d = rsp.json()
+ for item in d:
+ release_name = item["name"]
+ versions.append({"name": release_name, "url": f"https://heizer.claudezss.com/docs/{release_name}/"})
+
+ with open(pathlib.Path(__file__).parent / "versions.json", "w") as f:
+ json.dump(versions, f, indent=4)
+
+
+if __name__ == "__main__":
+ create()
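For reference, a minimal sketch of consuming the generated file (the release names and URLs in the comment are illustrative, following the URL scheme above):

    import json
    import pathlib

    # docs/versions.json is a flat list such as:
    # [{"name": "main", "url": "https://heizer.claudezss.com/docs/main"},
    #  {"name": "0.2.0", "url": "https://heizer.claudezss.com/docs/0.2.0/"}]
    versions = json.loads((pathlib.Path("docs") / "versions.json").read_text())
    for version in versions:
        print(version["name"], "->", version["url"])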
diff --git a/docs/source/_templates/versions.html b/docs/source/_templates/versions.html
index b3c63b5..8624712 100644
--- a/docs/source/_templates/versions.html
+++ b/docs/source/_templates/versions.html
@@ -1,28 +1,29 @@
-{%- if current_version %}
+
+
Other Versions
- v: {{ current_version.name }}
+ v: {{ release }}
-
- {%- if versions.tags %}
-
+
+
- Tags
- {%- for item in versions.tags %}
- - {{ item.name }}
- {%- endfor %}
- {%- endif %}
- {%- if versions.branches %}
-
- - Branches
- {%- for item in versions.branches %}
- - {{ item.name }}
- {%- endfor %}
-
- {%- endif %}
-{%- endif %}
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 89adc47..e4f5213 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -9,11 +9,14 @@
# import sys
# sys.path.insert(0, os.path.abspath("../.."))
+from pathlib import Path
+
+root_folder = Path(__file__).parent.parent.parent
project = "heizer"
copyright = "2023, Yan Zhang"
author = "Yan Zhang"
-release = "main"
+release = open(root_folder / "heizer" / "VERSION").read().strip()
# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
@@ -32,6 +35,7 @@
"_templates",
]
+html_context = {"release": release}
autodoc_typehints = "description"
diff --git a/docs/source/tutorial/consumer.rst b/docs/source/tutorial/consumer.rst
index 3529610..8dc4feb 100644
--- a/docs/source/tutorial/consumer.rst
+++ b/docs/source/tutorial/consumer.rst
@@ -8,60 +8,68 @@ Basic Producer and Consumer
.. ipython:: python
- from heizer import HeizerConfig, HeizerTopic, consumer, producer, HeizerMessage
+ from heizer import Topic, consumer, Producer, Message, ProducerConfig, ConsumerConfig, create_new_topics
import json
+ import uuid
+ import asyncio
- producer_config = HeizerConfig(
- {
- "bootstrap.servers": "localhost:9092",
- }
- )
+ producer_config = ProducerConfig(bootstrap_servers="localhost:9092")
- consumer_config = HeizerConfig(
- {
- "bootstrap.servers": "localhost:9092",
- "group.id": "default",
- "auto.offset.reset": "earliest",
- }
- )
+ consumer_config = ConsumerConfig(bootstrap_servers="localhost:9092", group_id="default")
-2. Create the topic
+2. Create the topic with 2 partitions
.. ipython:: python
- topics = [HeizerTopic(name="my.topic1.consumer.example")]
+ topics = [Topic(name=f"my.topic1.consumer.example.{uuid.uuid4()}", num_partitions=2)]
+ create_new_topics(config=producer_config, topics=topics)
3. Create producer
.. ipython:: python
- @producer(
- topics=topics,
- config=producer_config,
- key_alias="key",
- headers_alias="headers",
- )
- def produce_data(status: str, result: str):
- return {
- "status": status,
- "result": result,
- "key": "my_key",
- "headers": {"my_header": "my_header_value"},
- }
+ pd = Producer(config=producer_config)
-4. Publish messages
+4. Publish messages synchronously to partition 0
.. ipython:: python
- produce_data("start", "1")
-
- produce_data("loading", "2")
+ for status, val in [("start", "1"), ("loading", "2"), ("success", "3"), ("postprocess", "4")]:
+ pd.produce(
+ topic=topics[0],
+ key="my_key",
+ value={"status": status, "result": val},
+ headers={"k": "v"},
+ partition=0,
+ auto_flush=False
+ )
+ pd.flush()
- produce_data("success", "3")
+5. Publish messages asynchronously to partition 1 (usually faster than sync produce)
- produce_data("postprocess", "4")
+.. ipython:: python
-5. Create consumer
+ jobs = []
+ async def produce():
+ for status, val in [("start", "1"), ("loading", "2"), ("success", "3"), ("postprocess", "4")]:
+ jobs.append(
+ asyncio.ensure_future(
+ pd.async_produce(
+ topic=topics[0],
+ key="my_key",
+ value={"status": status, "result": val},
+ headers={"k": "v"},
+ partition=1,
+ auto_flush=False
+ )
+ )
+ )
+ await asyncio.gather(*jobs)
+ pd.flush()
+
+ asyncio.run(produce())
+
+6. Create consumer
.. ipython:: python
@@ -70,7 +78,8 @@ Basic Producer and Consumer
# `status` is `success` in msg
# If there is no stopper func, consumer will keep running forever
- def stopper(msg: HeizerMessage):
+ def stopper(msg: Message, C: consumer, *args, **kwargs):
+ print(f"Consumer name: {C.name}")
data = json.loads(msg.value)
if data["status"] == "success":
return True
@@ -81,12 +90,18 @@ Basic Producer and Consumer
config=consumer_config,
stopper=stopper,
)
- def consume_data(message: HeizerMessage):
+ def consume_data(message: Message, *args, **kwargs):
data = json.loads(message.value)
- print(data)
- print(message.key)
- print(message.headers)
+ print(f"message data: {data}")
+ print(f"message key: {message.key}")
+ print(f"message headers: {message.headers}")
return data["result"]
result = consume_data()
- print("Expected Result:", result)
+ print("Expected Result (should be 3):", result)
+
+
+7. More samples:
+
+.. literalinclude:: ../../../tests/test_consumer.py
+ :language: python
diff --git a/heizer/__init__.py b/heizer/__init__.py
index 5f3b41b..5060e68 100644
--- a/heizer/__init__.py
+++ b/heizer/__init__.py
@@ -1,16 +1,24 @@
-from heizer._source.admin import create_new_topic, get_admin_client
-from heizer._source.consumer import consumer
-from heizer._source.message import HeizerMessage
-from heizer._source.producer import producer
-from heizer._source.topic import HeizerTopic
-from heizer.config import HeizerConfig
+from heizer._source.admin import create_new_topics, delete_topics, get_admin_client, list_topics
+from heizer._source.consumer import ConsumerSignal, consumer
+from heizer._source.message import Message
+from heizer._source.producer import Producer
+from heizer._source.status_manager import read_consumer_status, write_consumer_status
+from heizer._source.topic import Topic
+from heizer.config import BaseConfig, ConsumerConfig, ProducerConfig
__all__ = [
"consumer",
- "producer",
- "HeizerConfig",
- "HeizerTopic",
- "HeizerMessage",
- "create_new_topic",
+ "ConsumerSignal",
+ "Producer",
+ "BaseConfig",
+ "ProducerConfig",
+ "ConsumerConfig",
+ "Message",
+ "Topic",
+ "create_new_topics",
"get_admin_client",
+ "write_consumer_status",
+ "read_consumer_status",
+ "delete_topics",
+ "list_topics",
]
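Taken together, the renamed exports support a round trip like this minimal sketch (assumes a broker at localhost:9092; topic and group names are illustrative):

    import json

    from heizer import ConsumerConfig, Message, Producer, ProducerConfig, Topic, consumer, create_new_topics

    topic = Topic(name="demo.topic")
    producer_config = ProducerConfig(bootstrap_servers="localhost:9092")
    create_new_topics(config=producer_config, topics=[topic])

    # Produce one message, then consume it once and return a field from it.
    Producer(config=producer_config).produce(topic=topic, value={"status": "success"})

    @consumer(
        topics=[topic],
        config=ConsumerConfig(bootstrap_servers="localhost:9092", group_id="demo"),
        call_once=True,
    )
    def handle(message: Message, C, *args, **kwargs):
        return json.loads(message.value)["status"]

    assert handle() == "success"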
diff --git a/heizer/_source/__init__.py b/heizer/_source/__init__.py
index ec5d26b..a51ef24 100644
--- a/heizer/_source/__init__.py
+++ b/heizer/_source/__init__.py
@@ -1,18 +1,20 @@
import logging
-import os
import sys
+from heizer.env_vars import HEIZER_LOG_LEVEL
+
FORMAT = "heizer %(asctime)s %(levelname)-8s %(message)s"
def get_logger(name: str) -> logging.Logger:
+ """Get a logger with the given name."""
formatter = logging.Formatter(fmt=FORMAT, datefmt="%Y-%m-%d %H:%M:%S")
handler = logging.FileHandler("/tmp/heizer_log.txt", mode="w")
handler.setFormatter(formatter)
screen_handler = logging.StreamHandler(stream=sys.stdout)
screen_handler.setFormatter(formatter)
logger = logging.getLogger(name)
- logger.setLevel(os.environ.get("HEIZER_LOG_LEVEL", "INFO"))
+ logger.setLevel(HEIZER_LOG_LEVEL)
logger.addHandler(handler)
logger.addHandler(screen_handler)
return logger
diff --git a/heizer/_source/admin.py b/heizer/_source/admin.py
index 11419b7..a597c5e 100644
--- a/heizer/_source/admin.py
+++ b/heizer/_source/admin.py
@@ -1,19 +1,30 @@
-from confluent_kafka.admin import AdminClient, NewTopic
+from confluent_kafka.admin import AdminClient
from heizer._source import get_logger
-from heizer._source.topic import HeizerTopic
-from heizer.config import HeizerConfig
-from heizer.types import List
+from heizer._source.topic import Topic
+from heizer.config import BaseConfig
+from heizer.types import KafkaConfig, List, Union
logger = get_logger(__name__)
-def get_admin_client(config: HeizerConfig) -> AdminClient:
- return AdminClient({"bootstrap.servers": config.value.get("bootstrap.servers")})
+def get_admin_client(config: Union[BaseConfig, KafkaConfig]) -> AdminClient:
+ """
+ Create an admin client using the provided configuration.
+ """
+ if isinstance(config, BaseConfig):
+ config_dict = {"bootstrap.servers": config.bootstrap_servers}
+ else:
+ config_dict = config
+ return AdminClient(config_dict)
-def create_new_topic(admin_client: AdminClient, topics: List[HeizerTopic]) -> None:
- new_topics = [NewTopic(topic.name, num_partitions=len(topic._partitions)) for topic in topics]
+def create_new_topics(config: Union[BaseConfig, KafkaConfig], topics: List[Topic]) -> None:
+ """
+ Create new topics using the provided configuration.
+ """
+ admin_client = get_admin_client(config)
+ new_topics = [topic._new_topic for topic in topics]
fs = admin_client.create_topics(new_topics)
# Wait for each operation to finish.
@@ -27,3 +38,23 @@ def create_new_topic(admin_client: AdminClient, topics: List[HeizerTopic]) -> None:
continue
else:
logger.exception(f"Failed to create topic {topic}", exc_info=e)
+
+
+def delete_topics(config: Union[BaseConfig, KafkaConfig], topics: List[Topic]) -> None:
+ """Delete topics using the provided configuration."""
+ admin_client = get_admin_client(config)
+ fs = admin_client.delete_topics([tp.name for tp in topics])
+
+ # Wait for each operation to finish.
+ for topic, f in fs.items():
+ try:
+ f.result() # The result itself is None
+ logger.info("Topic {} deleted".format(topic))
+ except Exception as e:
+ logger.exception(f"Failed to delete topic {topic}", exc_info=e)
+
+
+def list_topics(config: Union[BaseConfig, KafkaConfig]) -> List[str]:
+ """List topics using the provided configuration."""
+ admin_client = get_admin_client(config)
+ return list(admin_client.list_topics().topics.keys())
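A minimal sketch of the three admin helpers together (broker address is an assumption; both a plain config dict and the new config dataclasses are accepted):

    from heizer import Topic, create_new_topics, delete_topics, list_topics

    config = {"bootstrap.servers": "localhost:9092"}
    topics = [Topic(name="admin.demo", num_partitions=3)]

    create_new_topics(config, topics)
    # Topic creation is asynchronous on the broker side, so the new name
    # may take a moment to show up in list_topics().
    print("admin.demo" in list_topics(config))
    delete_topics(config, topics)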
diff --git a/heizer/_source/consumer.py b/heizer/_source/consumer.py
index e42d772..1b9e157 100644
--- a/heizer/_source/consumer.py
+++ b/heizer/_source/consumer.py
@@ -1,28 +1,33 @@
import asyncio
+import atexit
import functools
import logging
+import os
import signal
+from dataclasses import dataclass
from uuid import uuid4
-from confluent_kafka import Consumer
-from pydantic import BaseModel
+import confluent_kafka as ck
from heizer._source import get_logger
-from heizer._source.admin import create_new_topic, get_admin_client
-from heizer._source.message import HeizerMessage
-from heizer._source.topic import HeizerTopic
-from heizer.config import HeizerConfig
+from heizer._source.admin import create_new_topics
+from heizer._source.enums import ConsumerStatusEnum
+from heizer._source.message import Message
+from heizer._source.producer import Producer
+from heizer._source.status_manager import write_consumer_status
+from heizer._source.topic import Topic
+from heizer.config import ConsumerConfig
from heizer.types import (
Any,
Awaitable,
Callable,
Concatenate,
Coroutine,
+ KafkaConfig,
List,
Optional,
ParamSpec,
Stopper,
- Type,
TypeVar,
Union,
)
@@ -35,86 +40,121 @@
logger = get_logger(__name__)
-def signal_handler(signal: Any, frame: Any) -> None:
- global interrupted
- interrupted = True
+def _make_consumer_config_dict(config: ConsumerConfig) -> KafkaConfig:
+ config_dict = {
+ "bootstrap.servers": config.bootstrap_servers,
+ "group.id": config.group_id,
+ "auto.offset.reset": config.auto_offset_reset,
+ "enable.auto.commit": config.enable_auto_commit,
+ }
+ if config.other_configs:
+ config_dict.update(config.other_configs)
-signal.signal(signal.SIGINT, signal_handler)
+ return config_dict
-interrupted = False
+
+@dataclass
+class ConsumerSignal:
+ is_running: bool = True
+
+ def stop(self) -> None:
+ self.is_running = False
class consumer(object):
- """A decorator to create a consumer"""
+ """
+ A decorator to create a consumer
+ """
- __id__: str
+ id: str
name: Optional[str]
- topics: List[HeizerTopic]
- config: HeizerConfig = HeizerConfig()
+ topics: List[Topic]
+ config: KafkaConfig
call_once: bool = False
stopper: Optional[Stopper] = None
- deserializer: Optional[Type[BaseModel]] = None
-
+ deserializer: Optional[Callable] = None
+ consumer_signal: ConsumerSignal
is_async: bool = False
poll_timeout: int = 1
-
init_topics: bool = False
+ enable_retry: bool = False
+ retry_topic: Optional[Topic] = None
+ retry_times: int = 1
+ retry_times_header_key: str = "!!-heizer_func_retried_times-!!"
+
+ # private attr
+ _consumer_instance: ck.Consumer
+ _producer_instance: Optional[Producer] = None # for retry
+
def __init__(
self,
*,
- topics: List[HeizerTopic],
- config: HeizerConfig = HeizerConfig(),
+ topics: List[Topic],
+ config: Union[KafkaConfig, ConsumerConfig],
call_once: bool = False,
stopper: Optional[Stopper] = None,
- deserializer: Optional[Type[BaseModel]] = None,
+ deserializer: Optional[Callable] = None,
is_async: bool = False,
+ id: Optional[str] = None,
name: Optional[str] = None,
poll_timeout: Optional[int] = None,
init_topics: bool = False,
+ consumer_signal: Optional[ConsumerSignal] = None,
+ enable_retry: bool = False,
+ retry_topic: Optional[Topic] = None,
+ retry_times: int = 1,
):
self.topics = topics
- self.config = config
+ self.config = _make_consumer_config_dict(config) if isinstance(config, ConsumerConfig) else config
self.call_once = call_once
self.stopper = stopper
self.deserializer = deserializer
- self.__id__ = str(uuid4())
- self.name = name or self.__id__
+ self.id = id or str(uuid4())
+ self.name = name or self.id
self.is_async = is_async
self.poll_timeout = poll_timeout if poll_timeout is not None else 1
self.init_topics = init_topics
+ self.consumer_signal = consumer_signal or ConsumerSignal()
+ self.enable_retry = enable_retry
+ self.retry_topic = retry_topic
+ self.retry_times = retry_times
def __repr__(self) -> str:
- return self.name or self.__id__
+ return self.name or self.id
- async def __run__( # type: ignore
- self,
- func: Callable[Concatenate[HeizerMessage, P], Union[T, Awaitable[T]]],
- is_async: bool,
- *args: P.args,
- **kwargs: P.kwargs,
- ) -> Union[Optional[T]]:
- """Run the consumer"""
+ @property
+ def ck_consumer(self) -> ck.Consumer:
+ return self._consumer_instance
- logger.debug(f"[{self}] Creating consumer ")
- c = Consumer(self.config.value)
+ @property
+ def ck_producer(self) -> Optional[ck.Producer]:
+ return self._producer_instance
- logger.debug(f"[{self}] Subscribing to topics {[topic.name for topic in self.topics]}")
- c.subscribe([topic.name for topic in self.topics])
+ def commit(self, asynchronous: bool = False, *args, **kwargs):
+ self.ck_consumer.commit(*args, asynchronous=asynchronous, **kwargs)
- if self.init_topics:
- logger.debug(f"[{self}] Initializing topics")
- admin_client = get_admin_client(self.config)
- create_new_topic(admin_client, self.topics)
+ async def _long_run(self, f: Callable, is_async: bool, *args, **kwargs):
+ target_topics: List[Topic] = list(self.topics)
- while True:
- if interrupted:
- logger.debug(f"[{self}] Interrupted by keyboard")
- break
+ if self.enable_retry and self.retry_topic:
+ target_topics.append(self.retry_topic)
+
+ logger.info(f"[{self}] Subscribing to topics {[topic.name for topic in target_topics]}")
+ self.ck_consumer.subscribe([topic.name for topic in target_topics])
+
+ result = None
+
+ while self.consumer_signal.is_running:
+ write_consumer_status(
+ consumer_id=self.id, consumer_name=self.name, pid=os.getpid(), status=ConsumerStatusEnum.RUNNING
+ )
- msg = c.poll(self.poll_timeout)
+ result = None
+
+ msg = self.ck_consumer.poll(self.poll_timeout)
if msg is None:
continue
@@ -124,37 +164,44 @@ async def __run__( # type: ignore
continue
logger.debug(f"[{self}] Received message")
- hmessage = HeizerMessage(msg)
+
+ h_message = Message(msg)
if self.deserializer is not None:
logger.debug(f"[{self}] Parsing message")
try:
- hmessage.formatted_value = self.deserializer.parse_raw(hmessage.value)
+ h_message.formatted_value = self.deserializer(h_message.value)
except Exception as e:
logger.exception(f"[{self}] Failed to deserialize message", exc_info=e)
- result = None
+ # If the heizer retry-count header exists, strip it before passing the message to the user func
+ has_retried: Optional[int] = None
+ if h_message.headers and self.retry_times_header_key in h_message.headers:
+ has_retried = int(h_message.headers.pop(self.retry_times_header_key))
+
+ logger.debug(f"[{self}] Executing function {f.__name__}")
- logger.debug(f"[{self}] Executing function {func.__name__}")
try:
if is_async:
- result = await func(hmessage, *args, **kwargs) # type: ignore
+ result = await f(h_message, self, *args, **kwargs) # type: ignore
else:
- result = func(hmessage, *args, **kwargs)
+ result = f(h_message, self, *args, **kwargs)
except Exception as e:
logger.exception(
- f"[{self}] Failed to execute function {func.__name__}",
+ f"[{self}] Failed to execute function {f.__name__}",
exc_info=e,
)
+ if self.enable_retry:
+ logger.info(f"[{self}] Start producing retry message for function {f.__name__}")
+ await self.produce_retry_message(message=h_message, has_retried=has_retried, func_name=f.__name__)
finally:
- # TODO: add failed message to a retry queue
logger.debug(f"[{self}] Committing message")
- c.commit()
+ self.commit()
if self.stopper is not None:
logger.debug(f"[{self}] Executing stopper function")
try:
- should_stop = self.stopper(hmessage)
+ should_stop = self.stopper(h_message, self)
except Exception as e:
logger.warning(
f"[{self}] Failed to execute stopper function {self.stopper.__name__}.",
@@ -163,25 +210,123 @@ async def __run__( # type: ignore
should_stop = False
if should_stop:
- return result
+ self.consumer_signal.stop()
+ break
if self.call_once is True:
logger.debug(f"[{self}] Call Once is on, returning result")
- return result
+ self.consumer_signal.stop()
+ break
+
+ return result
+
+ async def _run( # type: ignore
+ self,
+ func: Callable[Concatenate[Message, P], Union[T, Awaitable[T]]],
+ is_async: bool,
+ *args: P.args,
+ **kwargs: P.kwargs,
+ ) -> Union[Optional[T]]:
+ """Run the consumer"""
+
+ if self.init_topics:
+ logger.info(f"[{self}] Initializing topics")
+ create_new_topics({"bootstrap.servers": self.config["bootstrap.servers"]}, self.topics)
+
+ logger.info(f"[{self}] Creating consumer")
+ self._consumer_instance = ck.Consumer(self.config)
+
+ if self.enable_retry:
+ self.retry_topic = self.retry_topic or Topic(name=f"{func.__name__}_heizer_retry")
+ self._producer_instance = Producer(config={"bootstrap.servers": self.config["bootstrap.servers"]})
+
+ logger.info(f"[{self}] Creating retry topic {self.retry_topic.name}")
+ create_new_topics({"bootstrap.servers": self.config["bootstrap.servers"]}, [self.retry_topic])
+
+ atexit.register(self._atexit)
+ signal.signal(signal.SIGTERM, self._exit)
+ signal.signal(signal.SIGINT, self._exit)
+
+ result = None
+
+ try:
+ result = await self._long_run(func, is_async, *args, **kwargs)
+ except KeyboardInterrupt:
+ logger.info(
+ f"[{self}] Stopped because of keyboard interruption",
+ )
+ except Exception as e:
+ logger.exception(f"[{self}] stopped", exc_info=e)
+ finally:
+ self.consumer_signal.stop()
+ self.ck_consumer.close()
+ write_consumer_status(
+ consumer_id=self.id, consumer_name=self.name, pid=os.getpid(), status=ConsumerStatusEnum.CLOSED
+ )
+ logger.info(
+ f"[{self}] Closed",
+ )
+ return result
def __call__(
- self, func: Callable[Concatenate[HeizerMessage, P], T]
+ self, func: Callable[Concatenate[Message, P], T]
) -> Callable[P, Union[Coroutine[Any, Any, Optional[T]], T, None]]:
@functools.wraps(func)
async def async_decorator(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
"""Async decorator"""
logging.debug(f"[{self}] Running async decorator")
- return await self.__run__(func, self.is_async, *args, **kwargs)
+ return await self._run(func, self.is_async, *args, **kwargs)
@functools.wraps(func)
def decorator(*args: P.args, **kwargs: P.kwargs) -> Optional[T]:
"""Sync decorator"""
logging.debug(f"[{self}] Running sync decorator")
- return asyncio.run(self.__run__(func, self.is_async, *args, **kwargs))
+ return asyncio.run(self._run(func, self.is_async, *args, **kwargs))
return async_decorator if self.is_async else decorator
+
+ def _exit(self, sig, frame, *args, **kwargs) -> None:
+ self.consumer_signal.stop()
+ self.ck_consumer.close()
+
+ def _atexit(self) -> None:
+ self.ck_consumer.close()
+
+ async def produce_retry_message(self, message: Message, has_retried: Optional[int], func_name: str) -> None:
+ """
+ Produces a retry message for a given function.
+
+ :param message: The original message that needs to be retried.
+ :type message: Message
+ :param has_retried: The number of times the message has already been retried. Defaults to None.
+ :type has_retried: Optional[int]
+ :param func_name: The name of the function.
+ :type func_name: str
+ :return: None
+ """
+ if not self._producer_instance:
+ logger.error(
+ f"[{self}] Confluent producer instance not found,"
+ f" failed to produce retry message for function {func_name}"
+ )
+ return
+
+ if not self.retry_topic:
+ logger.error(f"[{self}] Retry topic not found, failed to produce retry message for function {func_name}")
+ return
+
+ if has_retried == self.retry_times:
+ logger.info(f"[{self}] Function {func_name} reached retry limit ({self.retry_times}), will give up")
+ return
+ elif has_retried is None:
+ has_retried = 1
+ else:
+ has_retried += 1
+
+ new_headers = message.headers or {}
+ new_headers[self.retry_times_header_key] = str(has_retried)
+
+ await self._producer_instance.async_produce(
+ topic=self.retry_topic, key=message.key, value=message.value or "", headers=new_headers, auto_flush=True
+ )
+ logger.info(f"[{self}] Produced retry message for function {func_name}, retry time(s): {has_retried}")
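A minimal sketch of the retry path above (topic names and the handler are illustrative): when the decorated function raises and enable_retry is set, the message is re-produced to the retry topic with the retry-count header, up to retry_times attempts:

    from heizer import ConsumerConfig, Message, Topic, consumer

    cfg = ConsumerConfig(bootstrap_servers="localhost:9092", group_id="retry-demo")

    @consumer(
        topics=[Topic(name="retry.demo")],
        config=cfg,
        enable_retry=True,
        retry_times=2,
        retry_topic=Topic(name="retry.demo.retry"),
    )
    def flaky(message: Message, C, *args, **kwargs):
        # The retry-count header is stripped before the handler runs, so the
        # handler never sees C.retry_times_header_key.
        raise ValueError("simulated failure; heizer re-produces the message")

    # Without a stopper or call_once, flaky() polls until its
    # ConsumerSignal is stopped (e.g. C.consumer_signal.stop()).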
diff --git a/heizer/_source/enums.py b/heizer/_source/enums.py
new file mode 100644
index 0000000..d496189
--- /dev/null
+++ b/heizer/_source/enums.py
@@ -0,0 +1,6 @@
+import enum
+
+
+class ConsumerStatusEnum(str, enum.Enum):
+ RUNNING = "running"
+ CLOSED = "closed"
diff --git a/heizer/_source/message.py b/heizer/_source/message.py
index 90afe86..db1d485 100644
--- a/heizer/_source/message.py
+++ b/heizer/_source/message.py
@@ -1,9 +1,9 @@
-from confluent_kafka import Message
+import confluent_kafka as ck
from heizer.types import Any, Dict, List, Optional, Tuple, Union
-class HeizerMessage:
+class Message:
"""
:class: HeizerMessage
@@ -20,7 +20,7 @@ class HeizerMessage:
"""
# initialized properties
- message: Message
+ _message: ck.Message
topic: Optional[str]
partition: int
headers: Optional[Dict[str, str]]
@@ -29,8 +29,8 @@ class HeizerMessage:
formatted_value: Optional[Any] = None
- def __init__(self, message: Message):
- self.message = message
+ def __init__(self, message: ck.Message):
+ self._message = message
self.topic = self._parse_topic(message.topic())
self.partition = message.partition()
self.headers = self._parse_headers(message.headers())
diff --git a/heizer/_source/producer.py b/heizer/_source/producer.py
index f2f4d97..b0de939 100644
--- a/heizer/_source/producer.py
+++ b/heizer/_source/producer.py
@@ -1,18 +1,24 @@
-import functools
+import asyncio
import json
from uuid import uuid4
-from confluent_kafka import Message, Producer
+import confluent_kafka as ck
from heizer._source import get_logger
-from heizer._source.topic import HeizerTopic
-from heizer.config import HeizerConfig
-from heizer.types import Any, Callable, Dict, List, Optional, TypeVar, Union, cast
+from heizer._source.topic import Topic
+from heizer.config import ProducerConfig
+from heizer.types import Any, Callable, Dict, KafkaConfig, Optional, Union
logger = get_logger(__name__)
-R = TypeVar("R")
-F = TypeVar("F", bound=Callable[..., Dict[str, str]])
+
+def _create_producer_config_dict(config: ProducerConfig) -> KafkaConfig:
+ config_dict = {"bootstrap.servers": config.bootstrap_servers}
+
+ if config.other_configs:
+ config_dict.update(config.other_configs)
+
+ return config_dict
def _default_serializer(value: Union[Dict[Any, Any], str, bytes]) -> bytes:
@@ -23,7 +29,7 @@ def _default_serializer(value: Union[Dict[Any, Any], str, bytes]) -> bytes:
Default Kafka message serializer, which will encode inputs to bytes
"""
- if isinstance(value, dict):
+ if isinstance(value, (dict, list)):
return json.dumps(value).encode("utf-8")
elif isinstance(value, str):
return value.encode("utf-8")
@@ -33,125 +39,154 @@ def _default_serializer(value: Union[Dict[Any, Any], str, bytes]) -> bytes:
raise ValueError(f"Input type is not supported: {type(value).__name__}")
-class producer(object):
+class Producer(object):
"""
- A decorator to create a producer
+ Kafka Message Producer
"""
__id__: str
name: str
- # args
- config: HeizerConfig = HeizerConfig()
- topics: List[HeizerTopic]
+ # attrs need be initialized
+ config: KafkaConfig
serializer: Callable[..., bytes] = _default_serializer
call_back: Optional[Callable[..., Any]] = None
- default_key: Optional[str] = None
- default_headers: Optional[Dict[str, str]] = None
-
- key_alias: str = "key"
- headers_alias: str = "headers"
-
# private properties
- _producer_instance: Optional[Producer] = None
+ _producer_instance: Optional[ck.Producer] = None
def __init__(
self,
- topics: List[HeizerTopic],
- config: HeizerConfig = HeizerConfig(),
+ config: Union[ProducerConfig, KafkaConfig],
serializer: Callable[..., bytes] = _default_serializer,
call_back: Optional[Callable[..., Any]] = None,
- default_key: Optional[str] = None,
- default_headers: Optional[Dict[str, str]] = None,
- key_alias: Optional[str] = None,
- headers_alias: Optional[str] = None,
name: Optional[str] = None,
- init_topics: bool = True,
):
- self.topics = topics
- self.config = config
+ self.config = _create_producer_config_dict(config) if isinstance(config, ProducerConfig) else config
self.serializer = serializer
- self.call_back = call_back
- self.default_key = default_key
- self.default_headers = default_headers
-
- if key_alias:
- self.key_alias = key_alias
- if headers_alias:
- self.headers_alias = headers_alias
+ self.call_back = call_back or self.default_delivery_report
self.__id__ = str(uuid4())
self.name = name or self.__id__
+ def __repr__(self) -> str:
+ return self.name
+
+ def default_delivery_report(self, err: Optional[ck.KafkaError], msg: ck.Message) -> None:
+ """
+ Called once for each message produced to indicate the delivery result.
+ Triggered by poll() or flush().
+
+ :param err: the delivery error, or None on success
+ :param msg: the produced message
+ :return: None
+ """
+ if err is not None:
+ logger.error(f"[{self}] Message delivery failed: {err}")
+ else:
+ logger.debug(f"[{self}] Message delivered to {msg.topic()} [{msg.partition()}]")
+
@property
- def _producer(self) -> Producer:
+ def _producer(self) -> ck.Producer:
if not self._producer_instance:
- self._producer_instance = Producer(self.config.value)
+ self._producer_instance = ck.Producer(self.config)
return self._producer_instance
- def __call__(self, func: F) -> F:
- @functools.wraps(func)
- def decorator(*args: Any, **kwargs: Any) -> Any:
- try:
- result = func(*args, **kwargs)
- except Exception as e:
- raise e
-
- try:
- key = result.pop(self.key_alias, self.default_key)
- except Exception as e:
- logger.debug(f"Failed to get key from result. {str(e)}")
- key = self.default_key
-
- try:
- headers = result.pop(self.headers_alias, self.default_headers)
- except Exception as e:
- logger.debug(
- f"Failed to get headers from result. {str(e)}",
- )
- headers = self.default_headers
-
- headers = cast(Dict[str, str], headers)
-
- msg = self.serializer(result)
-
- self._produce_message(message=msg, key=key, headers=headers)
-
- return result
-
- return cast(F, decorator)
-
- def _produce_message(self, message: bytes, key: Optional[str], headers: Optional[Dict[str, str]]) -> None:
- for topic in self.topics:
- for partition in topic.partitions:
- try:
- self._producer.poll(0)
- self._producer.produce(
- topic=topic.name,
- value=message,
- partition=partition,
- key=key,
- headers=headers,
- on_delivery=self.call_back,
- )
- self._producer.flush()
- except Exception as e:
- logger.exception(f"Failed to produce msg to topic: {topic.name}", exc_info=e)
- raise e
-
-
-def delivery_report(err: str, msg: Message) -> None:
- """
- Called once for each message produced to indicate delivery result.
- Triggered by poll() or flush().
-
- :param msg:
- :param err:
- :return:
- """
- if err is not None:
- logger.error("Message delivery failed: {}".format(err))
- else:
- logger.debug("Message delivered to {} [{}]".format(msg.topic(), msg.partition()))
+ def produce(
+ self,
+ topic: Union[str, Topic],
+ value: Union[bytes, str, Dict[Any, Any]],
+ key: Optional[str] = None,
+ headers: Optional[Dict[str, str]] = None,
+ partition: Optional[int] = -1,
+ auto_flush: bool = True,
+ ) -> None:
+ """
+ This method is used to produce messages to a Kafka topic using the confluent_kafka library.
+
+ :param topic: The topic to publish the message to. Can be either a string representing the topic name,
+ or an instance of the Topic class.
+ :param value: The message body to be produced. Can be either bytes, a string, or a dictionary.
+ :param key: Optional. The key for the message.
+ :param headers: Optional. Additional headers for the message.
+ :param partition: Optional. The partition to produce to; defaults to -1 (let the partitioner decide).
+ :param auto_flush: Optional. Default is True. If True, the producer will automatically flush messages after
+ producing them.
+
+ :return: None
+
+ Example usage:
+
+ >>> producer = Producer(config={"xx": "xx"})
+ >>> producer.produce(topic="my_topic", value={"k":"v"})
+
+ """
+ return asyncio.run(
+ self._produce(
+ topic=topic, value=value, key=key, headers=headers, partition=partition, auto_flush=auto_flush
+ )
+ )
+
+ async def async_produce(
+ self,
+ topic: Union[str, Topic],
+ value: Union[bytes, str, Dict[Any, Any]],
+ key: Optional[str] = None,
+ headers: Optional[Dict[str, str]] = None,
+ partition: Optional[int] = -1,
+ auto_flush: bool = True,
+ ) -> None:
+ """
+ Produce a message asynchronously to a Kafka topic.
+
+ :param topic: The topic to produce the message to. Can be either a string or a `Topic` object.
+ :param value: The message body to be produced. Can be either bytes, a string, or a dictionary.
+ :param key: The key to associate with the message. Optional.
+ :param headers: Additional headers to include with the message. Optional.
+ :param partition: The partition to produce to; defaults to -1 (let the partitioner decide). Optional.
+ :param auto_flush: If set to True, automatically flush the producer after producing the message.
+ Default is True.
+ :return: None
+
+ This method asynchronously produces a message to a Kafka topic. It accepts the topic, value, key, headers,
+ partition, and auto_flush parameters.
+
+ Example usage:
+
+ >>> producer = Producer(config={"xx": "xx"})
+ >>> asyncio.run(producer.async_produce(topic="my_topic", value={"k":"v"}))
+ """
+ return await self._produce(
+ topic=topic, value=value, key=key, headers=headers, partition=partition, auto_flush=auto_flush
+ )
+
+ async def _produce(
+ self,
+ topic: Union[str, Topic],
+ value: Union[bytes, str, Dict[Any, Any]],
+ key: Optional[str] = None,
+ headers: Optional[Dict[str, str]] = None,
+ partition: Optional[int] = -1,
+ auto_flush: bool = True,
+ ) -> None:
+ topic = Topic(name=topic) if isinstance(topic, str) else topic
+
+ try:
+ self._producer.produce(
+ topic=topic.name,
+ value=self.serializer(value),
+ key=key,
+ headers=headers,
+ on_delivery=self.call_back,
+ partition=partition,
+ )
+ if auto_flush:
+ self.flush()
+ except BufferError:
+ # confluent-kafka's produce() raises BufferError when the local queue is full
+ logger.warning(f"[{self}] Local queue is full, force to flush")
+ self.flush()
+ except Exception as e:
+ logger.exception(f"[{self}] Failed to produce msg to topic: {topic.name}", exc_info=e)
+ raise e
+
+ def flush(self) -> None:
+ self._producer.flush()
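A minimal sketch of batched production with the new class (broker address is an assumption): with auto_flush=False, messages queue locally and a single flush() delivers them all, the same pattern the tutorial and tests use:

    from heizer import Producer, ProducerConfig

    pd = Producer(config=ProducerConfig(bootstrap_servers="localhost:9092"))

    for i in range(100):
        # Queued locally; nothing is delivered yet.
        pd.produce(topic="batch.demo", value={"i": i}, auto_flush=False)

    pd.flush()  # deliver the whole batch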
diff --git a/heizer/_source/status_manager.py b/heizer/_source/status_manager.py
new file mode 100644
index 0000000..62fa66d
--- /dev/null
+++ b/heizer/_source/status_manager.py
@@ -0,0 +1,48 @@
+import json
+import os
+from datetime import datetime
+from typing import Any, Dict, Optional
+
+from heizer._source.enums import ConsumerStatusEnum
+from heizer.env_vars import CONSUMER_STATUS_FILE_PATH
+
+
+def write_consumer_status(
+ consumer_id: str,
+ status: ConsumerStatusEnum,
+ pid: int,
+ consumer_name: Optional[str] = None,
+) -> None:
+ if os.path.exists(CONSUMER_STATUS_FILE_PATH):
+ data = json.loads(open(CONSUMER_STATUS_FILE_PATH).read())
+ data.update(
+ {
+ consumer_id: {
+ "name": consumer_name or consumer_id,
+ "status": status,
+ "pid": pid,
+ "timestamp": datetime.utcnow().isoformat(),
+ }
+ }
+ )
+ else:
+ data = {}
+ data[consumer_id] = {
+ "name": consumer_name or consumer_id,
+ "status": status,
+ "pid": pid,
+ "timestamp": datetime.utcnow().isoformat(),
+ }
+
+ with open(CONSUMER_STATUS_FILE_PATH, "w") as f:
+ json.dump(data, f)
+
+
+def read_consumer_status(consumer_id: Optional[str] = None) -> Dict[str, Any]:
+ with open(CONSUMER_STATUS_FILE_PATH) as f:
+ data = json.loads(f.read())
+ if consumer_id:
+ return data.get(consumer_id, {})
+ else:
+ return data
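A minimal health-check sketch on top of the status file (assumes at least one consumer has written it; the probe logic is illustrative):

    from heizer import read_consumer_status

    # Entries are keyed by consumer id and carry name, status, pid, and timestamp.
    for consumer_id, info in read_consumer_status().items():
        healthy = info["status"] == "running"
        print(f"{info['name']} (pid={info['pid']}): {'OK' if healthy else 'closed'}")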
diff --git a/heizer/_source/topic.py b/heizer/_source/topic.py
index 380c85f..ee0bdd6 100644
--- a/heizer/_source/topic.py
+++ b/heizer/_source/topic.py
@@ -1,32 +1,64 @@
-from typing import List, Optional
+from confluent_kafka import TopicPartition, admin
-from confluent_kafka import TopicPartition
+from heizer.types import Any, Dict, List, Optional
-class HeizerTopic(TopicPartition):
+class Topic:
name: str
- _partitions: List[int]
- _replication_factor: int
- _topic_partitions: List[TopicPartition]
+ partition: int
+ offset: Optional[int] = None
+ metadata: Optional[str] = None
+ leader_epoch: Optional[int] = None
+ num_partitions: Optional[int] = None
+ replication_factor: Optional[int] = None
+ replica_assignment: Optional[List[Any]] = None
+ config: Optional[Dict[Any, Any]] = None
+
+ _topic_partition: TopicPartition
+ _new_topic: admin.NewTopic
def __init__(
self,
name: str,
- partitions: Optional[List[int]] = None,
- replication_factor: int = 1,
+ partition: Optional[int] = None,
+ offset: Optional[int] = None,
+ metadata: Optional[str] = None,
+ leader_epoch: Optional[int] = None,
+ num_partitions: Optional[int] = None,
+ replication_factor: Optional[int] = None,
+ replica_assignment: Optional[List[Any]] = None,
+ config: Optional[Dict[Any, Any]] = None,
):
- self._partitions = partitions or [-1]
- self._topic_partitions = []
- self._replication_factor = replication_factor
self.name = name
+ self.partition = partition if partition is not None else -1
+ self.offset = offset if offset is not None else -1
+ self.metadata = metadata
+ self.leader_epoch = leader_epoch
+ self.num_partitions = num_partitions if num_partitions is not None else 1
+ self.replication_factor = replication_factor
+ self.replica_assignment = replica_assignment
+ self.config = config
+
+ topic_partition_args = {
+ "topic": self.name,
+ "partition": self.partition,
+ }
+ if self.offset is not None:
+ topic_partition_args["offset"] = self.offset
+ if self.metadata is not None:
+ topic_partition_args["metadata"] = self.metadata
+ if self.leader_epoch is not None:
+ topic_partition_args["leader_epoch"] = self.leader_epoch
+
+ self._topic_partition = TopicPartition(**topic_partition_args)
- for partition in self._partitions:
- self._topic_partitions.append(TopicPartition(topic=name, partition=partition))
+ new_topic_args = {"topic": self.name, "num_partitions": self.num_partitions}
- @property
- def partitions(self) -> List[int]:
- return self._partitions
+ if self.replication_factor is not None:
+ new_topic_args["replication_factor"] = self.replication_factor
+ if self.replica_assignment is not None:
+ new_topic_args["replica_assignment"] = self.replica_assignment
+ if self.config is not None:
+ new_topic_args["config"] = self.config
- @property
- def topic_partitions(self) -> List[TopicPartition]:
- return self._topic_partitions
+ self._new_topic = admin.NewTopic(**new_topic_args)
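A minimal sketch of the reworked Topic (values are illustrative): a single object now carries both the consumer-side TopicPartition and the admin-side NewTopic, which is what create_new_topics reads via topic._new_topic:

    from heizer import Topic

    t = Topic(name="orders", num_partitions=3, replication_factor=1)
    print(t._topic_partition)  # TopicPartition for "orders" (partition=-1, offset=-1)
    print(t._new_topic)        # NewTopic for "orders" (num_partitions=3, replication_factor=1)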
diff --git a/heizer/config.py b/heizer/config.py
index a7e6716..5240af8 100644
--- a/heizer/config.py
+++ b/heizer/config.py
@@ -1,31 +1,37 @@
-import os
-from typing import Optional
+from dataclasses import dataclass, field
-from heizer.types import KafkaConfig
+from heizer.types import Any, Dict
-DEFAULT_KAFKA_BOOTSTRAP_SERVER = os.environ.get("KAFKA_BOOTSTRAP_SERVER", "localhost:9094")
-DEFAULT_KAFKA_GROUP = os.environ.get("KAFKA_GROUP", "default")
+# confluent kafka configs
+# https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#kafka-client-configuration
-class HeizerConfig(object):
- # confluent kafka configs
- # https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#kafka-client-configuration
- __kafka_config: KafkaConfig
+@dataclass
+class BaseConfig:
+ """
+ Base configurations
+ """
- def __init__(self, config: Optional[KafkaConfig] = None):
- self.__kafka_config = config or {
- "bootstrap.servers": DEFAULT_KAFKA_BOOTSTRAP_SERVER,
- "group.id": DEFAULT_KAFKA_GROUP,
- }
+ bootstrap_servers: str
- @property
- def value(self) -> KafkaConfig:
- """
- Return kafka configurations
- Returns
- -------
- config : dict
- The dict contains kafka configurations
- """
- return self.__kafka_config
+@dataclass
+class ProducerConfig(BaseConfig):
+ """
+ Configuration class for producer.
+ """
+
+ other_configs: Dict[str, Any] = field(default_factory=dict)
+
+
+@dataclass
+class ConsumerConfig(BaseConfig):
+ """
+ Configuration class for consumer.
+ """
+
+ group_id: str
+ auto_offset_reset: str = "earliest"
+ other_configs: Dict[str, Any] = field(default_factory=dict)
+ enable_auto_commit: bool = False
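A minimal sketch of the new config dataclasses (the option values are assumptions, not defaults); other_configs is merged verbatim into the underlying confluent-kafka dict, so its keys use the dotted names:

    from heizer import ConsumerConfig, ProducerConfig

    producer_cfg = ProducerConfig(
        bootstrap_servers="localhost:9092",
        other_configs={"compression.type": "gzip"},
    )

    consumer_cfg = ConsumerConfig(
        bootstrap_servers="localhost:9092",
        group_id="my-group",
        auto_offset_reset="latest",  # default is "earliest"
        other_configs={"session.timeout.ms": 45000},
    )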
diff --git a/heizer/env_vars.py b/heizer/env_vars.py
new file mode 100644
index 0000000..4e5b7bb
--- /dev/null
+++ b/heizer/env_vars.py
@@ -0,0 +1,5 @@
+import os
+
+CONSUMER_STATUS_FILE_PATH = os.environ.get("HEIZER_CONSUMER_STATUS_FILE_PATH", "/tmp/heizer_consumers_status.json")
+
+HEIZER_LOG_LEVEL = os.environ.get("HEIZER_LOG_LEVEL", "INFO")
diff --git a/heizer/types.py b/heizer/types.py
index e269f7e..e920819 100644
--- a/heizer/types.py
+++ b/heizer/types.py
@@ -3,7 +3,7 @@
import sys
from typing import Any, Awaitable, Callable, Coroutine, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast
-if sys.version_info.minor <= 10:
+if sys.version_info.minor < 10:
from typing_extensions import Concatenate, ParamSpec
else:
from typing import Concatenate, ParamSpec
diff --git a/pyproject.toml b/pyproject.toml
index c79cdc5..0f516e6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -33,27 +33,29 @@ classifiers = [
dependencies = [
"confluent_kafka",
- "pydantic",
"typing; python_version<'3.10'",
"typing_extensions; python_version<'3.10'",
]
[project.optional-dependencies]
-all = ["websockets"]
+
+
socket = ["websockets"]
dev = ["pre-commit"]
doc = [
- "sphinx==5.3.0",
+ "sphinx",
"sphinx_rtd_theme",
"myst-parser",
"wheel",
"ipython",
"sphinx-multiversion",
+ "requests",
]
test = [
"pytest",
- "pytest-cov"
+ "pytest-cov",
+ "pydantic"
]
@@ -77,7 +79,7 @@ exclude =[".git", "__pycache__", "venv", "env", 'docs/*']
include = ["heizer/*"]
[tool.mypy]
-strict=true
+strict=false
ignore_missing_imports=true
disallow_subclassing_any=false
exclude = ['venv', '.venv', 'tests/*', 'docs/*', 'samples/*']
diff --git a/samples/websockets/server.py b/samples/websockets/server.py
index 9a64bb5..8745d67 100644
--- a/samples/websockets/server.py
+++ b/samples/websockets/server.py
@@ -2,27 +2,30 @@
import websockets
-from heizer import HeizerConfig, HeizerTopic, consumer
-
-consumer_config = HeizerConfig(
- {
- "bootstrap.servers": "0.0.0.0:9092",
- "group.id": "test",
- "auto.offset.reset": "earliest",
- }
+from heizer import ConsumerConfig, Message, Topic, consumer
+
+consumer_config = ConsumerConfig(
+ bootstrap_servers="localhost:9092",
+ group_id="websockets_sample",
+ auto_offset_reset="earliest",
+ enable_auto_commit=False,
)
-topics = [HeizerTopic(name="my.topic1")]
+topics = [Topic(name="my.topic1")]
-@consumer(topics=topics, config=consumer_config, is_async=True)
-async def handler(message, websocket, *args, **kwargs):
+@consumer(topics=topics, config=consumer_config, is_async=True, init_topics=True, name="websocket_sample")
+async def handler(message: Message, C: consumer, websocket, *args, **kwargs):
+ print(C.name)
await websocket.send(message.value)
async def main():
async with websockets.serve(handler, "", 8001):
- await asyncio.Future()
+ try:
+ await asyncio.Future()
+ except KeyboardInterrupt:
+ return
if __name__ == "__main__":
diff --git a/tests/conftest.py b/tests/conftest.py
index 605a353..e9dc089 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,12 +3,24 @@
import pytest
from confluent_kafka.admin import AdminClient
+from heizer import Topic, delete_topics, list_topics
+
+BOOTSTRAP_SERVERS = os.environ.get("KAFKA_SERVER", "localhost:9092")
+
@pytest.fixture()
def bootstrap_server() -> str:
- return os.environ.get("KAFKA_SERVER", "localhost:9092")
+ return BOOTSTRAP_SERVERS
@pytest.fixture()
def admin_client(bootstrap_server) -> AdminClient:
return AdminClient({"bootstrap.servers": bootstrap_server})
+
+
+@pytest.fixture(autouse=True, scope="session")
+def clean_topics():
+ yield
+ config = {"bootstrap.servers": BOOTSTRAP_SERVERS}
+ topics = list_topics(config)
+ delete_topics(config, [Topic(t) for t in topics])
diff --git a/tests/test_consumer.py b/tests/test_consumer.py
index 3093857..517d612 100644
--- a/tests/test_consumer.py
+++ b/tests/test_consumer.py
@@ -1,18 +1,23 @@
import json
import logging
import os
-from typing import Any, Dict, cast
+from typing import cast
+from uuid import uuid4
import pytest
from pydantic import BaseModel
-from heizer import HeizerConfig, HeizerMessage, HeizerTopic, consumer, create_new_topic, get_admin_client, producer
-
-Producer_Config = HeizerConfig(
- {
- "bootstrap.servers": os.environ.get("KAFKA_SERVER", "localhost:9092"),
- }
+from heizer import (
+ ConsumerSignal,
+ Message,
+ Producer,
+ ProducerConfig,
+ Topic,
+ consumer,
+ create_new_topics,
+ read_consumer_status,
)
+from heizer.env_vars import CONSUMER_STATUS_FILE_PATH
@pytest.fixture
@@ -20,140 +25,150 @@ def group_id():
return "test_group"
+@pytest.fixture(autouse=True)
+def clean_consumer_status():
+ yield
+ if os.path.exists(CONSUMER_STATUS_FILE_PATH):
+ os.remove(CONSUMER_STATUS_FILE_PATH)
+
+
@pytest.fixture
-def consumer_config(group_id):
- return HeizerConfig(
- {
- "bootstrap.servers": os.environ.get("KAFKA_SERVER", "localhost:9092"),
- "group.id": group_id,
- "auto.offset.reset": "earliest",
- }
- )
+def producer_config(bootstrap_server):
+ return ProducerConfig(bootstrap_servers=bootstrap_server)
+
+
+@pytest.fixture
+def consumer_config(group_id, bootstrap_server):
+ return {
+ "bootstrap.servers": bootstrap_server,
+ "group.id": group_id,
+ "auto.offset.reset": "earliest",
+ }
@pytest.mark.parametrize("group_id", ["test_consumer_stopper"])
-def test_consumer_stopper(group_id, consumer_config) -> None:
- toppics = [HeizerTopic(name="heizer.test.result", partitions=[0, 1, 2], replication_factor=2)]
- admin = get_admin_client(consumer_config)
- create_new_topic(admin, toppics)
-
- @producer(
- topics=toppics,
- config=Producer_Config,
- key_alias="myKey",
- headers_alias="myHeaders",
- )
- def produce_data(status: str, result_value: str) -> Dict[str, Any]:
- return {
- "key1": 1,
- "key2": "2",
- "key3": True,
- "status": status,
- "result": result_value,
- "myKey": "id1",
- "myHeaders": {"header1": "value1", "header2": "value2"},
- }
-
- def stopper(msg: HeizerMessage) -> bool:
+def test_consumer_stopper(group_id, consumer_config, producer_config, caplog, bootstrap_server) -> None:
+ topics = [Topic(name=f"heizer.test.result.{uuid4()}", num_partitions=3)]
+ create_new_topics({"bootstrap.servers": bootstrap_server}, topics)
+
+ pd = Producer(config=producer_config)
+
+ for status, result in [("start", 1), ("loading", 2), ("success", 3), ("postprocess", 4)]:
+ pd.produce(
+ topic=topics[0],
+ key="key1",
+ value={"status": status, "result": result},
+ headers={"header1": "value1", "header2": "value2"},
+ auto_flush=False,
+ )
+
+ pd.flush()
+
+ def stopper(msg: Message, *args, **kwargs) -> bool:
data = json.loads(msg.value)
if data["status"] == "success":
return True
return False
@consumer(
- topics=[HeizerTopic(name="heizer.test.result")],
+ topics=topics,
config=consumer_config,
stopper=stopper,
)
def consume_data(msg, *args, **kwargs) -> str:
data = json.loads(msg.value)
- assert msg.key == "id1"
+ assert msg.key == "key1"
assert msg.headers == {"header1": "value1", "header2": "value2"}
- assert "myKey" not in data
- assert "myHeaders" not in data
-
return cast(str, data["result"])
- produce_data("start", "waiting")
- produce_data("loading", "waiting")
- produce_data("success", "finished")
- produce_data("postprocess", "postprocess")
-
result = consume_data() # type: ignore
- assert result == "finished"
+ assert result == 3
@pytest.mark.parametrize("group_id", ["test_consumer_call_once"])
-def test_consumer_call_once(consumer_config) -> None:
+def test_consumer_call_once(group_id, producer_config, consumer_config, caplog) -> None:
+ caplog.set_level(logging.DEBUG)
topic_name = "heizer.test.test_consumer_call_once"
+ topic = Topic(name=f"{topic_name}.{uuid4()}")
- @producer(
- topics=[HeizerTopic(name=topic_name)],
- config=Producer_Config,
- key_alias="non_existing_key",
- headers_alias="non_existing_headers",
- )
- def produce_data(status: str, result: str) -> Dict[str, Any]:
- return {
- "key1": 1,
- "key2": "2",
- "key3": True,
- "status": status,
- "result": result,
- }
+ producer = Producer(config=producer_config)
- @consumer(
- topics=[HeizerTopic(name=topic_name)],
- config=consumer_config,
- call_once=True,
- )
+ for status, result in [("start", 1), ("loading", 2), ("success", 3), ("postprocess", 4)]:
+ producer.produce(
+ topic=topic,
+ key="key1",
+ value={"status": status, "result": result},
+ headers={"header1": "value1", "header2": "value2"},
+ auto_flush=True,
+ )
+
+ @consumer(topics=[topic], config=consumer_config, call_once=True)
def consume_data(msg, *args, **kwargs) -> str:
data = json.loads(msg.value)
- return cast(str, data["result"])
+ return data["result"]
+
+ result = consume_data()
- produce_data("start", "waiting")
- produce_data("loading", "waiting")
- produce_data("success", "finished")
+ assert result == 1
+
+
+@pytest.mark.parametrize("group_id", ["test_stop_consumer_by_signal"])
+def test_stop_consumer_by_signal(group_id, producer_config, consumer_config, caplog) -> None:
+ caplog.set_level(logging.DEBUG)
+ topic_name = "heizer.test.test_stop_consumer_by_signal"
+ topic = Topic(name=f"{topic_name}.{uuid4()}")
+
+ producer = Producer(config=producer_config)
+
+ for status, result in [("start", 1), ("loading", 2)]:
+ producer.produce(
+ topic=topic,
+ key="key1",
+ value={"status": status, "result": result},
+ headers={"header1": "value1", "header2": "value2"},
+ auto_flush=True,
+ )
+ sg = ConsumerSignal()
+
+ @consumer(topics=[topic], config=consumer_config, consumer_signal=sg)
+ def consume_data(msg, *args, **kwargs) -> str:
+ data = json.loads(msg.value)
+ sg.stop()
+ return data["result"]
result = consume_data()
- assert result == "waiting"
+ assert result == 1
@pytest.mark.parametrize("group_id", ["test_consumer_deserializer"])
-def test_consumer_deserializer(caplog, consumer_config) -> None:
+def test_consumer_deserializer(caplog, consumer_config, group_id, producer_config) -> None:
caplog.set_level(logging.DEBUG)
- topic_name = "heizer.test.test_consumer_deserializer"
+ topic = Topic(f"heizer.test.test_consumer_deserializer.{uuid4()}")
class TestModel(BaseModel):
name: str
age: int
- deserializer = TestModel
+ deserializer = TestModel.parse_raw
- @producer(
- topics=[HeizerTopic(name=topic_name)],
- config=Producer_Config,
- )
- def produce_data() -> Dict[str, Any]:
- return {
+ producer = Producer(config=producer_config)
+
+ producer.produce(
+ topic=topic,
+ value={
"name": "mike",
"age": 20,
- }
-
- @consumer(
- topics=[HeizerTopic(name=topic_name)],
- config=consumer_config,
- call_once=True,
- deserializer=deserializer,
+ },
)
- def consume_data(message: HeizerMessage, *args, **kwargs):
- return message.formatted_value
- produce_data()
+ @consumer(topics=[topic], config=consumer_config, call_once=True, deserializer=deserializer, id="test_consumer_x")
+ def consume_data(message: Message, C, *args, **kwargs):
+ C.consumer_signal.stop()
+ return message.formatted_value
result = consume_data()
@@ -161,3 +176,63 @@ def consume_data(message: HeizerMessage, *args, **kwargs):
assert result.name == "mike"
assert result.age == 20
+
+ status = read_consumer_status(consumer_id="test_consumer_x")
+ assert status["status"] == "closed"
+
+
+@pytest.mark.parametrize("group_id", ["test_consumer_retry_failed_func"])
+def test_consumer_retry_failed_func(caplog, consumer_config, group_id, producer_config) -> None:
+ caplog.set_level(logging.DEBUG)
+ topic = Topic(f"heizer.test.test_consumer_retry_failed_func.{uuid4()}")
+ retry_topic = Topic(f"heizer.test.test_consumer_retry_failed_func.retry.{uuid4()}")
+
+ class TestModel(BaseModel):
+ name: str
+ age: int
+
+ deserializer = TestModel.parse_raw
+
+ producer = Producer(config=producer_config)
+
+ producer.produce(
+ topic=topic,
+ headers={"k": "v"},
+ value={
+ "name": "mike",
+ "age": 20,
+ },
+ )
+
+ def stopper(message, C, *args, **kwargs) -> bool:
+ if not getattr(C, "msg_count", None):
+ setattr(C, "msg_count", 1)
+
+ C.msg_count += 1
+
+ if C.msg_count > 4:
+ return True
+ else:
+ return False
+
+ @consumer(
+ topics=[topic],
+ config=consumer_config,
+ deserializer=deserializer,
+ enable_retry=True,
+ retry_times=3,
+ id="failed_to_consume_data_consumer",
+ name="test_consumer",
+ retry_topic=retry_topic,
+ stopper=stopper,
+ )
+ def failed_to_consume_data(message: Message, C, *args, **kwargs):
+ assert C.retry_times_header_key not in message.headers
+ raise ValueError
+
+ failed_to_consume_data()
+
+ assert "[test_consumer] Function failed_to_consume_data reached retry limit (3), will give up" in caplog.messages
+
+ status = read_consumer_status()
+ assert status["failed_to_consume_data_consumer"]["status"] == "closed"
diff --git a/tests/test_producer.py b/tests/test_producer.py
new file mode 100644
index 0000000..724a84b
--- /dev/null
+++ b/tests/test_producer.py
@@ -0,0 +1,46 @@
+import asyncio
+import uuid
+
+import pytest
+
+from heizer import Producer, ProducerConfig
+
+
+@pytest.fixture
+def message_count() -> int:
+ return 10_000
+
+
+@pytest.fixture
+def messages(message_count):
+ return [
+ {
+ "key": "test_key",
+ "value": ["test"] * 1000,
+ "headers": {"k": "v"},
+ }
+ ] * message_count
+
+
+def test_produce_message(messages, bootstrap_server, caplog):
+ pd = Producer(config=ProducerConfig(bootstrap_servers=bootstrap_server), name="test_producer")
+
+ topic = f"test_producer_topic_{uuid.uuid4()}"
+
+ for msg in messages:
+ pd.produce(topic=topic, auto_flush=False, **msg)
+ pd.flush()
+
+
+def test_async_produce_message(messages, bootstrap_server, caplog):
+ pd = Producer(config=ProducerConfig(bootstrap_servers=bootstrap_server), name="test_producer")
+ topic = f"test_async_producer_topic_{uuid.uuid4()}"
+ jobs = []
+
+ async def produce():
+ for msg in messages:
+ jobs.append(asyncio.ensure_future(pd.async_produce(topic=topic, auto_flush=False, **msg)))
+ await asyncio.gather(*jobs)
+ pd.flush()
+
+ asyncio.run(produce())