Skip to content

Commit

Permalink
support union of types
Browse files Browse the repository at this point in the history
  • Loading branch information
yxlao committed Jul 9, 2024
1 parent 6ae814e commit 990dc6e
Show file tree
Hide file tree
Showing 3 changed files with 108 additions and 115 deletions.
144 changes: 46 additions & 98 deletions camtools/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,31 +167,47 @@ def _is_shape_compatible(

@lru_cache(maxsize=1024)
def _assert_tensor_hint(
hint: AbstractArray,
hint: Union[AbstractArray, Union[AbstractArray]],
arg_shape: Tuple[int, ...],
arg_dtype: Any,
arg_name: str,
):
"""
Args:
hint: A type hint for a tensor, must be javtyping.AbstractArray.
hint: A type hint for a tensor, must be javtyping.AbstractArray or
a Union of javtyping.AbstractArray.
arg: An argument to check, typically a tensor.
arg_name: The name of the argument, for error messages.
"""
# Check shapes.
gt_shape = _shape_from_dim_str(hint.dim_str)
if not _is_shape_compatible(arg_shape, gt_shape):
if typing.get_origin(hint) is Union:
for sub_hint in typing.get_args(hint):
gt_shape = _shape_from_dim_str(sub_hint.dim_str)
gt_dtypes = sub_hint.dtypes
if (
_is_shape_compatible(arg_shape, gt_shape)
and _dtype_to_str(arg_dtype) in gt_dtypes
):
return
raise TypeError(
f"{arg_name} must be of shape {gt_shape}, but got shape {arg_shape}."
)

# Check dtype.
gt_dtypes = hint.dtypes
if _dtype_to_str(arg_dtype) not in gt_dtypes:
raise TypeError(
f"{arg_name} must be of dtype {gt_dtypes}, "
f"but got dtype {_dtype_to_str(arg_dtype)}."
f"{arg_name} must be of shape and dtype "
f"compatible with any of {hint}, "
f"but got shape {arg_shape} and got dtype {_dtype_to_str(arg_dtype)}."
)
else:
# Check shapes.
gt_shape = _shape_from_dim_str(hint.dim_str)
if not _is_shape_compatible(arg_shape, gt_shape):
raise TypeError(
f"{arg_name} must be of shape {gt_shape}, "
f"but got shape {arg_shape}."
)
# Check dtype.
gt_dtypes = hint.dtypes
if _dtype_to_str(arg_dtype) not in gt_dtypes:
raise TypeError(
f"{arg_name} must be of dtype {gt_dtypes}, "
f"but got dtype {_dtype_to_str(arg_dtype)}."
)


@lru_cache(maxsize=1)
Expand Down Expand Up @@ -243,6 +259,21 @@ class Tensor:
"""


def _is_pure_tensor_hint(hint):
"""
Returns True if the type hint is a tensor hint or a Union of purely
tensor hints, and False otherwise.
"""
if typing.get_origin(hint) is Union:
is_tensor_hints = [
inspect.isclass(h) and issubclass(h, AbstractArray)
for h in typing.get_args(hint)
]
return all(is_tensor_hints)
else:
return inspect.isclass(hint) and issubclass(hint, AbstractArray)


def tensor_backend_auto(func, force_backend=None):
"""
Automatic backend selection based on the backend of type-annotated input
Expand Down Expand Up @@ -291,9 +322,7 @@ def tensor_backend_auto(func, force_backend=None):
if name in arg_names
}
tensor_names = [
name
for name, hint in arg_name_to_hint.items()
if inspect.isclass(hint) and issubclass(hint, AbstractArray)
name for name, hint in arg_name_to_hint.items() if _is_pure_tensor_hint(hint)
]

def _convert_tensor_to_backend(arg, backend):
Expand Down Expand Up @@ -477,84 +506,3 @@ def wrapper(*args, **kwargs):
return wrapped_func(*args, **kwargs)

return wrapper


def tensor_type_check(func):
"""
Function decorator to enforce tensor type and shape checks on input
arguments based on their type annotations using AbstractArray.
This decorator ensures that tensors passed to the function conform to the
specified dimensions and data types declared in the function's type hints.
If tensor type check is globally enabled, it verifies each tensor's shape
and dtype against the expected types specified in the function's
annotations.
This function is extracted from `tensor_backend_auto`. It shall be used
independently if the backend conversion is not required. If the backend
conversion is needed, us `tensor_backend_auto` instead as
`tensor_type_check` is a subset check of `tensor_backend_auto`.
Args:
func (callable): The function to decorate.
Behavior:
- Only processes input arguments that are explicitly typed as
AbstractArray.
- If tensor type checking is enabled, it verifies that each tensor matches
the declared shape and dtype in the type hints.
- Raises TypeError if any tensor does not conform to the expected shape or
dtype.
"""
# Pre-compute the function signature and type hints
sig = inspect.signature(func)
arg_names = [
param.name
for param in sig.parameters.values()
if param.kind
in (
inspect.Parameter.POSITIONAL_ONLY,
inspect.Parameter.POSITIONAL_OR_KEYWORD,
inspect.Parameter.KEYWORD_ONLY,
)
]
arg_name_to_hint = {
name: hint
for name, hint in typing.get_type_hints(func).items()
if name in arg_names
and inspect.isclass(hint)
and issubclass(hint, AbstractArray)
}

@wraps(func)
def wrapper(*args, **kwargs):
if is_tensor_check_enabled():
# Bind args and kwargs
arg_name_to_arg = dict(zip(arg_names, args))
arg_name_to_arg.update(kwargs)

# Fill in missing arguments with their default values
for arg_name, param in sig.parameters.items():
if arg_name not in arg_name_to_arg and param.default is not param.empty:
arg_name_to_arg[arg_name] = param.default

# Check tensor dtype and shape if enabled
for tensor_name, tensor_arg in arg_name_to_arg.items():
if tensor_name in arg_name_to_hint:
hint = arg_name_to_hint[tensor_name]
if isinstance(tensor_arg, _get_valid_array_types()):
_assert_tensor_hint(
hint=hint,
arg_shape=tensor_arg.shape,
arg_dtype=tensor_arg.dtype,
arg_name=tensor_name,
)

# Call the original function with updated arguments
return func(**arg_name_to_arg)

else:
# Call the original function without tensor type checks
return func(*args, **kwargs)

return wrapper
17 changes: 15 additions & 2 deletions camtools/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,21 @@
from . import sanity
from . import convert


def pad_0001(array):
from .backend import (
Tensor,
tensor_backend_numpy,
tensor_backend_auto,
ivy,
)
from jaxtyping import Float
from typing import List, Tuple, Dict, Union
from matplotlib import pyplot as plt


@tensor_backend_auto
def pad_0001(
array: Union[Float[Tensor, "3 4"], Float[Tensor, "N 3 4"]]
) -> Union[Float[Tensor, "4 4"], Float[Tensor, "N 4 4"]]:
"""
Pad [0, 0, 0, 1] to the bottom row.
Expand Down
62 changes: 47 additions & 15 deletions test/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import numpy as np
import pytest
from jaxtyping import Float
from jaxtyping import Float, Int

import camtools as ct
from camtools.backend import Tensor, ivy, is_torch_available, torch
import warnings
from typing import Union


@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -173,22 +174,22 @@ def add(
assert np.allclose(result, expected, atol=1e-5)

# Incorrect shapes
with pytest.raises(TypeError, match=r".*but got shape.*"):
with pytest.raises(TypeError, match=r".*got shape.*"):
y_wrong = np.array([[1, 1, 1, 1]], dtype=np.float32)
add(x, y_wrong)

# Incorrect shape with lists
with pytest.raises(TypeError, match=r".*but got shape.*"):
with pytest.raises(TypeError, match=r".*got shape.*"):
y_wrong = [[1.0, 1.0, 1.0, 1.0]]
add(x, y_wrong)

# Incorrect dtype
with pytest.raises(TypeError, match=r".*but got dtype.*"):
with pytest.raises(TypeError, match=r".*got dtype.*"):
y_wrong = np.array([[1, 1, 1]], dtype=np.int64)
add(x, y_wrong)

# Incorrect dtype with lists
with pytest.raises(TypeError, match=r".*but got dtype.*"):
with pytest.raises(TypeError, match=r".*got dtype.*"):
y_wrong = [[1, 1, 1]]
add(x, y_wrong)

Expand Down Expand Up @@ -216,22 +217,22 @@ def add(
assert torch.allclose(result, expected, atol=1e-5)

# Incorrect shapes
with pytest.raises(TypeError, match=r".*but got shape.*"):
with pytest.raises(TypeError, match=r".*got shape.*"):
y_wrong = torch.tensor([[1, 1, 1, 1]], dtype=torch.float32)
add(x, y_wrong)

# Incorrect shape with lists
with pytest.raises(TypeError, match=r".*but got shape.*"):
with pytest.raises(TypeError, match=r".*got shape.*"):
y_wrong = [[1.0, 1.0, 1.0, 1.0]]
add(x, y_wrong)

# Incorrect dtype
with pytest.raises(TypeError, match=r".*but got dtype.*"):
with pytest.raises(TypeError, match=r".*got dtype.*"):
y_wrong = torch.tensor([[1, 1, 1]], dtype=torch.int64)
add(x, y_wrong)

# Incorrect dtype with lists
with pytest.raises(TypeError, match=r".*but got dtype.*"):
with pytest.raises(TypeError, match=r".*got dtype.*"):
y_wrong = [[1, 1, 1]]
add(x, y_wrong)

Expand Down Expand Up @@ -262,7 +263,7 @@ def add(
assert np.allclose(result, expected, atol=1e-5)

# Test for a shape mismatch where y does not conform to "n 3"
with pytest.raises(TypeError, match=r".*but got shape.*"):
with pytest.raises(TypeError, match=r".*got shape.*"):
y_wrong = np.array([4.0, 5.0, 6.0], dtype=np.float32) # Shape (3,)
add(x, y_wrong)

Expand All @@ -273,7 +274,7 @@ def add(
assert np.allclose(result, expected, atol=1e-5)

# Incorrect dtype with lists, expect dtype error
with pytest.raises(TypeError, match=r".*but got dtype.*"):
with pytest.raises(TypeError, match=r".*got dtype.*"):
y_wrong = [[4, 5, 6], [7, 8, 9]] # int type elements in list
add(x, y_wrong)

Expand Down Expand Up @@ -305,7 +306,7 @@ def add(
assert torch.allclose(result, expected, atol=1e-5)

# Test for a shape mismatch where y does not conform to "n 3"
with pytest.raises(TypeError, match=r".*but got shape.*"):
with pytest.raises(TypeError, match=r".*got shape.*"):
y_wrong = torch.tensor([4.0, 5.0, 6.0], dtype=torch.float32) # Shape (3,)
add(x, y_wrong)

Expand All @@ -316,7 +317,7 @@ def add(
assert torch.allclose(result, expected, atol=1e-5)

# Incorrect dtype with lists, expect dtype error
with pytest.raises(TypeError, match=r".*but got dtype.*"):
with pytest.raises(TypeError, match=r".*got dtype.*"):
y_wrong = [[4, 5, 6], [7, 8, 9]] # int type elements in list
add(x, y_wrong)

Expand Down Expand Up @@ -492,7 +493,7 @@ def test_disable_tensor_check():
result = sum_xyz(x, y=y)
assert np.allclose(result, np.array([9.0, 9.0, 9.0]))
ct.backend.enable_tensor_check()
with pytest.raises(TypeError, match=r".*but got shape.*"):
with pytest.raises(TypeError, match=r".*got shape.*"):
sum_xyz(x, y=y)

# Wrong dtype
Expand All @@ -502,5 +503,36 @@ def test_disable_tensor_check():
result = sum_xyz(x, y=y)
assert np.allclose(result, np.array([9.0, 9.0, 9.0]))
ct.backend.enable_tensor_check()
with pytest.raises(TypeError, match=r".*but got dtype.*"):
with pytest.raises(TypeError, match=r".*got dtype.*"):
sum_xyz(x, y=y)


def test_union_of_hints():

@ct.backend.tensor_backend_auto
def identity(
arr: Union[
Float[Tensor, "3 4"],
Float[Tensor, "N 3 4"],
Int[Tensor, "N 3 4"],
]
):
return arr

# Supported dtype and shape
x = np.random.rand(3, 4).astype(np.float32)
assert np.allclose(identity(x), x)
x = np.random.rand(5, 3, 4).astype(np.float32)
assert np.allclose(identity(x), x)
x = np.random.randint(0, 10, (5, 3, 4)).astype(np.int32)
assert np.allclose(identity(x), x)

# Wrong dtype
with pytest.raises(TypeError, match=r".*got dtype.*"):
x = np.random.randint(0, 10, (3, 4)).astype(np.int32)
identity(x)

# Wrong shape
with pytest.raises(TypeError, match=r".*got shape.*"):
x = np.random.rand(5, 4).astype(np.float32)
identity(x)

0 comments on commit 990dc6e

Please sign in to comment.