Skip to content

Commit

Permalink
Reuse sd _set_literal with a public wrapper method
Browse files Browse the repository at this point in the history
Signed-off-by: JiaWei Jiang <[email protected]>
  • Loading branch information
JiangJiaWei1103 committed Nov 26, 2024
1 parent 6bf3dd0 commit d5d8620
Showing 1 changed file with 27 additions and 18 deletions.
45 changes: 27 additions & 18 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,12 +197,25 @@ def return_sd() -> StructuredDataset:
return df
For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5954.
2. Need access to self._literal_sd when converting task output back to flyteidl, please see:
https://github.com/flyteorg/flytekit/blob/master/flytekit/bin/entrypoint.py#L326
For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5956.
"""
to_literal = loop_manager.synced(flyte_dataset_transformer.async_to_literal)
self._literal_sd = to_literal(ctx, self, StructuredDataset, expected).scalar.structured_dataset
if self.metadata is None:
self._metadata = self._literal_sd.metadata

def set_literal(self, ctx: FlyteContext, expected: LiteralType) -> None:
    """
    Set the StructuredDataset Literal through the public API.

    This is a thin wrapper: ``ctx`` and ``expected`` are forwarded unchanged
    to the internal ``_set_literal`` method, and whatever that call returns
    is handed back to the caller.
    """
    outcome = self._set_literal(ctx, expected)
    return outcome

def iter(self) -> Generator[DF, None, None]:
if self._dataframe_type is None:
raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
Expand Down Expand Up @@ -795,24 +808,18 @@ def dict_to_structured_dataset(
if uri is None:
raise ValueError("StructuredDataset's uri and file format should not be None")

# Construct models.literal.StructuredDataset
py_sd = StructuredDataset(
metadata=StructuredDatasetMetadata(
structured_dataset_type=StructuredDatasetType(format=file_format)
),
uri=uri,
)
to_literal = loop_manager.synced(flyte_dataset_transformer.async_to_literal)
# Construct python StructuredDataset
sdt = StructuredDatasetType(format=file_format)
metad = StructuredDatasetMetadata(structured_dataset_type=sdt)
sd = StructuredDataset(uri=uri, metadata=metad)

# Explicitly set StructuredDataset Literal
expected = TypeEngine.to_literal_type(StructuredDataset)
lit_sd = to_literal(ctx, py_sd, StructuredDataset, expected).scalar.structured_dataset
sd.set_literal(ctx, expected)

return StructuredDatasetTransformerEngine().to_python_value(
FlyteContextManager.current_context(),
Literal(
scalar=Scalar(
structured_dataset=lit_sd
)
),
Literal(scalar=Scalar(structured_dataset=sd._literal_sd)),
expected_python_type,
)

Expand Down Expand Up @@ -845,12 +852,14 @@ def wf(dc: DC):
"""
if binary_idl_object.tag == MESSAGEPACK:
python_val = msgpack.loads(binary_idl_object.value)
return self.dict_to_structured_dataset(ctx=ctx, dict_obj=python_val, expected_python_type=expected_python_type)
return self.dict_to_structured_dataset(
ctx=ctx, dict_obj=python_val, expected_python_type=expected_python_type
)
else:
raise TypeTransformerFailedError(f"Unsupported binary format: `{binary_idl_object.tag}`")

def from_generic_idl(
self, generic: Struct, expected_python_type: Type[T] | StructuredDataset
self, ctx: FlyteContext, generic: Struct, expected_python_type: Type[T] | StructuredDataset
) -> T | StructuredDataset:
"""
If the input is from Flyte Console, the Life Cycle will be as follows:
Expand All @@ -877,7 +886,7 @@ def wf(dc: DC):
"""
json_str = _json_format.MessageToJson(generic)
python_val = json.loads(json_str)
return self.dict_to_structured_dataset(dict_obj=python_val, expected_python_type=expected_python_type)
return self.dict_to_structured_dataset(ctx=ctx, dict_obj=python_val, expected_python_type=expected_python_type)

async def async_to_python_value(
self, ctx: FlyteContext, lv: Literal, expected_python_type: Type[T] | StructuredDataset
Expand Down Expand Up @@ -917,7 +926,7 @@ def t2(in_a: Annotated[StructuredDataset, kwtypes(col_b=float)]): ...
if lv.scalar.binary:
return self.from_binary_idl(ctx, lv.scalar.binary, expected_python_type)
if lv.scalar.generic:
return self.from_generic_idl(lv.scalar.generic, expected_python_type)
return self.from_generic_idl(ctx, lv.scalar.generic, expected_python_type)

# Detect annotations and extract out all the relevant information that the user might supply
expected_python_type, column_dict, storage_fmt, pa_schema = extract_cols_and_format(expected_python_type)
Expand Down

0 comments on commit d5d8620

Please sign in to comment.