feat: flexible numpy/torch backend and tensor type checking #69

Merged
merged 3 commits into main from flex-backend on Jul 9, 2024

Conversation

yxlao (Owner) commented on Jul 9, 2024

TL;DR

Flexible Backend + Type Checking

We want CamTools to work seamlessly with NumPy/Torch and to enforce dtype and shape checking for tensors, all handled automatically by decorators. This PR provides the initial infrastructure; future PRs will incrementally migrate the existing functions.
@ct.backend.tensor_backend_auto
def func(t: Float[Tensor, "n 3"]):
    """
    Automatically uses the backend determined by `t`.
    Also checks the dtype and shape of `t`.
    """
    pass

@ct.backend.tensor_backend_numpy
def func(t: Float[Tensor, "n 3"]):
    """
    Attempts to convert `t` to NumPy.
    Also checks the dtype and shape of `t`.
    """
    pass

Overall goals

  1. All functions in CamTools shall be able to handle both NumPy arrays and Torch tensors. This shall be achieved automatically by decorators.
  2. CamTools determines which backend to use by inspecting the input arguments; if no argument dictates a backend, the default backend is NumPy.
  3. Torch shall remain an optional dependency of CamTools; CamTools detects whether Torch is installed and enables the corresponding features.
  4. Functions in CamTools shall be type-hinted with the expected tensor shape and dtype. The shape and dtypes are automatically checked and enforced by decorators. Compatible list inputs are automatically converted to the native tensor format if the type hint is a tensor.
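
As a stdlib-only illustration of goals 2 and 3 (backend inferred from the arguments, NumPy as the default, and Torch never imported unconditionally), the detection step could look roughly like this. This is a simplified sketch of the idea, not CamTools' actual implementation:

```python
# Simplified sketch of argument-based backend detection (not CamTools'
# real code). The backend is inferred from each argument's class module,
# so torch is never imported unless the caller already passed a tensor.
def infer_backend(*args):
    """Return "torch" if any arg is a torch tensor, else "numpy"."""
    backends = set()
    for arg in args:
        root_module = type(arg).__module__.split(".")[0]
        if root_module in ("torch", "numpy"):
            backends.add(root_module)
    if len(backends) > 1:
        raise ValueError("Tensors must be from the same backend")
    # Plain Python lists and scalars fall through to the NumPy default.
    return backends.pop() if backends else "numpy"
```

A decorator can then call `infer_backend` on the tensor-hinted arguments before dispatching the computation.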

Key usages

With this PR, the key functions include:

  1. Tensor-like inputs shall be type-annotated with jaxtyping, specifying the shape and dtype. A single tensor type or a union of tensor types can be used. Flexible shape (e.g. "... 3") or context-dependent shape (e.g. "n 3") can be specified.
  2. @ct.backend.tensor_backend_auto: Automatically determines the backend from the input arguments. Compatible list inputs are automatically converted to the native tensor format. This will only handle arguments that are hinted as tensors.
  3. @ct.backend.tensor_backend_numpy and @ct.backend.tensor_backend_torch: Enforce the NumPy or Torch backend, respectively, by converting input tensors. Compatible list inputs are automatically converted to the native tensor format. These decorators only handle arguments that are hinted as tensors.
  4. ct.backend.enable_tensor_check() and ct.backend.disable_tensor_check(): Enable or disable tensor type checking (for dtype and shape) globally. By default, the tensor type checking is enabled. The checks will be done if @ct.backend.tensor_backend_xxx decorators are used and the argument is hinted as a tensor.
  5. Inside the functions, you may use ivy for computation or use native Python operators. The goal is to make the functions compatible with both NumPy and Torch.
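
To make the shape-checking behavior in item 4 concrete, here is a toy, dependency-free decorator in the same spirit. The real decorators read jaxtyping annotations; this sketch instead takes explicit shape specs, with `None` standing in for a symbolic dimension like "n":

```python
import functools
import inspect

def _shape_of(x):
    """Shape of a nested list, e.g. [[1.0, 1.0, 1.0]] -> (1, 3)."""
    shape = []
    while isinstance(x, list):
        shape.append(len(x))
        x = x[0] if x else None
    return tuple(shape)

def check_shape(**specs):
    """Toy stand-in for tensor type checking: maps argument names to
    expected shapes; None matches any size (like "n" in "n 3")."""
    def decorator(func):
        sig = inspect.signature(func)

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            for name, expected in specs.items():
                got = _shape_of(bound.arguments[name])
                matches = len(got) == len(expected) and all(
                    e is None or e == g for e, g in zip(expected, got)
                )
                if not matches:
                    raise ValueError(f"Expected shape {expected}, but got {got}")
            return func(*args, **kwargs)

        return wrapper

    return decorator

@check_shape(points=(None, 3))
def num_points(points):
    return len(points)
```

A call like `num_points([[1.0, 2.0]])` raises `ValueError` because the last dimension is not 3, mirroring the shape errors shown in the examples below.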

Here are some examples.

import numpy as np

import camtools as ct
from camtools.backend import Tensor, ivy, torch
from jaxtyping import Float

@ct.backend.tensor_backend_auto
def add(x: Float[Tensor, "2 3"], y: Float[Tensor, "1 3"]) -> Float[Tensor, "2 3"]:
    # Or, you may use ivy:
    # return ivy.add(x, y)
    return x + y

# Numpy inputs: works, returns numpy array
x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = np.array([[1.0, 1.0, 1.0]])
result = add(x, y)

# Torch inputs: works, returns torch tensor
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = torch.tensor([[1.0, 1.0, 1.0]])
result = add(x, y)

# Mixed numpy and list: works, returns numpy array
x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = [[1.0, 1.0, 1.0]]
result = add(x, y)

# Mixed torch and list: works, returns torch tensor
x = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = [[1.0, 1.0, 1.0]]
result = add(x, y)

# Mixed numpy and torch: exception
x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = torch.tensor([[1.0, 1.0, 1.0]])
result = add(x, y)  # ValueError: Tensors must be from the same backend

# Numpy inputs with wrong shape: exception
x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = np.array([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0]])
result = add(x, y)  # ValueError: Expected shape (1, 3), but got (2, 3)

# Numpy inputs with wrong dtype: exception
x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = np.array([[1, 1, 1]]) # int dtype
result = add(x, y)  # ValueError: Expected dtype float32, but got int64

@ct.backend.tensor_backend_numpy
def add_numpy(x: Float[Tensor, "2 3"], y: Float[Tensor, "1 3"]) -> Float[Tensor, "2 3"]:
    return x + y

# Mixed numpy and torch: works, returns numpy array
x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = torch.tensor([[1.0, 1.0, 1.0]])
result = add_numpy(x, y)

@ct.backend.tensor_backend_torch
def add_torch(x: Float[Tensor, "2 3"], y: Float[Tensor, "1 3"]) -> Float[Tensor, "2 3"]:
    return x + y

# Mixed numpy and torch: works, returns torch tensor
x = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
y = torch.tensor([[1.0, 1.0, 1.0]])
result = add_torch(x, y)

Other notes

  • Tensor creation APIs support creating tensors with different backends: ct.backend.create_array, ct.backend.create_ones, ct.backend.create_zeros, and ct.backend.create_empty.
  • Union of hints is supported. For example, Union[Float[Tensor, "3 4"], Float[Tensor, "N 3 4"]].
  • You shall import ivy and torch from camtools.backend, as this sets up some internal configurations for the backend.
  • The wrappers are carefully optimized to minimize the overhead of backend conversion and type checking. More details in perf: performance improvements for backend function wrappers #66.
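
The per-backend creation helpers above suggest a registry-style dispatch. The following stdlib-only sketch shows one way such dispatch could be organized; the names and signatures here are hypothetical illustrations, not CamTools' actual API:

```python
# Hypothetical registry-based dispatch for per-backend tensor creation,
# illustrating the idea behind helpers like ct.backend.create_zeros.
# This is not CamTools' real implementation.
_ZEROS_CREATORS = {}

def register_zeros(backend_name, creator):
    """Register a zeros-creation function for a backend."""
    _ZEROS_CREATORS[backend_name] = creator

def create_zeros(shape, backend="numpy"):
    """Create a zero-filled tensor with the requested backend."""
    try:
        creator = _ZEROS_CREATORS[backend]
    except KeyError:
        raise ValueError(f"Backend not registered: {backend}") from None
    return creator(shape)

# A toy nested-list "backend" stands in for numpy/torch in this sketch.
def _list_zeros(shape):
    if not shape:
        return 0.0
    return [_list_zeros(shape[1:]) for _ in range(shape[0])]

register_zeros("list", _list_zeros)
```

Registering each backend lazily keeps Torch optional: its creator is only added when Torch is detected.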

yxlao changed the title from "feat: initial infrastructures for flexible backend support" to "feat: initial infrastructures for flexible numpy/torch backend support" on Jul 9, 2024
yxlao changed the title from "feat: initial infrastructures for flexible numpy/torch backend support" to "feat: initial infrastructures for flexible numpy/torch backend" on Jul 9, 2024
yxlao changed the title from "feat: initial infrastructures for flexible numpy/torch backend" to "feat: flexible numpy/torch backend and tensor type checking" on Jul 9, 2024
yxlao merged commit 04beaea into main on Jul 9, 2024
9 checks passed
yxlao deleted the flex-backend branch on Jul 9, 2024 at 15:16