Skip to content

Commit

Permalink
put {'format':'msgpack'} in literal's metadata
Browse files Browse the repository at this point in the history
Signed-off-by: Future-Outlier <[email protected]>
  • Loading branch information
Future-Outlier committed Sep 8, 2024
1 parent d154af8 commit 86c24ae
Showing 1 changed file with 45 additions and 39 deletions.
84 changes: 45 additions & 39 deletions flytekit/core/type_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from functools import lru_cache
from typing import Dict, List, NamedTuple, Optional, Type, cast

import msgpack
from dataclasses_json import DataClassJsonMixin, dataclass_json
from flyteidl.core import literals_pb2
from google.protobuf import json_format as _json_format
Expand Down Expand Up @@ -336,7 +337,7 @@ class Test(DataClassJsonMixin):

def __init__(self):
super().__init__("Object-Dataclass-Transformer", object)
self._json_decoder: Dict[Type, JSONDecoder] = {} # This will deprecated in the future
self._json_decoder: Dict[Type, JSONDecoder] = {} # This will be deprecated in the future
self._msgpack_encoder: Dict[Type, MessagePackEncoder] = {}
self._msgpack_decoder: Dict[Type, MessagePackDecoder] = {}

Expand Down Expand Up @@ -507,14 +508,12 @@ def get_literal_type(self, t: Type[T]) -> LiteralType:
return _type_models.LiteralType(simple=_type_models.SimpleType.JSON, metadata=schema, structure=ts)

def to_literal(self, ctx: FlyteContext, python_val: T, python_type: Type[T], expected: LiteralType) -> Literal:
import msgpack

from flytekit.models.literals import Json

if isinstance(python_val, dict):
# There will be a bug with dataclass attribute access in local execution
msgpack_bytes = msgpack.dumps(python_val)
return Literal(scalar=Scalar(json=Json(value=msgpack_bytes)))
return Literal(scalar=Scalar(json=Json(value=msgpack_bytes)), metadata={"format": "msgpack"})

Check warning on line 516 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L515-L516

Added lines #L515 - L516 were not covered by tests

if not dataclasses.is_dataclass(python_val):
raise TypeTransformerFailedError(
Expand Down Expand Up @@ -692,7 +691,6 @@ def _fix_dataclass_int(self, dc_type: Type[dataclasses.dataclass], dc: typing.An
return dc

def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T]) -> T:
import msgpack
from mashumaro.codecs.msgpack import MessagePackDecoder

if not dataclasses.is_dataclass(expected_python_type):
Expand All @@ -701,21 +699,35 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
"user defined datatypes in Flytekit"
)

dc = None
scalar = lv.scalar
metadata = lv.metadata
if scalar.json:
# Check whether the format declared in the literal's metadata is supported;
# if it is, use that format to deserialize the dataclass.
# Currently, only the msgpack format is supported.
# Other formats can be added in the future, and the metadata ensures backward compatibility.
# We can't use hasattr(expected_python_type, "from_json") here because we rely on
# mashumaro's API to customize the deserialization behavior for Flyte types.
if issubclass(expected_python_type, DataClassJSONMixin):
dict_obj = msgpack.loads(scalar.json.value)
json_str = json.dumps(dict_obj)
dc = expected_python_type.from_json(json_str) # type: ignore
else:
try:
decoder = self._msgpack_decoder[expected_python_type]
except KeyError:
decoder = MessagePackDecoder(expected_python_type)
self._msgpack_decoder[expected_python_type] = decoder
dc = decoder.decode(scalar.json.value)
if metadata and metadata.get("format", None) == "msgpack":
if issubclass(expected_python_type, DataClassJSONMixin):
dict_obj = msgpack.loads(scalar.json.value)
json_str = json.dumps(dict_obj)
dc = expected_python_type.from_json(json_str) # type: ignore

Check warning on line 716 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L714-L716

Added lines #L714 - L716 were not covered by tests
else:
try:
decoder = self._msgpack_decoder[expected_python_type]
except KeyError:
decoder = MessagePackDecoder(expected_python_type)
self._msgpack_decoder[expected_python_type] = decoder
dc = decoder.decode(scalar.json.value)

if dc is None:
raise TypeTransformerFailedError(

Check warning on line 726 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L726

Added line #L726 was not covered by tests
f"Failed to convert {scalar.json.value} to {expected_python_type}. "
f"Please check if the literal's metadata: {metadata} has a 'format' field, and ensure flytekit support it."
)

elif scalar.generic:
# The `from_json` function is provided from mashumaro's `DataClassJSONMixin`.
# It deserializes a JSON string into a data class, and supports additional functionality over JSONDecoder
Expand All @@ -739,7 +751,7 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
if scalar.generic:
self._fix_dataclass_int(expected_python_type, dc)

Check warning on line 752 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L752

Added line #L752 was not covered by tests

return dc
return dc # type: ignore

# This ensures that calls with the same literal type return the same dataclass. For example, the `pyflyte run`
# command needs to call guess_python_type to get the TypeEngine-derived dataclass. Without caching here, separate
Expand Down Expand Up @@ -1709,19 +1721,22 @@ def dict_to_json_literal(ctx: FlyteContext, v: dict, allow_pickle: bool) -> Lite
"""
Creates a flyte-specific ``Literal`` value from a native python dictionary.
"""
import msgpack

from flytekit.models.literals import Json
from flytekit.types.pickle import FlytePickle

Check warning on line 1725 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1724-L1725

Added lines #L1724 - L1725 were not covered by tests

try:
msgpack_bytes = msgpack.dumps(v)
return Literal(scalar=Scalar(json=Json(value=msgpack_bytes)), metadata={"format": "msgpack"})
except TypeError as e:

Check warning on line 1730 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1727-L1730

Added lines #L1727 - L1730 were not covered by tests
# There's no need to use Json type in this case, as key and value in the dict are all strings.
if allow_pickle:
remote_path = FlytePickle.to_pickle(ctx, v)
msgpack_bytes = msgpack.dumps({"pickle_file": remote_path})
return Literal(scalar=Scalar(json=Json(value=msgpack_bytes)), metadata={"format": "pickle"})
return Literal(

Check warning on line 1734 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1733-L1734

Added lines #L1733 - L1734 were not covered by tests
scalar=Scalar(
generic=_json_format.Parse(json.dumps({"pickle_file": remote_path}), _struct.Struct())
),
metadata={"format": "pickle"},
)
raise e

Check warning on line 1740 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1740

Added line #L1740 was not covered by tests

@staticmethod
Expand Down Expand Up @@ -1812,32 +1827,23 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
# for empty generic we have to explicitly test for lv.scalar.generic is not None as empty dict
# evaluates to false
if lv and lv.scalar:
if lv.scalar.generic is not None:
if lv.metadata and lv.metadata.get("format", None) == "pickle":
from flytekit.types.pickle import FlytePickle
metadata = lv.metadata

Check warning on line 1830 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1830

Added line #L1830 was not covered by tests
if metadata and metadata.get("format", None) == "pickle":
from flytekit.types.pickle import FlytePickle

uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file")
return FlytePickle.from_pickle(uri)
uri = json.loads(_json_format.MessageToJson(lv.scalar.generic)).get("pickle_file")
return FlytePickle.from_pickle(uri)

if lv.scalar.generic is not None:
try:
return json.loads(_json_format.MessageToJson(lv.scalar.generic))
except TypeError:
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

Check warning on line 1841 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1838-L1841

Added lines #L1838 - L1841 were not covered by tests
elif lv.scalar.json is not None:
import msgpack

if lv.metadata and lv.metadata.get("format", None) == "pickle":
from flytekit.types.pickle import FlytePickle

msgpack_bytes = lv.scalar.json.value
dict_obj = msgpack.loads(msgpack_bytes)
uri = dict_obj.get("pickle_file")
return FlytePickle.from_pickle(uri)

try:

Check warning on line 1843 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1843

Added line #L1843 was not covered by tests
msgpack_bytes = lv.scalar.json.value
return msgpack.loads(msgpack_bytes)

if metadata.get("format", None) == "msgpack":
msgpack_bytes = lv.scalar.json.value
return msgpack.loads(msgpack_bytes)
except TypeError:
raise TypeTransformerFailedError(f"Cannot convert from {lv} to {expected_python_type}")

Check warning on line 1848 in flytekit/core/type_engine.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/type_engine.py#L1845-L1848

Added lines #L1845 - L1848 were not covered by tests

Expand Down

0 comments on commit 86c24ae

Please sign in to comment.