Skip to content

Commit

Permalink
Async/tasks remote (#2964)
Browse files Browse the repository at this point in the history
auto detect flyte remote connection

Signed-off-by: Yee Hing Tong <[email protected]>
  • Loading branch information
wild-endeavor authored Nov 27, 2024
1 parent 45f07b0 commit eea2605
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 24 deletions.
17 changes: 15 additions & 2 deletions flytekit/configuration/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
```
"""

import os
from typing import Optional, Protocol, runtime_checkable

from click import Group
Expand Down Expand Up @@ -59,10 +60,22 @@ def get_remote(
config: Optional[str], project: str, domain: str, data_upload_location: Optional[str] = None
) -> FlyteRemote:
"""Get FlyteRemote object for CLI session."""

cfg_file = get_config_file(config)

# The assumption here (no config file means we want sandbox) is too broad.
# TODO: improve this in the future. Rather than just checking one env var, auto() with
# nothing configured should probably not return sandbox, but that is still open for consideration.
if cfg_file is None:
cfg_obj = Config.for_sandbox()
logger.info("No config files found, creating remote with sandbox config")
# We really are just looking for endpoint, client_id, and client_secret. These correspond to the env vars
# FLYTE_PLATFORM_URL, FLYTE_CREDENTIALS_CLIENT_ID, and FLYTE_CREDENTIALS_CLIENT_SECRET;
# auto() should pick these up.
if "FLYTE_PLATFORM_URL" in os.environ:
cfg_obj = Config.auto(None)
logger.warning(f"Auto-created config object to pick up env vars {cfg_obj}")

Check warning on line 75 in flytekit/configuration/plugin.py

View check run for this annotation

Codecov / codecov/patch

flytekit/configuration/plugin.py#L74-L75

Added lines #L74 - L75 were not covered by tests
else:
cfg_obj = Config.for_sandbox()
logger.info("No config files found, creating remote with sandbox config")

Check warning on line 78 in flytekit/configuration/plugin.py

View check run for this annotation

Codecov / codecov/patch

flytekit/configuration/plugin.py#L77-L78

Added lines #L77 - L78 were not covered by tests
else: # pragma: no cover
cfg_obj = Config.auto(config)
logger.debug(f"Creating remote with config {cfg_obj}" + (f" with file {config}" if config else ""))
Expand Down
2 changes: 2 additions & 0 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ class TaskMetadata(object):
retries: int = 0
timeout: Optional[Union[datetime.timedelta, int]] = None
pod_template_name: Optional[str] = None
is_eager: bool = False

def __post_init__(self):
if self.timeout:
Expand Down Expand Up @@ -180,6 +181,7 @@ def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
cache_serializable=self.cache_serialize,
pod_template_name=self.pod_template_name,
cache_ignore_input_vars=self.cache_ignore_input_vars,
is_eager=self.is_eager,
)


Expand Down
43 changes: 25 additions & 18 deletions flytekit/core/python_function_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from flytekit.configuration import ImageConfig, SerializationSettings
from flytekit.core import launch_plan as _annotated_launch_plan
from flytekit.core.base_task import Task, TaskResolverMixin
from flytekit.core.base_task import Task, TaskMetadata, TaskResolverMixin
from flytekit.core.constants import EAGER_ROOT_ENV_NAME
from flytekit.core.context_manager import ExecutionState, FlyteContext, FlyteContextManager
from flytekit.core.docstring import Docstring
Expand Down Expand Up @@ -455,6 +455,11 @@ def __init__(
if "execution_mode" in kwargs:
del kwargs["execution_mode"]

Check warning on line 456 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L456

Added line #L456 was not covered by tests

if "metadata" in kwargs:
kwargs["metadata"].is_eager = True

Check warning on line 459 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L459

Added line #L459 was not covered by tests
else:
kwargs["metadata"] = TaskMetadata(is_eager=True)

Check warning on line 461 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L461

Added line #L461 was not covered by tests

super().__init__(

Check warning on line 463 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L463

Added line #L463 was not covered by tests
task_config,
task_function,
Expand Down Expand Up @@ -516,25 +521,38 @@ async def async_execute(self, *args, **kwargs) -> Any:
return await self._task_function(**kwargs)

Check warning on line 521 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L521

Added line #L521 was not covered by tests

def execute(self, **kwargs) -> Any:
from flytekit.experimental.eager_function import _internal_demo_remote
from flytekit.remote.remote import FlyteRemote

remote = FlyteRemote.for_sandbox(default_project="flytesnacks", default_domain="development")
remote = _internal_demo_remote(remote)

ctx = FlyteContextManager.current_context()
is_local_execution = cast(ExecutionState, ctx.execution_state).is_local_execution()
builder = ctx.new_builder()

Check warning on line 526 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L524-L526

Added lines #L524 - L526 were not covered by tests
if not is_local_execution:
# ensure that the worker queue is in context
if not ctx.worker_queue:
from flytekit.configuration.plugin import get_plugin

Check warning on line 530 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L530

Added line #L530 was not covered by tests

# This should be read from transport at real runtime if available, but if not, we should either run
# remote in interactive mode, or let users configure the version to use.
ss = ctx.serialization_settings

Check warning on line 534 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L534

Added line #L534 was not covered by tests
if not ss:
ss = SerializationSettings(

Check warning on line 536 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L536

Added line #L536 was not covered by tests
image_config=ImageConfig.auto_default_image(),
)

# In order to build the controller, we really just need a remote.
project = (

Check warning on line 541 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L541

Added line #L541 was not covered by tests
ctx.user_space_params.execution_id.project
if ctx.user_space_params and ctx.user_space_params.execution_id
else "flytesnacks"
)
domain = (

Check warning on line 546 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L546

Added line #L546 was not covered by tests
ctx.user_space_params.execution_id.domain
if ctx.user_space_params and ctx.user_space_params.execution_id
else "development"
)
raw_output = ctx.user_space_params.raw_output_prefix
remote = get_plugin().get_remote(

Check warning on line 552 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L551-L552

Added lines #L551 - L552 were not covered by tests
config=None, project=project, domain=domain, data_upload_location=raw_output
)

# tag is the current execution id
# root tag is read from the environment variable if it exists, if not, it's the current execution id
if not ctx.user_space_params or not ctx.user_space_params.execution_id:
Expand Down Expand Up @@ -589,14 +607,3 @@ async def run_with_backend(self, **kwargs):
# now have to fail this eager task, because we don't want it to show up as succeeded.
raise FlyteNonRecoverableSystemException(base_error)
return result

Check warning on line 609 in flytekit/core/python_function_task.py

View check run for this annotation

Codecov / codecov/patch

flytekit/core/python_function_task.py#L608-L609

Added lines #L608 - L609 were not covered by tests


"""
update code comments and remove int test for now
verify auth env var and start auto loading
- figure out how remotes can be different.
pure watch informer pattern
priority for flytekit - fix naming, depending on src
"""
11 changes: 10 additions & 1 deletion flytekit/models/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def __init__(
cache_serializable,
pod_template_name,
cache_ignore_input_vars,
is_eager: bool = False,
):
"""
Information needed at runtime to determine behavior such as whether or not outputs are discoverable, timeouts,
Expand All @@ -200,6 +201,7 @@ def __init__(
single instance over identical inputs is executed, other concurrent executions wait for the cached results.
:param pod_template_name: The name of the existing PodTemplate resource which will be used in this task.
:param cache_ignore_input_vars: Input variables that should not be included when calculating hash for cache.
:param is_eager: Whether this task is an eager task. Defaults to False.
"""
self._discoverable = discoverable
self._runtime = runtime
Expand All @@ -211,6 +213,11 @@ def __init__(
self._cache_serializable = cache_serializable
self._pod_template_name = pod_template_name
self._cache_ignore_input_vars = cache_ignore_input_vars
self._is_eager = is_eager

@property
def is_eager(self) -> bool:
    """Return the eager flag stored on this metadata object at construction time.

    :rtype: bool
    """
    eager_flag = self._is_eager
    return eager_flag

@property
def discoverable(self):
Expand Down Expand Up @@ -310,13 +317,14 @@ def to_flyte_idl(self):
cache_serializable=self.cache_serializable,
pod_template_name=self.pod_template_name,
cache_ignore_input_vars=self.cache_ignore_input_vars,
is_eager=self.is_eager,
)
if self.timeout:
tm.timeout.FromTimedelta(self.timeout)
return tm

@classmethod
def from_flyte_idl(cls, pb2_object):
def from_flyte_idl(cls, pb2_object: _core_task.TaskMetadata):
"""
:param flyteidl.core.task_pb2.TaskMetadata pb2_object:
:rtype: TaskMetadata
Expand All @@ -332,6 +340,7 @@ def from_flyte_idl(cls, pb2_object):
cache_serializable=pb2_object.cache_serializable,
pod_template_name=pb2_object.pod_template_name,
cache_ignore_input_vars=pb2_object.cache_ignore_input_vars,
is_eager=pb2_object.is_eager,
)


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ dependencies = [
"diskcache>=5.2.1",
"docker>=4.0.0",
"docstring-parser>=0.9.0",
"flyteidl>=1.13.7",
# "flyteidl>=1.13.7", # todo:async bump after releasing idl
"fsspec>=2023.3.0",
"gcsfs>=2023.3.0",
"googleapis-common-protos>=1.57",
Expand Down
21 changes: 19 additions & 2 deletions tests/flytekit/unit/core/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,27 @@
from flytekit.core.worker_queue import Controller
from flytekit.utils.asyn import loop_manager
from flytekit.core.context_manager import FlyteContextManager
from flytekit.configuration import Config, DataConfig, S3Config, FastSerializationSettings
from flytekit.configuration import Config, DataConfig, S3Config, FastSerializationSettings, ImageConfig, SerializationSettings, Image
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.tools.translator import get_serializable
from collections import OrderedDict

# Module-level fixtures: one placeholder image serves as both the default image
# and the only entry in the image list of the serialization settings.
default_img = Image(tag="tag", fqn="test", name="default")
serialization_settings = SerializationSettings(
    image_config=ImageConfig(images=[default_img], default_image=default_img),
    env=None,
    version="version",
    domain="domain",
    project="project",
)


@task
def add_one(x: int) -> int:
    # Minimal task: hand back the input incremented by one.
    incremented = x + 1
    return incremented


@eager
@eager(environment={"a": "b"})
async def simple_eager_workflow(x: int) -> int:
# This is the normal way of calling tasks. Call normal tasks in an effectively async way by hanging and waiting for
# the result.
Expand Down Expand Up @@ -41,3 +52,9 @@ def test_easy_2():
):
res = loop_manager.run_sync(simple_eager_workflow.run_with_backend, x=1)
assert res == 2


def test_serialization():
    """Serialize the eager workflow and inspect the resulting task template."""
    template = get_serializable(OrderedDict(), serialization_settings, simple_eager_workflow).template
    # The eager flag must be present on the serialized task metadata.
    assert template.metadata.is_eager
    # NOTE(review): two env entries are expected; presumably the user-supplied
    # {"a": "b"} plus one injected for eager tasks — confirm against the task implementation.
    assert len(template.container.env) == 2

0 comments on commit eea2605

Please sign in to comment.