Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Use models literal StructuredDataset to enable sd bypass task #2954

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,25 @@ def return_sd() -> StructuredDataset:
return df
For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5954.
2. Need access to self._literal_sd when converting task output back to flyteidl, please see:
https://github.com/flyteorg/flytekit/blob/master/flytekit/bin/entrypoint.py#L326
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use github permanent link?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, let me fix it. Thanks!

For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5956.
"""
to_literal = loop_manager.synced(flyte_dataset_transformer.async_to_literal)
self._literal_sd = to_literal(ctx, self, StructuredDataset, expected).scalar.structured_dataset
if self.metadata is None:
self._metadata = self._literal_sd.metadata

def set_literal(self, ctx: FlyteContext, expected: LiteralType) -> None:
"""
A public wrapper method to set the StructuredDataset Literal.
This method provides external access to the internal _set_literal method.
"""
return self._set_literal(ctx, expected)

def iter(self) -> Generator[DF, None, None]:
if self._dataframe_type is None:
raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
Expand Down Expand Up @@ -787,31 +800,31 @@ def encode(
return lit

def dict_to_structured_dataset(
self, dict_obj: typing.Dict[str, str], expected_python_type: Type[T] | StructuredDataset
self, ctx: FlyteContext, dict_obj: typing.Dict[str, str], expected_python_type: Type[T] | StructuredDataset
) -> T | StructuredDataset:
uri = dict_obj.get("uri", None)
file_format = dict_obj.get("file_format", None)

if uri is None:
raise ValueError("StructuredDataset's uri and file format should not be None")

# Construct python StructuredDataset
sdt = StructuredDatasetType(format=file_format)
metad = StructuredDatasetMetadata(structured_dataset_type=sdt)
sd = StructuredDataset(uri=uri, metadata=metad)

# Explicitly set StructuredDataset Literal
expected = TypeEngine.to_literal_type(StructuredDataset)
sd.set_literal(ctx, expected)

return StructuredDatasetTransformerEngine().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
structured_dataset=StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=file_format)
),
uri=uri,
)
)
),
Literal(scalar=Scalar(structured_dataset=sd._literal_sd)),
expected_python_type,
)

def from_binary_idl(
self, binary_idl_object: Binary, expected_python_type: Type[T] | StructuredDataset
self, ctx: FlyteContext, binary_idl_object: Binary, expected_python_type: Type[T] | StructuredDataset
) -> T | StructuredDataset:
"""
If the input is from flytekit, the Life Cycle will be as follows:
Expand Down Expand Up @@ -839,12 +852,14 @@ def wf(dc: DC):
"""
if binary_idl_object.tag == MESSAGEPACK:
python_val = msgpack.loads(binary_idl_object.value)
return self.dict_to_structured_dataset(dict_obj=python_val, expected_python_type=expected_python_type)
return self.dict_to_structured_dataset(
ctx=ctx, dict_obj=python_val, expected_python_type=expected_python_type
)
else:
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

def from_generic_idl(
self, generic: Struct, expected_python_type: Type[T] | StructuredDataset
self, ctx: FlyteContext, generic: Struct, expected_python_type: Type[T] | StructuredDataset
) -> T | StructuredDataset:
"""
If the input is from Flyte Console, the Life Cycle will be as follows:
Expand All @@ -871,7 +886,7 @@ def wf(dc: DC):
"""
json_str = _json_format.MessageToJson(generic)
python_val = json.loads(json_str)
return self.dict_to_structured_dataset(dict_obj=python_val, expected_python_type=expected_python_type)
return self.dict_to_structured_dataset(ctx=ctx, dict_obj=python_val, expected_python_type=expected_python_type)

async def async_to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset
Expand Down Expand Up @@ -909,9 +924,9 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
# Handle dataclass attribute access
if lv.scalar:
if lv.scalar.binary:
return self.from_binary_idl(lv.scalar.binary, expected_python_type)
return self.from_binary_idl(ctx, lv.scalar.binary, expected_python_type)
if lv.scalar.generic:
return self.from_generic_idl(lv.scalar.generic, expected_python_type)
return self.from_generic_idl(ctx, lv.scalar.generic, expected_python_type)

# Detect annotations and extract out all the relevant information that the user might supply
expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type)
Expand Down
Loading