Skip to content

Commit

Permalink
device type
Browse files Browse the repository at this point in the history
  • Loading branch information
Henry Isaacson committed Jul 30, 2024
1 parent 74f8f82 commit 81277a5
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 9 deletions.
3 changes: 2 additions & 1 deletion src/beignet/func/_interact.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
safe_index,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class _ParameterTreeKind(Enum):
BOND = 0
Expand Down Expand Up @@ -120,7 +121,7 @@ def compute_force(R, *args, **kwargs):
R = R.requires_grad_(True)
energy = energy_fn(R, *args, **kwargs)
force = -torch.autograd.grad(
energy, R, grad_outputs=torch.ones_like(energy), create_graph=True
energy, R, grad_outputs=torch.ones_like(energy).to(device=device), create_graph=True
)[0]

return force
Expand Down
12 changes: 6 additions & 6 deletions src/beignet/func/_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1060,7 +1060,7 @@ def _particles_per_cell(

particle_hash = torch.sum(particle_index * hash_multipliers, dim=1)

filling = _segment_sum(torch.ones_like(particle_hash), particle_hash, n)
filling = _segment_sum(torch.ones_like(particle_hash).to(device=device), particle_hash, n)

return filling

Expand Down Expand Up @@ -1233,7 +1233,7 @@ def fn(
)

exceeded_maximum_size = exceeded_maximum_size | (
torch.max(_segment_sum(torch.ones_like(hashes), hashes, unit_count))
torch.max(_segment_sum(torch.ones_like(hashes).to(device=device), hashes, unit_count))
> buffer_size
)

Expand Down Expand Up @@ -1434,7 +1434,7 @@ def prune_dense_neighbor_list(
mask = (displacements < squared_cutoff) & (indexes < positions.shape[0])

output_indexes = positions.shape[0] * torch.ones(
indexes.shape, dtype=torch.int32
indexes.shape, dtype=torch.int32, device=device
)

cumsum = torch.cumsum(mask, dim=1)
Expand Down Expand Up @@ -1471,10 +1471,10 @@ def prune_sparse_neighbor_list(
if neighbor_list_format is _NeighborListFormat.ORDERED_SPARSE:
mask = mask & (receiver_idx < sender_idx)

out_idx = position.shape[0] * torch.ones(receiver_idx.shape, dtype=torch.int32)
out_idx = position.shape[0] * torch.ones(receiver_idx.shape, dtype=torch.int32, device=device)

cumsum = torch.cumsum(torch.flatten(mask), dim=0)
index = torch.where(mask, cumsum - 1, len(receiver_idx) - 1)
cumsum = torch.cumsum(torch.flatten(mask), dim=0).to(device=device)
index = torch.where(mask, cumsum - 1, len(receiver_idx) - 1).to(device=device)

index = index.to(torch.int64)
sender_idx = sender_idx.to(torch.int32)
Expand Down
6 changes: 4 additions & 2 deletions src/beignet/func/_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ def setup_fn(

momentums = torch.zeros(size, dtype=kinetic_energies.dtype)

masses = torch.ones(size, dtype=torch.float32)
masses = torch.ones(size, dtype=torch.float32, device=device)

masses = temperature * oscillation**2.0 * masses

Expand Down Expand Up @@ -519,7 +519,7 @@ def update_mass_fn(
positions,
) = dataclasses.astuple(state)

masses = torch.ones(size, dtype=torch.float32)
masses = torch.ones(size, dtype=torch.float32, device=device)

masses = temperature * oscillations**2 * masses

Expand Down Expand Up @@ -608,6 +608,8 @@ def setup_fn(
masses: Tensor | None = None,
**kwargs,
):
positions = positions.to(device=device)

if not masses:
masses = 1.0

Expand Down

0 comments on commit 81277a5

Please sign in to comment.