From 442103d45acd39b36984aa30188df1c5574c479b Mon Sep 17 00:00:00 2001 From: Elias Freider Date: Wed, 27 Nov 2024 09:35:54 +0100 Subject: [PATCH 1/4] Support ImageGetOrCreate returning build results and skipping ImageJoinStreaming (#2495) ## Changelog * Redeploying/running apps where images are already built is now slightly faster --- modal/image.py | 146 +++++++++++++++++++++++++-------------------- modal/object.py | 2 +- tasks.py | 1 + test/image_test.py | 45 ++++++++++++++ 4 files changed, 127 insertions(+), 67 deletions(-) diff --git a/modal/image.py b/modal/image.py index e5d7957d1..dfcbc540c 100644 --- a/modal/image.py +++ b/modal/image.py @@ -19,7 +19,6 @@ Optional, Sequence, Set, - Tuple, Union, cast, get_args, @@ -36,6 +35,7 @@ from ._utils.blob_utils import MAX_OBJECT_SIZE_BYTES from ._utils.function_utils import FunctionInfo from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors +from .client import _Client from .cloud_bucket_mount import _CloudBucketMount from .config import config, logger, user_config_path from .environments import _get_environment_cached @@ -52,7 +52,6 @@ if typing.TYPE_CHECKING: import modal.functions - # This is used for both type checking and runtime validation ImageBuilderVersion = Literal["2023.12", "2024.04", "2024.10"] @@ -147,8 +146,8 @@ def _get_modal_requirements_command(version: ImageBuilderVersion) -> str: return f"{prefix} -r {CONTAINER_REQUIREMENTS_PATH}" -def _flatten_str_args(function_name: str, arg_name: str, args: Tuple[Union[str, List[str]], ...]) -> List[str]: - """Takes a tuple of strings, or string lists, and flattens it. +def _flatten_str_args(function_name: str, arg_name: str, args: Sequence[Union[str, List[str]]]) -> List[str]: + """Takes a sequence of strings, or string lists, and flattens it. Raises an error if any of the elements are not strings or string lists. """ @@ -244,7 +243,7 @@ class _ImageRegistryConfig: def __init__( self, # TODO: change to _PUBLIC after worker starts handling it. 
- registry_auth_type: int = api_pb2.REGISTRY_AUTH_TYPE_UNSPECIFIED, + registry_auth_type: "api_pb2.RegistryAuthType.ValueType" = api_pb2.REGISTRY_AUTH_TYPE_UNSPECIFIED, secret: Optional[_Secret] = None, ): self.registry_auth_type = registry_auth_type @@ -253,7 +252,7 @@ def __init__( def get_proto(self) -> api_pb2.ImageRegistryConfig: return api_pb2.ImageRegistryConfig( registry_auth_type=self.registry_auth_type, - secret_id=(self.secret.object_id if self.secret else None), + secret_id=(self.secret.object_id if self.secret else ""), ) @@ -264,6 +263,45 @@ class DockerfileSpec: context_files: Dict[str, str] +async def _image_await_build_result(image_id: str, client: _Client) -> api_pb2.ImageJoinStreamingResponse: + last_entry_id: str = "" + result_response: Optional[api_pb2.ImageJoinStreamingResponse] = None + + async def join(): + nonlocal last_entry_id, result_response + + request = api_pb2.ImageJoinStreamingRequest(image_id=image_id, timeout=55, last_entry_id=last_entry_id) + async for response in client.stub.ImageJoinStreaming.unary_stream(request): + if response.entry_id: + last_entry_id = response.entry_id + if response.result.status: + result_response = response + # can't return yet, since there may still be logs streaming back in subsequent responses + for task_log in response.task_logs: + if task_log.task_progress.pos or task_log.task_progress.len: + assert task_log.task_progress.progress_type == api_pb2.IMAGE_SNAPSHOT_UPLOAD + if output_mgr := _get_output_manager(): + output_mgr.update_snapshot_progress(image_id, task_log.task_progress) + elif task_log.data: + if output_mgr := _get_output_manager(): + await output_mgr.put_log_content(task_log) + if output_mgr := _get_output_manager(): + output_mgr.flush_lines() + + # Handle up to n exceptions while fetching logs + retry_count = 0 + while result_response is None: + try: + await join() + except (StreamTerminatedError, GRPCError) as exc: + if isinstance(exc, GRPCError) and exc.status not in RETRYABLE_GRPC_STATUS_CODES: + raise exc + retry_count += 1 + if retry_count >= 3: + raise exc + return result_response + + class _Image(_Object, type_prefix="im"): """Base class for container images to run functions in. 
@@ -292,7 +330,7 @@ def _initialize_from_other(self, other: "_Image"): self._serve_mounts = other._serve_mounts self._deferred_mounts = other._deferred_mounts - def _hydrate_metadata(self, message: Optional[Message]): + def _hydrate_metadata(self, metadata: Optional[Message]): env_image_id = config.get("image_id") # set as an env var in containers if env_image_id == self.object_id: for exc in self.inside_exceptions: @@ -300,9 +338,9 @@ def _hydrate_metadata(self, message: Optional[Message]): # if the hydrated image is the one used by the container raise exc - if message: - assert isinstance(message, api_pb2.ImageMetadata) - self._metadata = message + if metadata: + assert isinstance(metadata, api_pb2.ImageMetadata) + self._metadata = metadata def _add_mount_layer_or_copy(self, mount: _Mount, copy: bool = False): if copy: @@ -318,7 +356,7 @@ async def _load(self2: "_Image", resolver: Resolver, existing_object_id: Optiona return _Image._from_loader(_load, "Image(local files)", deps=lambda: [base_image, mount]) @property - def _mount_layers(self) -> typing.Tuple[_Mount]: + def _mount_layers(self) -> typing.Sequence[_Mount]: """Non-evaluated mount layers on the image When the image is used by a Modal container, these mounts need to be attached as well to @@ -362,7 +400,7 @@ def _from_args( context_mount: Optional[_Mount] = None, force_build: bool = False, # For internal use only. - _namespace: int = api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, + _namespace: "api_pb2.DeploymentNamespace.ValueType" = api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE, _do_assert_no_mount_layers: bool = True, ): if base_images is None: @@ -382,14 +420,14 @@ def _from_args( if build_function and len(base_images) != 1: raise InvalidError("Cannot run a build function with multiple base images!") - def _deps() -> List[_Object]: - deps: List[_Object] = list(base_images.values()) + list(secrets) + def _deps() -> Sequence[_Object]: + deps = tuple(base_images.values()) + tuple(secrets) if build_function: - deps.append(build_function) + deps += (build_function,) if context_mount: - deps.append(context_mount) - if image_registry_config.secret: - deps.append(image_registry_config.secret) + deps += (context_mount,) + if image_registry_config and image_registry_config.secret: + deps += (image_registry_config.secret,) return deps async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[str]): @@ -398,6 +436,7 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[s # base images can't have image._assert_no_mount_layers() + assert resolver.app_id # type narrowing environment = await _get_environment_cached(resolver.environment_name or "", resolver.client) # A bit hacky,but assume that the environment provides a valid builder version image_builder_version = cast(ImageBuilderVersion, environment._settings.image_builder_version) @@ -432,7 +471,6 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[s if build_function: build_function_id = build_function.object_id - globals = build_function._get_info().get_globals() attrs = build_function._get_info().get_cls_var_attrs() globals = {**globals, **attrs} @@ -454,14 +492,14 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[s # Cloudpickle function serialization produces unstable values. # TODO: better way to filter out types that don't have a stable hash? 
- build_function_globals = serialize(filtered_globals) if filtered_globals else None + build_function_globals = serialize(filtered_globals) if filtered_globals else b"" _build_function = api_pb2.BuildFunction( definition=build_function.get_build_def(), globals=build_function_globals, input=build_function_input, ) else: - build_function_id = None + build_function_id = "" _build_function = None image_definition = api_pb2.Image( @@ -470,7 +508,7 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[s context_files=context_file_pb2s, secret_ids=[secret.object_id for secret in secrets], gpu=bool(gpu_config.type), # Note: as of 2023-01-27, server still uses this - context_mount_id=(context_mount.object_id if context_mount else None), + context_mount_id=(context_mount.object_id if context_mount else ""), gpu_config=gpu_config, # Note: as of 2023-01-27, server ignores this image_registry_config=image_registry_config.get_proto(), runtime=config.get("function_runtime"), @@ -481,7 +519,7 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[s req = api_pb2.ImageGetOrCreateRequest( app_id=resolver.app_id, image=image_definition, - existing_image_id=existing_object_id, # TODO: ignored + existing_image_id=existing_object_id or "", # TODO: ignored build_function_id=build_function_id, force_build=config.get("force_build") or force_build, namespace=_namespace, @@ -492,46 +530,22 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[s ) resp = await retry_transient_errors(resolver.client.stub.ImageGetOrCreate, req) image_id = resp.image_id + result: api_pb2.GenericResult + metadata: Optional[api_pb2.ImageMetadata] = None + + if resp.result.status: + # image already built + result = resp.result + if resp.HasField("metadata"): + metadata = resp.metadata + else: + # not built or in the process of building - wait for build + logger.debug("Waiting for image %s" % image_id) + resp = await _image_await_build_result(image_id, resolver.client) + result = resp.result + if resp.HasField("metadata"): + metadata = resp.metadata - logger.debug("Waiting for image %s" % image_id) - last_entry_id: Optional[str] = None - result_response: Optional[api_pb2.ImageJoinStreamingResponse] = None - - async def join(): - nonlocal last_entry_id, result_response - - request = api_pb2.ImageJoinStreamingRequest(image_id=image_id, timeout=55, last_entry_id=last_entry_id) - - async for response in resolver.client.stub.ImageJoinStreaming.unary_stream(request): - if response.entry_id: - last_entry_id = response.entry_id - if response.result.status: - result_response = response - # can't return yet, since there may still be logs streaming back in subsequent responses - for task_log in response.task_logs: - if task_log.task_progress.pos or task_log.task_progress.len: - assert task_log.task_progress.progress_type == api_pb2.IMAGE_SNAPSHOT_UPLOAD - if output_mgr := _get_output_manager(): - output_mgr.update_snapshot_progress(image_id, task_log.task_progress) - elif task_log.data: - if output_mgr := _get_output_manager(): - await output_mgr.put_log_content(task_log) - if output_mgr := _get_output_manager(): - output_mgr.flush_lines() - - # Handle up to n exceptions while fetching logs - retry_count = 0 - while result_response is None: - try: - await join() - except (StreamTerminatedError, GRPCError) as exc: - if isinstance(exc, GRPCError) and exc.status not in RETRYABLE_GRPC_STATUS_CODES: - raise exc - retry_count += 1 - if retry_count >= 3: - raise exc - - result 
= result_response.result if result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE: raise RemoteError(f"Image build for {image_id} failed with the exception:\n{result.exception}") elif result.status == api_pb2.GenericResult.GENERIC_STATUS_TERMINATED: @@ -545,7 +559,7 @@ async def join(): else: raise RemoteError("Unknown status %s!" % result.status) - self._hydrate(image_id, resolver.client, result_response.metadata) + self._hydrate(image_id, resolver.client, metadata) local_mounts = set() for base in base_images.values(): local_mounts |= base._serve_mounts @@ -666,7 +680,7 @@ def build_dockerfile(version: ImageBuilderVersion) -> DockerfileSpec: context_mount=mount, ) - def _add_local_python_packages(self, *packages: Union[str, Path], copy: bool = False) -> "_Image": + def _add_local_python_packages(self, *packages: str, copy: bool = False) -> "_Image": """Adds Python package files to containers Adds all files from the specified Python packages to containers running the Image. @@ -1632,7 +1646,7 @@ def my_build_function(): function = _Function.from_args( info, app=None, - image=self, + image=self, # type: ignore[reportArgumentType] # TODO: probably conflict with type stub? secrets=secrets, gpu=gpu, mounts=mounts, @@ -1744,7 +1758,7 @@ async def _logs(self) -> typing.AsyncGenerator[str, None]: This method is considered private since its interface may change - use it at your own risk! """ - last_entry_id: Optional[str] = None + last_entry_id: str = "" request = api_pb2.ImageJoinStreamingRequest( image_id=self._object_id, timeout=55, last_entry_id=last_entry_id, include_logs_for_finished=True diff --git a/modal/object.py b/modal/object.py index f23e93a87..4d6d02a8c 100644 --- a/modal/object.py +++ b/modal/object.py @@ -205,7 +205,7 @@ def local_uuid(self): return self._local_uuid @property - def object_id(self): + def object_id(self) -> str: """mdmd:hidden""" return self._object_id diff --git a/tasks.py b/tasks.py index 06fdea87c..c0ef4526e 100644 --- a/tasks.py +++ b/tasks.py @@ -150,6 +150,7 @@ def type_check(ctx): "test/cls_test.py", # see mypy bug above - but this works with pyright, so we run that instead "modal/_runtime/container_io_manager.py", "modal/io_streams.py", + "modal/image.py", ] ctx.run(f"pyright {' '.join(pyright_allowlist)}", pty=True) diff --git a/test/image_test.py b/test/image_test.py index f7af141ca..57314b317 100644 --- a/test/image_test.py +++ b/test/image_test.py @@ -1431,3 +1431,48 @@ def test_add_locals_build_function(servicer, client, supports_on_path): # TODO: test modal shell w/ lazy mounts # this works since the image is passed on as is to a sandbox which will load it and # transfer any virtual mount layers from the image as mounts to the sandbox + + +def test_image_only_joins_unfinished_steps(servicer, client): + app = App() + deb_slim = Image.debian_slim() + image = deb_slim.pip_install("foobarbaz") + app.function(image=image)(dummy) + with servicer.intercept() as ctx: + # default - image not built, should stream + with app.run(client=client): + pass + image_gets = ctx.get_requests("ImageGetOrCreate") + assert len(image_gets) == 2 + image_joins = ctx.get_requests("ImageJoinStreaming") + assert len(image_joins) == 2 + + with servicer.intercept() as ctx: + # lets mock that deb_slim has been built already + + async def custom_responder(servicer, stream): + image_get_or_create_request = await stream.recv_message() + is_base_image = any("FROM python:" in cmd for cmd in image_get_or_create_request.image.dockerfile_commands) + if is_base_image: + # base image 
done + await stream.send_message( + api_pb2.ImageGetOrCreateResponse( + image_id="im-123", + result=api_pb2.GenericResult(status=api_pb2.GenericResult.GENERIC_STATUS_SUCCESS), + ) + ) + else: + await stream.send_message( + api_pb2.ImageGetOrCreateResponse( + image_id="im-124", + ) + ) + + ctx.set_responder("ImageGetOrCreate", custom_responder) + with app.run(client=client): + pass + image_gets = ctx.get_requests("ImageGetOrCreate") + assert len(image_gets) == 2 + image_joins = ctx.get_requests("ImageJoinStreaming") + assert len(image_joins) == 1 # should now skip building of second build step + assert image_joins[0].image_id == "im-124" From 36c4fdcae26d87c123dbce7bf27d98cee2a74df8 Mon Sep 17 00:00:00 2001 From: github-actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 27 Nov 2024 08:36:21 +0000 Subject: [PATCH 2/4] [auto-commit] [skip ci] Bump the build number --- modal_version/_version_generated.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal_version/_version_generated.py b/modal_version/_version_generated.py index a3f77b61b..88ce76f83 100644 --- a/modal_version/_version_generated.py +++ b/modal_version/_version_generated.py @@ -1,4 +1,4 @@ # Copyright Modal Labs 2024 # Note: Reset this value to -1 whenever you make a minor `0.X` release of the client. -build_number = 50 # git: 2903bbe +build_number = 51 # git: 442103d From a1760e20cd660e793aec2c5cbd5ad8463424370a Mon Sep 17 00:00:00 2001 From: Elias Freider Date: Wed, 27 Nov 2024 13:19:01 +0100 Subject: [PATCH 3/4] Fixes issue where Volume.from_name eagerly assigns event loop (#2581) * Fixes issue where Volume.from_name eagerly assigns the main thread event loop to internal lock --- modal/image.py | 2 +- modal/object.py | 2 +- modal/volume.py | 19 +++++++++++++------ test/volume_test.py | 20 ++++++++++++++++++++ 4 files changed, 35 insertions(+), 8 deletions(-) diff --git a/modal/image.py b/modal/image.py index dfcbc540c..24538d437 100644 --- a/modal/image.py +++ b/modal/image.py @@ -1714,7 +1714,7 @@ def workdir(self, path: Union[str, PurePosixPath]) -> "_Image": """ def build_dockerfile(version: ImageBuilderVersion) -> DockerfileSpec: - commands = ["FROM base", f"WORKDIR {shlex.quote(path)}"] + commands = ["FROM base", f"WORKDIR {shlex.quote(str(path))}"] return DockerfileSpec(commands=commands, context_files={}) return _Image._from_args( diff --git a/modal/object.py b/modal/object.py index 4d6d02a8c..22c6cfc0f 100644 --- a/modal/object.py +++ b/modal/object.py @@ -17,7 +17,7 @@ _BLOCKING_O = synchronize_api(O) -EPHEMERAL_OBJECT_HEARTBEAT_SLEEP = 300 +EPHEMERAL_OBJECT_HEARTBEAT_SLEEP: int = 300 def _get_environment_name(environment_name: Optional[str] = None, resolver: Optional[Resolver] = None) -> Optional[str]: diff --git a/modal/volume.py b/modal/volume.py index 015f77a09..cc67aee12 100644 --- a/modal/volume.py +++ b/modal/volume.py @@ -122,14 +122,21 @@ def g(): ``` """ - _lock: asyncio.Lock + _lock: Optional[asyncio.Lock] = None - def _initialize_from_empty(self): + async def _get_lock(self): # To (mostly*) prevent multiple concurrent operations on the same volume, which can cause problems under # some unlikely circumstances. # *: You can bypass this by creating multiple handles to the same volume, e.g. via lookup. But this # covers the typical case = good enough. 
- self._lock = asyncio.Lock() + + # Note: this function runs no async code but is marked as async to ensure it's + # being run inside the synchronicity event loop and binds the lock to the + # correct event loop on Python 3.9 which eagerly assigns event loops on + # constructions of locks + if self._lock is None: + self._lock = asyncio.Lock() + return self._lock @staticmethod def new(): @@ -188,7 +195,7 @@ async def ephemeral( environment_name: Optional[str] = None, version: "typing.Optional[modal_proto.api_pb2.VolumeFsVersion.ValueType]" = None, _heartbeat_sleep: float = EPHEMERAL_OBJECT_HEARTBEAT_SLEEP, - ) -> AsyncIterator["_Volume"]: + ) -> AsyncGenerator["_Volume", None]: """Creates a new ephemeral volume within a context manager: Usage: @@ -269,7 +276,7 @@ async def create_deployed( @live_method async def _do_reload(self, lock=True): - async with self._lock if lock else asyncnullcontext(): + async with (await self._get_lock()) if lock else asyncnullcontext(): req = api_pb2.VolumeReloadRequest(volume_id=self.object_id) _ = await retry_transient_errors(self._client.stub.VolumeReload, req) @@ -280,7 +287,7 @@ async def commit(self): If successful, the changes made are now persisted in durable storage and available to other containers accessing the volume. """ - async with self._lock: + async with await self._get_lock(): req = api_pb2.VolumeCommitRequest(volume_id=self.object_id) try: # TODO(gongy): only apply indefinite retries on 504 status. diff --git a/test/volume_test.py b/test/volume_test.py index 6a00862f8..4e449adca 100644 --- a/test/volume_test.py +++ b/test/volume_test.py @@ -381,3 +381,23 @@ async def test_open_files_error_annotation(tmp_path): def test_invalid_name(servicer, client, name): with pytest.raises(InvalidError, match="Invalid Volume name"): modal.Volume.lookup(name) + + +@pytest.fixture() +def unset_main_thread_event_loop(): + try: + event_loop = asyncio.get_event_loop() + except RuntimeError: + event_loop = None + + asyncio.set_event_loop(None) + try: + yield + finally: + asyncio.set_event_loop(event_loop) # reset so we don't break other tests + + +@pytest.mark.usefixtures("unset_main_thread_event_loop") +def test_lock_is_py39_safe(set_env_client): + vol = modal.Volume.from_name("my_vol", create_if_missing=True) + vol.reload() From 04ce2fa7f18fbaa901f228555b2d14620df66e42 Mon Sep 17 00:00:00 2001 From: github-actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 27 Nov 2024 12:19:34 +0000 Subject: [PATCH 4/4] [auto-commit] [skip ci] Bump the build number --- modal_version/_version_generated.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modal_version/_version_generated.py b/modal_version/_version_generated.py index 88ce76f83..f8ed129fb 100644 --- a/modal_version/_version_generated.py +++ b/modal_version/_version_generated.py @@ -1,4 +1,4 @@ # Copyright Modal Labs 2024 # Note: Reset this value to -1 whenever you make a minor `0.X` release of the client. -build_number = 51 # git: 442103d +build_number = 52 # git: a1760e2
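
A minimal, self-contained sketch of the lazy-lock pattern that patch 3 applies in `modal.Volume._get_lock`. The class and method names below (`LazyLockExample`, `do_work`) are illustrative only, not part of Modal's API; only the deferred `asyncio.Lock()` construction mirrors the patch. On Python 3.9 and earlier, `asyncio.Lock()` binds to whatever event loop is current at construction time, so a lock created eagerly in `_initialize_from_empty` on the main thread can end up attached to the wrong loop; constructing it inside a coroutine defers the binding to the loop that actually runs the operation.

```python
import asyncio
from typing import Optional


class LazyLockExample:
    # Deferred: do not construct the lock at instance creation time, where
    # (on Python <= 3.9) it would bind to the then-current event loop.
    _lock: Optional[asyncio.Lock] = None

    async def _get_lock(self) -> asyncio.Lock:
        # Runs inside the active event loop, so the lock binds to the right loop.
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    async def do_work(self) -> None:
        async with await self._get_lock():
            await asyncio.sleep(0)  # stand-in for the guarded volume operation


if __name__ == "__main__":
    asyncio.run(LazyLockExample().do_work())
```

The same reasoning explains the `unset_main_thread_event_loop` test fixture in the patch: it clears the main-thread event loop before constructing the volume, which would trip an eagerly bound lock on Python 3.9 but passes with the deferred construction above.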