Skip to content

Commit

Permalink
Add types to functions.py (#1328)
Browse files Browse the repository at this point in the history
* add types to functions.py
* add asserts for hidden attributes
* add .py mypy check

---------

Co-authored-by: Elias Freider <[email protected]>
  • Loading branch information
savarin and freider authored Feb 19, 2024
1 parent 83e90f2 commit 2cb7fae
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 7 deletions.
1 change: 1 addition & 0 deletions modal/cli/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ async def list(json: Optional[bool] = False):
# Catch-all for other exceptions, like incorrect server url
workspace = "Unknown (profile misconfigured)"
else:
assert hasattr(resp, "username")
workspace = resp.username
content = ["•" if active else "", profile, workspace]
rows.append((active, content))
Expand Down
17 changes: 10 additions & 7 deletions modal/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Optional,
Sequence,
Set,
Sized,
Type,
Union,
)
Expand Down Expand Up @@ -82,7 +83,7 @@
from .secret import _Secret
from .volume import _Volume

OUTPUTS_TIMEOUT = 55 # seconds
OUTPUTS_TIMEOUT = 55.0 # seconds
ATTEMPT_TIMEOUT_GRACE_PERIOD = 5 # seconds


Expand Down Expand Up @@ -563,13 +564,13 @@ def from_args(
stub,
image=None,
secret: Optional[_Secret] = None,
secrets: Collection[_Secret] = (),
secrets: Sequence[_Secret] = (),
schedule: Optional[Schedule] = None,
is_generator=False,
gpu: GPU_T = None,
# TODO: maybe break this out into a separate decorator for notebooks.
mounts: Collection[_Mount] = (),
network_file_systems: Dict[Union[str, os.PathLike], _NetworkFileSystem] = {},
network_file_systems: Dict[Union[str, PurePosixPath], _NetworkFileSystem] = {},
allow_cross_region_volumes: bool = False,
volumes: Dict[Union[str, os.PathLike], Union[_Volume, _S3Mount]] = {},
webhook_config: Optional[api_pb2.WebhookConfig] = None,
Expand Down Expand Up @@ -926,7 +927,7 @@ def from_parametrized(
obj,
from_other_workspace: bool,
options: Optional[api_pb2.FunctionOptions],
args: Iterable[Any],
args: Sized,
kwargs: Dict[str, Any],
) -> "_Function":
async def _load(provider: _Function, resolver: Resolver, existing_object_id: Optional[str]):
Expand Down Expand Up @@ -1017,6 +1018,7 @@ async def lookup(
@property
def tag(self):
"""mdmd:hidden"""
assert hasattr(self, "_tag")
return self._tag

@property
Expand All @@ -1033,6 +1035,7 @@ def env(self) -> FunctionEnv:

def get_build_def(self) -> str:
"""mdmd:hidden"""
assert hasattr(self, "_raw_f") and hasattr(self, "_build_args")
return f"{inspect.getsource(self._raw_f)}\n{repr(self._build_args)}"

# Live handle methods
Expand Down Expand Up @@ -1262,7 +1265,7 @@ async def remote_gen(self, *args, **kwargs) -> AsyncGenerator[Any, None]:
async for item in self._call_generator(args, kwargs): # type: ignore
yield item

def call(self, *args, **kwargs) -> Awaitable[Any]:
def call(self, *args, **kwargs) -> None:
"""Deprecated. Use `f.remote` or `f.remote_gen` instead."""
# TODO: Generics/TypeVars
if self._is_generator:
Expand Down Expand Up @@ -1462,8 +1465,8 @@ async def _gather(*function_calls: _FunctionCall):
gather = synchronize_api(_gather)


_current_input_id = ContextVar("_current_input_id")
_current_function_call_id = ContextVar("_current_function_call_id")
_current_input_id: ContextVar = ContextVar("_current_input_id")
_current_function_call_id: ContextVar = ContextVar("_current_function_call_id")


def current_input_id() -> Optional[str]:
Expand Down
5 changes: 5 additions & 0 deletions tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,12 @@ def lint(ctx):

@task
def mypy(ctx):
mypy_allowlist = [
"modal/functions.py",
]

ctx.run("mypy .", pty=True)
ctx.run(f"mypy {' '.join(mypy_allowlist)} --follow-imports=skip", pty=True)


@task
Expand Down

0 comments on commit 2cb7fae

Please sign in to comment.