Skip to content

Commit

Permalink
Remove static functions
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Feb 8, 2024
1 parent b4c166b commit 153af16
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 4 deletions.
5 changes: 4 additions & 1 deletion examples/power_spectrum.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,10 @@ def forward(self, spex: TensorMap):
values=ps_values_ai,
samples=spex.block({"lam": 0, "a_i": a_i}).samples,
components=[],
properties=Labels.range("property", ps_values_ai.shape[-1])
properties=Labels(
"property",
torch.range(ps_values_ai.shape[-1], device=ps_values_ai.device).reshape(-1, 1)
)
)
keys.append([a_i])
blocks.append(block)
Expand Down
5 changes: 4 additions & 1 deletion torch_spex/radial_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,10 @@ def __init__(self, hypers, all_species) -> None:
self.is_alchemical = False
self.n_pseudo_species = 0 # dummy for torchscript
self.combination_matrix = torch.nn.Linear(1, 1) # dummy for torchscript
self.species_neighbor_labels = Labels.empty("dummy")
self.species_neighbor_labels = Labels(
names=["dummy"],
values=torch.empty((0, 1), dtype=torch.int)
)

self.apply_mlp = False
if hypers["mlp"]:
Expand Down
10 changes: 8 additions & 2 deletions torch_spex/spherical_expansions.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,10 @@ def forward(self,
)
)
else:
properties = Labels.range("n", n_max_l)
properties = Labels(
names=["n"],
values = torch.range(n_max_l, device=vector_expansion_l.device).reshape(n_max_l, 1)
)
vector_expansion_blocks.append(
TensorBlock(
values = vector_expansion_l.reshape(vector_expansion_l.shape[0], 2*l+1, -1),
Expand Down Expand Up @@ -420,7 +423,10 @@ def get_cartesian_vectors(positions, cells, species, cell_shifts, centers, pairs
values = torch.tensor([-1, 0, 1], dtype=torch.int32, device=direction_vectors.device).reshape((-1, 1))
)
],
properties = Labels.single().to(direction_vectors.device)
properties = Labels(
names=["_"],
values=torch.zeros((1, 1), device=direction_vectors.device)
)
)

return block

0 comments on commit 153af16

Please sign in to comment.