Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fix JSONParamType to handle serialization of custom objects #2931

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from

Conversation

dalaoqi
Copy link

@dalaoqi dalaoqi commented Nov 14, 2024

Tracking issue

Closes flyteorg/flyte#5985

Why are the changes needed?

The issue arises when trying to serialize lists or dictionaries that contain custom objects, as non-serializable custom objects cause a TypeError by default. This prevents successful serialization of data structures with instances of user-defined classes, necessitating a solution to convert these objects into a JSON-compatible format.

What changes were proposed in this pull request?

This pull request fixes JSON serialization issues for lists and dictionaries containing custom objects by adding support to convert non-serializable objects to JSON-compatible dictionaries.

Added default Parameter to json.dumps: Utilized default=lambda o: o.dict in json.dumps calls to convert non-serializable custom objects into a JSON-compatible dictionary format.

How was this patch tested?

Setup process

Screenshots

Check all the applicable boxes

  • I updated the documentation accordingly.
  • All new and existing tests passed.
  • All commits are signed-off.

Related PRs

Docs link

@dalaoqi dalaoqi force-pushed the 5985-bug-with-customObject-in-dict-and-list branch from a155053 to 36f98ed Compare November 14, 2024 16:46
Copy link

codecov bot commented Nov 15, 2024

Codecov Report

Attention: Patch coverage is 0% with 4 lines in your changes missing coverage. Please review.

Project coverage is 46.75%. Comparing base (3f0ab84) to head (36f98ed).
Report is 4 commits behind head on master.

Files with missing lines Patch % Lines
flytekit/interaction/click_types.py 0.00% 4 Missing ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #2931       +/-   ##
===========================================
- Coverage   76.33%   46.75%   -29.59%     
===========================================
  Files         199      199               
  Lines       20840    20789       -51     
  Branches     2681     2684        +3     
===========================================
- Hits        15908     9719     -6189     
- Misses       4214    10594     +6380     
+ Partials      718      476      -242     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

- Added a fallback check for `lv.scalar.union.value.collection.literals` in case `lv.collection` is None.
- Prevents potential errors when `lv.collection` is not properly initialized.

Signed-off-by: Tsung-Han Ho (dalaoqi) <[email protected]>
- Added validation to check for duplicate input names in the function's input interface.
- Raised a ValueError if duplicate input names are detected to prevent issues during argument assignment.
- Simplified the conversion of args to kwargs by removing the redundant multiple values check.

Signed-off-by: Tsung-Han Ho (dalaoqi) <[email protected]>
@dalaoqi dalaoqi force-pushed the 5985-bug-with-customObject-in-dict-and-list branch from 2ad9c6c to 1ce0704 Compare November 16, 2024 15:13
Copy link
Member

@Future-Outlier Future-Outlier left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've tested 4 cases and all of them worked!
can you help me add some unit tests?

pyflyte run build/taiwan/2931/example1.py wf
pyflyte run build/taiwan/2931/example1.py dict_wf
pyflyte run build/taiwan/2931/example1.py list_wf
pyflyte run build/taiwan/2931/example1.py dc_wf
from textwrap import shorten

from flytekit import task, workflow
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, List
from flytekit.core.type_engine import TypeEngine
from dataclasses_json import dataclass_json
from flyteidl.core.execution_pb2 import TaskExecution
from flytekit.core.context_manager import FlyteContextManager
from flytekit.configuration import SerializationSettings
from flytekit.core.base_task import PythonTask
from flytekit.core.interface import Interface
from flytekit.extend.backend.base_agent import (
    AgentRegistry,
    Resource,
    SyncAgentBase,
    SyncAgentExecutorMixin,
)
from flytekit.models.literals import LiteralMap
from flytekit.models.task import TaskTemplate

from flytekit.types.file import FlyteFile
from flytekit.types.directory import FlyteDirectory
from flytekit.types.structured import StructuredDataset
from flytekit.types.schema import FlyteSchema
import pandas as pd
import os

image = None

@dataclass_json
@dataclass
class Foo:
    val: str


class FooAgent(SyncAgentBase):
    def __init__(self) -> None:
        super().__init__(task_type_name="foo")

    def do(
        self,
        task_template: TaskTemplate,
        inputs: Optional[LiteralMap] = None,
        **kwargs: Any,
    ) -> Resource:
        return Resource(
            phase=TaskExecution.SUCCEEDED, outputs={"foos": [Foo(val="a"), Foo(val="b")], "has_foos": True}
        )


AgentRegistry.register(FooAgent())


class FooTask(SyncAgentExecutorMixin, PythonTask):  # type: ignore
    _TASK_TYPE = "foo"

    def __init__(self, name: str, **kwargs: Any) -> None:
        task_config: dict[str, Any] = {}

        outputs = {"has_foos": bool, "foos": Optional[List[Foo]]}

        super().__init__(
            task_type=self._TASK_TYPE,
            name=name,
            task_config=task_config,
            interface=Interface(outputs=outputs),
            **kwargs,
        )

    def get_custom(self, settings: SerializationSettings) -> Dict[str, Any]:
        return {}


foo_task = FooTask(name="foo_task")


@task
def foos_task(foos: list[Foo]) -> None:
    print(f"hi {foos}")


@workflow
def dc_wf(foos: list[Foo] = [Foo(val="a"), Foo(val="b")]) -> None:
    has_foos, foos = foo_task()
    foos_task(foos=foos)



@dataclass
class DC:
    ff: FlyteFile = field(default_factory=lambda: FlyteFile(os.path.realpath(__file__)))
    sd: StructuredDataset = field(default_factory=lambda: StructuredDataset(
        uri="/Users/future-outlier/code/dev/flytekit/build/debugyt/user/FE/src/data/df.parquet",
        file_format="parquet"
    ))
    fd: FlyteDirectory = field(default_factory=lambda: FlyteDirectory(
        "/Users/future-outlier/code/dev/flytekit/build/debugyt/user/FE/src/data/"
    ))


@task(container_image=image)
def t1(dc: DC = DC()) -> DC:
    with open(dc.ff, "r") as f:
        print("File Content: ", f.read())

    print("sd:", dc.sd.open(pd.DataFrame).all())

    df_path = os.path.join(dc.fd.path, "df.parquet")
    print("fd: ", os.path.isdir(df_path))

    return dc

@workflow
def wf(dc: DC = DC()):
    t1(dc=dc)

@task(container_image=image)
def list_t1(list_dc: list[DC] = [DC(), DC()]) -> list[DC]:
    for dc in list_dc:
        with open(dc.ff, "r") as f:
            print("File Content: ", f.read())

        print("sd:", dc.sd.open(pd.DataFrame).all())

        df_path = os.path.join(dc.fd.path, "df.parquet")
        print("fd: ", os.path.isdir(df_path))

    return list_dc

@workflow
def list_wf(list_dc: list[DC] = [DC(), DC()]):
    list_t1(list_dc=list_dc)

@task(container_image=image)
def dict_t1(dict_dc: dict[str, DC] = {"a": DC(), "b": DC()}) -> dict[str, DC]:
    for _, dc in dict_dc.items():
        with open(dc.ff, "r") as f:
            print("File Content: ", f.read())

        print("sd:", dc.sd.open(pd.DataFrame).all())

        df_path = os.path.join(dc.fd.path, "df.parquet")
        print("fd: ", os.path.isdir(df_path))

    return dict_dc

@workflow
def dict_wf(dict_dc: dict[str, DC] = {"a": DC(), "b": DC()}):
    dict_t1(dict_dc=dict_dc)

if __name__ == "__main__":
    from flytekit.clis.sdk_in_container import pyflyte
    from click.testing import CliRunner
    import os

    # wf()

    runner = CliRunner()
    path = os.path.realpath(__file__)
    result = runner.invoke(pyflyte.main, ["run", path, "dict_wf", ])
    print("Remote Execution: ", result.output)

Comment on lines +1676 to +1677
if lv and lv.scalar and lv.scalar.union.value.collection.literals:
lits = lv.scalar.union.value.collection.literals
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you help me add more comments to let others know what this is for?
I think you might be right, but why this is related to union but not literalmap something like that?

Comment on lines +308 to 309
if isinstance(value, dict) or isinstance(value, list):
return value
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change type to isinstance?

return json.loads(value)
if isinstance(value, str):
return json.loads(value)
return json.loads(json.dumps(value, default=lambda o: o.__dict__))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's the type of value will be here?
is it only dataclass?
if only dataclass, can we use
import dataclasses; dataclasses.is_dataclass(something) to check the value first?

https://stackoverflow.com/questions/56106116/check-if-a-class-is-a-dataclass-in-python

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] dataclass in list default input error
2 participants