From d154af85e8b4caecb88c451a482148757be89418 Mon Sep 17 00:00:00 2001 From: Future-Outlier Date: Sun, 8 Sep 2024 14:28:31 +0800 Subject: [PATCH] support dataclass with flyte types for local execution Signed-off-by: Future-Outlier --- flytekit/core/type_engine.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/flytekit/core/type_engine.py b/flytekit/core/type_engine.py index e91d2a0aed..84a30fd46a 100644 --- a/flytekit/core/type_engine.py +++ b/flytekit/core/type_engine.py @@ -390,10 +390,24 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T): f"The original fields have the following extra keys that are not in dataclass fields: {list(extra_keys)}" ) + from flytekit import StructuredDataset + from flytekit.types.directory import FlyteDirectory + from flytekit.types.file import FlyteFile + from flytekit.types.schema.types import FlyteSchema + + FLYTE_TYPES = [FlyteFile, FlyteDirectory, StructuredDataset, FlyteSchema] + for k, v in original_dict.items(): if k in expected_fields_dict: if isinstance(v, dict): - self.assert_type(expected_fields_dict[k], v) + # todo: 1. check if expected_fields_dict[k] is a flyte type, if yes, then use v to construct a flyet types and assert them + expected_type = expected_fields_dict[k] + if expected_type in FLYTE_TYPES: + new_v = copy.deepcopy(v) + new_v = expected_type(**new_v) + self.assert_type(expected_fields_dict[k], new_v) + else: + self.assert_type(expected_fields_dict[k], v) else: expected_type = expected_fields_dict[k] original_type = type(v) @@ -498,6 +512,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp from flytekit.models.literals import Json if isinstance(python_val, dict): + # There will be bug for local execution dataclass attribute access msgpack_bytes = msgpack.dumps(python_val) return Literal(scalar=Scalar(json=Json(value=msgpack_bytes)))