From 04beaea8741576cd8ff5cec5646dcec7f92d3199 Mon Sep 17 00:00:00 2001
From: Yixing Lao
Date: Tue, 9 Jul 2024 08:16:38 -0700
Subject: [PATCH] feat: flexible numpy/torch backend and tensor type checking
 (#69)

This PR is a combination of:

* feat: automatic tensor backend and type checks (#64)
* perf: performance improvements for backend function wrappers (#66)
* feat: improved backend, union handling, tensor creation APIs (#68)
---
 .github/workflows/unit_test.yml |   8 +-
 .gitignore                      |   3 +
 camtools/__init__.py            |   2 +-
 camtools/backend.py             | 588 ++++++++++++++++++++++++++++++++
 camtools/metric.py              |   2 +-
 pyproject.toml                  |   3 +
 setup.py                        |   1 -
 test/test_backend.py            | 538 +++++++++++++++++++++++++++++
 8 files changed, 1141 insertions(+), 4 deletions(-)
 create mode 100644 camtools/backend.py
 create mode 100644 test/test_backend.py

diff --git a/.github/workflows/unit_test.yml b/.github/workflows/unit_test.yml
index 7028b807..1634eb1c 100644
--- a/.github/workflows/unit_test.yml
+++ b/.github/workflows/unit_test.yml
@@ -25,6 +25,12 @@ jobs:
       run: |
         python -m pip install --upgrade pip
         pip install -e .[dev]
-    - name: Run unit tests
+    - name: Run unit tests (numpy)
+      run: |
+        pytest
+    - name: Install dependencies with torch
+      run: |
+        pip install -e .[torch]
+    - name: Run unit tests (numpy + torch)
       run: |
         pytest
diff --git a/.gitignore b/.gitignore
index 130234ba..cb172652 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,6 +1,9 @@
 camtools/assets/bbox_box.png
 camtools/assets/bbox_box_blender.png
 
+# cProfile
+*.prof
+
 # Byte-compiled / optimized / DLL files
 __pycache__/
 *.py[cod]
diff --git a/camtools/__init__.py b/camtools/__init__.py
index 3e40db69..2b0c3e6d 100644
--- a/camtools/__init__.py
+++ b/camtools/__init__.py
@@ -1,4 +1,5 @@
 from . import artifact
+from . import backend
 from . import camera
 from . import colmap
 from . import colormap
@@ -16,7 +17,6 @@
 from . import transform
 from . import util
 
-
 try:
     # Python >= 3.8
     from importlib.metadata import version
diff --git a/camtools/backend.py b/camtools/backend.py
new file mode 100644
index 00000000..c19f3eab
--- /dev/null
+++ b/camtools/backend.py
@@ -0,0 +1,588 @@
+import inspect
+import typing
+import warnings
+from functools import lru_cache, wraps
+from typing import Any, Tuple, Union
+
+from jaxtyping import AbstractArray
+import numpy as np
+from typing import List, Tuple, Dict, Literal
+
+
+@lru_cache(maxsize=1)
+def _safely_import_torch():
+    """
+    Open3D has an issue where it must be imported before torch. If Open3D is
+    installed, this function will import Open3D before torch. Otherwise, it
+    will simply import and return torch.
+
+    Use this function to import torch within camtools to handle the Open3D
+    import order issue. That is, within camtools, we shall avoid
+    `import torch`, and instead use `from camtools.backend import torch`. As
+    torch is an optional dependency for camtools, this function will return
+    None if torch is not available.
+
+    Returns:
+        module: The torch module if available, otherwise None.
+    """
+    try:
+        __import__("open3d")
+    except ImportError:
+        pass
+
+    try:
+        _torch = __import__("torch")
+        return _torch
+    except ImportError:
+        return None
+
+
+torch = _safely_import_torch()
+
+
+@lru_cache(maxsize=1)
+def is_torch_available():
+    return _safely_import_torch() is not None
+
+
+@lru_cache(maxsize=1)
+def _safely_import_ivy():
+    """
+    This function sets up the warnings filters to suppress the deprecation
+    warnings from numpy 2.0 before importing ivy; this is a temporary
+    workaround.
+
+    Within camtools, we shall avoid `import ivy`, and instead use
+    `from camtools.backend import ivy`.
+    """
+    warnings.filterwarnings(
+        "ignore",
+        message=".*numpy.core.numeric is deprecated.*",
+        category=DeprecationWarning,
+        module="ivy",
+    )
+    warnings.filterwarnings(
+        "ignore",
+        message=".*Compositional function.*array_mode is set to False.*",
+        category=UserWarning,
+        module="ivy",
+    )
+    ivy = __import__("ivy")
+    ivy.set_array_mode(False)
+    return ivy
+
+
+ivy = _safely_import_ivy()
+
+
+@lru_cache(maxsize=64)
+def _dtype_to_str(dtype):
+    """
+    Convert a numpy or torch dtype to a string, one of:
+
+    - "bool"
+    - "bool_"
+    - "uint4"
+    - "uint8"
+    - "uint16"
+    - "uint32"
+    - "uint64"
+    - "int4"
+    - "int8"
+    - "int16"
+    - "int32"
+    - "int64"
+    - "bfloat16"
+    - "float16"
+    - "float32"
+    - "float64"
+    - "complex64"
+    - "complex128"
+    """
+    if isinstance(dtype, np.dtype):
+        return dtype.name
+
+    if is_torch_available():
+        if isinstance(dtype, torch.dtype):
+            return str(dtype).split(".")[1]
+
+    raise ValueError(f"Unknown dtype {dtype}.")
+
+
+@lru_cache(maxsize=1024)
+def _shape_from_dim_str(dim_str: str) -> Tuple[Union[int, None, str], ...]:
+    shape = []
+    elements = dim_str.split()
+    for elem in elements:
+        if elem == "...":
+            shape.append("...")
+        elif elem.isdigit():
+            shape.append(int(elem))
+        else:
+            shape.append(None)
+    return tuple(shape)
+
+
+@lru_cache(maxsize=1024)
+def _is_shape_compatible(
+    arg_shape: Tuple[Union[int, None, str], ...],
+    gt_shape: Tuple[Union[int, None, str], ...],
+) -> bool:
+    if "..." in gt_shape:
+        pre_ellipsis = None
+        post_ellipsis = None
+
+        for i, dim in enumerate(gt_shape):
+            if dim == "...":
+                pre_ellipsis = i
+                post_ellipsis = len(gt_shape) - i - 1
+                break
+
+        if pre_ellipsis is None or gt_shape.count("...") > 1:
+            raise ValueError(
+                "Only one ellipsis is supported in the shape hint for now."
+            )
+
+        if len(arg_shape) < len(gt_shape) - 1:
+            return False
+
+        for i in range(pre_ellipsis):
+            if arg_shape[i] != gt_shape[i] and gt_shape[i] is not None:
+                return False
+
+        for i in range(1, post_ellipsis + 1):
+            if arg_shape[-i] != gt_shape[-i] and gt_shape[-i] is not None:
+                return False
+
+        return True
+    else:
+        if len(arg_shape) != len(gt_shape):
+            return False
+
+        for arg_dim, gt_dim in zip(arg_shape, gt_shape):
+            if arg_dim != gt_dim and gt_dim is not None:
+                return False
+
+        return True
+
+
+@lru_cache(maxsize=1024)
+def _assert_tensor_hint(
+    hint: Union[AbstractArray, Union[AbstractArray]],
+    arg_shape: Tuple[int, ...],
+    arg_dtype: Any,
+    arg_name: str,
+):
+    """
+    Args:
+        hint: A type hint for a tensor; must be jaxtyping.AbstractArray or
+            a Union of jaxtyping.AbstractArray.
+        arg_shape: The shape of the argument to check.
+        arg_dtype: The dtype of the argument to check.
+        arg_name: The name of the argument, for error messages.
+    """
+    if typing.get_origin(hint) is Union:
+        for sub_hint in typing.get_args(hint):
+            gt_shape = _shape_from_dim_str(sub_hint.dim_str)
+            gt_dtypes = sub_hint.dtypes
+            if (
+                _is_shape_compatible(arg_shape, gt_shape)
+                and _dtype_to_str(arg_dtype) in gt_dtypes
+            ):
+                return
+        raise ValueError(
+            f"{arg_name} must be of shape and dtype "
+            f"compatible with any of {hint}, "
+            f"but got shape {arg_shape} and dtype {_dtype_to_str(arg_dtype)}."
+        )
+    else:
+        # Check shapes.
+        gt_shape = _shape_from_dim_str(hint.dim_str)
+        if not _is_shape_compatible(arg_shape, gt_shape):
+            raise ValueError(
+                f"{arg_name} must be of shape {gt_shape}, "
+                f"but got shape {arg_shape}."
+            )
+        # Check dtype.
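+        # Assumed from jaxtyping's public attributes: `hint.dtypes` is a
+        # tuple of dtype-name strings (e.g. ("float32", "float64")), so
+        # membership is checked against the stringified argument dtype.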
+ gt_dtypes = hint.dtypes + if _dtype_to_str(arg_dtype) not in gt_dtypes: + raise ValueError( + f"{arg_name} must be of dtype {gt_dtypes}, " + f"but got dtype {_dtype_to_str(arg_dtype)}." + ) + + +@lru_cache(maxsize=1) +def _get_valid_array_types(): + if is_torch_available(): + valid_array_types = (np.ndarray, torch.Tensor) + else: + valid_array_types = (np.ndarray,) + return valid_array_types + + +# Global variable to keep track of the tensor type check status +_tensor_check_enabled = True + + +def enable_tensor_check(): + """ + Enable the tensor type check globally. This function activates type checking + for tensors, which is useful for ensuring that tensor operations are + performed correctly, especially during debugging and development. + """ + global _tensor_check_enabled + _tensor_check_enabled = True + + +def disable_tensor_check(): + """ + Disable the tensor type check globally. This function deactivates type checking + for tensors, which can be useful for performance optimizations or when + type checks are known to be unnecessary or problematic. + """ + global _tensor_check_enabled + _tensor_check_enabled = False + + +def is_tensor_check_enabled(): + """ + Returns True if the tensor dtype and shape check is enabled, and False + otherwise. This will be used when @tensor_backend_auto, + @tensor_backend_numpy, or @tensor_backend_torch is called. + """ + return _tensor_check_enabled + + +class Tensor: + """ + An abstract tensor type for type hinting only. + Typically np.ndarray or torch.Tensor is supported. + """ + + +def _is_pure_tensor_hint(hint): + """ + Returns True if the type hint is a tensor hint or a Union of purely + tensor hints, and False otherwise. + """ + if typing.get_origin(hint) is Union: + is_tensor_hints = [ + inspect.isclass(h) and issubclass(h, AbstractArray) + for h in typing.get_args(hint) + ] + return all(is_tensor_hints) + else: + return inspect.isclass(hint) and issubclass(hint, AbstractArray) + + +def tensor_backend_auto(func, force_backend: Literal["numpy", "torch"] = None): + """ + Automatic backend selection based on the backend of type-annotated input + tensors, and run tensor type and shape checks if is_tensor_check_enabled(). + If there are no tensors, or if the tensors do not have the necessary type + annotations, the default backend is used. The function targets specifically + AbstractArray annotations to determine tensor treatment and + backend usage. + + Raises: + ValueError: If the tensor's shape and dtype do not match the type hints. + + Detailed behaviors: + 1. Only processes input arguments that are explicitly typed as + AbstractArray. Arguments without this type hint or with + different annotations maintain their default behavior without backend + modification. + 2. Supports handling of numpy.ndarray, torch.Tensor, and Python lists that + should be converted to tensors based on their type hints. + 3. If the type hint is AbstractArray and the argument is a list, + the list will be converted to a tensor using the native array + functionality of the active backend. + 4. Ensures all tensor arguments must be from the same backend to avoid + conflicts. + 5. Uses the default backend if no tensors are present or if the tensors do + not require specific backend handling based on their annotations. + 6. If force_backend is specified, the inferred backend from arguments + and type hints will be ignored, and the specified backend will be used + instead. Don't confuse this with the default backend as this takes + higher precedence. 
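
    Example (an illustrative sketch; `scale_points` below is a hypothetical
    function, not part of camtools):

        @tensor_backend_auto
        def scale_points(points: Float[Tensor, "n 3"], factor: float):
            return points * factor

        scale_points(np.ones((5, 3), dtype=np.float32), 2.0)  # numpy backend
        scale_points(torch.ones(5, 3), 2.0)                   # torch backend
        scale_points([[1.0, 2.0, 3.0]], 2.0)                  # list -> numpy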
+ """ + + # Pre-compute the function signature and type hints + # This is called per function declaration and not per function call + sig = inspect.signature(func) + arg_names = [ + param.name + for param in sig.parameters.values() + if param.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ) + ] + arg_name_to_hint = { + name: hint + for name, hint in typing.get_type_hints(func).items() + if name in arg_names + } + tensor_names = [ + name for name, hint in arg_name_to_hint.items() if _is_pure_tensor_hint(hint) + ] + + def _convert_tensor_to_backend(arg, backend): + """ + Convert the tensor to the specified backend. It shall already be checked + that the arg is a tensor-like object, the tensor is type-annotated, and + the backend is valid. + """ + if backend == "numpy": + if isinstance(arg, np.ndarray): + return arg + elif is_torch_available() and isinstance(arg, torch.Tensor): + return arg.detach().cpu().numpy() + elif isinstance(arg, (list, tuple)): + return np.array(arg) + else: + raise ValueError( + f"Unsupported type {type(arg)} for conversion to numpy." + ) + elif backend == "torch": + if not is_torch_available(): + raise ValueError("Torch is not available.") + elif isinstance(arg, torch.Tensor): + return arg + elif isinstance(arg, np.ndarray): + return torch.from_numpy(arg) + elif isinstance(arg, (list, tuple)): + return torch.tensor(arg) + else: + raise ValueError( + f"Unsupported type {type(arg)} for conversion to torch." + ) + else: + raise ValueError(f"Unsupported backend {backend}.") + + @wraps(func) + def wrapper(*args, **kwargs): + # Bind args and kwargs + # This is faster than sig.bind() but less flexible + arg_name_to_arg = dict(zip(arg_names, args)) + arg_name_to_arg.update(kwargs) + + # Fill in missing arguments with their default values + for arg_name, param in sig.parameters.items(): + if arg_name not in arg_name_to_arg and param.default is not param.empty: + arg_name_to_arg[arg_name] = param.default + + # Determine backend + if force_backend is None: + # Recursively collect np.ndarray and torch.Tensor objects + # Other types including lists are ignored + if is_torch_available(): + tensor_types = (np.ndarray, torch.Tensor) + else: + tensor_types = (np.ndarray,) + tensors = [ + arg_name_to_arg[tensor_name] + for tensor_name in tensor_names + if isinstance(arg_name_to_arg[tensor_name], tensor_types) + ] + + # Determine the backend based on tensor types present + if not tensors: + backend = "numpy" + else: + tensor_types_used = {type(t) for t in tensors} + if tensor_types_used == {np.ndarray}: + backend = "numpy" + elif is_torch_available() and tensor_types_used == {torch.Tensor}: + backend = "torch" + else: + raise ValueError( + f"All tensors must be from the same backend, " + f"but got {tensor_types_used}." 
+ ) + elif force_backend in ("numpy", "torch"): + backend = force_backend + else: + raise ValueError(f"Unsupported forced backend {force_backend}.") + + # Convert tensors to the appropriate backend + for tensor_name in tensor_names: + arg_name_to_arg[tensor_name] = _convert_tensor_to_backend( + arg_name_to_arg[tensor_name], backend + ) + + # Check tensor dtype and shape if enabled + if is_tensor_check_enabled(): + for tensor_name in tensor_names: + hint = arg_name_to_hint[tensor_name] + tensor_arg = arg_name_to_arg[tensor_name] + if isinstance(tensor_arg, _get_valid_array_types()): + _assert_tensor_hint( + hint=hint, + arg_shape=tensor_arg.shape, + arg_dtype=tensor_arg.dtype, + arg_name=tensor_name, + ) + + # Call the original function with updated arguments + result = func(**arg_name_to_arg) + + return result + + return wrapper + + +def tensor_backend_numpy(func): + """ + Run this function by first converting its input tensors to numpy arrays. + Only jaxtyping-annotated tensors will be processed. This wrapper shall be + used if the internal implementation is numpy-only or if we expect to return + numpy arrays. + + Behavior: + 1. Only converts arguments that are annotated explicitly with a jaxtyping + tensor type. If the type hint is a container of tensors, the conversion + will not be performed. + 2. Supports conversion of lists into numpy arrays if they are intended to be + tensors, according to the function's type annotations. + 3. The conversion is applied to top-level arguments and does not recursively + convert tensors within nested custom types (e.g., custom classes + containing tensors). + 4. This decorator is particularly useful for functions requiring consistent + tensor handling specifically with numpy, ensuring compatibility and + simplifying operations that depend on numpy's functionality. + + Note: + - The decorator inspects type annotations and applies conversions where + specified. + - Lists of tensors or tensors within lists annotated as tensors + will be converted to numpy arrays if not already in that format. + + This function simply wraps the tensor_auto_backend function with the + force_backend argument set to "numpy". + """ + # Wrap the original function with tensor_auto_backend enforcing numpy. + wrapped_func = tensor_backend_auto(func, force_backend="numpy") + + @wraps(func) + def wrapper(*args, **kwargs): + return wrapped_func(*args, **kwargs) + + return wrapper + + +def tensor_backend_torch(func): + """ + Run this function by first converting its input tensors to torch tensors. + Only jaxtyping-annotated tensors will be processed. This wrapper shall be + used if the internal implementation is torch-only or if we expect to return + torch tensors. + + Behavior: + 1. Only converts arguments that are annotated explicitly with a jaxtyping + tensor type. If the type hint is a container of tensors, the conversion + will not be performed. + 2. Supports conversion of lists into torch tensors if they are intended to be + tensors, according to the function's type annotations. + 3. The conversion is applied to top-level arguments and does not recursively + convert tensors within nested custom types (e.g., custom classes + containing tensors). + 4. This decorator is particularly useful for functions requiring consistent + tensor handling specifically with torch, ensuring compatibility and + simplifying operations that depend on torch's functionality. + + Note: + - The decorator inspects type annotations and applies conversions where + specified. 
+ - Lists of tensors or tensors within lists annotated as tensors + will be converted to torch tensors if not already in that format. + + This function simply wraps the tensor_auto_backend function with the + force_backend argument set to "torch". + """ + # Wrap the original function with tensor_auto_backend enforcing torch. + wrapped_func = tensor_backend_auto(func, force_backend="torch") + + @wraps(func) + def wrapper(*args, **kwargs): + return wrapped_func(*args, **kwargs) + + return wrapper + + +def create_array( + arr: List, + dtype: Any, + backend: Literal["numpy", "torch"], +) -> Tensor: + """ + Call np.array() or torch.tensor() depending on the backend. + """ + if backend == "numpy": + return np.array(arr, dtype=dtype) + elif backend == "torch": + if not is_torch_available(): + raise ValueError("Torch is not available.") + return torch.tensor(arr, dtype=dtype) + + +def create_ones( + shape: Tuple[int, ...], + dtype: Any, + backend: Literal["numpy", "torch"], +) -> Tensor: + """ + Call np.ones() or torch.ones() depending on the backend. + """ + if backend == "numpy": + return np.ones(shape, dtype=dtype) + elif backend == "torch": + if not is_torch_available(): + raise ValueError("Torch is not available.") + return torch.ones(shape, dtype=dtype) + + +def create_zeros( + shape: Tuple[int, ...], + dtype: Any, + backend: Literal["numpy", "torch"], +) -> Tensor: + """ + Call np.zeros() or torch.zeros() depending on the backend. + """ + if backend == "numpy": + return np.zeros(shape, dtype=dtype) + elif backend == "torch": + if not is_torch_available(): + raise ValueError("Torch is not available.") + return torch.zeros(shape, dtype=dtype) + + +def create_empty( + shape: Tuple[int, ...], + dtype: Any, + backend: Literal["numpy", "torch"], +) -> Tensor: + """ + Call np.empty() or torch.empty() depending on the backend. + """ + if backend == "numpy": + return np.empty(shape, dtype=dtype) + elif backend == "torch": + if not is_torch_available(): + raise ValueError("Torch is not available.") + return torch.empty(shape, dtype=dtype) + + +def get_tensor_backend(arr: Tensor) -> Literal["numpy", "torch"]: + """ + Get the backend of a tensor. + """ + if isinstance(arr, np.ndarray): + return "numpy" + elif is_torch_available() and isinstance(arr, torch.Tensor): + return "torch" + else: + raise ValueError(f"Unsupported tensor type {type(arr)}.") diff --git a/camtools/metric.py b/camtools/metric.py index 2d675c4f..7cc8badb 100644 --- a/camtools/metric.py +++ b/camtools/metric.py @@ -9,6 +9,7 @@ from . import image from . import io from . import sanity +from .backend import torch def image_psnr( @@ -94,7 +95,6 @@ def image_lpips( Returns: LPIPS value in float. """ - import torch import lpips if im_mask is None: diff --git a/pyproject.toml b/pyproject.toml index d185fbb1..ba235db4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,8 @@ dependencies = [ "matplotlib>=3.3.4", "scikit-image>=0.16.2", "tqdm>=4.60.0", + "ivy>=0.0.8.0", + "jaxtyping>=0.2.12", ] description = "CamTools: Camera Tools for Computer Vision." 
license = {text = "MIT"} @@ -34,6 +36,7 @@ Homepage = "https://github.com/yxlao/camtools" dev = [ "black>=22.1.0", "pytest>=6.2.2", + "pytest-benchmark>=4.0.0", "ipdb", ] torch = [ diff --git a/setup.py b/setup.py index b024da80..60684932 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,3 @@ from setuptools import setup - setup() diff --git a/test/test_backend.py b/test/test_backend.py new file mode 100644 index 00000000..e76bd431 --- /dev/null +++ b/test/test_backend.py @@ -0,0 +1,538 @@ +import numpy as np +import pytest +from jaxtyping import Float, Int + +import camtools as ct +from camtools.backend import Tensor, ivy, is_torch_available, torch +import warnings +from typing import Union + + +@pytest.fixture(autouse=True) +def ignore_ivy_warnings(): + warnings.filterwarnings( + "ignore", + message=".*Compositional function.*array_mode is set to False.*", + category=UserWarning, + ) + yield + + +@ct.backend.tensor_backend_auto +def concat(x: Float[Tensor, "..."], y: Float[Tensor, "..."]): + return ivy.concat([x, y], axis=0) + + +def test_concat_numpy(): + """ + Test the default backend when no tensors are provided. + """ + x = np.array([1.0, 2.0, 3.0]) + y = np.array([4.0, 5.0, 6.0]) + result = concat(x, y) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + +@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +def test_concat_torch(): + """ + Test the default backend when no tensors are provided. + """ + x = torch.tensor([1.0, 2.0, 3.0]) + y = torch.tensor([4.0, 5.0, 6.0]) + result = concat(x, y) + assert isinstance(result, torch.Tensor) + assert torch.equal(result, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + +def test_concat_list_to_numpy(): + """ + Test the default backend when no tensors are provided. + """ + x = [1.0, 2.0, 3.0] + y = [4.0, 5.0, 6.0] + result = concat(x, y) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + +def test_concat_mix_list_and_numpy(): + """ + Test handling of mixed list and tensor types. + """ + x = [1.0, 2.0, 3.0] + y = np.array([4.0, 5.0, 6.0]) + result = concat(x, y) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + +@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +def test_concat_mix_list_and_torch(): + """ + Test handling of mixed list and tensor types. + """ + x = [1.0, 2.0, 3.0] + y = torch.tensor([4.0, 5.0, 6.0]) + result = concat(x, y) + assert isinstance(result, torch.Tensor) + assert torch.equal(result, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + +@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +def test_concat_mix_numpy_and_torch(): + """ + Test error handling with mixed tensor types across arguments. + """ + x = np.array([1.0, 2.0, 3.0]) + y = torch.tensor([4.0, 5.0, 6.0]) + with pytest.raises(ValueError, match=r".*must be from the same backend.*"): + concat(x, y) + + +def test_concat_list_of_numpy(): + """ + Test handling of containers holding tensors from different backends. 
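
    Here the container holds numpy scalars only; since lists are not
    type-checked, x is converted with np.array(x) under the default numpy
    backend.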
+ """ + + x = [np.array(1.0), np.array(2.0), np.array(3.0)] + y = np.array([4.0, 5.0, 6.0]) + result = concat(x, y) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + +@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +def test_concat_list_of_torch(): + """ + Test handling of containers holding tensors from different backends. + """ + x = [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)] + y = torch.tensor([4.0, 5.0, 6.0]) + result = concat(x, y) + + assert isinstance(result, torch.Tensor) + assert torch.equal(result, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + +@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +def test_concat_list_of_numpy_and_torch(): + """ + Test handling with mixed tensor types across containers. + + In this case as lists are not type-checked, we both x and y will be + converted to default backend's arrays internally. That is, + x <- np.array(x) and y <- np.array(y) are both valid operation. In this + case, even though y contains tensors from both numpy and torch, as + np.asarray(y) is valid, the function should work. + + However, this can be very slow. As creating a torch tensor from a list of + np.ndarray is very slow and likewise for creating np.ndarray from a list of + torch tensors. Therefore, you shall avoid doing this in practice. + """ + x = [np.array(1.0), np.array(2.0), np.array(3.0)] + y = [torch.tensor(4.0), torch.tensor(5.0), torch.tensor(6.0)] + result = concat(x, y) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + +def test_creation(): + @ct.backend.tensor_backend_auto + def creation(): + zeros = ivy.zeros([2, 3]) + return zeros + + # Default backend is numpy + tensor = creation() + assert isinstance(tensor, np.ndarray) + assert tensor.shape == (2, 3) + assert tensor.dtype == np.float32 + + +def test_type_hint_arguments_numpy(): + @ct.backend.tensor_backend_auto + def add( + x: Float[Tensor, "2 3"], + y: Float[Tensor, "1 3"], + ) -> Float[Tensor, "2 3"]: + return x + y + + # Default backend is numpy + x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) + y = np.array([[1, 1, 1]], dtype=np.float32) + result = add(x, y) + expected = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.float32) + assert isinstance(result, np.ndarray) + assert np.allclose(result, expected, atol=1e-5) + + # List can be converted to numpy automatically + x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + result = add(x, y) + assert isinstance(result, np.ndarray) + assert np.allclose(result, expected, atol=1e-5) + + # Incorrect shapes + with pytest.raises(ValueError, match=r".*got shape.*"): + y_wrong = np.array([[1, 1, 1, 1]], dtype=np.float32) + add(x, y_wrong) + + # Incorrect shape with lists + with pytest.raises(ValueError, match=r".*got shape.*"): + y_wrong = [[1.0, 1.0, 1.0, 1.0]] + add(x, y_wrong) + + # Incorrect dtype + with pytest.raises(ValueError, match=r".*got dtype.*"): + y_wrong = np.array([[1, 1, 1]], dtype=np.int64) + add(x, y_wrong) + + # Incorrect dtype with lists + with pytest.raises(ValueError, match=r".*got dtype.*"): + y_wrong = [[1, 1, 1]] + add(x, y_wrong) + + +@pytest.mark.skipif(not ct.backend.is_torch_available(), reason="Skip torch") +def test_type_hint_arguments_torch(): + @ct.backend.tensor_backend_auto + def add( + x: Float[Tensor, "2 3"], + y: Float[Tensor, "1 3"], + ) -> Float[Tensor, "2 3"]: + return x + y + + x = torch.tensor([[1, 2, 3], [4, 5, 6]], 
dtype=torch.float32) + y = torch.tensor([[1, 1, 1]], dtype=torch.float32) + result = add(x, y) + expected = torch.tensor([[2, 3, 4], [5, 6, 7]], dtype=torch.float32) + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected, atol=1e-5) + + # List can be converted to torch automatically + x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] + result = add(x, y) + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected, atol=1e-5) + + # Incorrect shapes + with pytest.raises(ValueError, match=r".*got shape.*"): + y_wrong = torch.tensor([[1, 1, 1, 1]], dtype=torch.float32) + add(x, y_wrong) + + # Incorrect shape with lists + with pytest.raises(ValueError, match=r".*got shape.*"): + y_wrong = [[1.0, 1.0, 1.0, 1.0]] + add(x, y_wrong) + + # Incorrect dtype + with pytest.raises(ValueError, match=r".*got dtype.*"): + y_wrong = torch.tensor([[1, 1, 1]], dtype=torch.int64) + add(x, y_wrong) + + # Incorrect dtype with lists + with pytest.raises(ValueError, match=r".*got dtype.*"): + y_wrong = [[1, 1, 1]] + add(x, y_wrong) + + +def test_named_dim_numpy(): + @ct.backend.tensor_backend_auto + def add( + x: Float[Tensor, "3"], + y: Float[Tensor, "n 3"], + ) -> Float[Tensor, "n 3"]: + return x + y + + # Fixed x tensor + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + # Valid y tensor with shape (1, 3) + y = np.array([[4.0, 5.0, 6.0]], dtype=np.float32) + result = add(x, y) + expected = np.array([[5.0, 7.0, 9.0]], dtype=np.float32) + assert isinstance(result, np.ndarray) + assert np.allclose(result, expected, atol=1e-5) + + # Valid y tensor with shape (2, 3) + y = np.array([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=np.float32) + result = add(x, y) + expected = np.array([[5.0, 7.0, 9.0], [8.0, 10.0, 12.0]], dtype=np.float32) + assert isinstance(result, np.ndarray) + assert np.allclose(result, expected, atol=1e-5) + + # Test for a shape mismatch where y does not conform to "n 3" + with pytest.raises(ValueError, match=r".*got shape.*"): + y_wrong = np.array([4.0, 5.0, 6.0], dtype=np.float32) # Shape (3,) + add(x, y_wrong) + + # List inputs that should be automatically converted and work + y = [[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] + result = add(x, y) + assert isinstance(result, np.ndarray) + assert np.allclose(result, expected, atol=1e-5) + + # Incorrect dtype with lists, expect dtype error + with pytest.raises(ValueError, match=r".*got dtype.*"): + y_wrong = [[4, 5, 6], [7, 8, 9]] # int type elements in list + add(x, y_wrong) + + +@pytest.mark.skipif(not ct.backend.is_torch_available(), reason="Skip torch") +def test_named_dim_torch(): + @ct.backend.tensor_backend_auto + def add( + x: Float[Tensor, "3"], + y: Float[Tensor, "n 3"], + ) -> Float[Tensor, "n 3"]: + return x + y + + # Fixed x tensor for Torch + x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + + # Valid y tensor with shape (1, 3) + y = torch.tensor([[4.0, 5.0, 6.0]], dtype=torch.float32) + result = add(x, y) + expected = torch.tensor([[5.0, 7.0, 9.0]], dtype=torch.float32) + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected, atol=1e-5) + + # Valid y tensor with shape (2, 3) + y = torch.tensor([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=torch.float32) + result = add(x, y) + expected = torch.tensor([[5.0, 7.0, 9.0], [8.0, 10.0, 12.0]], dtype=torch.float32) + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected, atol=1e-5) + + # Test for a shape mismatch where y does not conform to "n 3" + with pytest.raises(ValueError, match=r".*got shape.*"): + 
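        # A rank-1 (3,) tensor cannot satisfy the rank-2 hint "n 3": the named
        # dim only leaves the size of `n` free, not the number of dimensions.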
y_wrong = torch.tensor([4.0, 5.0, 6.0], dtype=torch.float32) # Shape (3,) + add(x, y_wrong) + + # List inputs that should be automatically converted and work + y = [[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] + result = add(x, y) + assert isinstance(result, torch.Tensor) + assert torch.allclose(result, expected, atol=1e-5) + + # Incorrect dtype with lists, expect dtype error + with pytest.raises(ValueError, match=r".*got dtype.*"): + y_wrong = [[4, 5, 6], [7, 8, 9]] # int type elements in list + add(x, y_wrong) + + +def test_concat_tensors_with_numpy(): + @ct.backend.tensor_backend_numpy + def concat_tensors_with_numpy( + x: Float[Tensor, "..."], + y: Float[Tensor, "..."], + ): + assert isinstance(x, np.ndarray) + assert isinstance(y, np.ndarray) + return np.concatenate([x, y], axis=0) + + # Test with numpy arrays + x = np.array([1.0, 2.0, 3.0]) + y = np.array([4.0, 5.0, 6.0]) + result = concat_tensors_with_numpy(x, y) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + # Test with torch tensors + if is_torch_available(): + x = torch.tensor([1.0, 2.0, 3.0]) + y = torch.tensor([4.0, 5.0, 6.0]) + result = concat_tensors_with_numpy(x, y) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + # Test with lists + x = [1.0, 2.0, 3.0] + y = [4.0, 5.0, 6.0] + result = concat_tensors_with_numpy(x, y) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + # Numpy and list mixed + x = np.array([1.0, 2.0, 3.0]) + y = [4.0, 5.0, 6.0] + result = concat_tensors_with_numpy(x, y) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + # Torch and list mixed + if is_torch_available(): + x = torch.tensor([1.0, 2.0, 3.0]) + y = [4.0, 5.0, 6.0] + result = concat_tensors_with_numpy(x, y) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + +@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +def test_concat_tensors_with_torch(): + @ct.backend.tensor_backend_torch + def concat_tensors_with_torch( + x: Float[Tensor, "..."], + y: Float[Tensor, "..."], + ): + assert isinstance(x, torch.Tensor) + assert isinstance(y, torch.Tensor) + return torch.cat([x, y], axis=0) + + # Test with numpy arrays + x_np = np.array([1.0, 2.0, 3.0]).astype(np.float32) + y_np = np.array([4.0, 5.0, 6.0]).astype(np.float32) + result_np = concat_tensors_with_torch(x_np, y_np) + assert isinstance(result_np, torch.Tensor) + assert torch.allclose(result_np, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + # Test with torch tensors + x_torch = torch.tensor([1.0, 2.0, 3.0]) + y_torch = torch.tensor([4.0, 5.0, 6.0]) + result_torch = concat_tensors_with_torch(x_torch, y_torch) + assert isinstance(result_torch, torch.Tensor) + assert torch.equal(result_torch, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + # Test with lists + x_list = [1.0, 2.0, 3.0] + y_list = [4.0, 5.0, 6.0] + result_list = concat_tensors_with_torch(x_list, y_list) + assert isinstance(result_list, torch.Tensor) + assert torch.allclose(result_list, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + # Mixed types: numpy array and list + x_mixed = np.array([1.0, 2.0, 3.0]).astype(np.float32) + y_mixed = [4.0, 5.0, 6.0] + result_mixed = concat_tensors_with_torch(x_mixed, y_mixed) + assert isinstance(result_mixed, torch.Tensor) + assert 
torch.allclose(result_mixed, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) + + # Mixed types: torch tensor and list + x_mixed_torch = torch.tensor([1.0, 2.0, 3.0]) + y_mixed_list = [4.0, 5.0, 6.0] + result_mixed_torch_list = concat_tensors_with_torch(x_mixed_torch, y_mixed_list) + assert isinstance(result_mixed_torch_list, torch.Tensor) + assert torch.allclose( + result_mixed_torch_list, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + ) + + +@ct.backend.tensor_backend_numpy +def sum_xyz( + x: Float[Tensor, "3"], + y: Float[Tensor, "3"] = (2.0, 2.0, 2.0), + z: Float[Tensor, "3"] = (3.0, 3.0, 3.0), +): + assert isinstance(x, np.ndarray) + assert isinstance(y, np.ndarray) + assert isinstance(z, np.ndarray) + return x + y + z + + +def test_kwargs_sum_xyz_with_x(): + x = np.array([1.0, 1.0, 1.0]) + result = sum_xyz(x) + expected_result = np.array([6.0, 6.0, 6.0]) + assert np.allclose(result, expected_result) + + +def test_kwargs_sum_xyz_with_x_y(): + x = np.array([1.0, 1.0, 1.0]) + y = np.array([5.0, 5.0, 5.0]) + result = sum_xyz(x, y=y) + expected_result = np.array([9.0, 9.0, 9.0]) + assert np.allclose(result, expected_result) + + +def test_kwargs_sum_xyz_with_x_y_z(): + x = np.array([1.0, 1.0, 1.0]) + y = np.array([5.0, 5.0, 5.0]) + z = np.array([10.0, 10.0, 10.0]) + result = sum_xyz(x, y=y, z=z) + expected_result = np.array([16.0, 16.0, 16.0]) + assert np.allclose(result, expected_result) + + +@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +def test_kwargs_sum_xyz_with_x_torch(): + x = torch.tensor([1.0, 1.0, 1.0]) + result = sum_xyz(x) + expected_result = np.array([6.0, 6.0, 6.0]) + assert np.allclose(result, expected_result) + + +@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +def test_kwargs_sum_xyz_with_x_y_torch(): + x = np.array([1.0, 1.0, 1.0]) + y = torch.tensor([5.0, 5.0, 5.0]) + result = sum_xyz(x, y=y) + expected_result = np.array([9.0, 9.0, 9.0]) + assert np.allclose(result, expected_result) + + +@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") +def test_kwargs_sum_xyz_with_x_y_z_torch(): + x = np.array([1.0, 1.0, 1.0]) + y = torch.tensor([5.0, 5.0, 5.0]) + z = torch.tensor([10.0, 10.0, 10.0]) + result = sum_xyz(x, y=y, z=z) + expected_result = np.array([16.0, 16.0, 16.0]) + assert np.allclose(result, expected_result) + + +def test_disable_tensor_check(): + """ + Test behavior when tensor type checks are disabled. 
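
    Checks are re-enabled after each case to confirm that the same call
    raises again, so the global toggle does not leak into other tests.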
+ """ + # Wrong shape + ct.backend.disable_tensor_check() + x = np.array([1.0, 1.0, 1.0], dtype=np.float32) + y = np.array(5.0, dtype=np.float32) + result = sum_xyz(x, y=y) + assert np.allclose(result, np.array([9.0, 9.0, 9.0])) + ct.backend.enable_tensor_check() + with pytest.raises(ValueError, match=r".*got shape.*"): + sum_xyz(x, y=y) + + # Wrong dtype + ct.backend.disable_tensor_check() + x = np.array([1.0, 1.0, 1.0], dtype=np.float32) + y = np.array([5, 5, 5], dtype=np.int32) + result = sum_xyz(x, y=y) + assert np.allclose(result, np.array([9.0, 9.0, 9.0])) + ct.backend.enable_tensor_check() + with pytest.raises(ValueError, match=r".*got dtype.*"): + sum_xyz(x, y=y) + + +def test_union_of_hints(): + + @ct.backend.tensor_backend_auto + def identity( + arr: Union[ + Float[Tensor, "3 4"], + Float[Tensor, "N 3 4"], + Int[Tensor, "N 3 4"], + ] + ): + return arr + + # Supported dtype and shape + x = np.random.rand(3, 4).astype(np.float32) + assert np.allclose(identity(x), x) + x = np.random.rand(5, 3, 4).astype(np.float32) + assert np.allclose(identity(x), x) + x = np.random.randint(0, 10, (5, 3, 4)).astype(np.int32) + assert np.allclose(identity(x), x) + + # Wrong dtype + with pytest.raises(ValueError, match=r".*got dtype.*"): + x = np.random.randint(0, 10, (3, 4)).astype(np.int32) + identity(x) + + # Wrong shape + with pytest.raises(ValueError, match=r".*got shape.*"): + x = np.random.rand(5, 4).astype(np.float32) + identity(x)