Skip to content

Commit

Permalink
support dataclass with flyte types for local execution
Browse files Browse the repository at this point in the history
Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier committed Sep 8, 2024
1 parent d3c9dd1 commit d154af8
Showing 1 changed file with 16 additions and 1 deletion.
17 changes: 16 additions & 1 deletion flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,10 +390,24 @@ def assert_type(self, expected_type: Type[DataClassJsonMixin], v: T):
f"The original fields have the following extra keys that are not in dataclass fields: {list(extra_keys)}"
)

from flytekit import StructuredDataset
from flytekit.types.directory import FlyteDirectory
from flytekit.types.file import FlyteFile
from flytekit.types.schema.types import FlyteSchema

FLYTE_TYPES = [FlyteFile, FlyteDirectory, StructuredDataset, FlyteSchema]

for k, v in original_dict.items():
if k in expected_fields_dict:
if isinstance(v, dict):
self.assert_type(expected_fields_dict[k], v)
# todo: 1. check if expected_fields_dict[k] is a flyte type, if yes, then use v to construct a flyet types and assert them
expected_type = expected_fields_dict[k]
if expected_type in FLYTE_TYPES:
new_v = copy.deepcopy(v)
new_v = expected_type(**new_v)
self.assert_type(expected_fields_dict[k], new_v)
else:
self.assert_type(expected_fields_dict[k], v)
else:
expected_type = expected_fields_dict[k]
original_type = type(v)
Expand Down Expand Up @@ -498,6 +512,7 @@ def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], exp
from flytekit.models.literals import Json

if isinstance(python_val, dict):
# There will be bug for local execution dataclass attribute access
msgpack_bytes = msgpack.dumps(python_val)
return Literal(scalar=Scalar(json=Json(value=msgpack_bytes)))

Expand Down

0 comments on commit d154af8

Please sign in to comment.