added option to rebuild model up to specific readout layer #396

Open · wants to merge 7 commits into main
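
A minimal usage sketch of the new `only_use_n_layers` option, modeled on the unit test added in this PR; the config path is a placeholder and the builder calls are assumed from the diff below:

```python
import jax

from apax.config.common import parse_config
from apax.utils.data import make_minimal_input

# Sketch only: rebuild a trained model up to a given readout layer and
# extract per-structure features with the same parameters.
R, Z, idx, box, offsets = make_minimal_input()
config = parse_config("train.yaml")  # placeholder path

Builder = config.model.get_builder()
builder = Builder(config.model.model_dump())

# Initialize parameters with the full energy model ...
energy_model = builder.build_energy_model()
params = energy_model.init(jax.random.PRNGKey(0), R, Z, idx, box, offsets)

# ... then rebuild up to the first readout layer and reuse those parameters.
feature_model = builder.build_feature_model(only_use_n_layers=1, init_box=box)
features = feature_model.apply(params, R, Z, idx, box, offsets)
```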
17 changes: 13 additions & 4 deletions apax/nn/builder.py
@@ -80,7 +80,7 @@ def build_descriptor(
):
raise NotImplementedError("use a subclass to facilitate this")

def build_readout(self, head_config, is_feature_fn=False):
def build_readout(self, head_config, is_feature_fn=False, only_use_n_layers=None):
has_ensemble = "ensemble" in head_config.keys() and head_config["ensemble"]
if has_ensemble and head_config["ensemble"]["kind"] == "shallow":
n_shallow_ensemble = head_config["ensemble"]["n_members"]
@@ -96,8 +96,14 @@ def build_readout(self, head_config, is_feature_fn=False):
else:
raise KeyError("No dtype specified in config")

nn_layers = head_config["nn"]
if only_use_n_layers is not None:
nn_layers = nn_layers[:only_use_n_layers]
if len(nn_layers) == 0:
return None

readout = AtomisticReadout(
units=head_config["nn"],
units=nn_layers,
b_init=head_config["b_init"],
w_init=head_config["w_init"],
use_ntk=head_config["use_ntk"],
@@ -207,15 +213,18 @@ def build_energy_derivative_model(
)
return model

def build_ll_feature_model(
def build_feature_model(
self,
only_use_n_layers=None,
apply_mask=True,
init_box: np.array = np.array([0.0, 0.0, 0.0]),
inference_disp_fn=None,
):
log.info("Building feature model")
descriptor = self.build_descriptor(apply_mask)
readout = self.build_readout(is_feature_fn=True)
readout = self.build_readout(
self.config, is_feature_fn=True, only_use_n_layers=only_use_n_layers
)

model = FeatureModel(
descriptor,
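For reference, the truncation semantics that `only_use_n_layers` adds to `build_readout`, sketched with the layer widths from the test config at the end of this PR:

```python
nn = [128, 64, 32]  # head_config["nn"], as in tests/unit_tests/nn/train.yaml
nn[:2]              # only_use_n_layers=2 -> [128, 64]: first two readout layers
nn[:0]              # only_use_n_layers=0 -> []: build_readout returns None
# only_use_n_layers=None skips the slicing and builds the full readout
```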
66 changes: 36 additions & 30 deletions tests/conftest.py
@@ -37,42 +37,48 @@ def create_cell(a: float, lattice: str) -> np.ndarray:
return cells[lattice]


def create_example_atoms(num_atoms: int, pbc: bool = False, calc_results=[]):
numbers = np.random.randint(1, 119, size=num_atoms)
cell_const = np.random.uniform(low=10.0, high=12.0)
positions = np.random.uniform(low=0.0, high=cell_const, size=(num_atoms, 3))

additional_data = {}
additional_data["pbc"] = pbc
# lattice = random.choice(["free", "sc", "fcc", "bcc"])
# at the moment we can only work with cubic cells
lattice = "sc"
if pbc:
additional_data["cell"] = create_cell(cell_const, lattice)
else:
additional_data["cell"] = [0.0, 0.0, 0.0]

result_shapes = {
"energy": (np.random.rand() - 5.0) * 10_000,
"forces": np.random.uniform(low=-1.0, high=1.0, size=(num_atoms, 3)),
# "stress": np.random.uniform(low=-1.0, high=1.0, size=(3, 3)),
# "dipole": np.random.randn(3),
# "charge": np.random.randint(-3, 4),
# "ma_tensors": np.random.uniform(low=-1.0, high=1.0, size=(3, 3)),
}

atoms = Atoms(numbers=numbers, positions=positions, **additional_data)
if calc_results:
results = {}
for key in calc_results:
results[key] = result_shapes[key]

atoms.calc = SinglePointCalculator(atoms, **results)
return atoms


@pytest.fixture()
def example_atoms(num_data: int, pbc: bool, calc_results: List[str]) -> Atoms:
def example_atoms_list(num_data: int, pbc: bool, calc_results: List[str]) -> List[Atoms]:
atoms_list = []

for _ in range(num_data):
num_atoms = np.random.randint(10, 15)
numbers = np.random.randint(1, 119, size=num_atoms)
cell_const = np.random.uniform(low=10.0, high=12.0)
positions = np.random.uniform(low=0.0, high=cell_const, size=(num_atoms, 3))
atoms = create_example_atoms(num_atoms, pbc, calc_results)

additional_data = {}
additional_data["pbc"] = pbc
# lattice = random.choice(["free", "sc", "fcc", "bcc"])
# at the moment we can only work with cubic cells
lattice = "sc"
if pbc:
additional_data["cell"] = create_cell(cell_const, lattice)
else:
additional_data["cell"] = [0.0, 0.0, 0.0]

result_shapes = {
"energy": (np.random.rand() - 5.0) * 10_000,
"forces": np.random.uniform(low=-1.0, high=1.0, size=(num_atoms, 3)),
# "stress": np.random.uniform(low=-1.0, high=1.0, size=(3, 3)),
# "dipole": np.random.randn(3),
# "charge": np.random.randint(-3, 4),
# "ma_tensors": np.random.uniform(low=-1.0, high=1.0, size=(3, 3)),
}

atoms = Atoms(numbers=numbers, positions=positions, **additional_data)
if calc_results:
results = {}
for key in calc_results:
results[key] = result_shapes[key]

atoms.calc = SinglePointCalculator(atoms, **results)
atoms_list.append(atoms)

return atoms_list
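A short sketch of the extracted helper in isolation (hypothetical values; the `example_atoms_list` fixture above calls it once per structure):

```python
# Hypothetical standalone use of the new helper: one random structure with
# single-point results attached, mirroring what the fixture builds per entry.
atoms = create_example_atoms(num_atoms=12, pbc=True, calc_results=["energy", "forces"])
assert atoms.calc.results["forces"].shape == (12, 3)
```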
8 changes: 4 additions & 4 deletions tests/integration_tests/bal/test_api.py
@@ -35,7 +35,7 @@
([20, False, ["energy", "forces"]],),
)
def test_kernel_selection(
config, features, example_atoms, get_tmp_path, get_sample_input
config, features, example_atoms_list, get_tmp_path, get_sample_input
):
model_config_path = TEST_PATH / config # "config.yaml"

@@ -54,10 +54,10 @@ def test_kernel_selection(
overwrite=True,
)

num_data = len(example_atoms)
num_data = len(example_atoms_list)
n_train = num_data // 2
train_atoms = example_atoms[:n_train]
pool_atoms = example_atoms[n_train:]
train_atoms = example_atoms_list[:n_train]
pool_atoms = example_atoms_list[n_train:]

base_fm_options = features
selection_method = "max_dist"
32 changes: 16 additions & 16 deletions tests/unit_tests/data/test_input_pipeline.py
@@ -93,22 +93,22 @@
[10, True, ["energy", "forces"]],
),
)
def test_split_data(example_atoms):
def test_split_data(example_atoms_list):
seed_py_np_tf(1)
train_idxs1, val_idxs1 = split_idxs(example_atoms, 4, 2)
train_idxs2, val_idxs2 = split_idxs(example_atoms, 4, 2)
train_idxs1, val_idxs1 = split_idxs(example_atoms_list, 4, 2)
train_idxs2, val_idxs2 = split_idxs(example_atoms_list, 4, 2)
assert np.all(train_idxs1 != train_idxs2) and np.all(val_idxs1 != val_idxs2)

train_atoms1, val_atoms1 = split_atoms(example_atoms, train_idxs1, val_idxs1)
train_atoms2, val_atoms2 = split_atoms(example_atoms, train_idxs2, val_idxs2)
train_atoms1, val_atoms1 = split_atoms(example_atoms_list, train_idxs1, val_idxs1)
train_atoms2, val_atoms2 = split_atoms(example_atoms_list, train_idxs2, val_idxs2)
assert np.all(train_atoms1[0].get_positions() != train_atoms2[0].get_positions())
assert np.all(val_atoms1[0].get_positions() != val_atoms2[0].get_positions())

seed_py_np_tf(1)
train_idxs2, val_idxs2 = split_idxs(example_atoms, 4, 2)
train_idxs2, val_idxs2 = split_idxs(example_atoms_list, 4, 2)
assert np.all(train_idxs1 == train_idxs2) and np.all(val_idxs1 == val_idxs2)

train_atoms2, val_atoms2 = split_atoms(example_atoms, train_idxs2, val_idxs2)
train_atoms2, val_atoms2 = split_atoms(example_atoms_list, train_idxs2, val_idxs2)
assert np.all(train_atoms1[0].get_positions() == train_atoms2[0].get_positions())
assert np.all(val_atoms1[0].get_positions() == val_atoms2[0].get_positions())

@@ -120,29 +120,29 @@ def test_split_data(example_atoms):
[5, True, ["energy", "forces"]],
),
)
def test_convert_atoms_to_arrays(example_atoms, pbc):
inputs = atoms_to_inputs(example_atoms)
labels = atoms_to_labels(example_atoms)
def test_convert_atoms_to_arrays(example_atoms_list, pbc):
inputs = atoms_to_inputs(example_atoms_list)
labels = atoms_to_labels(example_atoms_list)

assert "positions" in inputs
assert len(inputs["positions"]) == len(example_atoms)
assert len(inputs["positions"]) == len(example_atoms_list)

assert "numbers" in inputs
assert len(inputs["numbers"]) == len(example_atoms)
assert len(inputs["numbers"]) == len(example_atoms_list)

assert "box" in inputs
assert len(inputs["box"]) == len(example_atoms)
assert len(inputs["box"]) == len(example_atoms_list)
if not pbc:
assert np.all(inputs["box"][0] < 1e-6)

assert "n_atoms" in inputs
assert len(inputs["n_atoms"]) == len(example_atoms)
assert len(inputs["n_atoms"]) == len(example_atoms_list)

assert "energy" in labels
assert len(labels["energy"]) == len(example_atoms)
assert len(labels["energy"]) == len(example_atoms_list)

assert "forces" in labels
assert len(labels["forces"]) == len(example_atoms)
assert len(labels["forces"]) == len(example_atoms_list)


@pytest.mark.parametrize(
37 changes: 37 additions & 0 deletions tests/unit_tests/nn/test_builder.py
@@ -0,0 +1,37 @@
import pathlib

import jax

from apax.config.common import parse_config
from apax.utils.data import make_minimal_input

TEST_PATH = pathlib.Path(__file__).parent.resolve()


def test_builder_feature_model():
R, Z, idx, box, offsets = make_minimal_input()

config = parse_config(TEST_PATH / "train.yaml")

Builder = config.model.get_builder()
builder = Builder(config.model.model_dump())

model = builder.build_energy_model()
key = jax.random.PRNGKey(0)
params = model.init(key, R, Z, idx, box, offsets)

model = builder.build_feature_model(only_use_n_layers=0, init_box=box)
out = model.apply(params, R, Z, idx, box, offsets)
assert out.shape == (360,)  # no readout layers: descriptor features only

model = builder.build_feature_model(only_use_n_layers=1, init_box=box)
out = model.apply(params, R, Z, idx, box, offsets)
assert out.shape == (128,)  # features after the first readout layer (width 128)

model = builder.build_feature_model(only_use_n_layers=2, init_box=box)
out = model.apply(params, R, Z, idx, box, offsets)
assert out.shape == (64,)  # features after the second readout layer (width 64)

model = builder.build_feature_model(only_use_n_layers=3, init_box=box)
out = model.apply(params, R, Z, idx, box, offsets)
assert out.shape == (32,)  # features after the third readout layer (width 32)
16 changes: 16 additions & 0 deletions tests/unit_tests/nn/train.yaml
@@ -0,0 +1,16 @@
n_epochs: 1000

data:
directory: test
experiment: feature_model
data_path: example.h5

loss:
- name: energy

model:
name: gmnn
nn:
- 128
- 64
- 32