fix: fix JSONParamType to handle serialization of custom objects #2931
base: master
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2931       +/-   ##
===========================================
- Coverage   76.33%   46.75%   -29.59%
===========================================
  Files         199      199
  Lines       20840    20789       -51
  Branches     2681     2684        +3
===========================================
- Hits        15908     9719     -6189
- Misses       4214    10594     +6380
+ Partials      718      476      -242
☔ View full report in Codecov by Sentry.
Signed-off-by: Tsung-Han Ho (dalaoqi) <[email protected]>
…intainability Signed-off-by: Tsung-Han Ho (dalaoqi) <[email protected]>
- Added a fallback check for `lv.scalar.union.value.collection.literals` in case `lv.collection` is None.
- Prevents potential errors when `lv.collection` is not properly initialized.
Signed-off-by: Tsung-Han Ho (dalaoqi) <[email protected]>
- Added validation to check for duplicate input names in the function's input interface.
- Raised a ValueError if duplicate input names are detected to prevent issues during argument assignment.
- Simplified the conversion of args to kwargs by removing the redundant multiple values check.
Signed-off-by: Tsung-Han Ho (dalaoqi) <[email protected]>
I've tested 4 cases and all of them worked!
Can you help me add some unit tests?
pyflyte run build/taiwan/2931/example1.py wf
pyflyte run build/taiwan/2931/example1.py dict_wf
pyflyte run build/taiwan/2931/example1.py list_wf
pyflyte run build/taiwan/2931/example1.py dc_wf
from textwrap import shorten
from flytekit import task, workflow
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, List
from flytekit.core.type_engine import TypeEngine
from dataclasses_json import dataclass_json
from flyteidl.core.execution_pb2 import TaskExecution
from flytekit.core.context_manager import FlyteContextManager
from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.interface import Interface
from flytekit.extend.backend.base_agent import (
    AgentRegistry,
    Resource,
    SyncAgentBase,
    SyncAgentExecutorMixin,
)
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate
from flytekit.types.file import FlyteFile
from flytekit.types.directory import FlyteDirectory
from flytekit.types.structured import StructuredDataset
from flytekit.types.schema import FlyteSchema
import pandas as pd
import os
image = None
@dataclass_json
@dataclass
class Foo:
    val: str


class FooAgent(SyncAgentBase):
    def __init__(self) -> None:
        super().__init__(task_type_name="foo")

    def do(
        self,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
        **kwargs: Any,
    ) -> Resource:
        return Resource(
            phase=TaskExecution.SUCCEEDED, outputs={"foos": [Foo(val="a"), Foo(val="b")], "has_foos": True}
        )


AgentRegistry.register(FooAgent())


class FooTask(SyncAgentExecutorMixin, PythonTask):  # type: ignore
    _TASK_TYPE = "foo"

    def __init__(self, name: str, **kwargs: Any) -> None:
        task_config: dict[str, Any] = {}
        outputs = {"has_foos": bool, "foos": Optional[List[Foo]]}
        super().__init__(
            task_type=self._TASK_TYPE,
            name=name,
            task_config=task_config,
            interface=Interface(outputs=outputs),
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        return {}


foo_task = FooTask(name="foo_task")


@task
def foos_task(foos: list[Foo]) -> None:
    print(f"hi {foos}")


@workflow
def dc_wf(foos: list[Foo] = [Foo(val="a"), Foo(val="b")]) -> None:
    has_foos, foos = foo_task()
    foos_task(foos=foos)
@dataclass
class DC:
    ff: FlyteFile = field(default_factory=lambda: FlyteFile(os.path.realpath(__file__)))
    sd: StructuredDataset = field(default_factory=lambda: StructuredDataset(
        uri="/Users/future-outlier/code/dev/flytekit/build/debugyt/user/FE/src/data/df.parquet",
        file_format="parquet"
    ))
    fd: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(
        "/Users/future-outlier/code/dev/flytekit/build/debugyt/user/FE/src/data/"
    ))


@task(container_image=image)
def t1(dc: DC = DC()) -> DC:
    with open(dc.ff, "r") as f:
        print("File Content: ", f.read())
    print("sd:", dc.sd.open(pd.DataFrame).all())
    df_path = os.path.join(dc.fd.path, "df.parquet")
    print("fd: ", os.path.isdir(df_path))
    return dc


@workflow
def wf(dc: DC = DC()):
    t1(dc=dc)


@task(container_image=image)
def list_t1(list_dc: list[DC] = [DC(), DC()]) -> list[DC]:
    for dc in list_dc:
        with open(dc.ff, "r") as f:
            print("File Content: ", f.read())
        print("sd:", dc.sd.open(pd.DataFrame).all())
        df_path = os.path.join(dc.fd.path, "df.parquet")
        print("fd: ", os.path.isdir(df_path))
    return list_dc


@workflow
def list_wf(list_dc: list[DC] = [DC(), DC()]):
    list_t1(list_dc=list_dc)


@task(container_image=image)
def dict_t1(dict_dc: dict[str, DC] = {"a": DC(), "b": DC()}) -> dict[str, DC]:
    for _, dc in dict_dc.items():
        with open(dc.ff, "r") as f:
            print("File Content: ", f.read())
        print("sd:", dc.sd.open(pd.DataFrame).all())
        df_path = os.path.join(dc.fd.path, "df.parquet")
        print("fd: ", os.path.isdir(df_path))
    return dict_dc


@workflow
def dict_wf(dict_dc: dict[str, DC] = {"a": DC(), "b": DC()}):
    dict_t1(dict_dc=dict_dc)
if __name__ == "__main__":
    from flytekit.clis.sdk_in_container import pyflyte
    from click.testing import CliRunner
    import os

    # wf()
    runner = CliRunner()
    path = os.path.realpath(__file__)
    result = runner.invoke(pyflyte.main, ["run", path, "dict_wf", ])
    print("Remote Execution: ", result.output)
if lv and lv.scalar and lv.scalar.union.value.collection.literals:
    lits = lv.scalar.union.value.collection.literals
Can you add more comments so others know what this is for?
I think you might be right, but it would help to explain, for example, why this is related to a union and not a LiteralMap.
if isinstance(value, dict) or isinstance(value, list):
    return value
Why change `type` to `isinstance`?
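For context on the question, here is a minimal sketch of how the two checks differ. The `OrderedDict` value is only an illustrative assumption and does not come from the PR; the point is that an exact `type(...)` comparison rejects subclasses, while `isinstance` accepts them.

```python
from collections import OrderedDict

# Illustrative value (an assumption): a dict subclass.
value = OrderedDict(a=1)

# Exact type comparison: a subclass of dict does not match.
exact_match = type(value) in (dict, list)

# isinstance also accepts subclasses of dict and list.
subclass_match = isinstance(value, (dict, list))

print(exact_match, subclass_match)
```

Whether subclasses should be passed through unchanged is a design decision for the PR author; the sketch only shows the behavioral difference.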
return json.loads(value)
if isinstance(value, str):
    return json.loads(value)
return json.loads(json.dumps(value, default=lambda o: o.__dict__))
What type will `value` be here? Is it only a dataclass?
If it is only a dataclass, can we use `import dataclasses; dataclasses.is_dataclass(something)` to check the value first?
https://stackoverflow.com/questions/56106116/check-if-a-class-is-a-dataclass-in-python
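A rough sketch of what that check could look like, assuming the goal is to handle dataclass instances explicitly and fall back to `__dict__` for other objects. The `to_jsonable` helper and the `Foo` and `Plain` classes are hypothetical names used only for illustration; they are not part of flytekit or of this PR.

```python
import dataclasses
import json
from dataclasses import dataclass


@dataclass
class Foo:
    val: str


class Plain:
    """A non-dataclass object, to show the fallback path."""

    def __init__(self, val: str) -> None:
        self.val = val


def to_jsonable(o):
    # Hypothetical helper passed as `default=` to json.dumps.
    # is_dataclass() is also True for dataclass *types*, so exclude classes.
    if dataclasses.is_dataclass(o) and not isinstance(o, type):
        return dataclasses.asdict(o)
    # Fallback used in the PR diff; raises AttributeError for objects
    # that have no __dict__ (for example, those using __slots__).
    return o.__dict__


encoded = json.dumps([Foo(val="a"), Plain("b")], default=to_jsonable)
decoded = json.loads(encoded)
print(decoded)
```

One difference worth noting: `dataclasses.asdict` recurses into nested dataclasses on its own, whereas the `__dict__` fallback relies on `json.dumps` calling `default` again for each nested object.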
Tracking issue
Closes flyteorg/flyte#5985
Why are the changes needed?
The issue arises when trying to serialize lists or dictionaries that contain custom objects, as non-serializable custom objects cause a TypeError by default. This prevents successful serialization of data structures with instances of user-defined classes, necessitating a solution to convert these objects into a JSON-compatible format.
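The failure mode can be reproduced with the standard library alone. The sketch below uses a hypothetical `Foo` class and `payload` dictionary (both are assumptions for illustration) to show the default `TypeError`, and how the `default=` hook of `json.dumps` avoids it.

```python
import json


class Foo:
    # Hypothetical user-defined class, for illustration only.
    def __init__(self, val: str) -> None:
        self.val = val


payload = {"foos": [Foo("a"), Foo("b")]}

# Without a `default` hook, json.dumps rejects the custom objects.
try:
    json.dumps(payload)
    error = None
except TypeError as exc:
    error = exc

# With `default`, each unsupported object is replaced by its __dict__.
converted = json.loads(json.dumps(payload, default=lambda o: o.__dict__))
print(converted)
```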
What changes were proposed in this pull request?
This pull request fixes JSON serialization issues for lists and dictionaries containing custom objects by adding support to convert non-serializable objects to JSON-compatible dictionaries.
Added a `default` parameter to `json.dumps`: the `json.dumps` calls now pass `default=lambda o: o.__dict__` to convert non-serializable custom objects into a JSON-compatible dictionary format.
How was this patch tested?
Setup process
Screenshots
Check all the applicable boxes
Related PRs
Docs link