Skip to content

Commit

Permalink
beignet.func.space
Browse files Browse the repository at this point in the history
  • Loading branch information
0x00b1 committed Apr 23, 2024
1 parent eaca190 commit 1a3ee6e
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
55 changes: 55 additions & 0 deletions tests/beignet/func/test__space.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import beignet
import torch.testing


def test_space():
PARTICLE_COUNT = 8
spatial_dimension = 2

input = torch.rand([PARTICLE_COUNT, spatial_dimension])

dR = beignet.func.space.map_product(beignet.func.space.pairwise_displacement)(
input, input
)

dR_wrapped = beignet.func.space.periodic_displacement(1.0, dR)

dR_direct = dR
dr_direct = beignet.func.space.distance(dR)
dr_direct = torch.reshape(dr_direct, dr_direct.shape + (1,))

if spatial_dimension == 2:
for i in range(-1, 2):
for j in range(-1, 2):
dR_shifted = dR + torch.tensor([i, j], dtype=input.dtype)

dr_shifted = beignet.func.space.distance(dR_shifted)
dr_shifted = torch.reshape(dr_shifted, dr_shifted.shape + (1,))

dR_direct = torch.where(dr_shifted < dr_direct, dR_shifted, dR_direct)
dr_direct = torch.where(dr_shifted < dr_direct, dr_shifted, dr_direct)
elif spatial_dimension == 3:
for i in range(-1, 2):
for j in range(-1, 2):
for k in range(-1, 2):
dR_shifted = dR + torch.tensor([i, j, k], dtype=input.dtype)

dr_shifted = beignet.func.space.distance(dR_shifted)
dr_shifted = torch.reshape(dr_shifted, dr_shifted.shape + (1,))

dR_direct = torch.where(
dr_shifted < dr_direct,
dR_shifted,
dR_direct,
)

dr_direct = torch.where(
dr_shifted < dr_direct,
dr_shifted,
dr_direct,
)

torch.testing.assert_close(
dR_wrapped,
dR_direct,
)
8 changes: 4 additions & 4 deletions tests/beignet/operators/test__transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def test_transform():
),
)

def f(r: Tensor) -> Tensor:
return torch.sum(r**2)
def f(input: Tensor) -> Tensor:
return torch.sum(input**2)

def g(t: Tensor, r: Tensor) -> Tensor:
return torch.sum(beignet.operators.transform(t, r) ** 2)
def g(transformation: Tensor, input: Tensor) -> Tensor:
return torch.sum(beignet.operators.transform(transformation, input) ** 2)

torch.testing.assert_allclose(
torch.func.grad(f)(beignet.operators.transform(transformation, input)),
Expand Down

0 comments on commit 1a3ee6e

Please sign in to comment.