Refactor examples #59
Annotations
1 error
/home/runner/work/mlcg/mlcg/mlcg/nn/test_utils.py#L34
ang_prior = HarmonicAngles(angle_dict)
dih_prior = Dihedral(dihedral_dict, n_degs=3)
-torch_assert_equal = functools.partial(torch.testing.assert_close, rtol=0, atol=0)
+torch_assert_equal = functools.partial(
+ torch.testing.assert_close, rtol=0, atol=0
+)
+
@pytest.mark.parametrize("prior_module", [ang_prior, dih_prior])
def test_prior_sparsification(prior_module: torch.nn.Module) -> None:
original_prior = deepcopy(prior_module)
sparsify_prior_module(prior_module)
|