Commit ad38730

Merge branch 'main' of github.com:modal-labs/modal-client into kramstrom/cli-227-remove-context_mount-from

kramstrom committed Nov 27, 2024
2 parents eb21079 + 04ce2fa

Showing 7 changed files with 163 additions and 76 deletions.
modal/image.py (148 changes: 81 additions & 67 deletions)
@@ -19,7 +19,6 @@
     Optional,
     Sequence,
     Set,
-    Tuple,
     Union,
     cast,
     get_args,
@@ -36,6 +35,7 @@
 from ._utils.blob_utils import MAX_OBJECT_SIZE_BYTES
 from ._utils.function_utils import FunctionInfo
 from ._utils.grpc_utils import RETRYABLE_GRPC_STATUS_CODES, retry_transient_errors
+from .client import _Client
 from .cloud_bucket_mount import _CloudBucketMount
 from .config import config, logger, user_config_path
 from .environments import _get_environment_cached
@@ -52,7 +52,6 @@
 if typing.TYPE_CHECKING:
     import modal.functions
 
-
 # This is used for both type checking and runtime validation
 ImageBuilderVersion = Literal["2023.12", "2024.04", "2024.10"]
 
@@ -147,8 +146,8 @@ def _get_modal_requirements_command(version: ImageBuilderVersion) -> str:
     return f"{prefix} -r {CONTAINER_REQUIREMENTS_PATH}"
 
 
-def _flatten_str_args(function_name: str, arg_name: str, args: Tuple[Union[str, List[str]], ...]) -> List[str]:
-    """Takes a tuple of strings, or string lists, and flattens it.
+def _flatten_str_args(function_name: str, arg_name: str, args: Sequence[Union[str, List[str]]]) -> List[str]:
+    """Takes a sequence of strings, or string lists, and flattens it.
 
     Raises an error if any of the elements are not strings or string lists.
     """
@@ -304,7 +303,7 @@ class _ImageRegistryConfig:
     def __init__(
         self,
         # TODO: change to _PUBLIC after worker starts handling it.
-        registry_auth_type: int = api_pb2.REGISTRY_AUTH_TYPE_UNSPECIFIED,
+        registry_auth_type: "api_pb2.RegistryAuthType.ValueType" = api_pb2.REGISTRY_AUTH_TYPE_UNSPECIFIED,
         secret: Optional[_Secret] = None,
     ):
         self.registry_auth_type = registry_auth_type
@@ -313,7 +312,7 @@ def __init__(
     def get_proto(self) -> api_pb2.ImageRegistryConfig:
         return api_pb2.ImageRegistryConfig(
             registry_auth_type=self.registry_auth_type,
-            secret_id=(self.secret.object_id if self.secret else None),
+            secret_id=(self.secret.object_id if self.secret else ""),
         )
 
 
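A note on this hunk and several below: the `None` → `""`/`b""` substitutions and the `"api_pb2.RegistryAuthType.ValueType"` annotation both track how proto3 stubs model fields. Proto3 scalar fields are non-optional with empty defaults, and mypy-protobuf-style stubs expose enum values as `<Enum>.ValueType` rather than bare `int`. A small sketch, assuming the `modal_proto` package is importable:

```python
from modal_proto import api_pb2

# Proto3 scalar fields default to their type's empty value, never None,
# so "" is the stub-friendly way to leave a string field unset.
cfg = api_pb2.ImageRegistryConfig()
assert cfg.secret_id == ""
assert cfg.registry_auth_type == api_pb2.REGISTRY_AUTH_TYPE_UNSPECIFIED  # enum default: 0

# The .ValueType alias narrows accepted values to the enum's domain instead
# of any int, which is what the new annotations in this commit rely on.
auth: "api_pb2.RegistryAuthType.ValueType" = api_pb2.REGISTRY_AUTH_TYPE_UNSPECIFIED
```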
@@ -324,6 +323,45 @@ class DockerfileSpec:
     context_files: Dict[str, str]
 
 
+async def _image_await_build_result(image_id: str, client: _Client) -> api_pb2.ImageJoinStreamingResponse:
+    last_entry_id: str = ""
+    result_response: Optional[api_pb2.ImageJoinStreamingResponse] = None
+
+    async def join():
+        nonlocal last_entry_id, result_response
+
+        request = api_pb2.ImageJoinStreamingRequest(image_id=image_id, timeout=55, last_entry_id=last_entry_id)
+        async for response in client.stub.ImageJoinStreaming.unary_stream(request):
+            if response.entry_id:
+                last_entry_id = response.entry_id
+            if response.result.status:
+                result_response = response
+                # can't return yet, since there may still be logs streaming back in subsequent responses
+            for task_log in response.task_logs:
+                if task_log.task_progress.pos or task_log.task_progress.len:
+                    assert task_log.task_progress.progress_type == api_pb2.IMAGE_SNAPSHOT_UPLOAD
+                    if output_mgr := _get_output_manager():
+                        output_mgr.update_snapshot_progress(image_id, task_log.task_progress)
+                elif task_log.data:
+                    if output_mgr := _get_output_manager():
+                        await output_mgr.put_log_content(task_log)
+        if output_mgr := _get_output_manager():
+            output_mgr.flush_lines()
+
+    # Handle up to n exceptions while fetching logs
+    retry_count = 0
+    while result_response is None:
+        try:
+            await join()
+        except (StreamTerminatedError, GRPCError) as exc:
+            if isinstance(exc, GRPCError) and exc.status not in RETRYABLE_GRPC_STATUS_CODES:
+                raise exc
+            retry_count += 1
+            if retry_count >= 3:
+                raise exc
+    return result_response
+
+
 class _Image(_Object, type_prefix="im"):
     """Base class for container images to run functions in.
 
@@ -352,17 +390,17 @@ def _initialize_from_other(self, other: "_Image"):
         self._serve_mounts = other._serve_mounts
         self._deferred_mounts = other._deferred_mounts
 
-    def _hydrate_metadata(self, message: Optional[Message]):
+    def _hydrate_metadata(self, metadata: Optional[Message]):
         env_image_id = config.get("image_id")  # set as an env var in containers
         if env_image_id == self.object_id:
             for exc in self.inside_exceptions:
                 # This raises exceptions from `with image.imports()` blocks
                 # if the hydrated image is the one used by the container
                 raise exc
 
-        if message:
-            assert isinstance(message, api_pb2.ImageMetadata)
-            self._metadata = message
+        if metadata:
+            assert isinstance(metadata, api_pb2.ImageMetadata)
+            self._metadata = metadata
 
     def _add_mount_layer_or_copy(self, mount: _Mount, copy: bool = False):
         if copy:
@@ -378,7 +416,7 @@ async def _load(self2: "_Image", resolver: Resolver, existing_object_id: Optional[str]):
         return _Image._from_loader(_load, "Image(local files)", deps=lambda: [base_image, mount])
 
     @property
-    def _mount_layers(self) -> typing.Tuple[_Mount]:
+    def _mount_layers(self) -> typing.Sequence[_Mount]:
         """Non-evaluated mount layers on the image
 
         When the image is used by a Modal container, these mounts need to be attached as well to
@@ -422,7 +460,7 @@ def _from_args(
         context_mount_function: Optional[Callable[[], _Mount]] = None,
         force_build: bool = False,
         # For internal use only.
-        _namespace: int = api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
+        _namespace: "api_pb2.DeploymentNamespace.ValueType" = api_pb2.DEPLOYMENT_NAMESPACE_WORKSPACE,
         _do_assert_no_mount_layers: bool = True,
     ):
         context_mount = context_mount_function() if modal.is_local() and context_mount_function else None
@@ -444,14 +482,14 @@ def _from_args(
         if build_function and len(base_images) != 1:
             raise InvalidError("Cannot run a build function with multiple base images!")
 
-        def _deps() -> List[_Object]:
-            deps: List[_Object] = list(base_images.values()) + list(secrets)
+        def _deps() -> Sequence[_Object]:
+            deps = tuple(base_images.values()) + tuple(secrets)
             if build_function:
-                deps.append(build_function)
+                deps += (build_function,)
             if context_mount:
-                deps.append(context_mount)
-            if image_registry_config.secret:
-                deps.append(image_registry_config.secret)
+                deps += (context_mount,)
+            if image_registry_config and image_registry_config.secret:
+                deps += (image_registry_config.secret,)
             return deps
 
         async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[str]):
@@ -460,6 +498,7 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[str]):
                     # base images can't have
                     image._assert_no_mount_layers()
 
+            assert resolver.app_id  # type narrowing
             environment = await _get_environment_cached(resolver.environment_name or "", resolver.client)
             # A bit hacky, but assume that the environment provides a valid builder version
             image_builder_version = cast(ImageBuilderVersion, environment._settings.image_builder_version)
@@ -494,7 +533,6 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[str]):
 
             if build_function:
                 build_function_id = build_function.object_id
-
                 globals = build_function._get_info().get_globals()
                 attrs = build_function._get_info().get_cls_var_attrs()
                 globals = {**globals, **attrs}
@@ -516,14 +554,14 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[str]):
 
                 # Cloudpickle function serialization produces unstable values.
                 # TODO: better way to filter out types that don't have a stable hash?
-                build_function_globals = serialize(filtered_globals) if filtered_globals else None
+                build_function_globals = serialize(filtered_globals) if filtered_globals else b""
                 _build_function = api_pb2.BuildFunction(
                     definition=build_function.get_build_def(),
                     globals=build_function_globals,
                     input=build_function_input,
                 )
             else:
-                build_function_id = None
+                build_function_id = ""
                 _build_function = None
 
             image_definition = api_pb2.Image(
@@ -532,7 +570,7 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[str]):
                 context_files=context_file_pb2s,
                 secret_ids=[secret.object_id for secret in secrets],
                 gpu=bool(gpu_config.type),  # Note: as of 2023-01-27, server still uses this
-                context_mount_id=(context_mount.object_id if context_mount else None),
+                context_mount_id=(context_mount.object_id if context_mount else ""),
                 gpu_config=gpu_config,  # Note: as of 2023-01-27, server ignores this
                 image_registry_config=image_registry_config.get_proto(),
                 runtime=config.get("function_runtime"),
@@ -543,7 +581,7 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[str]):
             req = api_pb2.ImageGetOrCreateRequest(
                 app_id=resolver.app_id,
                 image=image_definition,
-                existing_image_id=existing_object_id,  # TODO: ignored
+                existing_image_id=existing_object_id or "",  # TODO: ignored
                 build_function_id=build_function_id,
                 force_build=config.get("force_build") or force_build,
                 namespace=_namespace,
@@ -554,46 +592,22 @@ async def _load(self: _Image, resolver: Resolver, existing_object_id: Optional[str]):
             )
             resp = await retry_transient_errors(resolver.client.stub.ImageGetOrCreate, req)
             image_id = resp.image_id
+            result: api_pb2.GenericResult
+            metadata: Optional[api_pb2.ImageMetadata] = None
+
+            if resp.result.status:
+                # image already built
+                result = resp.result
+                if resp.HasField("metadata"):
+                    metadata = resp.metadata
+            else:
+                # not built or in the process of building - wait for build
+                logger.debug("Waiting for image %s" % image_id)
+                resp = await _image_await_build_result(image_id, resolver.client)
+                result = resp.result
+                if resp.HasField("metadata"):
+                    metadata = resp.metadata
 
-            logger.debug("Waiting for image %s" % image_id)
-            last_entry_id: Optional[str] = None
-            result_response: Optional[api_pb2.ImageJoinStreamingResponse] = None
-
-            async def join():
-                nonlocal last_entry_id, result_response
-
-                request = api_pb2.ImageJoinStreamingRequest(image_id=image_id, timeout=55, last_entry_id=last_entry_id)
-
-                async for response in resolver.client.stub.ImageJoinStreaming.unary_stream(request):
-                    if response.entry_id:
-                        last_entry_id = response.entry_id
-                    if response.result.status:
-                        result_response = response
-                        # can't return yet, since there may still be logs streaming back in subsequent responses
-                    for task_log in response.task_logs:
-                        if task_log.task_progress.pos or task_log.task_progress.len:
-                            assert task_log.task_progress.progress_type == api_pb2.IMAGE_SNAPSHOT_UPLOAD
-                            if output_mgr := _get_output_manager():
-                                output_mgr.update_snapshot_progress(image_id, task_log.task_progress)
-                        elif task_log.data:
-                            if output_mgr := _get_output_manager():
-                                await output_mgr.put_log_content(task_log)
-                if output_mgr := _get_output_manager():
-                    output_mgr.flush_lines()
-
-            # Handle up to n exceptions while fetching logs
-            retry_count = 0
-            while result_response is None:
-                try:
-                    await join()
-                except (StreamTerminatedError, GRPCError) as exc:
-                    if isinstance(exc, GRPCError) and exc.status not in RETRYABLE_GRPC_STATUS_CODES:
-                        raise exc
-                    retry_count += 1
-                    if retry_count >= 3:
-                        raise exc
-
-            result = result_response.result
             if result.status == api_pb2.GenericResult.GENERIC_STATUS_FAILURE:
                 raise RemoteError(f"Image build for {image_id} failed with the exception:\n{result.exception}")
             elif result.status == api_pb2.GenericResult.GENERIC_STATUS_TERMINATED:
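A note on the `resp.HasField("metadata")` checks introduced above: in proto3, submessage fields have explicit presence (readable via `HasField`), and reading an unset submessage silently yields an empty default instance rather than `None` — hence the explicit check before capturing it. A small sketch, assuming the response type is `api_pb2.ImageGetOrCreateResponse` from the installed `modal_proto` stubs:

```python
from modal_proto import api_pb2

resp = api_pb2.ImageGetOrCreateResponse()
# Submessage fields track presence; unset means HasField(...) is False, even
# though resp.metadata would still read as a default ImageMetadata instance.
assert not resp.HasField("metadata")
metadata = resp.metadata if resp.HasField("metadata") else None
assert metadata is None
```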
@@ -607,7 +621,7 @@ async def join():
             else:
                 raise RemoteError("Unknown status %s!" % result.status)
 
-            self._hydrate(image_id, resolver.client, result_response.metadata)
+            self._hydrate(image_id, resolver.client, metadata)
             local_mounts = set()
             for base in base_images.values():
                 local_mounts |= base._serve_mounts
@@ -727,7 +741,7 @@ def build_dockerfile(version: ImageBuilderVersion) -> DockerfileSpec:
             context_mount_function=lambda: _Mount.from_local_file(local_path, remote_path=f"/{basename}"),
         )
 
-    def _add_local_python_packages(self, *packages: Union[str, Path], copy: bool = False) -> "_Image":
+    def _add_local_python_packages(self, *packages: str, copy: bool = False) -> "_Image":
         """Adds Python package files to containers
 
         Adds all files from the specified Python packages to containers running the Image.
@@ -1685,7 +1699,7 @@ def my_build_function():
         function = _Function.from_args(
             info,
             app=None,
-            image=self,
+            image=self,  # type: ignore[reportArgumentType]  # TODO: probably conflict with type stub?
             secrets=secrets,
             gpu=gpu,
             mounts=mounts,
@@ -1753,7 +1767,7 @@ def workdir(self, path: Union[str, PurePosixPath]) -> "_Image":
         """
 
         def build_dockerfile(version: ImageBuilderVersion) -> DockerfileSpec:
-            commands = ["FROM base", f"WORKDIR {shlex.quote(path)}"]
+            commands = ["FROM base", f"WORKDIR {shlex.quote(str(path))}"]
             return DockerfileSpec(commands=commands, context_files={})
 
         return _Image._from_args(
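The added `str(...)` matters because `shlex.quote` only accepts strings; handing it a `PurePosixPath` (which `workdir`'s signature allows) raises a `TypeError` from the internal regex search. A standard-library-only illustration:

```python
import shlex
from pathlib import PurePosixPath

path = PurePosixPath("/app/my dir")
print(shlex.quote(str(path)))  # '/app/my dir' -- quoted, shell-safe

try:
    shlex.quote(path)  # type: ignore[arg-type]
except TypeError as exc:
    print(f"TypeError: {exc}")  # quote() runs a regex over its argument; str only
```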
@@ -1797,7 +1811,7 @@ async def _logs(self) -> typing.AsyncGenerator[str, None]:
         This method is considered private since its interface may change - use it at your own risk!
         """
-        last_entry_id: Optional[str] = None
+        last_entry_id: str = ""
 
         request = api_pb2.ImageJoinStreamingRequest(
             image_id=self._object_id, timeout=55, last_entry_id=last_entry_id, include_logs_for_finished=True
         )
modal/object.py (4 changes: 2 additions & 2 deletions)
@@ -17,7 +17,7 @@
 
 _BLOCKING_O = synchronize_api(O)
 
-EPHEMERAL_OBJECT_HEARTBEAT_SLEEP = 300
+EPHEMERAL_OBJECT_HEARTBEAT_SLEEP: int = 300
 
 
 def _get_environment_name(environment_name: Optional[str] = None, resolver: Optional[Resolver] = None) -> Optional[str]:
@@ -205,7 +205,7 @@ def local_uuid(self):
         return self._local_uuid
 
     @property
-    def object_id(self):
+    def object_id(self) -> str:
         """mdmd:hidden"""
         return self._object_id
 
modal/volume.py (19 changes: 13 additions & 6 deletions)
@@ -122,14 +122,21 @@ def g():
     ```
     """
 
-    _lock: asyncio.Lock
+    _lock: Optional[asyncio.Lock] = None
 
-    def _initialize_from_empty(self):
+    async def _get_lock(self):
         # To (mostly*) prevent multiple concurrent operations on the same volume, which can cause problems under
         # some unlikely circumstances.
         # *: You can bypass this by creating multiple handles to the same volume, e.g. via lookup. But this
         # covers the typical case = good enough.
-        self._lock = asyncio.Lock()
+
+        # Note: this function runs no async code but is marked as async to ensure it's
+        # being run inside the synchronicity event loop and binds the lock to the
+        # correct event loop on Python 3.9 which eagerly assigns event loops on
+        # constructions of locks
+        if self._lock is None:
+            self._lock = asyncio.Lock()
+        return self._lock
 
     @staticmethod
     def new():
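Background for the comment in this hunk: on Python 3.9 and earlier, `asyncio.Lock()` captures an event loop at construction time, so a lock built outside the loop that later uses it can raise `RuntimeError`; from 3.10 the loop is bound lazily on first use. Deferring construction into an async method, as the change does, sidesteps this. A minimal standalone sketch of the pattern:

```python
import asyncio
from typing import Optional

class Guarded:
    def __init__(self) -> None:
        # Don't build the Lock here: on Python <= 3.9 it would bind to whatever
        # loop (if any) exists at construction time.
        self._lock: Optional[asyncio.Lock] = None

    async def _get_lock(self) -> asyncio.Lock:
        # Runs inside the event loop, so the Lock binds to the correct loop.
        if self._lock is None:
            self._lock = asyncio.Lock()
        return self._lock

    async def do_work(self) -> None:
        async with await self._get_lock():
            await asyncio.sleep(0)  # stand-in for the guarded operation

asyncio.run(Guarded().do_work())
```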
@@ -188,7 +195,7 @@ async def ephemeral(
         environment_name: Optional[str] = None,
         version: "typing.Optional[modal_proto.api_pb2.VolumeFsVersion.ValueType]" = None,
         _heartbeat_sleep: float = EPHEMERAL_OBJECT_HEARTBEAT_SLEEP,
-    ) -> AsyncIterator["_Volume"]:
+    ) -> AsyncGenerator["_Volume", None]:
         """Creates a new ephemeral volume within a context manager:
 
         Usage:
@@ -269,7 +276,7 @@ async def create_deployed(
 
     @live_method
     async def _do_reload(self, lock=True):
-        async with self._lock if lock else asyncnullcontext():
+        async with (await self._get_lock()) if lock else asyncnullcontext():
             req = api_pb2.VolumeReloadRequest(volume_id=self.object_id)
             _ = await retry_transient_errors(self._client.stub.VolumeReload, req)
 
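A note on the conditional context manager in `_do_reload`: the ternary picks either the real lock or a no-op async context manager before `async with` enters it. `asyncnullcontext` appears to be a modal-internal helper; on Python 3.10+ the stdlib's `contextlib.nullcontext` supports `async with` directly. A sketch using the stdlib equivalent:

```python
import asyncio
from contextlib import nullcontext  # usable with `async with` on Python 3.10+

async def reload(lock_obj: asyncio.Lock, lock: bool = True) -> None:
    # Choose the context manager first, then enter whichever was picked.
    async with lock_obj if lock else nullcontext():
        await asyncio.sleep(0)  # the reload RPC would run here

asyncio.run(reload(asyncio.Lock()))
```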

Expand All @@ -280,7 +287,7 @@ async def commit(self):
If successful, the changes made are now persisted in durable storage and available to other containers accessing
the volume.
"""
async with self._lock:
async with await self._get_lock():
req = api_pb2.VolumeCommitRequest(volume_id=self.object_id)
try:
# TODO(gongy): only apply indefinite retries on 504 status.
modal_version/_version_generated.py (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
 # Copyright Modal Labs 2024
 
 # Note: Reset this value to -1 whenever you make a minor `0.X` release of the client.
-build_number = 50  # git: 2903bbe
+build_number = 52  # git: a1760e2