diff --git a/tests/flytekit/integration/remote/test_remote.py b/tests/flytekit/integration/remote/test_remote.py index 0d9047294c..9cf21ae8b4 100644 --- a/tests/flytekit/integration/remote/test_remote.py +++ b/tests/flytekit/integration/remote/test_remote.py @@ -15,6 +15,7 @@ import uuid import pytest +import flytekit from flytekit import LaunchPlan, kwtypes from flytekit.configuration import Config, ImageConfig, SerializationSettings from flytekit.core.launch_plan import reference_launch_plan @@ -22,7 +23,7 @@ from flytekit.core.workflow import reference_workflow from flytekit.exceptions.user import FlyteAssertion, FlyteEntityNotExistException from flytekit.extras.sqlite3.task import SQLite3Config, SQLite3Task -from flytekit.remote.remote import FlyteRemote +from flytekit.remote.remote import FlyteRemote, _get_git_root from flytekit.types.schema import FlyteSchema MODULE_PATH = pathlib.Path(__file__).parent / "workflows/basic" @@ -608,3 +609,8 @@ def test_register_wf_fast(register): subworkflow_node_executions = execution.node_executions["n1"].subworkflow_node_executions subworkflow_node_executions["n1-0-n0"].inputs == {"a": 103} subworkflow_node_executions["n1-0-n1"].outputs == {"t1_int_output": 107, "c": "world"} + + +def test_get_git_root(): + flytekit_module = pathlib.Path(flytekit.__file__).parent + assert _get_git_root(str(flytekit_module)) == str(flytekit_module.parent) diff --git a/tests/flytekit/unit/test_translator.py b/tests/flytekit/unit/test_translator.py index 792ca0b131..3ab3f488e4 100644 --- a/tests/flytekit/unit/test_translator.py +++ b/tests/flytekit/unit/test_translator.py @@ -1,9 +1,12 @@ import typing from collections import OrderedDict +from pathlib import Path + +import mock import flytekit.configuration from flytekit import ContainerTask, Resources -from flytekit.configuration import FastSerializationSettings, Image, ImageConfig +from flytekit.configuration import FastSerializationSettings, Image, ImageConfig, SerializationSettings from flytekit.core.base_task import kwtypes from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan from flytekit.core.reference_entity import ReferenceSpec, ReferenceTemplate @@ -11,7 +14,7 @@ from flytekit.core.workflow import ReferenceWorkflow, workflow from flytekit.models.core import identifier as identifier_models from flytekit.models.task import Resources as resource_model -from flytekit.tools.translator import get_serializable +from flytekit.tools.translator import get_serializable, _get_git_link, _is_file_pushed default_img = Image(name="default", fqn="test", tag="tag") serialization_settings = flytekit.configuration.SerializationSettings( @@ -166,3 +169,32 @@ def morning_greeter_caller(day_of_week: str) -> str: assert len(task_spec.template.interface.outputs) == 1 assert len(task_spec.template.nodes) == 1 assert len(task_spec.template.nodes[0].inputs) == 2 + + +@mock.patch("flytekit.remote.remote._get_git_root") +@mock.patch("flytekit.tools.translator._is_file_pushed") +def test_get_git_link(mock_is_file_pushed, mock_get_git_root): + ss = SerializationSettings( + project="test_proj", + domain="test_domain", + version="abc", + image_config=ImageConfig(Image(name="name", fqn="image", tag="name")), + env={}, + ) + + assert _get_git_link(module="flytekit/workflow.py", settings=ss) is None + + mock_is_file_pushed.return_value = True + mock_get_git_root.return_value = "flytekit" + ss.git_repo = "https://github.com/flyteorg/flytekit/blob/master" + + assert _get_git_link(module="flytekit/workflow.py", settings=ss) == "https://github.com/flyteorg/flytekit/blob/master/workflow.py" + + mock_is_file_pushed.return_value = False + assert _get_git_link(module="flytekit/workflow.py", settings=ss) is None + + +def test_is_file_pushed(): + module = str(Path(flytekit.__file__).parent.parent) + assert _is_file_pushed(module, "workflow.py") is False + assert _is_file_pushed(module, "NOTICE") is True