From 153af166fada1f9b1892d95a9264ec43086be927 Mon Sep 17 00:00:00 2001 From: frostedoyster Date: Thu, 8 Feb 2024 06:53:22 +0100 Subject: [PATCH] Remove static functions --- examples/power_spectrum.py | 5 ++++- torch_spex/radial_basis.py | 5 ++++- torch_spex/spherical_expansions.py | 10 ++++++++-- 3 files changed, 16 insertions(+), 4 deletions(-) diff --git a/examples/power_spectrum.py b/examples/power_spectrum.py index c7a185b..2deb87c 100644 --- a/examples/power_spectrum.py +++ b/examples/power_spectrum.py @@ -35,7 +35,10 @@ def forward(self, spex: TensorMap): values=ps_values_ai, samples=spex.block({"lam": 0, "a_i": a_i}).samples, components=[], - properties=Labels.range("property", ps_values_ai.shape[-1]) + properties=Labels( + "property", + torch.range(ps_values_ai.shape[-1], device=ps_values_ai.device).reshape(-1, 1) + ) ) keys.append([a_i]) blocks.append(block) diff --git a/torch_spex/radial_basis.py b/torch_spex/radial_basis.py index 09ca8bc..e3cebf8 100644 --- a/torch_spex/radial_basis.py +++ b/torch_spex/radial_basis.py @@ -56,7 +56,10 @@ def __init__(self, hypers, all_species) -> None: self.is_alchemical = False self.n_pseudo_species = 0 # dummy for torchscript self.combination_matrix = torch.nn.Linear(1, 1) # dummy for torchscript - self.species_neighbor_labels = Labels.empty("dummy") + self.species_neighbor_labels = Labels( + names=["dummy"], + values=torch.empty((0, 1), dtype=torch.int) + ) self.apply_mlp = False if hypers["mlp"]: diff --git a/torch_spex/spherical_expansions.py b/torch_spex/spherical_expansions.py index cafa789..82d7378 100644 --- a/torch_spex/spherical_expansions.py +++ b/torch_spex/spherical_expansions.py @@ -356,7 +356,10 @@ def forward(self, ) ) else: - properties = Labels.range("n", n_max_l) + properties = Labels( + names=["n"], + values = torch.range(n_max_l, device=vector_expansion_l.device).reshape(n_max_l, 1) + ) vector_expansion_blocks.append( TensorBlock( values = vector_expansion_l.reshape(vector_expansion_l.shape[0], 2*l+1, -1), @@ -420,7 +423,10 @@ def get_cartesian_vectors(positions, cells, species, cell_shifts, centers, pairs values = torch.tensor([-1, 0, 1], dtype=torch.int32, device=direction_vectors.device).reshape((-1, 1)) ) ], - properties = Labels.single().to(direction_vectors.device) + properties = Labels( + names=["_"], + values=torch.zeros((1, 1), device=direction_vectors.device) + ) ) return block