Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: improved backend, union handling, tensor creation APIs #68

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 138 additions & 31 deletions camtools/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
from functools import lru_cache, wraps
from typing import Any, Tuple, Union

import jaxtyping
from jaxtyping import AbstractArray
import numpy as np
from typing import List, Tuple, Dict, Literal


@lru_cache(maxsize=1)
Expand Down Expand Up @@ -167,31 +168,47 @@ def _is_shape_compatible(

@lru_cache(maxsize=1024)
def _assert_tensor_hint(
hint: jaxtyping.AbstractArray,
hint: Union[AbstractArray, Union[AbstractArray]],
arg_shape: Tuple[int, ...],
arg_dtype: Any,
arg_name: str,
):
"""
Args:
hint: A type hint for a tensor, must be jaxtyping.AbstractArray.
hint: A type hint for a tensor, must be jaxtyping.AbstractArray or
a Union of jaxtyping.AbstractArray.
arg_shape: The shape of the argument to check.
arg_dtype: The dtype of the argument to check.
arg_name: The name of the argument, for error messages.
"""
# Check shapes.
gt_shape = _shape_from_dim_str(hint.dim_str)
if not _is_shape_compatible(arg_shape, gt_shape):
raise TypeError(
f"{arg_name} must be of shape {gt_shape}, but got shape {arg_shape}."
)

# Check dtype.
gt_dtypes = hint.dtypes
if _dtype_to_str(arg_dtype) not in gt_dtypes:
raise TypeError(
f"{arg_name} must be of dtype {gt_dtypes}, "
f"but got dtype {_dtype_to_str(arg_dtype)}."
if typing.get_origin(hint) is Union:
for sub_hint in typing.get_args(hint):
gt_shape = _shape_from_dim_str(sub_hint.dim_str)
gt_dtypes = sub_hint.dtypes
if (
_is_shape_compatible(arg_shape, gt_shape)
and _dtype_to_str(arg_dtype) in gt_dtypes
):
return
raise ValueError(
f"{arg_name} must be of shape and dtype "
f"compatible with any of {hint}, "
f"but got shape {arg_shape} and got dtype {_dtype_to_str(arg_dtype)}."
)
else:
# Check shapes.
gt_shape = _shape_from_dim_str(hint.dim_str)
if not _is_shape_compatible(arg_shape, gt_shape):
raise ValueError(
f"{arg_name} must be of shape {gt_shape}, "
f"but got shape {arg_shape}."
)
# Check dtype.
gt_dtypes = hint.dtypes
if _dtype_to_str(arg_dtype) not in gt_dtypes:
raise ValueError(
f"{arg_name} must be of dtype {gt_dtypes}, "
f"but got dtype {_dtype_to_str(arg_dtype)}."
)


@lru_cache(maxsize=1)
Expand Down Expand Up @@ -230,8 +247,8 @@ def disable_tensor_check():
def is_tensor_check_enabled():
"""
Returns True if the tensor dtype and shape check is enabled, and False
otherwise. This will be used when @tensor_to_auto_backend,
@tensor_to_numpy_backend, or @tensor_to_torch_backend is called.
otherwise. This will be used when @tensor_backend_auto,
@tensor_backend_numpy, or @tensor_backend_torch is called.
"""
return _tensor_check_enabled

Expand All @@ -242,26 +259,42 @@ class Tensor:
Typically np.ndarray or torch.Tensor is supported.
"""

pass

def _is_pure_tensor_hint(hint):
"""
Returns True if the type hint is a tensor hint or a Union of purely
tensor hints, and False otherwise.
"""
if typing.get_origin(hint) is Union:
is_tensor_hints = [
inspect.isclass(h) and issubclass(h, AbstractArray)
for h in typing.get_args(hint)
]
return all(is_tensor_hints)
else:
return inspect.isclass(hint) and issubclass(hint, AbstractArray)


def tensor_to_auto_backend(func, force_backend=None):
def tensor_backend_auto(func, force_backend: Literal["numpy", "torch"] = None):
"""
Automatic backend selection based on the backend of type-annotated input
tensors, and run tensor type and shape checks if is_tensor_check_enabled().
If there are no tensors, or if the tensors do not have the necessary type
annotations, the default backend is used. The function targets specifically
jaxtyping.AbstractArray annotations to determine tensor treatment and
AbstractArray annotations to determine tensor treatment and
backend usage.

Raises:
ValueError: If the tensor's shape and dtype do not match the type hints.

Detailed behaviors:
1. Only processes input arguments that are explicitly typed as
jaxtyping.AbstractArray. Arguments without this type hint or with
AbstractArray. Arguments without this type hint or with
different annotations maintain their default behavior without backend
modification.
2. Supports handling of numpy.ndarray, torch.Tensor, and Python lists that
should be converted to tensors based on their type hints.
3. If the type hint is jaxtyping.AbstractArray and the argument is a list,
3. If the type hint is AbstractArray and the argument is a list,
the list will be converted to a tensor using the native array
functionality of the active backend.
4. Ensures all tensor arguments must be from the same backend to avoid
Expand Down Expand Up @@ -293,9 +326,7 @@ def tensor_to_auto_backend(func, force_backend=None):
if name in arg_names
}
tensor_names = [
name
for name, hint in arg_name_to_hint.items()
if inspect.isclass(hint) and issubclass(hint, jaxtyping.AbstractArray)
name for name, hint in arg_name_to_hint.items() if _is_pure_tensor_hint(hint)
]

def _convert_tensor_to_backend(arg, backend):
Expand Down Expand Up @@ -367,7 +398,7 @@ def wrapper(*args, **kwargs):
elif is_torch_available() and tensor_types_used == {torch.Tensor}:
backend = "torch"
else:
raise TypeError(
raise ValueError(
f"All tensors must be from the same backend, "
f"but got {tensor_types_used}."
)
Expand Down Expand Up @@ -403,7 +434,7 @@ def wrapper(*args, **kwargs):
return wrapper


def tensor_to_numpy_backend(func):
def tensor_backend_numpy(func):
"""
Run this function by first converting its input tensors to numpy arrays.
Only jaxtyping-annotated tensors will be processed. This wrapper shall be
Expand Down Expand Up @@ -433,7 +464,7 @@ def tensor_to_numpy_backend(func):
force_backend argument set to "numpy".
"""
# Wrap the original function with tensor_backend_auto enforcing numpy.
wrapped_func = tensor_to_auto_backend(func, force_backend="numpy")
wrapped_func = tensor_backend_auto(func, force_backend="numpy")

@wraps(func)
def wrapper(*args, **kwargs):
Expand All @@ -442,7 +473,7 @@ def wrapper(*args, **kwargs):
return wrapper


def tensor_to_torch_backend(func):
def tensor_backend_torch(func):
"""
Run this function by first converting its input tensors to torch tensors.
Only jaxtyping-annotated tensors will be processed. This wrapper shall be
Expand Down Expand Up @@ -472,10 +503,86 @@ def tensor_to_torch_backend(func):
force_backend argument set to "torch".
"""
# Wrap the original function with tensor_backend_auto enforcing torch.
wrapped_func = tensor_to_auto_backend(func, force_backend="torch")
wrapped_func = tensor_backend_auto(func, force_backend="torch")

@wraps(func)
def wrapper(*args, **kwargs):
return wrapped_func(*args, **kwargs)

return wrapper


def create_array(
    arr: List,
    dtype: Any,
    backend: Literal["numpy", "torch"],
) -> Tensor:
    """
    Create a tensor from a Python list using the requested backend.

    Calls np.array() or torch.tensor() depending on the backend.

    Args:
        arr: The data to convert. It is forwarded unchanged to the backend
            constructor.
        dtype: The data type, forwarded as the dtype keyword of the backend
            constructor.
        backend: Either "numpy" or "torch".

    Returns:
        The result of np.array(arr, dtype=dtype) for "numpy", or of
        torch.tensor(arr, dtype=dtype) for "torch".

    Raises:
        ValueError: If backend is "torch" but is_torch_available() is False,
            or if backend is neither "numpy" nor "torch".
    """
    if backend == "numpy":
        return np.array(arr, dtype=dtype)
    if backend == "torch":
        if not is_torch_available():
            raise ValueError("Torch is not available.")
        return torch.tensor(arr, dtype=dtype)
    # An unknown backend used to fall through and return None silently,
    # which only surfaced later as a confusing error at the call site.
    raise ValueError(f"Unsupported backend {backend!r}.")


def create_ones(
    shape: Tuple[int, ...],
    dtype: Any,
    backend: Literal["numpy", "torch"],
) -> Tensor:
    """
    Create a tensor filled with ones using the requested backend.

    Calls np.ones() or torch.ones() depending on the backend.

    Args:
        shape: The shape of the tensor to create.
        dtype: The data type, forwarded as the dtype keyword of the backend
            function.
        backend: Either "numpy" or "torch".

    Returns:
        The result of np.ones(shape, dtype=dtype) for "numpy", or of
        torch.ones(shape, dtype=dtype) for "torch".

    Raises:
        ValueError: If backend is "torch" but is_torch_available() is False,
            or if backend is neither "numpy" nor "torch".
    """
    if backend == "numpy":
        return np.ones(shape, dtype=dtype)
    if backend == "torch":
        if not is_torch_available():
            raise ValueError("Torch is not available.")
        return torch.ones(shape, dtype=dtype)
    # An unknown backend used to fall through and return None silently.
    raise ValueError(f"Unsupported backend {backend!r}.")


def create_zeros(
    shape: Tuple[int, ...],
    dtype: Any,
    backend: Literal["numpy", "torch"],
) -> Tensor:
    """
    Create a tensor filled with zeros using the requested backend.

    Calls np.zeros() or torch.zeros() depending on the backend.

    Args:
        shape: The shape of the tensor to create.
        dtype: The data type, forwarded as the dtype keyword of the backend
            function.
        backend: Either "numpy" or "torch".

    Returns:
        The result of np.zeros(shape, dtype=dtype) for "numpy", or of
        torch.zeros(shape, dtype=dtype) for "torch".

    Raises:
        ValueError: If backend is "torch" but is_torch_available() is False,
            or if backend is neither "numpy" nor "torch".
    """
    if backend == "numpy":
        return np.zeros(shape, dtype=dtype)
    if backend == "torch":
        if not is_torch_available():
            raise ValueError("Torch is not available.")
        return torch.zeros(shape, dtype=dtype)
    # An unknown backend used to fall through and return None silently.
    raise ValueError(f"Unsupported backend {backend!r}.")


def create_empty(
    shape: Tuple[int, ...],
    dtype: Any,
    backend: Literal["numpy", "torch"],
) -> Tensor:
    """
    Create a tensor of the given shape without filling it, using the
    requested backend.

    Calls np.empty() or torch.empty() depending on the backend.

    Args:
        shape: The shape of the tensor to create.
        dtype: The data type, forwarded as the dtype keyword of the backend
            function.
        backend: Either "numpy" or "torch".

    Returns:
        The result of np.empty(shape, dtype=dtype) for "numpy", or of
        torch.empty(shape, dtype=dtype) for "torch".

    Raises:
        ValueError: If backend is "torch" but is_torch_available() is False,
            or if backend is neither "numpy" nor "torch".
    """
    if backend == "numpy":
        return np.empty(shape, dtype=dtype)
    if backend == "torch":
        if not is_torch_available():
            raise ValueError("Torch is not available.")
        return torch.empty(shape, dtype=dtype)
    # An unknown backend used to fall through and return None silently.
    raise ValueError(f"Unsupported backend {backend!r}.")


def get_tensor_backend(arr: Tensor) -> Literal["numpy", "torch"]:
    """
    Report which backend a tensor belongs to.

    Args:
        arr: The tensor to inspect.

    Returns:
        "numpy" if arr is an np.ndarray, or "torch" if torch is available
        and arr is a torch.Tensor.

    Raises:
        ValueError: If arr is neither of the above.
    """
    if isinstance(arr, np.ndarray):
        return "numpy"

    # The availability check comes first so that torch.Tensor is only
    # looked up when torch can actually be used.
    if is_torch_available() and isinstance(arr, torch.Tensor):
        return "torch"

    raise ValueError(f"Unsupported tensor type {type(arr)}.")
Loading
Loading