Skip to content

Commit

Permalink
beignet.func.space
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 22, 2024
1 parent 0eb76eb commit f3f894b
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 20 deletions.
20 changes: 0 additions & 20 deletions tests/beignet/operators/test__inverse_transform.py
Original file line number Diff line number Diff line change
@@ -1,20 +0,0 @@
import beignet.operators
import torch.testing


def test_transform():
input = torch.randn([256, 3], dtype=torch.float32)

transformation = torch.randn([3, 3], dtype=torch.float32)

torch.testing.assert_allclose(
torch.einsum(
"ij,kj->ki",
transformation,
input,
),
beignet.operators.transform(
transformation,
input,
),
)
32 changes: 32 additions & 0 deletions tests/beignet/operators/test__transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import beignet.operators
import torch.func
import torch.testing


def test_transform():
input = torch.randn([256, 3], dtype=torch.float32)

transformation = torch.randn([3, 3], dtype=torch.float32)

torch.testing.assert_allclose(
torch.einsum(
"ij,kj->ki",
transformation,
input,
),
beignet.operators.transform(
transformation,
input,
),
)

def f(R):
return torch.sum(input**2)

def g(T, R):
return torch.sum(beignet.operators.transform(transformation, input) ** 2)

torch.testing.assert_allclose(
torch.func.grad(f)(beignet.operators.transform(transformation, input)),
torch.func.grad(g, 1)(transformation, input),
)

0 comments on commit f3f894b

Please sign in to comment.