Skip to content

Commit

Permalink
Store the interactive_mode_enabled to avoid using ipython_check
Browse files Browse the repository at this point in the history
Signed-off-by: Mecoli1219 <[email protected]>
  • Loading branch information
Mecoli1219 committed Jul 29, 2024
1 parent bf95827 commit 8bd5e81
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 3 deletions.
11 changes: 11 additions & 0 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,6 +637,7 @@ class FlyteContext(object):
origin_stackframe: Optional[traceback.FrameSummary] = None
output_metadata_tracker: Optional[OutputMetadataTracker] = None
fast_register_file_uploader: Optional[typing.Callable] = None
interactive_mode_enabled: bool = False

@property
def user_space_params(self) -> Optional[ExecutionParameters]:
Expand All @@ -663,6 +664,7 @@ def new_builder(self) -> Builder:
execution_state=self.execution_state,
in_a_condition=self.in_a_condition,
output_metadata_tracker=self.output_metadata_tracker,
interactive_mode_enabled=self.interactive_mode_enabled,
)

def enter_conditional_section(self) -> Builder:
Expand Down Expand Up @@ -690,6 +692,9 @@ def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> Builder:
def with_fast_register_file_uploader(self, f: typing.Callable) -> Builder:
return self.new_builder().with_fast_register_file_uploader(f)

def with_interactive_mode_enabled(self, i: bool) -> Builder:
return self.new_builder().with_interactive_mode_enabled(i)

def new_compilation_state(self, prefix: str = "") -> CompilationState:
"""
Creates and returns a default compilation state. For most of the code this should be the entrypoint
Expand Down Expand Up @@ -752,6 +757,7 @@ class Builder(object):
in_a_condition: bool = False
output_metadata_tracker: Optional[OutputMetadataTracker] = None
fast_register_file_uploader: Optional[typing.Callable] = None
interactive_mode_enabled: bool = False

def build(self) -> FlyteContext:
return FlyteContext(
Expand All @@ -764,6 +770,7 @@ def build(self) -> FlyteContext:
in_a_condition=self.in_a_condition,
output_metadata_tracker=self.output_metadata_tracker,
fast_register_file_uploader=self.fast_register_file_uploader,
interactive_mode_enabled=self.interactive_mode_enabled,
)

def enter_conditional_section(self) -> FlyteContext.Builder:
Expand Down Expand Up @@ -812,6 +819,10 @@ def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> FlyteContext
self.output_metadata_tracker = t
return self

def with_interactive_mode_enabled(self, i: bool) -> FlyteContext.Builder:
self.interactive_mode_enabled = i
return self

def new_compilation_state(self, prefix: str = "") -> CompilationState:
"""
Creates and returns a default compilation state. For most of the code this should be the entrypoint
Expand Down
1 change: 1 addition & 0 deletions flytekit/remote/init_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def init_remote(
default_project=default_project,
default_domain=default_domain,
data_upload_location=data_upload_location,
interactive_mode_enabled=True,
**kwargs,
)
# TODO: This should be merged into the FlyteRemote in the future
Expand Down
17 changes: 16 additions & 1 deletion flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,6 +199,7 @@ def __init__(
default_project: typing.Optional[str] = None,
default_domain: typing.Optional[str] = None,
data_upload_location: str = "flyte://my-s3-bucket/",
interactive_mode_enabled: bool = False,
**kwargs,
):
"""Initialize a FlyteRemote object.
Expand All @@ -209,6 +210,7 @@ def __init__(
:param default_domain: default domain to use when fetching or executing flyte entities.
:param data_upload_location: this is where all the default data will be uploaded when providing inputs.
The default location - `s3://my-s3-bucket/data` works for sandbox/demo environment. Please override this for non-sandbox cases.
:param interactive_mode_enabled: If set to True, the FlyteRemote will pickle the task/workflow objects if not found
"""
if config is None or config.platform is None or config.platform.endpoint is None:
raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.")
Expand All @@ -232,6 +234,7 @@ def __init__(

# Save the file access object locally, build a context for it and save that as well.
self._ctx = FlyteContextManager.current_context().with_file_access(self._file_access).build()
self._interactive_mode_enabled = interactive_mode_enabled

@property
def context(self) -> FlyteContext:
Expand Down Expand Up @@ -265,6 +268,11 @@ def file_access(self) -> FileAccessProvider:
"""File access provider to use for offloading non-literal inputs/outputs."""
return self._file_access

@property
def interactive_mode_enabled(self) -> bool:
"""If set to True, the FlyteRemote will pickle the task/workflow objects if not found."""
return self._interactive_mode_enabled

def get(
self, flyte_uri: typing.Optional[str] = None
) -> typing.Optional[typing.Union[LiteralsResolver, Literal, HTML, bytes]]:
Expand Down Expand Up @@ -315,6 +323,12 @@ def remote_context(self):
FlyteContextManager.current_context().with_file_access(self.file_access)
)

def interactive_context(self):
"""Context manager with interactive-specific configuration."""
return FlyteContextManager.with_context(
FlyteContextManager.current_context().with_interactive_mode_enabled(self.interactive_mode_enabled)
)

def fetch_task_lazy(
self, project: str = None, domain: str = None, name: str = None, version: str = None
) -> LazyEntity:
Expand Down Expand Up @@ -760,7 +774,8 @@ async def _serialize_and_register(
if options.file_uploader is None:
options.file_uploader = self.upload_file

_ = get_serializable(m, settings=serialization_settings, entity=entity, options=options)
with self.interactive_context():
_ = get_serializable(m, settings=serialization_settings, entity=entity, options=options)
# concurrent register
cp_task_entity_map = OrderedDict(filter(lambda x: isinstance(x[1], task_models.TaskSpec), m.items()))
tasks = []
Expand Down
5 changes: 3 additions & 2 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from flytekit.core.base_task import PythonTask
from flytekit.core.condition import BranchNode
from flytekit.core.container_task import ContainerTask
from flytekit.core.context_manager import FlyteContextManager
from flytekit.core.gate import Gate
from flytekit.core.launch_plan import LaunchPlan, ReferenceLaunchPlan
from flytekit.core.legacy_map_task import MapPythonTask
Expand Down Expand Up @@ -201,10 +202,10 @@ def _update_serialization_settings_for_ipython(
if not isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)):
return serialization_settings

from flytekit.tools.interactive import ipython_check
ctx = FlyteContextManager.current_context()

# Let's check if we are in an interactive environment like Jupyter notebook
if ipython_check():
if ctx.interactive_mode_enabled:
# We are in an interactive environment, let's check if the task is a PythonFunctionTask and the task function
# is defined in the main module. If so, we will serialize the task as a pickled object and upload it to remote
# storage. The main module check is to ensure that the task function is not defined in a notebook cell.
Expand Down

0 comments on commit 8bd5e81

Please sign in to comment.