From 9b3700b889ffb6b2991f82d03223cd197ebfbf93 Mon Sep 17 00:00:00 2001 From: Yixing Lao Date: Thu, 19 Dec 2024 13:37:57 +0800 Subject: [PATCH] remove auto backend switcher --- camtools/__init__.py | 1 - camtools/backend.py | 588 ------------------------------------------- camtools/metric.py | 2 +- camtools/util.py | 32 +++ pyproject.toml | 1 - test/test_backend.py | 538 --------------------------------------- 6 files changed, 33 insertions(+), 1129 deletions(-) delete mode 100644 test/test_backend.py diff --git a/camtools/__init__.py b/camtools/__init__.py index 2b0c3e6d..13f2e922 100644 --- a/camtools/__init__.py +++ b/camtools/__init__.py @@ -1,5 +1,4 @@ from . import artifact -from . import backend from . import camera from . import colmap from . import colormap diff --git a/camtools/backend.py b/camtools/backend.py index c19f3eab..e69de29b 100644 --- a/camtools/backend.py +++ b/camtools/backend.py @@ -1,588 +0,0 @@ -import inspect -import typing -import warnings -from functools import lru_cache, wraps -from typing import Any, Tuple, Union - -from jaxtyping import AbstractArray -import numpy as np -from typing import List, Tuple, Dict, Literal - - -@lru_cache(maxsize=1) -def _safely_import_torch(): - """ - Open3D has an issue where it must be imported before torch. If Open3D is - installed, this function will import Open3D before torch. Otherwise, it - will return simply import and return torch. - - Use this function to import torch within camtools to handle the Open3D - import order issue. That is, within camtools, we shall avoid `import torch`, - and instead use `from camtools.backend import torch`. As torch is an - optional dependency for camtools, this function will return None if torch - is not available. - - Returns: - module: The torch module if available, otherwise None. - """ - try: - __import__("open3d") - except ImportError: - pass - - try: - _torch = __import__("torch") - return _torch - except ImportError: - return None - - -torch = _safely_import_torch() - - -@lru_cache(maxsize=1) -def is_torch_available(): - return _safely_import_torch() is not None - - -@lru_cache(maxsize=1) -def _safely_import_ivy(): - """ - This function sets up the warnings filter to suppress the deprecation - before importing ivy. This is a temporary workaround to suppress the - deprecation warning from numpy 2.0. - - Within camtools, we shall avoid `import ivy`, and instead use - `from camtools.backend import ivy`. - """ - warnings.filterwarnings( - "ignore", - message=".*numpy.core.numeric is deprecated.*", - category=DeprecationWarning, - module="ivy", - ) - warnings.filterwarnings( - "ignore", - message=".*Compositional function.*array_mode is set to False.*", - category=UserWarning, - module="ivy", - ) - ivy = __import__("ivy") - ivy.set_array_mode(False) - return ivy - - -ivy = _safely_import_ivy() - - -@lru_cache(maxsize=64) -def _dtype_to_str(dtype): - """ - Convert numpy or torch dtype to string - - - "bool" - - "bool_" - - "uint4" - - "uint8" - - "uint16" - - "uint32" - - "uint64" - - "int4" - - "int8" - - "int16" - - "int32" - - "int64" - - "bfloat16" - - "float16" - - "float32" - - "float64" - - "complex64" - - "complex128" - """ - if isinstance(dtype, np.dtype): - return dtype.name - - if is_torch_available(): - if isinstance(dtype, torch.dtype): - return str(dtype).split(".")[1] - - return ValueError(f"Unknown dtype {dtype}.") - - -@lru_cache(maxsize=1024) -def _shape_from_dim_str(dim_str: str) -> Tuple[Union[int, None, str], ...]: - shape = [] - elements = dim_str.split() - for elem in elements: - if elem == "...": - shape.append("...") - elif elem.isdigit(): - shape.append(int(elem)) - else: - shape.append(None) - return tuple(shape) - - -@lru_cache(maxsize=1024) -def _is_shape_compatible( - arg_shape: Tuple[Union[int, None, str], ...], - gt_shape: Tuple[Union[int, None, str], ...], -) -> bool: - if "..." in gt_shape: - pre_ellipsis = None - post_ellipsis = None - - for i, dim in enumerate(gt_shape): - if dim == "...": - pre_ellipsis = i - post_ellipsis = len(gt_shape) - i - 1 - break - - if pre_ellipsis is None or gt_shape.count("...") > 1: - raise ValueError( - "Only one ellipsis is supported in the shape hint for now." - ) - - if len(arg_shape) < len(gt_shape) - 1: - return False - - for i in range(pre_ellipsis): - if arg_shape[i] != gt_shape[i] and gt_shape[i] is not None: - return False - - for i in range(1, post_ellipsis + 1): - if arg_shape[-i] != gt_shape[-i] and gt_shape[-i] is not None: - return False - - return True - else: - if len(arg_shape) != len(gt_shape): - return False - - for arg_dim, gt_dim in zip(arg_shape, gt_shape): - if arg_dim != gt_dim and gt_dim is not None: - return False - - return True - - -@lru_cache(maxsize=1024) -def _assert_tensor_hint( - hint: Union[AbstractArray, Union[AbstractArray]], - arg_shape: Tuple[int, ...], - arg_dtype: Any, - arg_name: str, -): - """ - Args: - hint: A type hint for a tensor, must be javtyping.AbstractArray or - a Union of javtyping.AbstractArray. - arg: An argument to check, typically a tensor. - arg_name: The name of the argument, for error messages. - """ - if typing.get_origin(hint) is Union: - for sub_hint in typing.get_args(hint): - gt_shape = _shape_from_dim_str(sub_hint.dim_str) - gt_dtypes = sub_hint.dtypes - if ( - _is_shape_compatible(arg_shape, gt_shape) - and _dtype_to_str(arg_dtype) in gt_dtypes - ): - return - raise ValueError( - f"{arg_name} must be of shape and dtype " - f"compatible with any of {hint}, " - f"but got shape {arg_shape} and got dtype {_dtype_to_str(arg_dtype)}." - ) - else: - # Check shapes. - gt_shape = _shape_from_dim_str(hint.dim_str) - if not _is_shape_compatible(arg_shape, gt_shape): - raise ValueError( - f"{arg_name} must be of shape {gt_shape}, " - f"but got shape {arg_shape}." - ) - # Check dtype. - gt_dtypes = hint.dtypes - if _dtype_to_str(arg_dtype) not in gt_dtypes: - raise ValueError( - f"{arg_name} must be of dtype {gt_dtypes}, " - f"but got dtype {_dtype_to_str(arg_dtype)}." - ) - - -@lru_cache(maxsize=1) -def _get_valid_array_types(): - if is_torch_available(): - valid_array_types = (np.ndarray, torch.Tensor) - else: - valid_array_types = (np.ndarray,) - return valid_array_types - - -# Global variable to keep track of the tensor type check status -_tensor_check_enabled = True - - -def enable_tensor_check(): - """ - Enable the tensor type check globally. This function activates type checking - for tensors, which is useful for ensuring that tensor operations are - performed correctly, especially during debugging and development. - """ - global _tensor_check_enabled - _tensor_check_enabled = True - - -def disable_tensor_check(): - """ - Disable the tensor type check globally. This function deactivates type checking - for tensors, which can be useful for performance optimizations or when - type checks are known to be unnecessary or problematic. - """ - global _tensor_check_enabled - _tensor_check_enabled = False - - -def is_tensor_check_enabled(): - """ - Returns True if the tensor dtype and shape check is enabled, and False - otherwise. This will be used when @tensor_backend_auto, - @tensor_backend_numpy, or @tensor_backend_torch is called. - """ - return _tensor_check_enabled - - -class Tensor: - """ - An abstract tensor type for type hinting only. - Typically np.ndarray or torch.Tensor is supported. - """ - - -def _is_pure_tensor_hint(hint): - """ - Returns True if the type hint is a tensor hint or a Union of purely - tensor hints, and False otherwise. - """ - if typing.get_origin(hint) is Union: - is_tensor_hints = [ - inspect.isclass(h) and issubclass(h, AbstractArray) - for h in typing.get_args(hint) - ] - return all(is_tensor_hints) - else: - return inspect.isclass(hint) and issubclass(hint, AbstractArray) - - -def tensor_backend_auto(func, force_backend: Literal["numpy", "torch"] = None): - """ - Automatic backend selection based on the backend of type-annotated input - tensors, and run tensor type and shape checks if is_tensor_check_enabled(). - If there are no tensors, or if the tensors do not have the necessary type - annotations, the default backend is used. The function targets specifically - AbstractArray annotations to determine tensor treatment and - backend usage. - - Raises: - ValueError: If the tensor's shape and dtype do not match the type hints. - - Detailed behaviors: - 1. Only processes input arguments that are explicitly typed as - AbstractArray. Arguments without this type hint or with - different annotations maintain their default behavior without backend - modification. - 2. Supports handling of numpy.ndarray, torch.Tensor, and Python lists that - should be converted to tensors based on their type hints. - 3. If the type hint is AbstractArray and the argument is a list, - the list will be converted to a tensor using the native array - functionality of the active backend. - 4. Ensures all tensor arguments must be from the same backend to avoid - conflicts. - 5. Uses the default backend if no tensors are present or if the tensors do - not require specific backend handling based on their annotations. - 6. If force_backend is specified, the inferred backend from arguments - and type hints will be ignored, and the specified backend will be used - instead. Don't confuse this with the default backend as this takes - higher precedence. - """ - - # Pre-compute the function signature and type hints - # This is called per function declaration and not per function call - sig = inspect.signature(func) - arg_names = [ - param.name - for param in sig.parameters.values() - if param.kind - in ( - inspect.Parameter.POSITIONAL_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - ) - ] - arg_name_to_hint = { - name: hint - for name, hint in typing.get_type_hints(func).items() - if name in arg_names - } - tensor_names = [ - name for name, hint in arg_name_to_hint.items() if _is_pure_tensor_hint(hint) - ] - - def _convert_tensor_to_backend(arg, backend): - """ - Convert the tensor to the specified backend. It shall already be checked - that the arg is a tensor-like object, the tensor is type-annotated, and - the backend is valid. - """ - if backend == "numpy": - if isinstance(arg, np.ndarray): - return arg - elif is_torch_available() and isinstance(arg, torch.Tensor): - return arg.detach().cpu().numpy() - elif isinstance(arg, (list, tuple)): - return np.array(arg) - else: - raise ValueError( - f"Unsupported type {type(arg)} for conversion to numpy." - ) - elif backend == "torch": - if not is_torch_available(): - raise ValueError("Torch is not available.") - elif isinstance(arg, torch.Tensor): - return arg - elif isinstance(arg, np.ndarray): - return torch.from_numpy(arg) - elif isinstance(arg, (list, tuple)): - return torch.tensor(arg) - else: - raise ValueError( - f"Unsupported type {type(arg)} for conversion to torch." - ) - else: - raise ValueError(f"Unsupported backend {backend}.") - - @wraps(func) - def wrapper(*args, **kwargs): - # Bind args and kwargs - # This is faster than sig.bind() but less flexible - arg_name_to_arg = dict(zip(arg_names, args)) - arg_name_to_arg.update(kwargs) - - # Fill in missing arguments with their default values - for arg_name, param in sig.parameters.items(): - if arg_name not in arg_name_to_arg and param.default is not param.empty: - arg_name_to_arg[arg_name] = param.default - - # Determine backend - if force_backend is None: - # Recursively collect np.ndarray and torch.Tensor objects - # Other types including lists are ignored - if is_torch_available(): - tensor_types = (np.ndarray, torch.Tensor) - else: - tensor_types = (np.ndarray,) - tensors = [ - arg_name_to_arg[tensor_name] - for tensor_name in tensor_names - if isinstance(arg_name_to_arg[tensor_name], tensor_types) - ] - - # Determine the backend based on tensor types present - if not tensors: - backend = "numpy" - else: - tensor_types_used = {type(t) for t in tensors} - if tensor_types_used == {np.ndarray}: - backend = "numpy" - elif is_torch_available() and tensor_types_used == {torch.Tensor}: - backend = "torch" - else: - raise ValueError( - f"All tensors must be from the same backend, " - f"but got {tensor_types_used}." - ) - elif force_backend in ("numpy", "torch"): - backend = force_backend - else: - raise ValueError(f"Unsupported forced backend {force_backend}.") - - # Convert tensors to the appropriate backend - for tensor_name in tensor_names: - arg_name_to_arg[tensor_name] = _convert_tensor_to_backend( - arg_name_to_arg[tensor_name], backend - ) - - # Check tensor dtype and shape if enabled - if is_tensor_check_enabled(): - for tensor_name in tensor_names: - hint = arg_name_to_hint[tensor_name] - tensor_arg = arg_name_to_arg[tensor_name] - if isinstance(tensor_arg, _get_valid_array_types()): - _assert_tensor_hint( - hint=hint, - arg_shape=tensor_arg.shape, - arg_dtype=tensor_arg.dtype, - arg_name=tensor_name, - ) - - # Call the original function with updated arguments - result = func(**arg_name_to_arg) - - return result - - return wrapper - - -def tensor_backend_numpy(func): - """ - Run this function by first converting its input tensors to numpy arrays. - Only jaxtyping-annotated tensors will be processed. This wrapper shall be - used if the internal implementation is numpy-only or if we expect to return - numpy arrays. - - Behavior: - 1. Only converts arguments that are annotated explicitly with a jaxtyping - tensor type. If the type hint is a container of tensors, the conversion - will not be performed. - 2. Supports conversion of lists into numpy arrays if they are intended to be - tensors, according to the function's type annotations. - 3. The conversion is applied to top-level arguments and does not recursively - convert tensors within nested custom types (e.g., custom classes - containing tensors). - 4. This decorator is particularly useful for functions requiring consistent - tensor handling specifically with numpy, ensuring compatibility and - simplifying operations that depend on numpy's functionality. - - Note: - - The decorator inspects type annotations and applies conversions where - specified. - - Lists of tensors or tensors within lists annotated as tensors - will be converted to numpy arrays if not already in that format. - - This function simply wraps the tensor_auto_backend function with the - force_backend argument set to "numpy". - """ - # Wrap the original function with tensor_auto_backend enforcing numpy. - wrapped_func = tensor_backend_auto(func, force_backend="numpy") - - @wraps(func) - def wrapper(*args, **kwargs): - return wrapped_func(*args, **kwargs) - - return wrapper - - -def tensor_backend_torch(func): - """ - Run this function by first converting its input tensors to torch tensors. - Only jaxtyping-annotated tensors will be processed. This wrapper shall be - used if the internal implementation is torch-only or if we expect to return - torch tensors. - - Behavior: - 1. Only converts arguments that are annotated explicitly with a jaxtyping - tensor type. If the type hint is a container of tensors, the conversion - will not be performed. - 2. Supports conversion of lists into torch tensors if they are intended to be - tensors, according to the function's type annotations. - 3. The conversion is applied to top-level arguments and does not recursively - convert tensors within nested custom types (e.g., custom classes - containing tensors). - 4. This decorator is particularly useful for functions requiring consistent - tensor handling specifically with torch, ensuring compatibility and - simplifying operations that depend on torch's functionality. - - Note: - - The decorator inspects type annotations and applies conversions where - specified. - - Lists of tensors or tensors within lists annotated as tensors - will be converted to torch tensors if not already in that format. - - This function simply wraps the tensor_auto_backend function with the - force_backend argument set to "torch". - """ - # Wrap the original function with tensor_auto_backend enforcing torch. - wrapped_func = tensor_backend_auto(func, force_backend="torch") - - @wraps(func) - def wrapper(*args, **kwargs): - return wrapped_func(*args, **kwargs) - - return wrapper - - -def create_array( - arr: List, - dtype: Any, - backend: Literal["numpy", "torch"], -) -> Tensor: - """ - Call np.array() or torch.tensor() depending on the backend. - """ - if backend == "numpy": - return np.array(arr, dtype=dtype) - elif backend == "torch": - if not is_torch_available(): - raise ValueError("Torch is not available.") - return torch.tensor(arr, dtype=dtype) - - -def create_ones( - shape: Tuple[int, ...], - dtype: Any, - backend: Literal["numpy", "torch"], -) -> Tensor: - """ - Call np.ones() or torch.ones() depending on the backend. - """ - if backend == "numpy": - return np.ones(shape, dtype=dtype) - elif backend == "torch": - if not is_torch_available(): - raise ValueError("Torch is not available.") - return torch.ones(shape, dtype=dtype) - - -def create_zeros( - shape: Tuple[int, ...], - dtype: Any, - backend: Literal["numpy", "torch"], -) -> Tensor: - """ - Call np.zeros() or torch.zeros() depending on the backend. - """ - if backend == "numpy": - return np.zeros(shape, dtype=dtype) - elif backend == "torch": - if not is_torch_available(): - raise ValueError("Torch is not available.") - return torch.zeros(shape, dtype=dtype) - - -def create_empty( - shape: Tuple[int, ...], - dtype: Any, - backend: Literal["numpy", "torch"], -) -> Tensor: - """ - Call np.empty() or torch.empty() depending on the backend. - """ - if backend == "numpy": - return np.empty(shape, dtype=dtype) - elif backend == "torch": - if not is_torch_available(): - raise ValueError("Torch is not available.") - return torch.empty(shape, dtype=dtype) - - -def get_tensor_backend(arr: Tensor) -> Literal["numpy", "torch"]: - """ - Get the backend of a tensor. - """ - if isinstance(arr, np.ndarray): - return "numpy" - elif is_torch_available() and isinstance(arr, torch.Tensor): - return "torch" - else: - raise ValueError(f"Unsupported tensor type {type(arr)}.") diff --git a/camtools/metric.py b/camtools/metric.py index 7cc8badb..6028a376 100644 --- a/camtools/metric.py +++ b/camtools/metric.py @@ -9,7 +9,7 @@ from . import image from . import io from . import sanity -from .backend import torch +from .util import _safe_torch as torch def image_psnr( diff --git a/camtools/util.py b/camtools/util.py index 638a1022..a7fc13f7 100644 --- a/camtools/util.py +++ b/camtools/util.py @@ -1,6 +1,7 @@ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor, as_completed from typing import Any, Callable, Iterable +from functools import lru_cache from tqdm import tqdm @@ -117,3 +118,34 @@ def query_yes_no(question, default=None): return response_to_bool[choice] else: print('Please respond with "yes" or "no" (or "y" or "n").') + + +@lru_cache(maxsize=1) +def _safely_import_torch(): + """ + Open3D has an issue where it must be imported before torch. If Open3D is + installed, this function will import Open3D before torch. Otherwise, it + will return simply import and return torch. + + Use this function to import torch within camtools to handle the Open3D + import order issue. That is, within camtools, we shall avoid `import torch`, + and instead use `from camtools.backend import torch`. As torch is an + optional dependency for camtools, this function will return None if torch + is not available. + + Returns: + module: The torch module if available, otherwise None. + """ + try: + __import__("open3d") + except ImportError: + pass + + try: + _torch = __import__("torch") + return _torch + except ImportError: + return None + + +_safe_torch = _safely_import_torch() diff --git a/pyproject.toml b/pyproject.toml index 5aeb8fa1..1a73503e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ dependencies = [ "matplotlib>=3.3.4", "scikit-image>=0.16.2", "tqdm>=4.60.0", - "ivy>=0.0.8.0", "jaxtyping>=0.2.12", ] description = "CamTools: Camera Tools for Computer Vision." diff --git a/test/test_backend.py b/test/test_backend.py deleted file mode 100644 index 07dd34f5..00000000 --- a/test/test_backend.py +++ /dev/null @@ -1,538 +0,0 @@ -import numpy as np -import pytest -from jaxtyping import Float, Int - -import camtools as ct -from camtools.backend import Tensor, ivy, is_torch_available, torch -import warnings -from typing import Union - - -@pytest.fixture(autouse=True) -def ignore_ivy_warnings(): - warnings.filterwarnings( - "ignore", - message=".*Compositional function.*array_mode is set to False.*", - category=UserWarning, - ) - yield - - -@ct.backend.tensor_backend_auto -def concat(x: Float[Tensor, "..."], y: Float[Tensor, "..."]): - return ivy.concat([x, y], axis=0) - - -def test_concat_numpy(): - """ - Test the default backend when no tensors are provided. - """ - x = np.array([1.0, 2.0, 3.0]) - y = np.array([4.0, 5.0, 6.0]) - result = concat(x, y) - assert isinstance(result, np.ndarray) - assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - -@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") -def test_concat_torch(): - """ - Test the default backend when no tensors are provided. - """ - x = torch.tensor([1.0, 2.0, 3.0]) - y = torch.tensor([4.0, 5.0, 6.0]) - result = concat(x, y) - assert isinstance(result, torch.Tensor) - assert torch.equal(result, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - -def test_concat_list_to_numpy(): - """ - Test the default backend when no tensors are provided. - """ - x = [1.0, 2.0, 3.0] - y = [4.0, 5.0, 6.0] - result = concat(x, y) - assert isinstance(result, np.ndarray) - assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - -def test_concat_mix_list_and_numpy(): - """ - Test handling of mixed list and tensor types. - """ - x = [1.0, 2.0, 3.0] - y = np.array([4.0, 5.0, 6.0]) - result = concat(x, y) - assert isinstance(result, np.ndarray) - assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - -@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") -def test_concat_mix_list_and_torch(): - """ - Test handling of mixed list and tensor types. - """ - x = [1.0, 2.0, 3.0] - y = torch.tensor([4.0, 5.0, 6.0]) - result = concat(x, y) - assert isinstance(result, torch.Tensor) - assert torch.equal(result, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - -@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") -def test_concat_mix_numpy_and_torch(): - """ - Test error handling with mixed tensor types across arguments. - """ - x = np.array([1.0, 2.0, 3.0]) - y = torch.tensor([4.0, 5.0, 6.0]) - with pytest.raises(ValueError, match=r".*must be from the same backend.*"): - concat(x, y) - - -def test_concat_list_of_numpy(): - """ - Test handling of containers holding tensors from different backends. - """ - - x = [np.array(1.0), np.array(2.0), np.array(3.0)] - y = np.array([4.0, 5.0, 6.0]) - result = concat(x, y) - assert isinstance(result, np.ndarray) - assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - -@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") -def test_concat_list_of_torch(): - """ - Test handling of containers holding tensors from different backends. - """ - x = [torch.tensor(1.0), torch.tensor(2.0), torch.tensor(3.0)] - y = torch.tensor([4.0, 5.0, 6.0]) - result = concat(x, y) - - assert isinstance(result, torch.Tensor) - assert torch.equal(result, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - -@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") -def test_concat_list_of_numpy_and_torch(): - """ - Test handling with mixed tensor types across containers. - - In this case as lists are not type-checked, we both x and y will be - converted to default backend's arrays internally. That is, - x <- np.array(x) and y <- np.array(y) are both valid operation. In this - case, even though y contains tensors from both numpy and torch, as - np.asarray(y) is valid, the function should work. - - However, this can be very slow. As creating a torch tensor from a list of - np.ndarray is very slow and likewise for creating np.ndarray from a list of - torch tensors. Therefore, you shall avoid doing this in practice. - """ - x = [np.array(1.0), np.array(2.0), np.array(3.0)] - y = [torch.tensor(4.0), torch.tensor(5.0), torch.tensor(6.0)] - result = concat(x, y) - assert isinstance(result, np.ndarray) - assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - -def test_creation(): - @ct.backend.tensor_backend_auto - def creation(): - zeros = ivy.zeros([2, 3]) - return zeros - - # Default backend is numpy - tensor = creation() - assert isinstance(tensor, np.ndarray) - assert tensor.shape == (2, 3) - assert tensor.dtype == np.float32 - - -def test_type_hint_arguments_numpy(): - @ct.backend.tensor_backend_auto - def add( - x: Float[Tensor, "2 3"], - y: Float[Tensor, "1 3"], - ) -> Float[Tensor, "2 3"]: - return x + y - - # Default backend is numpy - x = np.array([[1, 2, 3], [4, 5, 6]], dtype=np.float32) - y = np.array([[1, 1, 1]], dtype=np.float32) - result = add(x, y) - expected = np.array([[2, 3, 4], [5, 6, 7]], dtype=np.float32) - assert isinstance(result, np.ndarray) - assert np.allclose(result, expected, atol=1e-5) - - # List can be converted to numpy automatically - x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] - result = add(x, y) - assert isinstance(result, np.ndarray) - assert np.allclose(result, expected, atol=1e-5) - - # Incorrect shapes - with pytest.raises(ValueError, match=r".*got shape.*"): - y_wrong = np.array([[1, 1, 1, 1]], dtype=np.float32) - add(x, y_wrong) - - # Incorrect shape with lists - with pytest.raises(ValueError, match=r".*got shape.*"): - y_wrong = [[1.0, 1.0, 1.0, 1.0]] - add(x, y_wrong) - - # Incorrect dtype - with pytest.raises(ValueError, match=r".*got dtype.*"): - y_wrong = np.array([[1, 1, 1]], dtype=np.int64) - add(x, y_wrong) - - # Incorrect dtype with lists - with pytest.raises(ValueError, match=r".*got dtype.*"): - y_wrong = [[1, 1, 1]] - add(x, y_wrong) - - -@pytest.mark.skipif(not ct.backend.is_torch_available(), reason="Skip torch") -def test_type_hint_arguments_torch(): - @ct.backend.tensor_backend_auto - def add( - x: Float[Tensor, "2 3"], - y: Float[Tensor, "1 3"], - ) -> Float[Tensor, "2 3"]: - return x + y - - x = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=torch.float32) - y = torch.tensor([[1, 1, 1]], dtype=torch.float32) - result = add(x, y) - expected = torch.tensor([[2, 3, 4], [5, 6, 7]], dtype=torch.float32) - assert isinstance(result, torch.Tensor) - assert torch.allclose(result, expected, atol=1e-5) - - # List can be converted to torch automatically - x = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] - result = add(x, y) - assert isinstance(result, torch.Tensor) - assert torch.allclose(result, expected, atol=1e-5) - - # Incorrect shapes - with pytest.raises(ValueError, match=r".*got shape.*"): - y_wrong = torch.tensor([[1, 1, 1, 1]], dtype=torch.float32) - add(x, y_wrong) - - # Incorrect shape with lists - with pytest.raises(ValueError, match=r".*got shape.*"): - y_wrong = [[1.0, 1.0, 1.0, 1.0]] - add(x, y_wrong) - - # Incorrect dtype - with pytest.raises(ValueError, match=r".*got dtype.*"): - y_wrong = torch.tensor([[1, 1, 1]], dtype=torch.int64) - add(x, y_wrong) - - # Incorrect dtype with lists - with pytest.raises(ValueError, match=r".*got dtype.*"): - y_wrong = [[1, 1, 1]] - add(x, y_wrong) - - -def test_named_dim_numpy(): - @ct.backend.tensor_backend_auto - def add( - x: Float[Tensor, "3"], - y: Float[Tensor, "n 3"], - ) -> Float[Tensor, "n 3"]: - return x + y - - # Fixed x tensor - x = np.array([1.0, 2.0, 3.0], dtype=np.float32) - - # Valid y tensor with shape (1, 3) - y = np.array([[4.0, 5.0, 6.0]], dtype=np.float32) - result = add(x, y) - expected = np.array([[5.0, 7.0, 9.0]], dtype=np.float32) - assert isinstance(result, np.ndarray) - assert np.allclose(result, expected, atol=1e-5) - - # Valid y tensor with shape (2, 3) - y = np.array([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=np.float32) - result = add(x, y) - expected = np.array([[5.0, 7.0, 9.0], [8.0, 10.0, 12.0]], dtype=np.float32) - assert isinstance(result, np.ndarray) - assert np.allclose(result, expected, atol=1e-5) - - # Test for a shape mismatch where y does not conform to "n 3" - with pytest.raises(ValueError, match=r".*got shape.*"): - y_wrong = np.array([4.0, 5.0, 6.0], dtype=np.float32) # Shape (3,) - add(x, y_wrong) - - # List inputs that should be automatically converted and work - y = [[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] - result = add(x, y) - assert isinstance(result, np.ndarray) - assert np.allclose(result, expected, atol=1e-5) - - # Incorrect dtype with lists, expect dtype error - with pytest.raises(ValueError, match=r".*got dtype.*"): - y_wrong = [[4, 5, 6], [7, 8, 9]] # int type elements in list - add(x, y_wrong) - - -@pytest.mark.skipif(not ct.backend.is_torch_available(), reason="Skip torch") -def test_named_dim_torch(): - @ct.backend.tensor_backend_auto - def add( - x: Float[Tensor, "3"], - y: Float[Tensor, "n 3"], - ) -> Float[Tensor, "n 3"]: - return x + y - - # Fixed x tensor for Torch - x = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) - - # Valid y tensor with shape (1, 3) - y = torch.tensor([[4.0, 5.0, 6.0]], dtype=torch.float32) - result = add(x, y) - expected = torch.tensor([[5.0, 7.0, 9.0]], dtype=torch.float32) - assert isinstance(result, torch.Tensor) - assert torch.allclose(result, expected, atol=1e-5) - - # Valid y tensor with shape (2, 3) - y = torch.tensor([[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]], dtype=torch.float32) - result = add(x, y) - expected = torch.tensor([[5.0, 7.0, 9.0], [8.0, 10.0, 12.0]], dtype=torch.float32) - assert isinstance(result, torch.Tensor) - assert torch.allclose(result, expected, atol=1e-5) - - # Test for a shape mismatch where y does not conform to "n 3" - with pytest.raises(ValueError, match=r".*got shape.*"): - y_wrong = torch.tensor([4.0, 5.0, 6.0], dtype=torch.float32) # Shape (3,) - add(x, y_wrong) - - # List inputs that should be automatically converted and work - y = [[4.0, 5.0, 6.0], [7.0, 8.0, 9.0]] - result = add(x, y) - assert isinstance(result, torch.Tensor) - assert torch.allclose(result, expected, atol=1e-5) - - # Incorrect dtype with lists, expect dtype error - with pytest.raises(ValueError, match=r".*got dtype.*"): - y_wrong = [[4, 5, 6], [7, 8, 9]] # int type elements in list - add(x, y_wrong) - - -def test_concat_tensors_with_numpy(): - @ct.backend.tensor_backend_numpy - def concat_tensors_with_numpy( - x: Float[Tensor, "..."], - y: Float[Tensor, "..."], - ): - assert isinstance(x, np.ndarray) - assert isinstance(y, np.ndarray) - return np.concatenate([x, y], axis=0) - - # Test with numpy arrays - x = np.array([1.0, 2.0, 3.0]) - y = np.array([4.0, 5.0, 6.0]) - result = concat_tensors_with_numpy(x, y) - assert isinstance(result, np.ndarray) - assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - # Test with torch tensors - if is_torch_available(): - x = torch.tensor([1.0, 2.0, 3.0]) - y = torch.tensor([4.0, 5.0, 6.0]) - result = concat_tensors_with_numpy(x, y) - assert isinstance(result, np.ndarray) - assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - # Test with lists - x = [1.0, 2.0, 3.0] - y = [4.0, 5.0, 6.0] - result = concat_tensors_with_numpy(x, y) - assert isinstance(result, np.ndarray) - assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - # Numpy and list mixed - x = np.array([1.0, 2.0, 3.0]) - y = [4.0, 5.0, 6.0] - result = concat_tensors_with_numpy(x, y) - assert isinstance(result, np.ndarray) - assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - # Torch and list mixed - if is_torch_available(): - x = torch.tensor([1.0, 2.0, 3.0]) - y = [4.0, 5.0, 6.0] - result = concat_tensors_with_numpy(x, y) - assert isinstance(result, np.ndarray) - assert np.array_equal(result, np.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - -@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") -def test_concat_tensors_with_torch(): - @ct.backend.tensor_backend_torch - def concat_tensors_with_torch( - x: Float[Tensor, "..."], - y: Float[Tensor, "..."], - ): - assert isinstance(x, torch.Tensor) - assert isinstance(y, torch.Tensor) - return torch.cat([x, y], axis=0) - - # Test with numpy arrays - x_np = np.array([1.0, 2.0, 3.0]).astype(np.float32) - y_np = np.array([4.0, 5.0, 6.0]).astype(np.float32) - result_np = concat_tensors_with_torch(x_np, y_np) - assert isinstance(result_np, torch.Tensor) - assert torch.allclose(result_np, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - # Test with torch tensors - x_torch = torch.tensor([1.0, 2.0, 3.0]) - y_torch = torch.tensor([4.0, 5.0, 6.0]) - result_torch = concat_tensors_with_torch(x_torch, y_torch) - assert isinstance(result_torch, torch.Tensor) - assert torch.equal(result_torch, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - # Test with lists - x_list = [1.0, 2.0, 3.0] - y_list = [4.0, 5.0, 6.0] - result_list = concat_tensors_with_torch(x_list, y_list) - assert isinstance(result_list, torch.Tensor) - assert torch.allclose(result_list, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - # Mixed types: numpy array and list - x_mixed = np.array([1.0, 2.0, 3.0]).astype(np.float32) - y_mixed = [4.0, 5.0, 6.0] - result_mixed = concat_tensors_with_torch(x_mixed, y_mixed) - assert isinstance(result_mixed, torch.Tensor) - assert torch.allclose(result_mixed, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])) - - # Mixed types: torch tensor and list - x_mixed_torch = torch.tensor([1.0, 2.0, 3.0]) - y_mixed_list = [4.0, 5.0, 6.0] - result_mixed_torch_list = concat_tensors_with_torch(x_mixed_torch, y_mixed_list) - assert isinstance(result_mixed_torch_list, torch.Tensor) - assert torch.allclose( - result_mixed_torch_list, torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) - ) - - -@ct.backend.tensor_backend_numpy -def sum_xyz( - x: Float[Tensor, "3"], - y: Float[Tensor, "3"] = (2.0, 2.0, 2.0), - z: Float[Tensor, "3"] = (3.0, 3.0, 3.0), -): - assert isinstance(x, np.ndarray) - assert isinstance(y, np.ndarray) - assert isinstance(z, np.ndarray) - return x + y + z - - -def test_kwargs_sum_xyz_with_x(): - x = np.array([1.0, 1.0, 1.0]) - result = sum_xyz(x) - expected_result = np.array([6.0, 6.0, 6.0]) - assert np.allclose(result, expected_result) - - -def test_kwargs_sum_xyz_with_x_y(): - x = np.array([1.0, 1.0, 1.0]) - y = np.array([5.0, 5.0, 5.0]) - result = sum_xyz(x, y=y) - expected_result = np.array([9.0, 9.0, 9.0]) - assert np.allclose(result, expected_result) - - -def test_kwargs_sum_xyz_with_x_y_z(): - x = np.array([1.0, 1.0, 1.0]) - y = np.array([5.0, 5.0, 5.0]) - z = np.array([10.0, 10.0, 10.0]) - result = sum_xyz(x, y=y, z=z) - expected_result = np.array([16.0, 16.0, 16.0]) - assert np.allclose(result, expected_result) - - -@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") -def test_kwargs_sum_xyz_with_x_torch(): - x = torch.tensor([1.0, 1.0, 1.0]) - result = sum_xyz(x) - expected_result = np.array([6.0, 6.0, 6.0]) - assert np.allclose(result, expected_result) - - -@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") -def test_kwargs_sum_xyz_with_x_y_torch(): - x = np.array([1.0, 1.0, 1.0]) - y = torch.tensor([5.0, 5.0, 5.0]) - result = sum_xyz(x, y=y) - expected_result = np.array([9.0, 9.0, 9.0]) - assert np.allclose(result, expected_result) - - -@pytest.mark.skipif(not is_torch_available(), reason="Torch is not available") -def test_kwargs_sum_xyz_with_x_y_z_torch(): - x = np.array([1.0, 1.0, 1.0]) - y = torch.tensor([5.0, 5.0, 5.0]) - z = torch.tensor([10.0, 10.0, 10.0]) - result = sum_xyz(x, y=y, z=z) - expected_result = np.array([16.0, 16.0, 16.0]) - assert np.allclose(result, expected_result) - - -def test_disable_tensor_check(): - """ - Test behavior when tensor type checks are disabled. - """ - # Wrong shape - ct.backend.disable_tensor_check() - x = np.array([1.0, 1.0, 1.0], dtype=np.float32) - y = np.array(5.0, dtype=np.float32) - result = sum_xyz(x, y=y) - assert np.allclose(result, np.array([9.0, 9.0, 9.0])) - ct.backend.enable_tensor_check() - with pytest.raises(ValueError, match=r".*got shape.*"): - sum_xyz(x, y=y) - - # Wrong dtype - ct.backend.disable_tensor_check() - x = np.array([1.0, 1.0, 1.0], dtype=np.float32) - y = np.array([5, 5, 5], dtype=np.int32) - result = sum_xyz(x, y=y) - assert np.allclose(result, np.array([9.0, 9.0, 9.0])) - ct.backend.enable_tensor_check() - with pytest.raises(ValueError, match=r".*got dtype.*"): - sum_xyz(x, y=y) - - -def test_union_of_hints(): - - @ct.backend.tensor_backend_auto - def identity( - arr: Union[ - Float[Tensor, "3 4"], - Float[Tensor, "n 3 4"], - Int[Tensor, "n 3 4"], - ] - ): - return arr - - # Supported dtype and shape - x = np.random.rand(3, 4).astype(np.float32) - assert np.allclose(identity(x), x) - x = np.random.rand(5, 3, 4).astype(np.float32) - assert np.allclose(identity(x), x) - x = np.random.randint(0, 10, (5, 3, 4)).astype(np.int32) - assert np.allclose(identity(x), x) - - # Wrong dtype - with pytest.raises(ValueError, match=r".*got dtype.*"): - x = np.random.randint(0, 10, (3, 4)).astype(np.int32) - identity(x) - - # Wrong shape - with pytest.raises(ValueError, match=r".*got shape.*"): - x = np.random.rand(5, 4).astype(np.float32) - identity(x)