Skip to content

Commit

Permalink
beignet.func.space
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 22, 2024
1 parent f3f894b commit eaca190
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions tests/beignet/operators/test__transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import beignet.operators
import torch.func
import torch.testing
from torch import Tensor


def test_transform():
Expand All @@ -20,11 +21,11 @@ def test_transform():
),
)

def f(R):
return torch.sum(input**2)
def f(r: Tensor) -> Tensor:
return torch.sum(r**2)

def g(T, R):
return torch.sum(beignet.operators.transform(transformation, input) ** 2)
def g(t: Tensor, r: Tensor) -> Tensor:
return torch.sum(beignet.operators.transform(t, r) ** 2)

torch.testing.assert_allclose(
torch.func.grad(f)(beignet.operators.transform(transformation, input)),
Expand Down

0 comments on commit eaca190

Please sign in to comment.