Skip to content

Commit

Permalink
Use msgpack for control message passing (#27)
Browse files Browse the repository at this point in the history
* Use yaml FullLoader that handle UUID type as well

* Use UnsafeLoader explicitly for bundles

* Using msgpack for passing control messages
  • Loading branch information
unkcpz authored Jan 30, 2025
1 parent eca3ae3 commit 4edd4df
Show file tree
Hide file tree
Showing 6 changed files with 106 additions and 5 deletions.
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ keywords = ['workflow', 'multithreaded', 'rabbitmq']
requires-python = '>=3.10'
dependencies = [
'kiwipy[rmq]~=0.8.5',
'msgpack~=1.1',
'nest_asyncio~=1.5,>=1.5.1',
'pyyaml~=6.0',
'typing-extensions~=4.12'
Expand Down Expand Up @@ -131,6 +132,7 @@ module = [
'aiocontextvars.*',
'frozendict.*',
'kiwipy.*',
'msgpack.*',
'nest_asyncio.*',
'tblib.*',
]
Expand Down
3 changes: 1 addition & 2 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import yaml
from typing_extensions import Self, override
from yaml.loader import Loader

from plumpy.loaders import ObjectLoader
from plumpy.message import MessageBuilder, MessageType
Expand Down Expand Up @@ -547,7 +546,7 @@ def recreate_from(cls, saved_state: SAVED_STATE_TYPE, load_context: LoadSaveCont
load_context = ensure_object_loader(load_context, saved_state)
obj = auto_load(cls, saved_state, load_context)

obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=Loader)
obj.exception = yaml.load(saved_state[obj.EXC_VALUE], Loader=yaml.UnsafeLoader)
if _HAS_TBLIB:
try:
obj.traceback = tblib.Traceback.from_string(saved_state[obj.TRACEBACK], strict=False)
Expand Down
44 changes: 44 additions & 0 deletions src/plumpy/rmq/process_control.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,13 @@
"""Module for process level communication functions and classes"""

import asyncio
import functools
import uuid
from typing import Any, Dict, Hashable, Optional, Sequence, Union

import kiwipy
import msgpack
import yaml

from plumpy import loaders
from plumpy.coordinator import Coordinator
Expand All @@ -27,6 +31,46 @@
ProcessStatus = Any


# Define the yaml type representer and constructor for UUID type handling in message passing.
# NOTE: it is recommended to use msgpack for sending messages; the yaml codec is only kept here for reference.
def uuid_representer(dumper: Any, data: uuid.UUID) -> Any:
    """Represent a UUID as a YAML scalar tagged ``!uuid``.

    :param dumper: the YAML dumper driving serialization; only its
        ``represent_scalar`` method is used.
    :param data: the UUID to serialize.
    :return: whatever ``dumper.represent_scalar`` returns for the tag ``!uuid``
        and the canonical string form of ``data``.
    """
    # The string form round-trips through ``uuid.UUID(value)`` on the loading side.
    return dumper.represent_scalar('!uuid', str(data))


def uuid_constructor(loader: Any, node: Any) -> uuid.UUID:
    """Rebuild a UUID from a YAML scalar node tagged ``!uuid``.

    :param loader: the YAML loader driving deserialization; only its
        ``construct_scalar`` method is used.
    :param node: the YAML node holding the textual UUID.
    :return: the UUID parsed from the node's scalar value.
    :raises ValueError: if the scalar is not a valid UUID string (raised by
        ``uuid.UUID``).
    """
    value = loader.construct_scalar(node)
    return uuid.UUID(value)


# Register the UUID handlers with PyYAML at import time: UUID objects are dumped
# through ``uuid_representer`` and scalars tagged ``!uuid`` are loaded through
# ``uuid_constructor``. NOTE(review): this mutates PyYAML's module-level
# registries, so it affects every user of yaml in the process — confirm intended.
yaml.add_representer(uuid.UUID, uuid_representer)
yaml.add_constructor('!uuid', uuid_constructor)

# Ready-made YAML codec pair: ``yaml.dump`` pre-bound to utf-8 encoding and
# ``yaml.load`` pre-bound to ``yaml.FullLoader``.
YAML_ENCODER = functools.partial(yaml.dump, encoding='utf-8')
YAML_DECODER = functools.partial(yaml.load, Loader=yaml.FullLoader)

# Define ext hook for msgpack to handle UUID type in message passing

# msgpack extension type code used to tag UUID payloads. The value itself is
# arbitrary, but the encoding hook (``default_uuid_ext``) and the decoding hook
# (``ext_hook``) must agree on it.
UUID_EXT_CODE = 42 # Just pick any integer < 128


def default_uuid_ext(obj: Any) -> msgpack.ExtType:
    """Pack a UUID into a msgpack extension value.

    Bound as the ``default`` callback of ``msgpack.packb`` in ``MSGPACK_ENCODER``.

    :param obj: the object to convert.
    :return: an ``ExtType`` tagged with ``UUID_EXT_CODE`` carrying the UUID's raw bytes.
    :raises TypeError: if ``obj`` is not a ``uuid.UUID``.
    """
    # Guard first: anything that is not a UUID is rejected outright.
    if not isinstance(obj, uuid.UUID):
        raise TypeError(f'Cannot serialize type {type(obj)}')

    return msgpack.ExtType(UUID_EXT_CODE, obj.bytes)


def ext_hook(code: Any, data: bytes | None) -> Any:
    """Turn a msgpack extension value back into a Python object.

    Bound as the ``ext_hook`` callback of ``msgpack.unpackb`` in ``MSGPACK_DECODER``.

    :param code: the extension type code read from the packed data.
    :param data: the raw payload attached to the extension value.
    :return: a ``uuid.UUID`` when ``code`` is ``UUID_EXT_CODE``; otherwise the
        payload re-wrapped unchanged in a ``msgpack.ExtType``.
    """
    # Unknown extension codes are passed through untouched.
    if code != UUID_EXT_CODE:
        return msgpack.ExtType(code, data)

    return uuid.UUID(bytes=data)


# Ready-made msgpack codec pair with UUID support: ``msgpack.packb`` pre-bound to
# ``default_uuid_ext`` as its ``default`` callback, and ``msgpack.unpackb``
# pre-bound to ``ext_hook``.
MSGPACK_ENCODER = functools.partial(msgpack.packb, default=default_uuid_ext)
MSGPACK_DECODER = functools.partial(msgpack.unpackb, ext_hook=ext_hook)


# FIXME: this class does not fit the typing of the ProcessController protocol
class RemoteProcessController:
"""
Expand Down
6 changes: 4 additions & 2 deletions tests/rmq/test_communicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import uuid
import pytest
import shortuuid
import yaml
# import yaml
import msgpack

from kiwipy.rmq import RmqThreadCommunicator

Expand Down Expand Up @@ -41,7 +42,8 @@ def _coordinator():
message_exchange=message_exchange,
task_exchange=task_exchange,
task_queue=task_queue,
decoder=functools.partial(yaml.load, Loader=yaml.Loader),
encoder=process_control.MSGPACK_ENCODER,
decoder=process_control.MSGPACK_DECODER,
)

loop = asyncio.get_event_loop()
Expand Down
2 changes: 1 addition & 1 deletion tests/test_persistence.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,6 @@ def test_bundle_yaml(self):
bundle = plumpy.Bundle(Save1())
represent = yaml.dump({'bundle': bundle})

bundle_loaded = yaml.load(represent, Loader=yaml.Loader)['bundle']
bundle_loaded = yaml.load(represent, Loader=yaml.UnsafeLoader)['bundle']
assert isinstance(bundle_loaded, plumpy.Bundle)
assert bundle_loaded == Save1().save()
54 changes: 54 additions & 0 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 4edd4df

Please sign in to comment.