Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Prevent UnionTransformer type ambiguity in combination with PyTorchTypeTransformer #2726

Merged
merged 2 commits into from
Sep 14, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions flytekit/extras/pytorch/native.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
python_type: Type[T],
expected: LiteralType,
) -> Literal:
if not isinstance(python_val, torch.Tensor) and not isinstance(python_val, torch.nn.Module):
raise TypeTransformerFailedError("Expected a torch.Tensor or nn.Module")

Check warning on line 32 in flytekit/extras/pytorch/native.py

View check run for this annotation

Codecov / codecov/patch

flytekit/extras/pytorch/native.py#L32

Added line #L32 was not covered by tests

meta = BlobMetadata(
type=_core_types.BlobType(
format=self.PYTORCH_FORMAT,
Expand Down
44 changes: 43 additions & 1 deletion tests/flytekit/unit/extras/pytorch/test_transformations.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
from collections import OrderedDict
from typing import Union

import pytest
import torch

import flytekit
from flytekit import task
from flytekit import task, workflow
from flytekit.configuration import Image, ImageConfig
from flytekit.core import context_manager
from flytekit.core.type_engine import TypeTransformerFailedError
from flytekit.extras.pytorch import (
PyTorchCheckpoint,
PyTorchCheckpointTransformer,
Expand All @@ -18,6 +20,7 @@
from flytekit.models.types import LiteralType
from flytekit.tools.translator import get_serializable


default_img = Image(name="default", fqn="test", tag="tag")
serialization_settings = flytekit.configuration.SerializationSettings(
project="project",
Expand Down Expand Up @@ -130,3 +133,42 @@ def t1() -> PyTorchCheckpoint:
task_spec.template.interface.outputs["o0"].type.blob.format
is PyTorchCheckpointTransformer.PYTORCH_CHECKPOINT_FORMAT
)


def test_to_literal_unambiguity():
"""Test that the pytorch type transformers raise an error when the input is a list of tensors or modules.

The PyTorchTypeTransformer uses `torch.save` for serialization which is able to serialize a list of tensors
or modules but this causes ambiguity in the Union type transformer as it cannot distinguish whether the
ListTransformer should invoke the PyTorchTypeTransformer for every element in the list or the
PyTorchTypeTransformer for the entire list.
"""
ctx = context_manager.FlyteContext.current_context()

with pytest.raises(TypeTransformerFailedError):
test_inp = torch.tensor([1, 2, 3])
trans = PyTorchTensorTransformer()
trans.to_literal(ctx, [test_inp], torch.Tensor, trans.get_literal_type(torch.Tensor))


with pytest.raises(TypeTransformerFailedError):
model = torch.nn.Linear(2, 2)
trans = PyTorchModuleTransformer()
trans.to_literal(ctx, [model], torch.nn.Module, trans.get_literal_type(torch.nn.Module))


def test_torch_tensor_list_union():
"""Test that a task can return a union of list of tensor and tensor.

See test_to_literal_unambiguity for more details why this failed.
"""

@task
def foo() -> Union[list[torch.Tensor], torch.Tensor]:
return [torch.tensor([1, 2, 3])]

@workflow
def wf():
foo()

wf()
Loading