Skip to content

Commit

Permalink
feat: support for Unpack[TypedDict] in task kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
dimastbk committed Dec 25, 2024
1 parent e815b95 commit 10ed15c
Show file tree
Hide file tree
Showing 5 changed files with 45 additions and 15 deletions.
27 changes: 21 additions & 6 deletions pyzeebe/function_tools/parameter_tools.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,40 @@
from __future__ import annotations

import inspect
from typing import Any
from typing import Any, get_type_hints

from typing_extensions import ( # type: ignore[attr-defined]
_is_unpack,
get_args,
is_typeddict,
)

from pyzeebe.function_tools import Function
from pyzeebe.job.job import Job


def get_parameters_from_function(task_function: Function[..., Any]) -> list[str] | None:
function_signature = inspect.signature(task_function)
for _, parameter in function_signature.parameters.items():
if parameter.kind in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD):
return []
variables_to_fetch: list[str] = []

function_signature = inspect.signature(task_function)
if not function_signature.parameters:
return None

for parameter in function_signature.parameters.values():
if parameter.kind == inspect.Parameter.VAR_POSITIONAL:
return []
elif parameter.kind == inspect.Parameter.VAR_KEYWORD:
if _is_unpack(parameter.annotation) and is_typeddict(get_args(parameter.annotation)[0]):
variables_to_fetch.extend(get_type_hints(get_args(parameter.annotation)[0]).keys())
else:
return []
elif parameter.annotation != Job:
variables_to_fetch.append(parameter.name)

if all(param.annotation == Job for param in function_signature.parameters.values()):
return []

return [param.name for param in function_signature.parameters.values() if param.annotation != Job]
return variables_to_fetch


def get_job_parameter_name(function: Function[..., Any]) -> str | None:
Expand Down
11 changes: 4 additions & 7 deletions pyzeebe/grpc_internals/zeebe_process_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,13 +206,10 @@ def _create_form_from_raw_form(response: FormMetadata) -> DeployResourceResponse

_METADATA_PARSERS: dict[
str,
Callable[
[ProcessMetadata | DecisionMetadata | DecisionRequirementsMetadata | FormMetadata],
DeployResourceResponse.ProcessMetadata
| DeployResourceResponse.DecisionMetadata
| DeployResourceResponse.DecisionRequirementsMetadata
| DeployResourceResponse.FormMetadata,
],
Callable[[ProcessMetadata], DeployResourceResponse.ProcessMetadata]
| Callable[[DecisionMetadata], DeployResourceResponse.DecisionMetadata]
| Callable[[DecisionRequirementsMetadata], DeployResourceResponse.DecisionRequirementsMetadata]
| Callable[[FormMetadata], DeployResourceResponse.FormMetadata],
] = {
"process": ZeebeProcessAdapter._create_process_from_raw_process,
"decision": ZeebeProcessAdapter._create_decision_from_raw_decision,
Expand Down
4 changes: 2 additions & 2 deletions pyzeebe/worker/task_router.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import logging
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from typing import Any, Callable, Literal, Optional, TypeVar, overload

from typing_extensions import ParamSpec
Expand All @@ -16,7 +16,7 @@

P = ParamSpec("P")
R = TypeVar("R")
RD = TypeVar("RD", bound=Optional[dict[str, Any]])
RD = TypeVar("RD", bound=Optional[Mapping[str, Any]])

logger = logging.getLogger(__name__)

Expand Down
2 changes: 2 additions & 0 deletions tests/unit/function_tools/parameter_tools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ class TestGetFunctionParameters:
(dummy_functions.positional_and_keyword_params, ["x", "y"]),
(dummy_functions.args_param, []),
(dummy_functions.kwargs_param, []),
(dummy_functions.kwargs_typed_dict_param, ["z"]),
(dummy_functions.positional_and_kwargs_typed_dict_param, ["x", "y", "z"]),
(dummy_functions.standard_named_params, ["args", "kwargs"]),
(dummy_functions.with_job_parameter, []),
(dummy_functions.with_job_parameter_and_param, ["x"]),
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/utils/dummy_functions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from typing import TypedDict

from typing_extensions import Unpack

from pyzeebe.job.job import Job


Expand Down Expand Up @@ -33,6 +37,18 @@ def kwargs_param(**kwargs):
pass


class Kwargs(TypedDict):
z: int


def kwargs_typed_dict_param(**kwargs: Unpack[Kwargs]):
pass


def positional_and_kwargs_typed_dict_param(x, y=1, **kwargs: Unpack[Kwargs]):
pass


def standard_named_params(args, kwargs):
pass

Expand Down

0 comments on commit 10ed15c

Please sign in to comment.