-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
🎨 Unify typecheck and shapecheck (#3)
- Loading branch information
Showing
8 changed files
with
368 additions
and
78 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,7 +5,7 @@ build-backend = "poetry.core.masonry.api" | |
|
||
[tool.poetry] | ||
name = "safecheck" | ||
version = "0.0.3" | ||
version = "0.1.0" | ||
description = "Utilities for typechecking, shapechecking and dispatch." | ||
readme = "README.md" | ||
authors = ["David Muhr <[email protected]>"] | ||
|
@@ -99,6 +99,8 @@ force-exclude = true | |
ignore = [ | ||
"D203", # one blank line required before class docstring | ||
"D213", # multi line summary should start at second line | ||
"ANN101", # missing type annotation for `self` in method | ||
"B905", # `zip()` without an explicit `strict=` parameter | ||
] | ||
|
||
[tool.ruff.isort] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
from collections.abc import Callable | ||
from inspect import Parameter, _empty, signature # type: ignore[reportPrivateUsage] | ||
from typing import Any | ||
|
||
from ._typecheck import typecheck | ||
|
||
__all__ = [ | ||
"implements", | ||
"protocol", | ||
] | ||
|
||
CallableAny = Callable[..., Any] | ||
|
||
|
||
class FunctionProtocol: | ||
def __init__( | ||
self, | ||
return_annotation: type, | ||
parameters: list[Parameter], | ||
) -> None: | ||
super().__init__() | ||
self.return_annotation = return_annotation | ||
self.parameters = parameters | ||
|
||
|
||
class InvalidProtocolError(Exception): | ||
def __init__(self, msg: str) -> None: | ||
super().__init__(msg) | ||
|
||
|
||
class ProtocolImplementationError(Exception): | ||
def __init__(self, msg: str) -> None: | ||
super().__init__(msg) | ||
|
||
|
||
def protocol(func: CallableAny) -> FunctionProtocol: | ||
sig = signature(func) | ||
params = list(sig.parameters.values()) | ||
if sig.return_annotation is _empty: | ||
msg = "Cannot construct a protocol with missing return type annotation." | ||
raise InvalidProtocolError(msg) | ||
|
||
for parameter in params: | ||
if parameter.annotation is _empty: | ||
msg = f"Cannot construct a protocol with missing type annotation, found {parameter}." | ||
raise InvalidProtocolError(msg) | ||
|
||
if parameter.default is not _empty: | ||
msg = f"Unexpected default value found in protocol definition, found {parameter}." | ||
raise InvalidProtocolError(msg) | ||
|
||
return FunctionProtocol(sig.return_annotation, params) | ||
|
||
|
||
def implements(protocol: FunctionProtocol) -> Callable[[CallableAny], CallableAny]: | ||
if not isinstance(protocol, FunctionProtocol): # type: ignore[reportUnnecessaryIsInstance] | ||
msg = ( | ||
f"A protocol implementation using `implements` expects a FunctionProtocol parameter, " | ||
f"but found {type(protocol)}. Did you use `@implements` without parameters? Use " | ||
f"@implements(protocol) instead." | ||
) | ||
raise ProtocolImplementationError(msg) | ||
|
||
def decorator(func: CallableAny) -> CallableAny: | ||
sig = signature(func) | ||
size = len(protocol.parameters) | ||
|
||
# check if the updated return annotation matches the protocol return annotation | ||
return_annotation = protocol.return_annotation if sig.return_annotation is _empty else sig.return_annotation | ||
if return_annotation != (proto_return := protocol.return_annotation): | ||
msg = ( | ||
f"Cannot implement a protocol without matching return types, but found return type " | ||
f"{return_annotation} for a protocol with return type {proto_return}." | ||
) | ||
raise ProtocolImplementationError(msg) | ||
|
||
# check if the updated shared parameters exactly match the protocol parameters | ||
sig_params = list(sig.parameters.values()) | ||
shared_params = update_annotations(protocol.parameters, sig_params) | ||
if strip_defaults(shared_params[:size]) != (proto_params := protocol.parameters): | ||
msg = ( | ||
f"Cannot implement a protocol without matching parameter types, but found parameters " | ||
f"{sig_params} for a protocol with parameters {proto_params}." | ||
) | ||
raise ProtocolImplementationError(msg) | ||
|
||
# check if the other parameters all have default values | ||
other_params = sig_params[size:] | ||
if any(p.default is _empty for p in other_params): | ||
msg = ( | ||
f"Cannot implement a protocol that requires substitution, if any parameters not " | ||
f"included in the protocol do not have a default value, found: {other_params}." | ||
) | ||
raise ProtocolImplementationError(msg) | ||
|
||
# replace the function signature | ||
final_parameters = shared_params + other_params | ||
func.__signature__ = sig.replace( # type: ignore[reportFunctionMemberAccess] | ||
parameters=final_parameters, | ||
return_annotation=return_annotation, | ||
) | ||
|
||
# replace the function annotations (used by runtime type checker) | ||
param_annotations = {p.name: p.annotation for p in final_parameters if p.annotation is not _empty} | ||
return_annotation = {} if return_annotation is _empty else {"return": return_annotation} | ||
func.__annotations__ = param_annotations | return_annotation | ||
return typecheck(func) | ||
|
||
return decorator | ||
|
||
|
||
def strip_defaults(params: list[Parameter]) -> list[Parameter]: | ||
params = params.copy() | ||
"""Strip the default values for the parameters in the list, which are irrelevant for the comparison.""" | ||
for param in params: | ||
setattr(param, "_default", _empty) # noqa[B010] | ||
|
||
return params | ||
|
||
|
||
def update_annotations(reference: list[Parameter], params: list[Parameter]) -> list[Parameter]: | ||
for ref, param in zip(reference, params): | ||
if param.annotation is _empty: | ||
setattr(param, "_annotation", ref.annotation) # noqa[B010] | ||
return params |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
from beartype import beartype as _typecheck | ||
from beartype._data.hint.datahinttyping import BeartypeableT, BeartypeReturn | ||
from jaxtyping import jaxtyped as _shapecheck | ||
|
||
|
||
def typecheck(fn: BeartypeableT) -> BeartypeReturn: | ||
"""Typecheck a function without jaxtyping annotations, otherwise additionally shapecheck the function. | ||
:param fn: Any function or method. | ||
:return: Typechecked function or method. | ||
:raises: BeartypeException if a call to the function does not satisfy the typecheck. | ||
""" | ||
# check if there is any annotation requiring a shapecheck, i.e. any jaxtyping annotation that is not "..." | ||
# this check is significantly slower than the string-based check implemented below (~+50%), but this should | ||
# only be relevant in tight loops. | ||
# for annotation in fn.__annotations__.values(): | ||
# if getattr(annotation, "dim_str", "") != "...": | ||
|
||
# simply check if there is any mention of jaxtyping in the annotations, this adds barely any overhead to | ||
# a base call of beartype's @beartype | ||
if "jaxtyping" in str(fn.__annotations__): | ||
# shapecheck implies typecheck | ||
return _shapecheck(_typecheck(fn)) | ||
|
||
return _typecheck(fn) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import numpy | ||
from beartype import beartype | ||
|
||
from safecheck import * | ||
|
||
args = list(range(10)) | ||
args_shaped = numpy.random.randn(10, 100) # dim0=number of args, dim1=size of arg | ||
|
||
|
||
def decorate(f): | ||
return f | ||
|
||
|
||
def f(*_: int) -> None: | ||
... | ||
|
||
|
||
def f_shaped(*_: Shaped[NumpyArray, "n"]) -> None: | ||
... | ||
|
||
|
||
def test_no_overhead(benchmark): | ||
benchmark(f, *args) | ||
|
||
|
||
def test_no_overhead_shaped(benchmark): | ||
benchmark(f_shaped, *args_shaped) | ||
|
||
|
||
def test_minimal_overhead(benchmark): | ||
benchmark(decorate(f), *args) | ||
|
||
|
||
def test_minimal_overhead_shaped(benchmark): | ||
benchmark(decorate(f_shaped), *args_shaped) | ||
|
||
|
||
def test_beartype(benchmark): | ||
benchmark(beartype(f), *args) | ||
|
||
|
||
def test_beartype_shaped(benchmark): | ||
benchmark(beartype(f_shaped), *args_shaped) | ||
|
||
|
||
def test_typecheck(benchmark): | ||
benchmark(typecheck(f), *args) | ||
|
||
|
||
def test_typecheck_shaped(benchmark): | ||
benchmark(typecheck(f_shaped), *args_shaped) | ||
|
||
|
||
def test_dispatch(benchmark): | ||
dispatch = Dispatcher() | ||
benchmark(dispatch(f), *args) | ||
|
||
|
||
def test_dispatch_shaped(benchmark): | ||
benchmark(dispatch(f_shaped), *args_shaped) | ||
|
||
|
||
def test_protocol(benchmark): | ||
benchmark(implements(protocol(f))(f), *args) | ||
|
||
|
||
def test_protocol_shaped(benchmark): | ||
benchmark(implements(protocol(f_shaped))(f_shaped), *args_shaped) |
Oops, something went wrong.