diff --git a/.github/workflows/deploy-latest.yml b/.github/workflows/deploy-latest.yml deleted file mode 100644 index 19af55b..0000000 --- a/.github/workflows/deploy-latest.yml +++ /dev/null @@ -1,31 +0,0 @@ - -name: "Deploy latest" - -on: - push: - branches: - - main - -jobs: - deploy_latest: - - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Set up Python - uses: actions/setup-python@v4 - with: - python-version: '3.x' - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install build - - name: Build package - run: | - echo "999999-dev-$(date +%s)" > heizer/VERSION - python -m build - - name: Publish package - uses: pypa/gh-action-pypi-publish@27b31702a0e7fc50959f5ad993c78deac1bdfc29 - with: - user: __token__ - password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 323a9f2..482e4df 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -83,7 +83,7 @@ jobs: run: | python -m pytest -s tests --cov=heizer --cov-report=xml --junit-xml=report.xml - test_build_doc: + build_doc_and_publish_on_main: runs-on: ubuntu-latest services: zookeeper: @@ -116,13 +116,16 @@ jobs: sudo apt-get install -y zip gzip tar python -m pip install --upgrade pip pip install -e .[all,doc] - sphinx-multiversion ./docs/source ./doc/ - zip -r doc.zip doc/ + echo ${{ github.ref_name }} > heizer/VERSION + python docs/create_versions_file.py + sphinx-build ./docs/source "./public/${{ github.ref_name }}" + mv ./docs/versions.json ./public/versions.json + zip -r public.zip ./public/ - name: Upload Artifact uses: actions/upload-artifact@v3 with: - name: doc_zip - path: doc.zip + name: public_zip + path: public.zip retention-days: 5 - name: deploy if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/tags/')) @@ -131,7 +134,7 @@ jobs: host: ${{ secrets.SERVER_IP }} username: "root" password: ${{ secrets.SERVER_PASS }} - source: "doc/" + source: "public/" target: "/var/www/html/docs/" strip_components: 1 overwrite: true diff --git a/.gitignore b/.gitignore index 07d18ad..207576e 100644 --- a/.gitignore +++ b/.gitignore @@ -156,3 +156,5 @@ cython_debug/ # Editors .idea .vscode +.fleet +public diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..9fde319 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,40 @@ +# Main/Dev + + +### Features + +### Improvements + +### Fixes + +--- + +# Releases + +## 0.2.0 + +### Breaking Changes + +- rewrite heizer producer to be a class instead of decorator. 
There is no clear benefit to creating the
+  producer as a decorator
+- updated class names
+
+### New Features
+
+- support retry in consumer
+- add helpers to list and delete topics
+- add consumer signal
+- write consumer status log file for health checking
+- support async produce in producer
+
+### Improvements
+
+- add more logs
+
+---
+
+## 0.1.5
+
+### Improvements
+
+- improve consumer
diff --git a/docs/create_versions_file.py b/docs/create_versions_file.py
new file mode 100644
index 0000000..4c31c24
--- /dev/null
+++ b/docs/create_versions_file.py
@@ -0,0 +1,21 @@
+import json
+import pathlib
+
+import requests
+
+
+def create():
+    versions = [{"name": "main", "url": "https://heizer.claudezss.com/docs/main"}]
+
+    rsp = requests.get("https://api.github.com/repos/claudezss/heizer/releases?per_page=100")
+    d = rsp.json()
+    for item in d:
+        release_name = item["name"]
+        versions.append({"name": release_name, "url": f"https://heizer.claudezss.com/docs/{release_name}/"})
+
+    with open(pathlib.Path(__file__).parent / "versions.json", "w") as f:
+        json.dump(versions, f, indent=4)
+
+
+if __name__ == "__main__":
+    create()
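For reference, create() writes a flat list that the version-switcher template below reads. Assuming the repository had two published releases named 0.2.0 and 0.1.5 (illustrative values, not taken from the repo), versions.json would contain:

    # Illustrative only: the shape of versions.json produced by create() above;
    # the CI job copies it to ./public/versions.json next to the built docs.
    [
        {"name": "main", "url": "https://heizer.claudezss.com/docs/main"},
        {"name": "0.2.0", "url": "https://heizer.claudezss.com/docs/0.2.0/"},
        {"name": "0.1.5", "url": "https://heizer.claudezss.com/docs/0.1.5/"},
    ]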
diff --git a/docs/source/_templates/versions.html b/docs/source/_templates/versions.html
index b3c63b5..8624712 100644
--- a/docs/source/_templates/versions.html
+++ b/docs/source/_templates/versions.html
@@ -1,28 +1,29 @@
-{%- if current_version %}
+  Other Versions
-  v: {{ current_version.name }}
+  v: {{ release }}
-  {%- if versions.tags %}
-    Tags
-    {%- for item in versions.tags %}
-      {{ item.name }}
-    {%- endfor %}
-  {%- endif %}
-  {%- if versions.branches %}
-    Branches
-    {%- for item in versions.branches %}
-      {{ item.name }}
-    {%- endfor %}
-  {%- endif %}
-{%- endif %}
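For context on where `{{ release }}` gets its value: Sphinx merges `html_context` (set in conf.py in the next hunk) into the namespace of every template it renders. A standalone Jinja2 sketch of the same mechanism, with a placeholder version string, not project code:

    from jinja2 import Template

    # html_context entries become plain template variables at render time.
    print(Template("v: {{ release }}").render(release="0.2.0"))  # -> v: 0.2.0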
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 89adc47..e4f5213 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -9,11 +9,14 @@
 # import sys
 # sys.path.insert(0, os.path.abspath("../.."))
+from pathlib import Path
+
+root_folder = Path(__file__).parent.parent.parent
 project = "heizer"
 copyright = "2023, Yan Zhang"
 author = "Yan Zhang"
-release = "main"
+release = (root_folder / "heizer" / "VERSION").read_text().strip()
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
@@ -32,6 +35,7 @@
     "_templates",
 ]
+html_context = {"release": release}
 autodoc_typehints = "description"
diff --git a/docs/source/tutorial/consumer.rst b/docs/source/tutorial/consumer.rst
index 3529610..8dc4feb 100644
--- a/docs/source/tutorial/consumer.rst
+++ b/docs/source/tutorial/consumer.rst
@@ -8,60 +8,68 @@ Basic Producer and Consumer
 .. ipython:: python

-    from heizer import HeizerConfig, HeizerTopic, consumer, producer, HeizerMessage
+    from heizer import Topic, consumer, Producer, Message, ProducerConfig, ConsumerConfig, create_new_topics
     import json
+    import uuid
+    import asyncio

-    producer_config = HeizerConfig(
-        {
-            "bootstrap.servers": "localhost:9092",
-        }
-    )
+    producer_config = ProducerConfig(bootstrap_servers="localhost:9092")

-    consumer_config = HeizerConfig(
-        {
-            "bootstrap.servers": "localhost:9092",
-            "group.id": "default",
-            "auto.offset.reset": "earliest",
-        }
-    )
+    consumer_config = ConsumerConfig(bootstrap_servers="localhost:9092", group_id="default")

-2. Create the topic
+2. Create the topic with 2 partitions

 .. ipython:: python

-    topics = [HeizerTopic(name="my.topic1.consumer.example")]
+    topics = [Topic(name=f"my.topic1.consumer.example.{uuid.uuid4()}", num_partitions=2)]
+    create_new_topics(config=producer_config, topics=topics)

 3. Create producer

 .. ipython:: python

-    @producer(
-        topics=topics,
-        config=producer_config,
-        key_alias="key",
-        headers_alias="headers",
-    )
-    def produce_data(status: str, result: str):
-        return {
-            "status": status,
-            "result": result,
-            "key": "my_key",
-            "headers": {"my_header": "my_header_value"},
-        }
+    pd = Producer(config=producer_config)

-4. Publish messages
+4. Publish messages synchronously to partition 0

 .. ipython:: python

-    produce_data("start", "1")
-
-    produce_data("loading", "2")
+    for status, val in [("start", "1"), ("loading", "2"), ("success", "3"), ("postprocess", "4")]:
+        pd.produce(
+            topic=topics[0],
+            key="my_key",
+            value={"status": status, "result": val},
+            headers={"k": "v"},
+            partition=0,
+            auto_flush=False
+        )
+    pd.flush()

-    produce_data("success", "3")
+5. Publish messages asynchronously to partition 1 (usually faster than producing synchronously)

-    produce_data("postprocess", "4")
+.. ipython:: python

+    jobs = []
+    async def produce():
+        for status, val in [("start", "1"), ("loading", "2"), ("success", "3"), ("postprocess", "4")]:
+            jobs.append(
+                asyncio.ensure_future(
+                    pd.async_produce(
+                        topic=topics[0],
+                        key="my_key",
+                        value={"status": status, "result": val},
+                        headers={"k": "v"},
+                        partition=1,
+                        auto_flush=False
+                    )
+                )
+            )
+        await asyncio.gather(*jobs)
+        pd.flush()
+
+    asyncio.run(produce())
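Both loops above pass ``auto_flush=False``, so messages are only queued in the client's local buffer until ``pd.flush()`` blocks and drains the queue. A rough way to time a batch yourself (illustrative only; absolute numbers depend entirely on your broker and network):

.. code-block:: python

    import time

    t0 = time.monotonic()
    for i in range(100):
        pd.produce(topic=topics[0], value={"n": i}, partition=0, auto_flush=False)
    pd.flush()  # blocks until every queued message is delivered
    print(f"sync batch of 100 took {time.monotonic() - t0:.3f} s")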
6. Create consumer

.. ipython:: python
@@ -70,7 +78,8 @@ Basic Producer and Consumer
     # The consumer stops when `status` is `success` in msg
     # If there is no stopper func, consumer will keep running forever
-    def stopper(msg: HeizerMessage):
+    def stopper(msg: Message, C: consumer, *args, **kwargs):
+        print(f"Consumer name: {C.name}")
         data = json.loads(msg.value)
         if data["status"] == "success":
             return True
@@ -81,12 +90,18 @@ Basic Producer and Consumer
         config=consumer_config,
         stopper=stopper,
     )
-    def consume_data(message: HeizerMessage):
+    def consume_data(message: Message, *args, **kwargs):
         data = json.loads(message.value)
-        print(data)
-        print(message.key)
-        print(message.headers)
+        print(f"message data: {data}")
+        print(f"message key: {message.key}")
+        print(f"message headers: {message.headers}")
         return data["result"]

     result = consume_data()
-    print("Expected Result:", result)
+    print("Expected Result (should be 3):", result)
+
+
+7. More samples
+
+.. literalinclude:: ./../../../tests/test_consumer.py
+   :language: python
diff --git a/heizer/__init__.py b/heizer/__init__.py
index 5f3b41b..5060e68 100644
--- a/heizer/__init__.py
+++ b/heizer/__init__.py
@@ -1,16 +1,24 @@
-from heizer._source.admin import create_new_topic, get_admin_client
-from heizer._source.consumer import consumer
-from heizer._source.message import HeizerMessage
-from heizer._source.producer import producer
-from heizer._source.topic import HeizerTopic
-from heizer.config import HeizerConfig
+from heizer._source.admin import create_new_topics, delete_topics, get_admin_client, list_topics
+from heizer._source.consumer import ConsumerSignal, consumer
+from heizer._source.message import Message
+from heizer._source.producer import Producer
+from heizer._source.status_manager import read_consumer_status, write_consumer_status
+from heizer._source.topic import Topic
+from heizer.config import BaseConfig, ConsumerConfig, ProducerConfig

 __all__ = [
     "consumer",
-    "producer",
-    "HeizerConfig",
-    "HeizerTopic",
-    "HeizerMessage",
-    "create_new_topic",
+    "ConsumerSignal",
+    "Producer",
+    "BaseConfig",
+    "ProducerConfig",
+    "ConsumerConfig",
+    "Message",
+    "Topic",
+    "create_new_topics",
     "get_admin_client",
+    "write_consumer_status",
+    "read_consumer_status",
+    "delete_topics",
+    "list_topics",
 ]
diff --git a/heizer/_source/__init__.py b/heizer/_source/__init__.py
index ec5d26b..a51ef24 100644
--- a/heizer/_source/__init__.py
+++ b/heizer/_source/__init__.py
@@ -1,18 +1,20 @@
 import logging
-import os
 import sys

+from heizer.env_vars import HEIZER_LOG_LEVEL
+
 FORMAT = "heizer %(asctime)s %(levelname)-8s %(message)s"

 def get_logger(name: str) -> logging.Logger:
+    """Get a logger with the given name."""
     formatter = logging.Formatter(fmt=FORMAT, datefmt="%Y-%m-%d %H:%M:%S")
     handler = logging.FileHandler("/tmp/heizer_log.txt", mode="w")
     handler.setFormatter(formatter)
     screen_handler = logging.StreamHandler(stream=sys.stdout)
     screen_handler.setFormatter(formatter)
     logger = logging.getLogger(name)
-    logger.setLevel(os.environ.get("HEIZER_LOG_LEVEL", "INFO"))
+    logger.setLevel(HEIZER_LOG_LEVEL)
     logger.addHandler(handler)
     logger.addHandler(screen_handler)
     return logger
diff --git a/heizer/_source/admin.py b/heizer/_source/admin.py
index 11419b7..a597c5e 100644
--- a/heizer/_source/admin.py
+++ b/heizer/_source/admin.py
@@ -1,19 +1,30 @@
-from confluent_kafka.admin import AdminClient, NewTopic
+from confluent_kafka.admin import AdminClient

 from heizer._source import get_logger
-from heizer._source.topic import HeizerTopic
-from heizer.config import HeizerConfig
-from heizer.types
import List +from heizer._source.topic import Topic +from heizer.config import BaseConfig +from heizer.types import KafkaConfig, List, Union logger = get_logger(__name__) -def get_admin_client(config: HeizerConfig) -> AdminClient: - return AdminClient({"bootstrap.servers": config.value.get("bootstrap.servers")}) +def get_admin_client(config: Union[BaseConfig, KafkaConfig]) -> AdminClient: + """ + Create an admin client using the provided configuration. + """ + if isinstance(config, BaseConfig): + config_dict = {"bootstrap.servers": config.bootstrap_servers} + else: + config_dict = config + return AdminClient(config_dict) -def create_new_topic(admin_client: AdminClient, topics: List[HeizerTopic]) -> None: - new_topics = [NewTopic(topic.name, num_partitions=len(topic._partitions)) for topic in topics] +def create_new_topics(config: Union[BaseConfig, KafkaConfig], topics: List[Topic]) -> None: + """ + Create new topics using the provided configuration. + """ + admin_client = get_admin_client(config) + new_topics = [topic._new_topic for topic in topics] fs = admin_client.create_topics(new_topics) # Wait for each operation to finish. @@ -27,3 +38,23 @@ def create_new_topic(admin_client: AdminClient, topics: List[HeizerTopic]) -> No continue else: logger.exception(f"Failed to create topic {topic}", exc_info=e) + + +def delete_topics(config: Union[BaseConfig, KafkaConfig], topics: List[Topic]) -> None: + """Delete topics using the provided configuration.""" + admin_client = get_admin_client(config) + fs = admin_client.delete_topics([tp.name for tp in topics]) + + # Wait for each operation to finish. + for topic, f in fs.items(): + try: + f.result() # The result itself is None + logger.info("Topic {} deleted".format(topic)) + except Exception as e: + logger.exception(f"Failed to delete topic {topic}", exc_info=e) + + +def list_topics(config: Union[BaseConfig, KafkaConfig]) -> List[str]: + """List topics using the provided configuration.""" + admin_client = get_admin_client(config) + return list(admin_client.list_topics().topics.keys()) diff --git a/heizer/_source/consumer.py b/heizer/_source/consumer.py index e42d772..1b9e157 100644 --- a/heizer/_source/consumer.py +++ b/heizer/_source/consumer.py @@ -1,28 +1,33 @@ import asyncio +import atexit import functools import logging +import os import signal +from dataclasses import dataclass from uuid import uuid4 -from confluent_kafka import Consumer -from pydantic import BaseModel +import confluent_kafka as ck from heizer._source import get_logger -from heizer._source.admin import create_new_topic, get_admin_client -from heizer._source.message import HeizerMessage -from heizer._source.topic import HeizerTopic -from heizer.config import HeizerConfig +from heizer._source.admin import create_new_topics +from heizer._source.enums import ConsumerStatusEnum +from heizer._source.message import Message +from heizer._source.producer import Producer +from heizer._source.status_manager import write_consumer_status +from heizer._source.topic import Topic +from heizer.config import ConsumerConfig from heizer.types import ( Any, Awaitable, Callable, Concatenate, Coroutine, + KafkaConfig, List, Optional, ParamSpec, Stopper, - Type, TypeVar, Union, ) @@ -35,86 +40,121 @@ logger = get_logger(__name__) -def signal_handler(signal: Any, frame: Any) -> None: - global interrupted - interrupted = True +def _make_consumer_config_dict(config: ConsumerConfig) -> KafkaConfig: + config_dict = { + "bootstrap.servers": config.bootstrap_servers, + "group.id": config.group_id, + 
"auto.offset.reset": config.auto_offset_reset, + "enable.auto.commit": config.enable_auto_commit, + } + if config.other_configs: + config_dict.update(config.other_configs) -signal.signal(signal.SIGINT, signal_handler) + return config_dict -interrupted = False + +@dataclass +class ConsumerSignal: + is_running: bool = True + + def stop(self) -> None: + self.is_running = False class consumer(object): - """A decorator to create a consumer""" + """ + A decorator to create a consumer + """ - __id__: str + id: str name: Optional[str] - topics: List[HeizerTopic] - config: HeizerConfig = HeizerConfig() + topics: List[Topic] + config: KafkaConfig call_once: bool = False stopper: Optional[Stopper] = None - deserializer: Optional[Type[BaseModel]] = None - + deserializer: Optional[Callable] = None + consumer_signal: ConsumerSignal is_async: bool = False poll_timeout: int = 1 - init_topics: bool = False + enable_retry: bool = False + retry_topic: Optional[Topic] = None + retry_times: int = 1 + retry_times_header_key: str = "!!-heizer_func_retried_times-!!" + + # private attr + _consumer_instance: ck.Consumer + _producer_instance: Optional[Producer] = None # for retry + def __init__( self, *, - topics: List[HeizerTopic], - config: HeizerConfig = HeizerConfig(), + topics: List[Topic], + config: Union[KafkaConfig, ConsumerConfig], call_once: bool = False, stopper: Optional[Stopper] = None, - deserializer: Optional[Type[BaseModel]] = None, + deserializer: Optional[Callable] = None, is_async: bool = False, + id: Optional[str] = None, name: Optional[str] = None, poll_timeout: Optional[int] = None, init_topics: bool = False, + consumer_signal: Optional[ConsumerSignal] = None, + enable_retry: bool = False, + retry_topic: Optional[Topic] = None, + retry_times: int = 1, ): self.topics = topics - self.config = config + self.config = _make_consumer_config_dict(config) if isinstance(config, ConsumerConfig) else config self.call_once = call_once self.stopper = stopper self.deserializer = deserializer - self.__id__ = str(uuid4()) - self.name = name or self.__id__ + self.id = id or str(uuid4()) + self.name = name or self.id self.is_async = is_async self.poll_timeout = poll_timeout if poll_timeout is not None else 1 self.init_topics = init_topics + self.consumer_signal = consumer_signal or ConsumerSignal() + self.enable_retry = enable_retry + self.retry_topic = retry_topic + self.retry_times = retry_times def __repr__(self) -> str: - return self.name or self.__id__ + return self.name or self.id - async def __run__( # type: ignore - self, - func: Callable[Concatenate[HeizerMessage, P], Union[T, Awaitable[T]]], - is_async: bool, - *args: P.args, - **kwargs: P.kwargs, - ) -> Union[Optional[T]]: - """Run the consumer""" + @property + def ck_consumer(self) -> ck.Consumer: + return self._consumer_instance - logger.debug(f"[{self}] Creating consumer ") - c = Consumer(self.config.value) + @property + def ck_producer(self) -> Optional[ck.Producer]: + return self._producer_instance - logger.debug(f"[{self}] Subscribing to topics {[topic.name for topic in self.topics]}") - c.subscribe([topic.name for topic in self.topics]) + def commit(self, asynchronous: bool = False, *arg, **kwargs): + self.ck_consumer.commit(asynchronous=asynchronous, *arg, **kwargs) - if self.init_topics: - logger.debug(f"[{self}] Initializing topics") - admin_client = get_admin_client(self.config) - create_new_topic(admin_client, self.topics) + async def _long_run(self, f: Callable, is_async: bool, *args, **kwargs): + target_topics: List[Topic] = [topic for 
topic in self.topics] - while True: - if interrupted: - logger.debug(f"[{self}] Interrupted by keyboard") - break + if self.enable_retry and self.retry_topic: + target_topics.append(self.retry_topic) + + logger.info(f"[{self}] Subscribing to topics {[topic.name for topic in target_topics]}") + self.ck_consumer.subscribe([topic.name for topic in target_topics]) + + result = None + + while self.consumer_signal.is_running: + write_consumer_status( + consumer_id=self.id, consumer_name=self.name, pid=os.getpid(), status=ConsumerStatusEnum.RUNNING + ) - msg = c.poll(self.poll_timeout) + result = None + + msg = self.ck_consumer.poll(self.poll_timeout) if msg is None: continue @@ -124,37 +164,44 @@ async def __run__( # type: ignore continue logger.debug(f"[{self}] Received message") - hmessage = HeizerMessage(msg) + + h_message = Message(msg) if self.deserializer is not None: logger.debug(f"[{self}] Parsing message") try: - hmessage.formatted_value = self.deserializer.parse_raw(hmessage.value) + h_message.formatted_value = self.deserializer(h_message.value) except Exception as e: logger.exception(f"[{self}] Failed to deserialize message", exc_info=e) - result = None + # If heizer retry times header key exists, remove it when passing to user func + has_retired: Optional[int] = None + if h_message.headers and self.retry_times_header_key in h_message.headers: + has_retired = int(h_message.headers.pop(self.retry_times_header_key)) + + logger.debug(f"[{self}] Executing function {f.__name__}") - logger.debug(f"[{self}] Executing function {func.__name__}") try: if is_async: - result = await func(hmessage, *args, **kwargs) # type: ignore + result = await f(h_message, self, *args, **kwargs) # type: ignore else: - result = func(hmessage, *args, **kwargs) + result = f(h_message, self, *args, **kwargs) except Exception as e: logger.exception( - f"[{self}] Failed to execute function {func.__name__}", + f"[{self}] Failed to execute function {f.__name__}", exc_info=e, ) + if self.enable_retry: + logger.info(f"[{self}] Start producing retry message for function {f.__name__}") + await self.produce_retry_message(message=h_message, has_retired=has_retired, func_name=f.__name__) finally: - # TODO: add failed message to a retry queue logger.debug(f"[{self}] Committing message") - c.commit() + self.commit() if self.stopper is not None: logger.debug(f"[{self}] Executing stopper function") try: - should_stop = self.stopper(hmessage) + should_stop = self.stopper(h_message, self) except Exception as e: logger.warning( f"[{self}] Failed to execute stopper function {self.stopper.__name__}.", @@ -163,25 +210,123 @@ async def __run__( # type: ignore should_stop = False if should_stop: - return result + self.consumer_signal.stop() + break if self.call_once is True: logger.debug(f"[{self}] Call Once is on, returning result") - return result + self.consumer_signal.stop() + break + + return result + + async def _run( # type: ignore + self, + func: Callable[Concatenate[Message, P], Union[T, Awaitable[T]]], + is_async: bool, + *args: P.args, + **kwargs: P.kwargs, + ) -> Union[Optional[T]]: + """Run the consumer""" + + if self.init_topics: + logger.info(f"[{self}] Initializing topics") + create_new_topics({"bootstrap.servers": self.config["bootstrap.servers"]}, self.topics) + + logger.info(f"[{self}] Creating consumer") + self._consumer_instance = ck.Consumer(self.config) + + if self.enable_retry: + self.retry_topic = self.retry_topic or Topic(name=f"{func.__name__}_heizer_retry") + self._producer_instance = 
Producer(config={"bootstrap.servers": self.config["bootstrap.servers"]}) + + logger.info(f"[{self}] Creating retry topic {self.retry_topic.name}") + create_new_topics({"bootstrap.servers": self.config["bootstrap.servers"]}, [self.retry_topic]) + + atexit.register(self._atexit) + signal.signal(signal.SIGTERM, self._exit) + signal.signal(signal.SIGINT, self._exit) + + result = None + + try: + result = await self._long_run(func, is_async, *args, **kwargs) + except KeyboardInterrupt: + logger.info( + f"[{self}] Stopped because of keyboard interruption", + ) + except Exception as e: + logger.exception(f"[{self}] stopped", exc_info=e) + finally: + self.consumer_signal.stop() + self.ck_consumer.close() + write_consumer_status( + consumer_id=self.id, consumer_name=self.name, pid=os.getpid(), status=ConsumerStatusEnum.CLOSED + ) + logger.info( + f"[{self}] Closed", + ) + return result def __call__( - self, func: Callable[Concatenate[HeizerMessage, P], T] + self, func: Callable[Concatenate[Message, P], T] ) -> Callable[P, Union[Coroutine[Any, Any, Optional[T]], T, None]]: @functools.wraps(func) async def async_decorator(*args: P.args, **kwargs: P.kwargs) -> Optional[T]: """Async decorator""" logging.debug(f"[{self}] Running async decorator") - return await self.__run__(func, self.is_async, *args, **kwargs) + return await self._run(func, self.is_async, *args, **kwargs) @functools.wraps(func) def decorator(*args: P.args, **kwargs: P.kwargs) -> Optional[T]: """Sync decorator""" logging.debug(f"[{self}] Running sync decorator") - return asyncio.run(self.__run__(func, self.is_async, *args, **kwargs)) + return asyncio.run(self._run(func, self.is_async, *args, **kwargs)) return async_decorator if self.is_async else decorator + + def _exit(self, sig, frame, *args, **kwargs) -> None: + self.consumer_signal.stop() + self.ck_consumer.close() + + def _atexit(self) -> None: + self.ck_consumer.close() + + async def produce_retry_message(self, message: Message, has_retired: Optional[int], func_name: str) -> None: + """ + Produces a retry message for a given function. + + :param message: The original message that needs to be retried. + :type message: Message + :param has_retired: The number of times the function has been retried before. Defaults to None. + :type has_retired: Optional[int] + :param func_name: The name of the function. 
+ :type func_name: str + :return: None + """ + if not self._producer_instance: + logger.error( + f"[{self}] Confluent producer instance not found," + f" failed to produce retry message for function {func_name}" + ) + return + + if not self.retry_topic: + logger.error(f"[{self}] Retry topic not found, failed to produce retry message for function {func_name}") + return + + if has_retired == self.retry_times: + logger.info(f"[{self}] Function {func_name} reached retry limit ({self.retry_times}), will give up") + return + elif has_retired is None: + has_retired = 1 + else: + has_retired += 1 + + new_headers = message.headers or {} + new_headers[self.retry_times_header_key] = str(has_retired) + + await self._producer_instance.async_produce( + topic=self.retry_topic, key=message.key, value=message.value or "", headers=new_headers, auto_flush=True + ) + logger.info(f"[{self}] Produced retry message for function {func_name}, retry time(s): {has_retired}") diff --git a/heizer/_source/enums.py b/heizer/_source/enums.py new file mode 100644 index 0000000..d496189 --- /dev/null +++ b/heizer/_source/enums.py @@ -0,0 +1,6 @@ +import enum + + +class ConsumerStatusEnum(str, enum.Enum): + RUNNING = "running" + CLOSED = "closed" diff --git a/heizer/_source/message.py b/heizer/_source/message.py index 90afe86..db1d485 100644 --- a/heizer/_source/message.py +++ b/heizer/_source/message.py @@ -1,9 +1,9 @@ -from confluent_kafka import Message +import confluent_kafka as ck from heizer.types import Any, Dict, List, Optional, Tuple, Union -class HeizerMessage: +class Message: """ :class: HeizerMessage @@ -20,7 +20,7 @@ class HeizerMessage: """ # initialized properties - message: Message + _message: ck.Message topic: Optional[str] partition: int headers: Optional[Dict[str, str]] @@ -29,8 +29,8 @@ class HeizerMessage: formatted_value: Optional[Any] = None - def __init__(self, message: Message): - self.message = message + def __init__(self, message: ck.Message): + self._message = message self.topic = self._parse_topic(message.topic()) self.partition = message.partition() self.headers = self._parse_headers(message.headers()) diff --git a/heizer/_source/producer.py b/heizer/_source/producer.py index f2f4d97..b0de939 100644 --- a/heizer/_source/producer.py +++ b/heizer/_source/producer.py @@ -1,18 +1,24 @@ -import functools +import asyncio import json from uuid import uuid4 -from confluent_kafka import Message, Producer +import confluent_kafka as ck from heizer._source import get_logger -from heizer._source.topic import HeizerTopic -from heizer.config import HeizerConfig -from heizer.types import Any, Callable, Dict, List, Optional, TypeVar, Union, cast +from heizer._source.topic import Topic +from heizer.config import ProducerConfig +from heizer.types import Any, Callable, Dict, KafkaConfig, Optional, Union logger = get_logger(__name__) -R = TypeVar("R") -F = TypeVar("F", bound=Callable[..., Dict[str, str]]) + +def _create_producer_config_dict(config: ProducerConfig) -> KafkaConfig: + config_dict = {"bootstrap.servers": config.bootstrap_servers} + + if config.other_configs: + config_dict.update(config.other_configs) + + return config_dict def _default_serializer(value: Union[Dict[Any, Any], str, bytes]) -> bytes: @@ -23,7 +29,7 @@ def _default_serializer(value: Union[Dict[Any, Any], str, bytes]) -> bytes: Default Kafka message serializer, which will encode inputs to bytes """ - if isinstance(value, dict): + if isinstance(value, dict) or isinstance(value, list): return json.dumps(value).encode("utf-8") elif 
isinstance(value, str): return value.encode("utf-8") @@ -33,125 +39,154 @@ def _default_serializer(value: Union[Dict[Any, Any], str, bytes]) -> bytes: raise ValueError(f"Input type is not supported: {type(value).__name__}") -class producer(object): +class Producer(object): """ - A decorator to create a producer + Kafka Message Producer """ __id__: str name: str - # args - config: HeizerConfig = HeizerConfig() - topics: List[HeizerTopic] + # attrs need be initialized + config: KafkaConfig serializer: Callable[..., bytes] = _default_serializer call_back: Optional[Callable[..., Any]] = None - default_key: Optional[str] = None - default_headers: Optional[Dict[str, str]] = None - - key_alias: str = "key" - headers_alias: str = "headers" - # private properties - _producer_instance: Optional[Producer] = None + _producer_instance: Optional[ck.Producer] = None def __init__( self, - topics: List[HeizerTopic], - config: HeizerConfig = HeizerConfig(), + config: Union[ProducerConfig, KafkaConfig], serializer: Callable[..., bytes] = _default_serializer, call_back: Optional[Callable[..., Any]] = None, - default_key: Optional[str] = None, - default_headers: Optional[Dict[str, str]] = None, - key_alias: Optional[str] = None, - headers_alias: Optional[str] = None, name: Optional[str] = None, - init_topics: bool = True, ): - self.topics = topics - self.config = config + self.config = _create_producer_config_dict(config) if isinstance(config, ProducerConfig) else config self.serializer = serializer - self.call_back = call_back - self.default_key = default_key - self.default_headers = default_headers - - if key_alias: - self.key_alias = key_alias - if headers_alias: - self.headers_alias = headers_alias + self.call_back = call_back or self.default_delivery_report self.__id__ = str(uuid4()) self.name = name or self.__id__ + def __repr__(self) -> str: + return self.name + + def default_delivery_report(self, err: str, msg: ck.Message) -> None: + """ + Called once for each message produced to indicate delivery result. + Triggered by poll() or flush(). + + :param msg: + :param err: + :return: + """ + if err is not None: + logger.error(f"[{self}] Message delivery failed: {err}") + else: + logger.debug(f"[{self}] Message delivered to {msg.topic()} [{msg.partition()}]") + @property - def _producer(self) -> Producer: + def _producer(self) -> ck.Producer: if not self._producer_instance: - self._producer_instance = Producer(self.config.value) + self._producer_instance = ck.Producer(self.config) return self._producer_instance - def __call__(self, func: F) -> F: - @functools.wraps(func) - def decorator(*args: Any, **kwargs: Any) -> Any: - try: - result = func(*args, **kwargs) - except Exception as e: - raise e - - try: - key = result.pop(self.key_alias, self.default_key) - except Exception as e: - logger.debug(f"Failed to get key from result. {str(e)}") - key = self.default_key - - try: - headers = result.pop(self.headers_alias, self.default_headers) - except Exception as e: - logger.debug( - f"Failed to get headers from result. 
{str(e)}", - ) - headers = self.default_headers - - headers = cast(Dict[str, str], headers) - - msg = self.serializer(result) - - self._produce_message(message=msg, key=key, headers=headers) - - return result - - return cast(F, decorator) - - def _produce_message(self, message: bytes, key: Optional[str], headers: Optional[Dict[str, str]]) -> None: - for topic in self.topics: - for partition in topic.partitions: - try: - self._producer.poll(0) - self._producer.produce( - topic=topic.name, - value=message, - partition=partition, - key=key, - headers=headers, - on_delivery=self.call_back, - ) - self._producer.flush() - except Exception as e: - logger.exception(f"Failed to produce msg to topic: {topic.name}", exc_info=e) - raise e - - -def delivery_report(err: str, msg: Message) -> None: - """ - Called once for each message produced to indicate delivery result. - Triggered by poll() or flush(). - - :param msg: - :param err: - :return: - """ - if err is not None: - logger.error("Message delivery failed: {}".format(err)) - else: - logger.debug("Message delivered to {} [{}]".format(msg.topic(), msg.partition())) + def produce( + self, + topic: Union[str, Topic], + value: Union[bytes, str, Dict[Any, Any]], + key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + partition: Optional[int] = -1, + auto_flush: bool = True, + ) -> None: + """ + This method is used to produce messages to a Kafka topic using the confluent_kafka library. + + :param topic: The topic to publish the message to. Can be either a string representing the topic name, + or an instance of the Topic class. + :param value: The message body to be produced. Can be either bytes, a string, or a dictionary. + :param key: Optional. The key for the message. + :param headers: Optional. Additional headers for the message. + :param auto_flush: Optional. Default is True. If True, the producer will automatically flush messages after + producing them. + + :return: None + + Example usage: + + >>> producer = Producer(config={"xx": "xx"}) + >>> producer.produce(topic="my_topic", value={"k":"v"}) + + """ + return asyncio.run( + self._produce( + topic=topic, value=value, key=key, headers=headers, partition=partition, auto_flush=auto_flush + ) + ) + + async def async_produce( + self, + topic: Union[str, Topic], + value: Union[bytes, str, Dict[Any, Any]], + key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + partition: Optional[int] = -1, + auto_flush: bool = True, + ) -> None: + """ + Produce a message asynchronously to a Kafka topic. + + :param topic: The topic to produce the message to. Can be either a string or a `Topic` object. + :param value: The message body to be produced. Can be either bytes, a string, or a dictionary. + :param key: The key to associate with the message. Optional. + :param headers: Additional headers to include with the message. Optional. + :param auto_flush: If set to True, automatically flush the producer after producing the message. + Default is True. + :return: None + + This method asynchronously produces a message to a Kafka topic. It accepts the topic, message, key, headers, + and auto_flush parameters. 
+
+        Example usage:
+
+        >>> producer = Producer(config={"xx": "xx"})
+        >>> asyncio.run(producer.async_produce(topic="my_topic", value={"k":"v"}))
+        """
+        return await self._produce(
+            topic=topic, value=value, key=key, headers=headers, partition=partition, auto_flush=auto_flush
+        )
+
+    async def _produce(
+        self,
+        topic: Union[str, Topic],
+        value: Union[bytes, str, Dict[Any, Any]],
+        key: Optional[str] = None,
+        headers: Optional[Dict[str, str]] = None,
+        partition: Optional[int] = -1,
+        auto_flush: bool = True,
+    ) -> None:
+        topic = Topic(name=topic) if isinstance(topic, str) else topic
+
+        try:
+            self._producer.produce(
+                topic=topic.name,
+                value=self.serializer(value),
+                key=key,
+                headers=headers,
+                on_delivery=self.call_back,
+                partition=partition,
+            )
+            if auto_flush:
+                self.flush()
+        except BufferError:
+            # librdkafka raises BufferError when the local queue is full
+            logger.warning(f"[{self}] Local queue is full, forcing a flush")
+            self.flush()
+        except Exception as e:
+            logger.exception(f"[{self}] Failed to produce msg to topic: {topic.name}", exc_info=e)
+            raise e
+
+    def flush(self) -> None:
+        self._producer.flush()
diff --git a/heizer/_source/status_manager.py b/heizer/_source/status_manager.py
new file mode 100644
index 0000000..62fa66d
--- /dev/null
+++ b/heizer/_source/status_manager.py
@@ -0,0 +1,48 @@
+import json
+import os
+from datetime import datetime
+from typing import Any, Dict, Optional
+
+from heizer._source.enums import ConsumerStatusEnum
+from heizer.env_vars import CONSUMER_STATUS_FILE_PATH
+
+
+def write_consumer_status(
+    consumer_id: str,
+    status: ConsumerStatusEnum,
+    pid: int,
+    consumer_name: Optional[str] = None,
+) -> None:
+    if os.path.exists(CONSUMER_STATUS_FILE_PATH):
+        data = json.loads(open(CONSUMER_STATUS_FILE_PATH).read())
+    else:
+        data = {}
+
+    data[consumer_id] = {
+        "name": consumer_name or consumer_id,
+        "status": status,
+        "pid": pid,
+        "timestamp": datetime.utcnow().isoformat(),
+    }
+
+    with open(CONSUMER_STATUS_FILE_PATH, "w") as f:
+        json.dump(data, f)
+
+
+def read_consumer_status(consumer_id: Optional[str] = None) -> Dict[str, Any]:
+    with open(CONSUMER_STATUS_FILE_PATH) as f:
+        data = json.loads(f.read())
+    if consumer_id:
+        return data.get(consumer_id, {})
+    else:
+        return data
diff --git a/heizer/_source/topic.py b/heizer/_source/topic.py
index 380c85f..ee0bdd6 100644
--- a/heizer/_source/topic.py
+++ b/heizer/_source/topic.py
@@ -1,32 +1,64 @@
-from typing import List, Optional
+from confluent_kafka import TopicPartition, admin

-from confluent_kafka import TopicPartition
+from heizer.types import Any, Dict, List, Optional

-class HeizerTopic(TopicPartition):
+class Topic:
     name: str
-    _partitions: List[int]
-    _replication_factor: int
-    _topic_partitions: List[TopicPartition]
+    partition: int
+    offset: Optional[int] = None
+    metadata: Optional[str] = None
+    leader_epoch: Optional[int] = None
+    num_partitions: Optional[int] = None
+    replication_factor: Optional[int] = None
+    replica_assignment: Optional[List[Any]] = None
+    config: Optional[Dict[Any, Any]] = None
+
+    _topic_partition: TopicPartition
+    _new_topic: admin.NewTopic

     def __init__(
         self,
         name: str,
-        partitions: Optional[List[int]] = None,
-        replication_factor: int = 1,
+        partition: Optional[int] = None,
+        offset: Optional[int] = None,
+        metadata: Optional[str] = None,
+        leader_epoch:
Optional[int] = None, + num_partitions: Optional[int] = None, + replication_factor: Optional[int] = None, + replica_assignment: Optional[List[Any]] = None, + config: Optional[Dict[Any, Any]] = None, ): - self._partitions = partitions or [-1] - self._topic_partitions = [] - self._replication_factor = replication_factor self.name = name + self.partition = partition if partition is not None else -1 + self.offset = offset if offset is not None else -1 + self.metadata = metadata + self.leader_epoch = leader_epoch + self.num_partitions = num_partitions if num_partitions is not None else 1 + self.replication_factor = replication_factor + self.replica_assignment = replica_assignment + self.config = config + + topic_partition_args = { + "topic": self.name, + "partition": self.partition, + } + if self.offset is not None: + topic_partition_args["offset"] = self.offset + if self.metadata is not None: + topic_partition_args["metadata"] = self.metadata + if self.leader_epoch is not None: + topic_partition_args["leader_epoch"] = self.leader_epoch + + self._topic_partition = TopicPartition(**topic_partition_args) - for partition in self._partitions: - self._topic_partitions.append(TopicPartition(topic=name, partition=partition)) + new_topic_args = {"topic": self.name, "num_partitions": self.num_partitions} - @property - def partitions(self) -> List[int]: - return self._partitions + if self.replication_factor is not None: + new_topic_args["replication_factor"] = self.replication_factor + if self.replica_assignment is not None: + new_topic_args["replica_assignment"] = self.replica_assignment + if self.config is not None: + new_topic_args["config"] = self.config - @property - def topic_partitions(self) -> List[TopicPartition]: - return self._topic_partitions + self._new_topic = admin.NewTopic(**new_topic_args) diff --git a/heizer/config.py b/heizer/config.py index a7e6716..5240af8 100644 --- a/heizer/config.py +++ b/heizer/config.py @@ -1,31 +1,37 @@ -import os -from typing import Optional +from collections import defaultdict +from dataclasses import dataclass, field -from heizer.types import KafkaConfig +from heizer.types import Any, Dict -DEFAULT_KAFKA_BOOTSTRAP_SERVER = os.environ.get("KAFKA_BOOTSTRAP_SERVER", "localhost:9094") -DEFAULT_KAFKA_GROUP = os.environ.get("KAFKA_GROUP", "default") +# confluent kafka configs +# https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#kafka-client-configuration -class HeizerConfig(object): - # confluent kafka configs - # https://docs.confluent.io/platform/current/clients/confluent-kafka-python/html/index.html#kafka-client-configuration - __kafka_config: KafkaConfig +@dataclass +class BaseConfig: + """ + Base configurations + """ - def __init__(self, config: Optional[KafkaConfig] = None): - self.__kafka_config = config or { - "bootstrap.servers": DEFAULT_KAFKA_BOOTSTRAP_SERVER, - "group.id": DEFAULT_KAFKA_GROUP, - } + bootstrap_servers: str - @property - def value(self) -> KafkaConfig: - """ - Return kafka configurations - Returns - ------- - config : dict - The dict contains kafka configurations - """ - return self.__kafka_config +@dataclass +class ProducerConfig(BaseConfig): + """ + Configuration class for producer. + """ + + other_configs: Dict[str, Any] = field(default_factory=defaultdict) + + +@dataclass +class ConsumerConfig(BaseConfig): + """ + Configuration class for consumer. 
+ """ + + group_id: str + auto_offset_reset: str = "earliest" + other_configs: Dict[str, Any] = field(default_factory=defaultdict) + enable_auto_commit: bool = False diff --git a/heizer/env_vars.py b/heizer/env_vars.py new file mode 100644 index 0000000..4e5b7bb --- /dev/null +++ b/heizer/env_vars.py @@ -0,0 +1,5 @@ +import os + +CONSUMER_STATUS_FILE_PATH = os.environ.get("HEIZER_CONSUMER_STATUS_FILE_PATH", "/tmp/heizer_consumers_status.json") + +HEIZER_LOG_LEVEL = os.environ.get("HEIZER_LOG_LEVEL", "INFO") diff --git a/heizer/types.py b/heizer/types.py index e269f7e..e920819 100644 --- a/heizer/types.py +++ b/heizer/types.py @@ -3,7 +3,7 @@ import sys from typing import Any, Awaitable, Callable, Coroutine, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast -if sys.version_info.minor <= 10: +if sys.version_info.minor < 10: from typing_extensions import Concatenate, ParamSpec else: from typing import Concatenate, ParamSpec diff --git a/pyproject.toml b/pyproject.toml index c79cdc5..0f516e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,27 +33,29 @@ classifiers = [ dependencies = [ "confluent_kafka", - "pydantic", "typing; python_version<'3.10'", "typing_extensions; python_version<'3.10'", ] [project.optional-dependencies] -all = ["websockets"] + + socket = ["websockets"] dev = ["pre-commit"] doc = [ - "sphinx==5.3.0", + "sphinx", "sphinx_rtd_theme", "myst-parser", "wheel", "ipython", "sphinx-multiversion", + "requests", ] test = [ "pytest", - "pytest-cov" + "pytest-cov", + "pydantic" ] @@ -77,7 +79,7 @@ exclude =[".git", "__pycache__", "venv", "env", 'docs/*'] include = ["heizer/*"] [tool.mypy] -strict=true +strict=false ignore_missing_imports=true disallow_subclassing_any=false exclude = ['venv', '.venv', 'tests/*', 'docs/*', 'samples/*'] diff --git a/samples/websockets/server.py b/samples/websockets/server.py index 9a64bb5..8745d67 100644 --- a/samples/websockets/server.py +++ b/samples/websockets/server.py @@ -2,27 +2,30 @@ import websockets -from heizer import HeizerConfig, HeizerTopic, consumer - -consumer_config = HeizerConfig( - { - "bootstrap.servers": "0.0.0.0:9092", - "group.id": "test", - "auto.offset.reset": "earliest", - } +from heizer import ConsumerConfig, Message, Topic, consumer + +consumer_config = ConsumerConfig( + bootstrap_servers="localhost:9092", + group_id="websockets_sample", + auto_offset_reset="earliest", + enable_auto_commit=False, ) -topics = [HeizerTopic(name="my.topic1")] +topics = [Topic(name="my.topic1")] -@consumer(topics=topics, config=consumer_config, is_async=True) -async def handler(message, websocket, *args, **kwargs): +@consumer(topics=topics, config=consumer_config, is_async=True, init_topics=True, name="websocket_sample") +async def handler(message: Message, C: consumer, websocket, *args, **kwargs): + print(C.name) await websocket.send(message.value) async def main(): async with websockets.serve(handler, "", 8001): - await asyncio.Future() + try: + await asyncio.Future() + except KeyboardInterrupt: + return if __name__ == "__main__": diff --git a/tests/conftest.py b/tests/conftest.py index 605a353..e9dc089 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,12 +3,24 @@ import pytest from confluent_kafka.admin import AdminClient +from heizer import Topic, delete_topics, list_topics + +BOOTSTRAP_SERVERS = os.environ.get("KAFKA_SERVER", "localhost:9092") + @pytest.fixture() def bootstrap_server() -> str: - return os.environ.get("KAFKA_SERVER", "localhost:9092") + return BOOTSTRAP_SERVERS @pytest.fixture() def 
admin_client(bootstrap_server) -> AdminClient: return AdminClient({"bootstrap.servers": bootstrap_server}) + + +@pytest.fixture(autouse=True, scope="session") +def clean_topics(): + yield + config = {"bootstrap.servers": BOOTSTRAP_SERVERS} + topics = list_topics(config) + delete_topics(config, [Topic(t) for t in topics]) diff --git a/tests/test_consumer.py b/tests/test_consumer.py index 3093857..517d612 100644 --- a/tests/test_consumer.py +++ b/tests/test_consumer.py @@ -1,18 +1,23 @@ import json import logging import os -from typing import Any, Dict, cast +from typing import cast +from uuid import uuid4 import pytest from pydantic import BaseModel -from heizer import HeizerConfig, HeizerMessage, HeizerTopic, consumer, create_new_topic, get_admin_client, producer - -Producer_Config = HeizerConfig( - { - "bootstrap.servers": os.environ.get("KAFKA_SERVER", "localhost:9092"), - } +from heizer import ( + ConsumerSignal, + Message, + Producer, + ProducerConfig, + Topic, + consumer, + create_new_topics, + read_consumer_status, ) +from heizer.env_vars import CONSUMER_STATUS_FILE_PATH @pytest.fixture @@ -20,140 +25,150 @@ def group_id(): return "test_group" +@pytest.fixture(autouse=True) +def clean_logs(): + yield + if os.path.exists(CONSUMER_STATUS_FILE_PATH): + os.remove(CONSUMER_STATUS_FILE_PATH) + + @pytest.fixture -def consumer_config(group_id): - return HeizerConfig( - { - "bootstrap.servers": os.environ.get("KAFKA_SERVER", "localhost:9092"), - "group.id": group_id, - "auto.offset.reset": "earliest", - } - ) +def producer_config(bootstrap_server): + return ProducerConfig(bootstrap_servers=bootstrap_server) + + +@pytest.fixture +def consumer_config(group_id, bootstrap_server): + return { + "bootstrap.servers": bootstrap_server, + "group.id": group_id, + "auto.offset.reset": "earliest", + } @pytest.mark.parametrize("group_id", ["test_consumer_stopper"]) -def test_consumer_stopper(group_id, consumer_config) -> None: - toppics = [HeizerTopic(name="heizer.test.result", partitions=[0, 1, 2], replication_factor=2)] - admin = get_admin_client(consumer_config) - create_new_topic(admin, toppics) - - @producer( - topics=toppics, - config=Producer_Config, - key_alias="myKey", - headers_alias="myHeaders", - ) - def produce_data(status: str, result_value: str) -> Dict[str, Any]: - return { - "key1": 1, - "key2": "2", - "key3": True, - "status": status, - "result": result_value, - "myKey": "id1", - "myHeaders": {"header1": "value1", "header2": "value2"}, - } - - def stopper(msg: HeizerMessage) -> bool: +def test_consumer_stopper(group_id, consumer_config, producer_config, caplog, bootstrap_server) -> None: + topics = [Topic(name=f"heizer.test.result.{uuid4()}", num_partitions=3)] + create_new_topics({"bootstrap.servers": bootstrap_server}, topics) + + pd = Producer(config=producer_config) + + for status, result in [("start", 1), ("loading", 2), ("success", 3), ("postprocess", 4)]: + pd.produce( + topic=topics[0], + key="key1", + value={"status": status, "result": result}, + headers={"header1": "value1", "header2": "value2"}, + auto_flush=False, + ) + + pd.flush() + + def stopper(msg: Message, *args, **kwargs) -> bool: data = json.loads(msg.value) if data["status"] == "success": return True return False @consumer( - topics=[HeizerTopic(name="heizer.test.result")], + topics=topics, config=consumer_config, stopper=stopper, ) def consume_data(msg, *args, **kwargs) -> str: data = json.loads(msg.value) - assert msg.key == "id1" + assert msg.key == "key1" assert msg.headers == {"header1": "value1", "header2": 
"value2"} - assert "myKey" not in data - assert "myHeaders" not in data - return cast(str, data["result"]) - produce_data("start", "waiting") - produce_data("loading", "waiting") - produce_data("success", "finished") - produce_data("postprocess", "postprocess") - result = consume_data() # type: ignore - assert result == "finished" + assert result == 3 @pytest.mark.parametrize("group_id", ["test_consumer_call_once"]) -def test_consumer_call_once(consumer_config) -> None: +def test_consumer_call_once(group_id, producer_config, consumer_config, caplog) -> None: + caplog.set_level(logging.DEBUG) topic_name = "heizer.test.test_consumer_call_once" + topic = Topic(name=f"{topic_name}.{uuid4()}") - @producer( - topics=[HeizerTopic(name=topic_name)], - config=Producer_Config, - key_alias="non_existing_key", - headers_alias="non_existing_headers", - ) - def produce_data(status: str, result: str) -> Dict[str, Any]: - return { - "key1": 1, - "key2": "2", - "key3": True, - "status": status, - "result": result, - } + producer = Producer(config=producer_config) - @consumer( - topics=[HeizerTopic(name=topic_name)], - config=consumer_config, - call_once=True, - ) + for status, result in [("start", 1), ("loading", 2), ("success", 3), ("postprocess", 4)]: + producer.produce( + topic=topic, + key="key1", + value={"status": status, "result": result}, + headers={"header1": "value1", "header2": "value2"}, + auto_flush=True, + ) + + @consumer(topics=[topic], config=consumer_config, call_once=True) def consume_data(msg, *args, **kwargs) -> str: data = json.loads(msg.value) - return cast(str, data["result"]) + return data["result"] + + result = consume_data() - produce_data("start", "waiting") - produce_data("loading", "waiting") - produce_data("success", "finished") + assert result == 1 + + +@pytest.mark.parametrize("group_id", ["test_stop_consumer_by_signal"]) +def test_stop_consumer_by_signal(group_id, producer_config, consumer_config, caplog) -> None: + caplog.set_level(logging.DEBUG) + topic_name = "heizer.test.test_stop_consumer_by_signal" + topic = Topic(name=f"{topic_name}.{uuid4()}") + + producer = Producer(config=producer_config) + + for status, result in [("start", 1), ("loading", 2)]: + producer.produce( + topic=topic, + key="key1", + value={"status": status, "result": result}, + headers={"header1": "value1", "header2": "value2"}, + auto_flush=True, + ) + sg = ConsumerSignal() + + @consumer(topics=[topic], config=consumer_config, consumer_signal=sg) + def consume_data(msg, *args, **kwargs) -> str: + data = json.loads(msg.value) + sg.stop() + return data["result"] result = consume_data() - assert result == "waiting" + assert result == 1 @pytest.mark.parametrize("group_id", ["test_consumer_deserializer"]) -def test_consumer_deserializer(caplog, consumer_config) -> None: +def test_consumer_deserializer(caplog, consumer_config, group_id, producer_config) -> None: caplog.set_level(logging.DEBUG) - topic_name = "heizer.test.test_consumer_deserializer" + topic = Topic(f"heizer.test.test_consumer_deserializer.{uuid4()}") class TestModel(BaseModel): name: str age: int - deserializer = TestModel + deserializer = TestModel.parse_raw - @producer( - topics=[HeizerTopic(name=topic_name)], - config=Producer_Config, - ) - def produce_data() -> Dict[str, Any]: - return { + producer = Producer(config=producer_config) + + producer.produce( + topic=topic, + value={ "name": "mike", "age": 20, - } - - @consumer( - topics=[HeizerTopic(name=topic_name)], - config=consumer_config, - call_once=True, - deserializer=deserializer, 
+ }, ) - def consume_data(message: HeizerMessage, *args, **kwargs): - return message.formatted_value - produce_data() + @consumer(topics=[topic], config=consumer_config, call_once=True, deserializer=deserializer, id="test_consumer_x") + def consume_data(message: Message, C, *args, **kwargs): + C.consumer_signal.stop() + return message.formatted_value result = consume_data() @@ -161,3 +176,63 @@ def consume_data(message: HeizerMessage, *args, **kwargs): assert result.name == "mike" assert result.age == 20 + + status = read_consumer_status(consumer_id="test_consumer_x") + assert status["status"] == "closed" + + +@pytest.mark.parametrize("group_id", ["test_consumer_retry_failed_func"]) +def test_consumer_retry_failed_func(caplog, consumer_config, group_id, producer_config) -> None: + caplog.set_level(logging.DEBUG) + topic = Topic(f"heizer.test.test_consumer_retry_failed_func.{uuid4()}") + retry_topic = Topic(f"heizer.test.test_consumer_retry_failed_func.retry.{uuid4()}") + + class TestModel(BaseModel): + name: str + age: int + + deserializer = TestModel.parse_raw + + producer = Producer(config=producer_config) + + producer.produce( + topic=topic, + headers={"k": "v"}, + value={ + "name": "mike", + "age": 20, + }, + ) + + def stopper(message, C, *args, **kwargs) -> bool: + if not getattr(C, "msg_count", None): + setattr(C, "msg_count", 1) + + C.msg_count += 1 + + if C.msg_count > 4: + return True + else: + return False + + @consumer( + topics=[topic], + config=consumer_config, + deserializer=deserializer, + enable_retry=True, + retry_times=3, + id="failed_to_consume_data_consumer", + name="test_consumer", + retry_topic=retry_topic, + stopper=stopper, + ) + def failed_to_consume_data(message: Message, C, *args, **kwargs): + assert C.retry_times_header_key not in message.headers + raise ValueError + + failed_to_consume_data() + + assert "[test_consumer] Function failed_to_consume_data reached retry limit (3), will give up" in caplog.messages + + status = read_consumer_status() + assert status["failed_to_consume_data_consumer"]["status"] == "closed" diff --git a/tests/test_producer.py b/tests/test_producer.py new file mode 100644 index 0000000..724a84b --- /dev/null +++ b/tests/test_producer.py @@ -0,0 +1,46 @@ +import asyncio +import uuid + +import pytest + +from heizer import Producer, ProducerConfig + + +@pytest.fixture +def message_count() -> int: + return 10_000 + + +@pytest.fixture +def messages(message_count): + return [ + { + "key": "test_key", + "value": ["test"] * 1000, + "headers": {"k": "v"}, + } + ] * message_count + + +def test_produce_message(messages, bootstrap_server, caplog): + pd = Producer(config=ProducerConfig(bootstrap_servers=bootstrap_server), name="test_producer") + + topic = f"test_producer_topic_{uuid.uuid4()}" + + for msg in messages: + pd.produce(topic=topic, auto_flush=False, **msg) + pd.flush() + + +def test_async_produce_message(messages, bootstrap_server, caplog): + pd = Producer(config=ProducerConfig(bootstrap_servers=bootstrap_server), name="test_producer") + topic = f"test_async_producer_topic_{uuid.uuid4()}" + jobs = [] + + async def produce(): + for msg in messages: + jobs.append(asyncio.ensure_future(pd.async_produce(topic=topic, auto_flush=False, **msg))) + await asyncio.gather(*jobs) + pd.flush() + + asyncio.run(produce())
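Taken together, the tests above exercise the reworked public API end to end. A minimal sketch of the same flow outside pytest (broker address, topic name, and group id are placeholders):

    import json

    from heizer import ConsumerConfig, Message, Producer, ProducerConfig, Topic, consumer, create_new_topics

    server = "localhost:9092"  # placeholder broker address
    topic = Topic(name="demo.heizer.topic", num_partitions=1)
    create_new_topics({"bootstrap.servers": server}, [topic])

    # Produce one message, then consume it with a call-once consumer.
    pd = Producer(config=ProducerConfig(bootstrap_servers=server))
    pd.produce(topic=topic, key="k1", value={"status": "success"})

    @consumer(
        topics=[topic],
        config=ConsumerConfig(bootstrap_servers=server, group_id="demo-group"),
        call_once=True,  # handle one message, then close the consumer
    )
    def handle(message: Message, C: consumer, *args, **kwargs):
        return json.loads(message.value)["status"]

    assert handle() == "success"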