diff --git a/apax/nn/builder.py b/apax/nn/builder.py index c9cf7c13..3d5e21f0 100644 --- a/apax/nn/builder.py +++ b/apax/nn/builder.py @@ -95,7 +95,7 @@ def build_readout(self, head_config, is_feature_fn=False, only_use_n_layers=None dtype = head_config["dtype"] else: raise KeyError("No dtype specified in config") - + nn_layers = head_config["nn"] if only_use_n_layers is not None: nn_layers = nn_layers[:only_use_n_layers] diff --git a/tests/conftest.py b/tests/conftest.py index 22d6ddb6..79cf0821 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -79,7 +79,7 @@ def example_atoms_list(num_data: int, pbc: bool, calc_results: List[str]) -> Ato for _ in range(num_data): num_atoms = np.random.randint(10, 15) atoms = create_example_atoms(num_atoms, pbc, calc_results) - + atoms_list.append(atoms) return atoms_list diff --git a/tests/unit_tests/nn/test_builder.py b/tests/unit_tests/nn/test_builder.py index 2c15fc86..aaf058fa 100644 --- a/tests/unit_tests/nn/test_builder.py +++ b/tests/unit_tests/nn/test_builder.py @@ -18,7 +18,7 @@ def test_builder_feature_model(): model = builder.build_energy_model() key = jax.random.PRNGKey(0) params = model.init(key, R, Z, idx, box, offsets) - + model = builder.build_feature_model(only_use_n_layers=0, init_box=box) out = model.apply(params, R, Z, idx, box, offsets) assert out.shape == (360,)