Skip to content

Commit

Permalink
save from merge
Browse files Browse the repository at this point in the history
Signed-off-by: Mecoli1219 <[email protected]>
  • Loading branch information
Mecoli1219 committed Aug 14, 2024
1 parent 7c8232f commit 6f51bce
Show file tree
Hide file tree
Showing 7 changed files with 100 additions and 104 deletions.
6 changes: 2 additions & 4 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -802,21 +802,19 @@ def execute(self, **kwargs) -> Any:
"""
pass

def remote(self, version: Optional[str] = None, options: Optional[Options] = None, **kwargs) -> FlyteFuture:
def remote(self, options: Optional[Options] = None, **kwargs) -> FlyteFuture:
"""
This method will be invoked to execute the task remotely. This will return a FlyteFuture object that can be
used to track the progress of the task execution.
This method should be executed after specifying the remote configuration via `flytekit.remote.init_remote()`.
:param version: an optional version string to fetch or register the task. If not specified, it will randomly
generate a version string.
:param options: optional options that override the default options of the task. If not
specified, the default options provided by `init_remote()` will be used.
:param kwargs: Dict[str, Any] the inputs to the task. The inputs should match the signature of the task.
:return: FlyteFuture
"""
return FlyteFuture(self, version=version, options=options, **kwargs)
return FlyteFuture(self, options=options, **kwargs)

def post_execute(self, user_params: Optional[ExecutionParameters], rval: Any) -> Any:
"""
Expand Down
79 changes: 60 additions & 19 deletions flytekit/core/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,14 @@
from datetime import timedelta
from typing import TYPE_CHECKING

import click
import cloudpickle

if TYPE_CHECKING:
from IPython.display import IFrame

from flytekit.core.base_task import PythonTask
from flytekit.core.type_engine import LiteralsResolver
from flytekit.core.workflow import WorkflowBase
from flytekit.remote.executions import FlyteWorkflowExecution
from flytekit.tools.translator import Options
Expand All @@ -22,7 +26,6 @@ class FlyteFuture:
def __init__(
self,
entity: typing.Union[PythonTask, WorkflowBase],
version: typing.Optional[str] = None,
options: typing.Optional[Options] = None,
**kwargs,
):
Expand All @@ -31,6 +34,7 @@ def __init__(
This object requires the FlyteRemote client to be initialized before it can be used. The FlyteRemote client
can be initialized by calling `flytekit.remote.init_remote()`.
"""
from flytekit.core.base_task import PythonTask
from flytekit.remote.init_remote import REMOTE_DEFAULT_OPTIONS, REMOTE_ENTRY
from flytekit.tools.script_mode import hash_file

Expand All @@ -40,23 +44,30 @@ def __init__(
)
self._remote_entry = REMOTE_ENTRY

if version is None:
with tempfile.TemporaryDirectory() as tmp_dir:
dest = pathlib.Path(tmp_dir, "pkl.gz")
with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped:
cloudpickle.dump(entity, gzipped)
md5_bytes, _, _ = hash_file(dest)
with tempfile.TemporaryDirectory() as tmp_dir:
dest = pathlib.Path(tmp_dir, "pkl.gz")
with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped:
cloudpickle.dump(entity, gzipped)
md5_bytes, _, _ = hash_file(dest)

h = hashlib.md5(md5_bytes)
version = base64.urlsafe_b64encode(h.digest()).decode("ascii").rstrip("=")
h = hashlib.md5(md5_bytes)
version = base64.urlsafe_b64encode(h.digest()).decode("ascii").rstrip("=")

if options is None:
options = REMOTE_DEFAULT_OPTIONS

self._version = version
self._exe = self._remote_entry.execute(entity, version=version, inputs=kwargs, options=options)
self._is_task = isinstance(entity, PythonTask)
console_url = self._remote_entry.generate_console_url(self._exe)
s = (
click.style("\n[✔] ", fg="green")
+ "Go to "
+ click.style(console_url, fg="cyan")
+ " to see execution in the console."
)
click.echo(s)

def wait(
def __wait(
self,
timeout: typing.Optional[timedelta] = None,
poll_interval: typing.Optional[timedelta] = None,
Expand All @@ -75,12 +86,42 @@ def wait(
sync_nodes=sync_nodes,
)

@property
def version(self) -> str:
"""The version of the task or workflow being executed."""
return self._version
def get(
self,
timeout: typing.Optional[timedelta] = None,
) -> typing.Optional[LiteralsResolver]:
"""Wait for the execution to complete and return the output.
:param timeout: maximum amount of time to wait
"""
out = self.__wait(
timeout=timeout,
poll_interval=timedelta(seconds=5),
)
return out.outputs

def get_deck(
self,
timeout: typing.Optional[timedelta] = None,
) -> typing.Optional[IFrame]:
"""Wait for the execution to complete and return the deck for the task.
:param timeout: maximum amount of time to wait
"""
if not self._is_task:
raise ValueError("Deck can only be retrieved for task executions.")
from IPython.display import IFrame

self.__wait(
timeout=timeout,
poll_interval=timedelta(seconds=5),
)
for node_execution in self._exe.node_executions.values():
uri = node_execution._closure.deck_uri
if uri:
break
if uri == "":
raise ValueError("Deck not found for task execution")

@property
def exe(self) -> FlyteWorkflowExecution:
"""The executing FlyteWorkflowExecution object."""
return self._exe
deck_uri = self._remote_entry.client.get_download_signed_url(uri)
return IFrame(src=deck_uri.signed_url, width=800, height=600)
6 changes: 2 additions & 4 deletions flytekit/core/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,21 +306,19 @@ def __call__(self, *args, **kwargs) -> Union[Tuple[Promise], Promise, VoidPromis
def execute(self, **kwargs):
raise Exception("Should not be called")

def remote(self, version: Optional[str] = None, options: Optional[Options] = None, **kwargs) -> FlyteFuture:
def remote(self, options: Optional[Options] = None, **kwargs) -> FlyteFuture:
"""
This method will be invoked to execute the workflow remotely. This will return a FlyteFuture object that can be
used to track the progress of the workflow execution.
This method should be executed after specifying the remote configuration via `flytekit.remote.init_remote()`.
:param version: an optional version string to fetch or register the workflow. If not specified, it will randomly
generate a version string.
:param options: optional options that override the default options of the workflow. If not
specified, the default options provided by `init_remote()` will be used.
:param kwargs: Dict[str, Any] the inputs to the workflow. The inputs should match the signature of the workflow.
:return: FlyteFuture
"""
return FlyteFuture(self, version=version, options=options, **kwargs)
return FlyteFuture(self, options=options, **kwargs)

def compile(self, **kwargs):
pass
Expand Down
23 changes: 10 additions & 13 deletions flytekit/remote/init_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,19 +37,16 @@ def init_remote(
"""
global REMOTE_ENTRY, REMOTE_DEFAULT_OPTIONS
with REMOTE_ENTRY_LOCK:
if REMOTE_ENTRY is None:
REMOTE_ENTRY = FlyteRemote(
config=config,
default_project=default_project,
default_domain=default_domain,
data_upload_location=data_upload_location,
interactive_mode_enabled=interactive_mode_enabled,
**kwargs,
)
# TODO: This should be merged into the FlyteRemote in the future
REMOTE_DEFAULT_OPTIONS = default_options
else:
raise AssertionError("Remote client already initialized")
REMOTE_ENTRY = FlyteRemote(
config=config,
default_project=default_project,
default_domain=default_domain,
data_upload_location=data_upload_location,
interactive_mode_enabled=interactive_mode_enabled,
**kwargs,
)
# TODO: This should be merged into the FlyteRemote in the future
REMOTE_DEFAULT_OPTIONS = default_options

# Set the log level
log_level = get_level_from_cli_verbosity(verbosity)
Expand Down
12 changes: 6 additions & 6 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
from flytekit.remote.remote_fs import get_flyte_fs
from flytekit.tools.fast_registration import FastPackageOptions, fast_package
from flytekit.tools.interactive import ipython_check
from flytekit.tools.script_mode import _find_project_root, compress_scripts, hash_file
from flytekit.tools.script_mode import _find_project_root, compress_scripts, get_all_modules, hash_file
from flytekit.tools.translator import (
FlyteControlPlaneEntity,
FlyteLocalEntity,
Expand Down Expand Up @@ -210,7 +210,7 @@ def __init__(
:param default_domain: default domain to use when fetching or executing flyte entities.
:param data_upload_location: this is where all the default data will be uploaded when providing inputs.
The default location - `s3://my-s3-bucket/data` works for sandbox/demo environment. Please override this for non-sandbox cases.
:param interactive_mode_enabled: If set to True, the FlyteRemote will pickle the task/workflow objects if not found
:param interactive_mode_enabled: If set to True, the FlyteRemote will pickle the task/workflow.
"""
if config is None or config.platform is None or config.platform.endpoint is None:
raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.")
Expand Down Expand Up @@ -270,7 +270,7 @@ def file_access(self) -> FileAccessProvider:

@property
def interactive_mode_enabled(self) -> bool:
"""If set to True, the FlyteRemote will pickle the task/workflow objects if not found."""
"""If set to True, the FlyteRemote will pickle the task/workflow."""
return self._interactive_mode_enabled

def get(
Expand Down Expand Up @@ -1049,7 +1049,7 @@ def register_script(
)
else:
archive_fname = pathlib.Path(os.path.join(tmp_dir, "script_mode.tar.gz"))
compress_scripts(source_path, str(archive_fname), module_name)
compress_scripts(source_path, str(archive_fname), get_all_modules(source_path, module_name))
md5_bytes, upload_native_url = self.upload_file(
archive_fname, project or self.default_project, domain or self.default_domain
)
Expand Down Expand Up @@ -2124,7 +2124,7 @@ def sync_node_execution(
if node_id in node_mapping:
execution._node = node_mapping[node_id]
else:
raise Exception(f"Missing node from mapping: {node_id}")
raise ValueError(f"Missing node from mapping: {node_id}")

# Get the node execution data
node_execution_get_data_response = self.client.get_node_execution_data(execution.id)
Expand Down Expand Up @@ -2219,7 +2219,7 @@ def sync_node_execution(
return execution
else:
logger.error(f"NE {execution} undeterminable, {type(execution._node)}, {execution._node}")
raise Exception(f"Node execution undeterminable, entity has type {type(execution._node)}")
raise ValueError(f"Node execution undeterminable, entity has type {type(execution._node)}")

# Handle the case for gate nodes
elif execution._node.gate_node is not None:
Expand Down
11 changes: 4 additions & 7 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,16 +163,13 @@ def get_command_prefix_for_fast_execute(settings: SerializationSettings) -> List
if settings.fast_serialization_settings and settings.fast_serialization_settings.distribution_location
else "{{ .remote_package_path }}"
),
"--dest-dir",
(
settings.fast_serialization_settings.destination_dir
if settings.fast_serialization_settings and settings.fast_serialization_settings.destination_dir
else "{{ .dest_dir }}"
),
]
# If pickling is enabled, we will add a pickled bit
if settings.fast_serialization_settings and settings.fast_serialization_settings.pickled:
prefix = prefix + ["--pickled"]
elif settings.fast_serialization_settings and settings.fast_serialization_settings.destination_dir:
prefix = prefix + ["--dest-dir", settings.fast_serialization_settings.destination_dir]
else:
prefix = prefix + ["--dest-dir", "{{ .dest_dir }}"]

return prefix + ["--"]

Expand Down
67 changes: 16 additions & 51 deletions tests/flytekit/integration/remote/test_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,54 +586,19 @@ def test_flyteremote_uploads_large_file(gigabytes):
def test_workflow_remote_func(mock_ipython_check):
"""Test the logic of the remote execution of workflows and tasks."""
mock_ipython_check.return_value = True
with pytest.raises(AssertionError):
init_remote(Config.auto(config_file=CONFIG), PROJECT, DOMAIN, interactive_mode_enabled=True)
from .workflows.basic.child_workflow import parent_wf, double

# child_workflow.parent_wf asynchronously registers a parent wf1 with a child lp from another wf2.
future0 = double.remote(a=3)
future1 = parent_wf.remote(a=3)
future2 = parent_wf.remote(a=2)
assert future0.version != VERSION
assert future1.version != VERSION
assert future2.version != VERSION
# It should generate a new version for each execution
assert future1.version != future2.version

out0 = future0.wait()
assert out0.outputs["o0"] == 6
out1 = future1.wait()
assert out1.outputs["o0"] == 18
out2 = future2.wait()
assert out2.outputs["o0"] == 12


def test_fetch_python_task_remote_func(register):
"""Test remote execution of a @task-decorated python function that is already registered."""
with patch("flytekit.tools.interactive.ipython_check") as mock_ipython_check:
mock_ipython_check.return_value = True

from .workflows.basic.basic_workflow import t1

future = t1.remote(a=10, version=VERSION)
out = future.wait()
assert future.version == VERSION

assert out.outputs["t1_int_output"] == 12
assert out.outputs["c"] == "world"


@pytest.mark.skip(reason="Waiting for supporting the `name` parameter in the remote function")
def test_fetch_python_workflow_remote_func(register):
"""Test remote execution of a @workflow-decorated python function that is already registered."""
with patch("flytekit.tools.interactive.ipython_check") as mock_ipython_check:
mock_ipython_check.return_value = True
from .workflows.basic.basic_workflow import my_basic_wf

future = my_basic_wf.remote(a=10, b="xyz", version=VERSION)
out = future.wait()
assert out.outputs["o0"] == 12
assert out.outputs["o1"] == "xyzworld"
out0 = future0.get()
assert out0["o0"] == 6
out1 = future1.get()
assert out1["o0"] == 18
out2 = future2.get()
assert out2["o0"] == 12


@mock.patch("flytekit.tools.interactive.ipython_check")
Expand All @@ -644,8 +609,8 @@ def test_execute_task_remote_func_list_of_floats(mock_ipython_check):

xs: typing.List[float] = [0.1, 0.2, 0.3, 0.4, -99999.7]
future = concat_list.remote(xs=xs)
out = future.wait()
assert out.outputs["o0"] == "[0.1, 0.2, 0.3, 0.4, -99999.7]"
out = future.get()
assert out["o0"] == "[0.1, 0.2, 0.3, 0.4, -99999.7]"


@mock.patch("flytekit.tools.interactive.ipython_check")
Expand All @@ -656,8 +621,8 @@ def test_execute_task_remote_func_convert_dict(mock_ipython_check):

d: typing.Dict[str, str] = {"key1": "value1", "key2": "value2"}
future = convert_to_string.remote(d=d)
out = future.wait()
assert json.loads(out.outputs["o0"]) == {"key1": "value1", "key2": "value2"}
out = future.get()
assert json.loads(out["o0"]) == {"key1": "value1", "key2": "value2"}


@mock.patch("flytekit.tools.interactive.ipython_check")
Expand All @@ -668,8 +633,8 @@ def test_execute_python_workflow_remote_func_dict_of_string_to_string(mock_ipyth

d: typing.Dict[str, str] = {"k1": "v1", "k2": "v2"}
future = my_dict_str_wf.remote(d=d)
out = future.wait()
assert json.loads(out.outputs["o0"]) == {"k1": "v1", "k2": "v2"}
out = future.get()
assert json.loads(out["o0"]) == {"k1": "v1", "k2": "v2"}


@mock.patch("flytekit.tools.interactive.ipython_check")
Expand All @@ -681,8 +646,8 @@ def test_execute_python_workflow_remote_func_list_of_floats(mock_ipython_check):

xs: typing.List[float] = [42.24, 999.1, 0.0001]
future = my_list_float_wf.remote(xs=xs)
out = future.wait()
assert out.outputs["o0"] == "[42.24, 999.1, 0.0001]"
out = future.get()
assert out["o0"] == "[42.24, 999.1, 0.0001]"

@mock.patch("flytekit.tools.interactive.ipython_check")
def test_execute_workflow_remote_fn_with_maptask(mock_ipython_check):
Expand All @@ -692,5 +657,5 @@ def test_execute_workflow_remote_fn_with_maptask(mock_ipython_check):

d: typing.List[int] = [1, 2, 3]
future = workflow_with_maptask.remote(data=d, y=3)
out = future.wait()
assert out.outputs["o0"] == [4, 5, 6]
out = future.get()
assert out["o0"] == [4, 5, 6]

0 comments on commit 6f51bce

Please sign in to comment.