Skip to content

Commit

Permalink
Implement customizable serializer (#214)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov authored Oct 10, 2023
1 parent 34db231 commit cf87480
Show file tree
Hide file tree
Showing 13 changed files with 306 additions and 57 deletions.
114 changes: 62 additions & 52 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ freezegun = "^1.2.2"
pytest-mock = "^3.11.1"
tzlocal = "^5.0.1"
types-tzlocal = "^5.0.1.1"
types-pytz = "^2023.3.1.1"

[tool.poetry.extras]
zmq = ["pyzmq"]
Expand Down
20 changes: 18 additions & 2 deletions taskiq/abc/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
from typing_extensions import ParamSpec, Self, TypeAlias

from taskiq.abc.middleware import TaskiqMiddleware
from taskiq.abc.serializer import TaskiqSerializer
from taskiq.acks import AckableMessage
from taskiq.decor import AsyncTaskiqDecoratedTask
from taskiq.events import TaskiqEvents
from taskiq.formatters.json_formatter import JSONFormatter
from taskiq.formatters.proxy_formatter import ProxyFormatter
from taskiq.message import BrokerMessage
from taskiq.result_backends.dummy import DummyResultBackend
from taskiq.serializers.json_serializer import JSONSerializer
from taskiq.state import TaskiqState
from taskiq.utils import maybe_awaitable, remove_suffix
from taskiq.warnings import TaskiqDeprecationWarning
Expand Down Expand Up @@ -97,7 +99,8 @@ def __init__(
self.middlewares: "List[TaskiqMiddleware]" = []
self.result_backend = result_backend
self.decorator_class = AsyncTaskiqDecoratedTask
self.formatter: "TaskiqFormatter" = JSONFormatter()
self.serializer: TaskiqSerializer = JSONSerializer()
self.formatter: "TaskiqFormatter" = ProxyFormatter(self)
self.id_generator = task_id_generator
self.local_task_registry: Dict[str, AsyncTaskiqDecoratedTask[Any, Any]] = {}
# Every event has a list of handlers.
Expand Down Expand Up @@ -479,6 +482,19 @@ def with_event_handlers(
self.event_handlers[event].extend(handlers)
return self

def with_serializer(
self,
serializer: TaskiqSerializer,
) -> "Self": # pragma: no cover
"""
Set a new serializer and return an updated broker.
:param serializer: new serializer.
:return: self
"""
self.serializer = serializer
return self

def _register_task(
self,
task_name: str,
Expand Down
24 changes: 24 additions & 0 deletions taskiq/abc/serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from abc import ABC, abstractmethod
from typing import Any


class TaskiqSerializer(ABC):
"""Custom serializer for brokers."""

@abstractmethod
def dumpb(self, value: Any) -> bytes:
"""
Dump value to bytes for sending through the wire.
:param value: value to encode.
:return: encoded value.
"""

@abstractmethod
def loadb(self, value: bytes) -> Any:
"""
Parse byte-encoded value received from the wire.
:param message: value to parse.
:return: decoded value.
"""
18 changes: 18 additions & 0 deletions taskiq/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,15 @@
def parse_obj_as(annot: T, obj: Any) -> T:
return pydantic.TypeAdapter(annot).validate_python(obj)

def model_validate(
model_class: Type[Model],
message: Dict[str, Any],
) -> Model:
return model_class.model_validate(message)

def model_dump(instance: Model) -> Dict[str, Any]:
return instance.model_dump()

def model_validate_json(
model_class: Type[Model],
message: Union[str, bytes, bytearray],
Expand All @@ -37,6 +46,15 @@ def model_copy(
else:
parse_obj_as = pydantic.parse_obj_as # type: ignore

def model_validate(
model_class: Type[Model],
message: Dict[str, Any],
) -> Model:
return model_class.parse_obj(message)

def model_dump(instance: Model) -> Dict[str, Any]:
return instance.dict()

def model_validate_json(
model_class: Type[Model],
message: Union[str, bytes, bytearray],
Expand Down
2 changes: 1 addition & 1 deletion taskiq/formatters/json_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


class JSONFormatter(TaskiqFormatter):
"""Default taskiq formatter."""
"""JSON taskiq formatter."""

def dumps(self, message: TaskiqMessage) -> BrokerMessage:
"""
Expand Down
38 changes: 38 additions & 0 deletions taskiq/formatters/proxy_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from typing import TYPE_CHECKING

from taskiq.abc.formatter import TaskiqFormatter
from taskiq.compat import model_dump, model_validate
from taskiq.message import BrokerMessage, TaskiqMessage

if TYPE_CHECKING:
from taskiq.abc.broker import AsyncBroker


class ProxyFormatter(TaskiqFormatter):
"""Default taskiq formatter."""

def __init__(self, broker: "AsyncBroker") -> None:
self.broker = broker

def dumps(self, message: TaskiqMessage) -> BrokerMessage:
"""
Dumps taskiq message to some broker message format.
:param message: message to send.
:return: Dumped message.
"""
return BrokerMessage(
task_id=message.task_id,
task_name=message.task_name,
message=self.broker.serializer.dumpb(model_dump(message)),
labels=message.labels,
)

def loads(self, message: bytes) -> TaskiqMessage:
"""
Loads json from message.
:param message: broker's message.
:return: parsed taskiq message.
"""
return model_validate(TaskiqMessage, self.broker.serializer.loadb(message))
4 changes: 2 additions & 2 deletions taskiq/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class TaskiqMessage(BaseModel):

task_id: str
task_name: str
labels: Dict[str, str]
labels: Dict[str, Any]
args: List[Any]
kwargs: Dict[str, Any]

Expand All @@ -25,4 +25,4 @@ class BrokerMessage(BaseModel):
task_id: str
task_name: str
message: bytes
labels: Dict[str, str]
labels: Dict[str, Any]
1 change: 1 addition & 0 deletions taskiq/serializers/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Taskiq serializers."""
26 changes: 26 additions & 0 deletions taskiq/serializers/json_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from json import dumps, loads
from typing import Any

from taskiq.abc.serializer import TaskiqSerializer


class JSONSerializer(TaskiqSerializer):
"""Default taskiq serizalizer."""

def dumpb(self, value: Any) -> bytes:
"""
Dumps taskiq message to some broker message format.
:param message: message to send.
:return: Dumped message.
"""
return dumps(value).encode()

def loadb(self, value: bytes) -> Any:
"""
Parse byte-encoded value received from the wire.
:param message: value to parse.
:return: decoded value.
"""
return loads(value.decode())
45 changes: 45 additions & 0 deletions tests/formatters/test_json_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import pytest

from taskiq.formatters.json_formatter import JSONFormatter
from taskiq.message import BrokerMessage, TaskiqMessage


@pytest.mark.anyio
async def test_json_dumps() -> None:
fmt = JSONFormatter()
msg = TaskiqMessage(
task_id="task-id",
task_name="task.name",
labels={"label1": 1, "label2": "text"},
args=[1, "a"],
kwargs={"p1": "v1"},
)
expected = BrokerMessage(
task_id="task-id",
task_name="task.name",
message=(
b'{"task_id":"task-id","task_name":"task.name",'
b'"labels":{"label1":1,"label2":"text"},'
b'"args":[1,"a"],"kwargs":{"p1":"v1"}}'
),
labels={"label1": 1, "label2": "text"},
)
assert fmt.dumps(msg) == expected


@pytest.mark.anyio
async def test_json_loads() -> None:
fmt = JSONFormatter()
msg = (
b'{"task_id":"task-id","task_name":"task.name",'
b'"labels":{"label1":1,"label2":"text"},'
b'"args":[1,"a"],"kwargs":{"p1":"v1"}}'
)
expected = TaskiqMessage(
task_id="task-id",
task_name="task.name",
labels={"label1": 1, "label2": "text"},
args=[1, "a"],
kwargs={"p1": "v1"},
)
assert fmt.loads(msg) == expected
47 changes: 47 additions & 0 deletions tests/formatters/test_proxy_formatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import pytest

from taskiq.brokers.inmemory_broker import InMemoryBroker
from taskiq.message import BrokerMessage, TaskiqMessage


@pytest.mark.anyio
async def test_proxy_dumps() -> None:
# uses json serializer by default
broker = InMemoryBroker()
msg = TaskiqMessage(
task_id="task-id",
task_name="task.name",
labels={"label1": 1, "label2": "text"},
args=[1, "a"],
kwargs={"p1": "v1"},
)
expected = BrokerMessage(
task_id="task-id",
task_name="task.name",
message=(
b'{"task_id": "task-id", "task_name": "task.name", '
b'"labels": {"label1": 1, "label2": "text"}, '
b'"args": [1, "a"], "kwargs": {"p1": "v1"}}'
),
labels={"label1": 1, "label2": "text"},
)
assert broker.formatter.dumps(msg) == expected


@pytest.mark.anyio
async def test_proxy_loads() -> None:
# uses json serializer by default
broker = InMemoryBroker()
msg = (
b'{"task_id":"task-id","task_name":"task.name",'
b'"labels":{"label1":1,"label2":"text"},'
b'"args":[1,"a"],"kwargs":{"p1":"v1"}}'
)
expected = TaskiqMessage(
task_id="task-id",
task_name="task.name",
labels={"label1": 1, "label2": "text"},
args=[1, "a"],
kwargs={"p1": "v1"},
)
assert broker.formatter.loads(msg) == expected
23 changes: 23 additions & 0 deletions tests/serializers/test_json_serializer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import pytest

from taskiq.serializers.json_serializer import JSONSerializer


@pytest.mark.anyio
async def test_json_dumpb() -> None:
serizalizer = JSONSerializer()
assert serizalizer.dumpb(None) == b"null" # noqa: PLR2004
assert serizalizer.dumpb(1) == b"1" # noqa: PLR2004
assert serizalizer.dumpb("a") == b'"a"' # noqa: PLR2004
assert serizalizer.dumpb(["a"]) == b'["a"]' # noqa: PLR2004
assert serizalizer.dumpb({"a": "b"}) == b'{"a": "b"}' # noqa: PLR2004


@pytest.mark.anyio
async def test_json_loadb() -> None:
serizalizer = JSONSerializer()
assert serizalizer.loadb(b"null") is None
assert serizalizer.loadb(b"1") == 1
assert serizalizer.loadb(b'"a"') == "a"
assert serizalizer.loadb(b'["a"]') == ["a"]
assert serizalizer.loadb(b'{"a": "b"}') == {"a": "b"}

0 comments on commit cf87480

Please sign in to comment.