Skip to content

Commit

Permalink
Add interactive mode enable for init_remote & fix unit test
Browse files Browse the repository at this point in the history
Signed-off-by: Mecoli1219 <[email protected]>
  • Loading branch information
Mecoli1219 committed Aug 2, 2024
1 parent 093ab35 commit f665f4d
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 394 deletions.
1 change: 1 addition & 0 deletions dev-requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,4 @@ prometheus-client

orjson
kubernetes>=12.0.1
nest-asyncio
5 changes: 5 additions & 0 deletions flytekit/configuration/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,6 +830,7 @@ class SerializationSettings(DataClassJsonMixin):
can be fast registered (and thus omit building a Docker image) this object contains additional parameters
for serialization.
source_root (Optional[str]): The root directory of the source code.
interactive_mode_enabled (bool): Whether or not the task is being serialized in interactive mode.
"""

image_config: ImageConfig
Expand All @@ -842,6 +843,7 @@ class SerializationSettings(DataClassJsonMixin):
flytekit_virtualenv_root: Optional[str] = None
fast_serialization_settings: Optional[FastSerializationSettings] = None
source_root: Optional[str] = None
interactive_mode_enabled: bool = False

def __post_init__(self):
if self.flytekit_virtualenv_root is None:
Expand Down Expand Up @@ -916,6 +918,7 @@ def new_builder(self) -> Builder:
python_interpreter=self.python_interpreter,
fast_serialization_settings=self.fast_serialization_settings,
source_root=self.source_root,
interactive_mode_enabled=self.interactive_mode_enabled,
)

def should_fast_serialize(self) -> bool:
Expand Down Expand Up @@ -967,6 +970,7 @@ class Builder(object):
python_interpreter: Optional[str] = None
fast_serialization_settings: Optional[FastSerializationSettings] = None
source_root: Optional[str] = None
interactive_mode_enabled: bool = False

def with_fast_serialization_settings(self, fss: fast_serialization_settings) -> SerializationSettings.Builder:
self.fast_serialization_settings = fss
Expand All @@ -984,4 +988,5 @@ def build(self) -> SerializationSettings:
python_interpreter=self.python_interpreter,
fast_serialization_settings=self.fast_serialization_settings,
source_root=self.source_root,
interactive_mode_enabled=self.interactive_mode_enabled,
)
11 changes: 0 additions & 11 deletions flytekit/core/context_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -637,7 +637,6 @@ class FlyteContext(object):
origin_stackframe: Optional[traceback.FrameSummary] = None
output_metadata_tracker: Optional[OutputMetadataTracker] = None
fast_register_file_uploader: Optional[typing.Callable] = None
interactive_mode_enabled: bool = False

@property
def user_space_params(self) -> Optional[ExecutionParameters]:
Expand All @@ -664,7 +663,6 @@ def new_builder(self) -> Builder:
execution_state=self.execution_state,
in_a_condition=self.in_a_condition,
output_metadata_tracker=self.output_metadata_tracker,
interactive_mode_enabled=self.interactive_mode_enabled,
)

def enter_conditional_section(self) -> Builder:
Expand Down Expand Up @@ -692,9 +690,6 @@ def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> Builder:
def with_fast_register_file_uploader(self, f: typing.Callable) -> Builder:
return self.new_builder().with_fast_register_file_uploader(f)

def with_interactive_mode_enabled(self, i: bool) -> Builder:
return self.new_builder().with_interactive_mode_enabled(i)

def new_compilation_state(self, prefix: str = "") -> CompilationState:
"""
Creates and returns a default compilation state. For most of the code this should be the entrypoint
Expand Down Expand Up @@ -757,7 +752,6 @@ class Builder(object):
in_a_condition: bool = False
output_metadata_tracker: Optional[OutputMetadataTracker] = None
fast_register_file_uploader: Optional[typing.Callable] = None
interactive_mode_enabled: bool = False

def build(self) -> FlyteContext:
return FlyteContext(
Expand All @@ -770,7 +764,6 @@ def build(self) -> FlyteContext:
in_a_condition=self.in_a_condition,
output_metadata_tracker=self.output_metadata_tracker,
fast_register_file_uploader=self.fast_register_file_uploader,
interactive_mode_enabled=self.interactive_mode_enabled,
)

def enter_conditional_section(self) -> FlyteContext.Builder:
Expand Down Expand Up @@ -819,10 +812,6 @@ def with_output_metadata_tracker(self, t: OutputMetadataTracker) -> FlyteContext
self.output_metadata_tracker = t
return self

def with_interactive_mode_enabled(self, i: bool) -> FlyteContext.Builder:
self.interactive_mode_enabled = i
return self

def new_compilation_state(self, prefix: str = "") -> CompilationState:
"""
Creates and returns a default compilation state. For most of the code this should be the entrypoint
Expand Down
2 changes: 1 addition & 1 deletion flytekit/core/future.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(
from flytekit.tools.script_mode import hash_file

if REMOTE_ENTRY is None:
raise Exception(
raise RuntimeError(
"Remote Flyte client has not been initialized. Please call flytekit.remote.init_remote() before executing tasks."
)
self._remote_entry = REMOTE_ENTRY
Expand Down
4 changes: 3 additions & 1 deletion flytekit/remote/init_remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from flytekit.configuration import Config
from flytekit.remote.remote import FlyteRemote
from flytekit.tools.translator import Options
from flytekit.tools.interactive import ipython_check

REMOTE_ENTRY: typing.Optional[FlyteRemote] = None
# TODO: This should be merged into the FlyteRemote in the future
Expand All @@ -17,6 +18,7 @@ def init_remote(
default_domain: typing.Optional[str] = None,
data_upload_location: str = "flyte://my-s3-bucket/",
default_options: typing.Optional[Options] = None,
interactive_mode: bool = ipython_check(),
**kwargs,
):
"""
Expand All @@ -37,7 +39,7 @@ def init_remote(
default_project=default_project,
default_domain=default_domain,
data_upload_location=data_upload_location,
interactive_mode_enabled=True,
interactive_mode_enabled=interactive_mode,
**kwargs,
)
# TODO: This should be merged into the FlyteRemote in the future
Expand Down
19 changes: 3 additions & 16 deletions flytekit/remote/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,7 @@ def __init__(
)

# Save the file access object locally, build a context for it and save that as well.
self._ctx = (
FlyteContextManager.current_context()
.with_file_access(self._file_access)
.build()
)
self._ctx = FlyteContextManager.current_context().with_file_access(self._file_access).build()
self._interactive_mode_enabled = interactive_mode_enabled

@property
Expand Down Expand Up @@ -327,12 +323,6 @@ def remote_context(self):
FlyteContextManager.current_context().with_file_access(self.file_access)
)

def interactive_context(self):
"""Context manager with interactive-specific configuration."""
return FlyteContextManager.with_context(
FlyteContextManager.current_context().with_interactive_mode_enabled(self.interactive_mode_enabled)
)

def fetch_task_lazy(
self, project: str = None, domain: str = None, name: str = None, version: str = None
) -> LazyEntity:
Expand Down Expand Up @@ -772,17 +762,14 @@ async def _serialize_and_register(
)
if serialization_settings.version is None:
serialization_settings.version = version
serialization_settings.interactive_mode_enabled = self.interactive_mode_enabled

if options is None:
options = Options()
if options.file_uploader is None:
options.file_uploader = self.upload_file

if self.interactive_mode_enabled:
with self.interactive_context():
_ = get_serializable(m, settings=serialization_settings, entity=entity, options=options)
else:
_ = get_serializable(m, settings=serialization_settings, entity=entity, options=options)
_ = get_serializable(m, settings=serialization_settings, entity=entity, options=options)
# concurrent register
cp_task_entity_map = OrderedDict(filter(lambda x: isinstance(x[1], task_models.TaskSpec), m.items()))
tasks = []
Expand Down
14 changes: 3 additions & 11 deletions flytekit/tools/translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,11 +201,8 @@ def _update_serialization_settings_for_ipython(
if not isinstance(entity, (PythonAutoContainerTask, ArrayNodeMapTask)):
return serialization_settings

# If the context is not interactive, we don't need to do anything
ctx = context_manager.FlyteContextManager.current_context()

# # Let's check if we are in an interactive environment like Jupyter notebook
if ctx.interactive_mode_enabled:
# Let's check if we are in an interactive environment like Jupyter notebook
if serialization_settings.interactive_mode_enabled is True:
# We are in an interactive environment, let's check if the task is a PythonFunctionTask and the task function
# is defined in the main module. If so, we will serialize the task as a pickled object and upload it to remote
# storage. The main module check is to ensure that the task function is not defined in a notebook cell.
Expand Down Expand Up @@ -238,13 +235,8 @@ def _update_serialization_settings_for_ipython(
with gzip.GzipFile(filename=dest, mode="wb", mtime=0) as gzipped:
cloudpickle.dump(entity, gzipped)
rich.get_console().print("[yellow]Uploading Pickled representation of Task to remote storage...[/ yellow]")
md5_bytes, native_url = options.file_uploader(dest)
# if not serialization_settings.version and md5_bytes:
# import base64
_, native_url = options.file_uploader(dest)

# h = hashlib.md5(md5_bytes)
# version = base64.urlsafe_b64encode(h.digest()).decode("ascii").rstrip("=")
# serialization_settings.version = version
serialization_settings.fast_serialization_settings = FastSerializationSettings(
enabled=True, pickled=True, distribution_location=native_url
)
Expand Down
Loading

0 comments on commit f665f4d

Please sign in to comment.