Transfer Learning fix #200

Merged
8 commits merged on Nov 17, 2023
11 changes: 3 additions & 8 deletions apax/train/run.py
@@ -9,11 +9,11 @@
from apax.model import ModelBuilder
from apax.optimizer import get_opt
from apax.train.callbacks import initialize_callbacks
from apax.train.checkpoints import create_params, create_train_state, load_params
from apax.train.checkpoints import create_params, create_train_state
from apax.train.loss import Loss, LossCollection
from apax.train.metrics import initialize_metrics
from apax.train.trainer import fit
from apax.transfer_learning import param_transfer
from apax.transfer_learning import transfer_parameters
from apax.utils.random import seed_py_np_tf

log = logging.getLogger(__name__)
@@ -101,12 +101,7 @@ def run(user_config, log_level="error"):
base_checkpoint = config.checkpoints.base_model_checkpoint
do_transfer_learning = base_checkpoint is not None
if do_transfer_learning:
source_params = load_params(base_checkpoint)
log.info("Transferring parameters from %s", base_checkpoint)
params = param_transfer(
source_params, state.params, config.checkpoints.reset_layers
)
state.replace(params=params)
state = transfer_parameters(state, config.checkpoints)

fit(
state,
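
The substantive fix above: the old code called state.replace(params=params) but discarded the return value; flax train states are immutable, so the replaced state must be reassigned, which transfer_parameters now does internally. For orientation, a hedged sketch (not part of the PR) of the config section that enables this path; the field names base_model_checkpoint and reset_layers come from the diff, while the checkpoint path and layer list are invented for illustration.

user_config = {
    "checkpoints": {
        # hypothetical path to a pretrained run
        "base_model_checkpoint": "models/base",
        # layers to reinitialize rather than transfer; hypothetical value
        "reset_layers": ["basis"],
    },
    # ... data/model/loss/optimizer sections as in the YAML configs below
}
# run(user_config) then delegates the whole transfer step to
# transfer_parameters(state, config.checkpoints) before calling fit().
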
7 changes: 5 additions & 2 deletions apax/transfer_learning/__init__.py
@@ -1,3 +1,6 @@
from apax.transfer_learning.parameter_transfer import param_transfer
from apax.transfer_learning.parameter_transfer import (
black_list_param_transfer,
transfer_parameters,
)

__all__ = ["param_transfer"]
__all__ = ["transfer_parameters", "black_list_param_transfer"]
16 changes: 14 additions & 2 deletions apax/transfer_learning/parameter_transfer.py
@@ -3,10 +3,12 @@
from flax.core.frozen_dict import freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict

from apax.train.checkpoints import load_params

log = logging.getLogger(__name__)


def param_transfer(source_params, target_params, param_black_list):
def black_list_param_transfer(source_params, target_params, param_black_list):
source_params = unfreeze(source_params)
target_params = unfreeze(target_params)

@@ -15,8 +17,18 @@ def param_transfer(source_params, target_params, param_black_list):
for p, v in flat_source.items():
if p[-2] not in param_black_list:
flat_target[p] = v
log.info("Transferring parameter: %s", p)
log.info("Transferring parameter: %s", p[-2])

transfered_target = unflatten_dict(flat_target)
transfered_target = freeze(transfered_target)
return transfered_target


def transfer_parameters(state, ckpt_config):
source_params = load_params(ckpt_config.base_model_checkpoint)
log.info("Transferring parameters from %s", ckpt_config.base_model_checkpoint)
params = black_list_param_transfer(
source_params, state.params, ckpt_config.reset_layers
)
state = state.replace(params=params)
return state
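
To see why the black list is matched against p[-2], a toy sketch: the parameter tree is invented, but flatten_dict from flax.traverse_util produces path-tuple keys as shown.

from flax.traverse_util import flatten_dict

toy_params = {"params": {"dense": {"w": 1.0, "b": 2.0}, "basis": {"w": 3.0}}}
flat = flatten_dict(toy_params)
for p in flat:
    # keys are path tuples such as ("params", "dense", "w"):
    # p[-1] is the parameter name, p[-2] the layer name the black list checks
    print(p, "-> layer:", p[-2])
# ('params', 'dense', 'w') -> layer: dense
# ('params', 'dense', 'b') -> layer: dense
# ('params', 'basis', 'w') -> layer: basis
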
47 changes: 42 additions & 5 deletions tests/conftest.py
@@ -8,10 +8,12 @@
import pytest
import yaml
from ase import Atoms
from ase.calculators.emt import EMT
from ase.calculators.singlepoint import SinglePointCalculator

from apax.config.train_config import Config
from apax.model.builder import ModelBuilder
from apax.train.run import run
from apax.utils.random import seed_py_np_tf


@@ -77,23 +79,49 @@ def example_atoms(num_data: int, pbc: bool, calc_results: List[str]) -> Atoms:
return atoms_list


@pytest.fixture()
def example_dataset(num_data: int) -> List[Atoms]:
atoms_list = []

p2 = np.random.uniform(low=1.0, high=1.5, size=(num_data,))
for i in range(num_data):
positions = np.array([[0.0, 0.0, 0.0], [0.0, 0.0, p2[i]]])

additional_data = {}
additional_data["cell"] = [0.0, 0.0, 0.0]

atoms = Atoms("H2", positions=positions, **additional_data)
atoms.calc = EMT()
atoms.get_potential_energy()
atoms.get_forces()

atoms_list.append(atoms)

return atoms_list


@pytest.fixture()
def get_tmp_path(tmp_path_factory):
test_path = tmp_path_factory.mktemp("apax_tests")
return test_path


@pytest.fixture(scope="session")
def get_md22_stachyose(get_tmp_path):
def tmp_data_path(tmp_path_factory):
test_path = tmp_path_factory.mktemp("data")
return test_path


@pytest.fixture(scope="session")
def get_md22_stachyose(tmp_data_path):
url = "http://www.quantum-machine.org/gdml/repo/static/md22_stachyose.zip"
data_path = get_tmp_path / "data"
file_path = data_path / "md22_stachyose.zip"
file_path = tmp_data_path / "md22_stachyose.zip"

os.makedirs(data_path, exist_ok=True)
os.makedirs(tmp_data_path, exist_ok=True)
urllib.request.urlretrieve(url, file_path)

with zipfile.ZipFile(file_path, "r") as zip_ref:
zip_ref.extractall(data_path)
zip_ref.extractall(tmp_data_path)

file_path = modify_xyz_file(
file_path.with_suffix(".xyz"), target_string="Energy", replacement_string="energy"
@@ -150,3 +178,12 @@ def load_and_dump_config(config_path, dump_path):
os.makedirs(model_config.data.model_version_path, exist_ok=True)
model_config.dump_config(model_config.data.model_version_path)
return model_config


def load_config_and_run_training(config_path, updated_config):
with open(config_path.as_posix(), "r") as stream:
config_dict = yaml.safe_load(stream)

for key, new_value in updated_config.items():
config_dict[key].update(new_value)
run(config_dict)
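
A usage sketch for the new helper (assumed, not part of the test suite): updated_config is applied as a shallow update per top-level section, so keys not listed in a section keep their values from the YAML file.

import pathlib

config_path = pathlib.Path("tests/integration_tests/transfer_learning/config_base.yaml")
mods = {"data": {"directory": "/tmp/apax_run", "experiment": "base"}}  # paths assumed
load_config_and_run_training(config_path, mods)
# config_dict["data"] keeps n_train, batch_size, ... from the YAML;
# only "directory" and "experiment" are overwritten.
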
Empty file.
29 changes: 29 additions & 0 deletions tests/integration_tests/transfer_learning/config_base.yaml
@@ -0,0 +1,29 @@
n_epochs: 20
seed: 1

data:
n_train: 16
n_valid: 10
batch_size: 4
valid_batch_size: 10

model:
nn: [32,32]
n_basis: 5
n_radial: 3
calc_stress: false
b_init: normal

metrics:
- name: energy
reductions: [mae]

loss:
- name: energy
- name: forces

optimizer:
emb_lr: 0.001
nn_lr: 0.001
scale_lr: 0.0001
shift_lr: 0.001
32 changes: 32 additions & 0 deletions tests/integration_tests/transfer_learning/config_ft.yaml
@@ -0,0 +1,32 @@
n_epochs: 2
seed: 2

data:
n_train: 16
n_valid: 10
batch_size: 4
valid_batch_size: 10

model:
nn: [32,32]
n_basis: 5
n_radial: 3
calc_stress: false
b_init: zeros

metrics:
- name: energy
reductions: [mae]

loss:
- name: energy
- name: forces

optimizer:
emb_lr: 0.0001
nn_lr: 0.0001
scale_lr: 0.00001
shift_lr: 0.001

checkpoints:
base_model_checkpoint: null
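
Note: base_model_checkpoint is left as null here; the integration test below injects the actual checkpoint path at runtime through its checkpoints override.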
@@ -0,0 +1,69 @@
import pathlib
import uuid

import jax
import numpy as np
import pytest
from ase.io import write

from apax.train.checkpoints import restore_parameters
from tests.conftest import load_config_and_run_training

TEST_PATH = pathlib.Path(__file__).parent.resolve()


def l2_param_diff(p1, p2):
p1, _ = jax.tree_flatten(p1)
p2, _ = jax.tree_flatten(p2)
diff = 0.0
for i in range(len(p1)):
diff += np.sum((p1[i] - p2[i]) ** 2)
return diff


@pytest.mark.parametrize("num_data", (30,))
def test_transfer_learning(get_tmp_path, example_dataset):
config_path = TEST_PATH / "config_base.yaml"
config_ft_path = TEST_PATH / "config_ft.yaml"
working_dir = get_tmp_path / str(uuid.uuid4())
data_path = get_tmp_path / "ds.extxyz"

write(data_path, example_dataset)

data_config_mods = {
"data": {
"directory": working_dir.as_posix(),
"experiment": "base",
"data_path": data_path.as_posix(),
},
}
load_config_and_run_training(config_path, data_config_mods)

data_config_mods = {
"data": {
"directory": working_dir.as_posix(),
"experiment": "fine_tune",
"data_path": data_path.as_posix(),
},
"checkpoints": {"base_model_checkpoint": (working_dir / "base").as_posix()},
}
load_config_and_run_training(config_ft_path, data_config_mods)

data_config_mods = {
"data": {
"directory": working_dir.as_posix(),
"experiment": "fine_tune_no_pre_training",
"data_path": data_path.as_posix(),
},
}
load_config_and_run_training(config_ft_path, data_config_mods)

# Compare parameters: fine-tuning from the base checkpoint should end up
# closer to the base parameters than a model trained from scratch does
_, base_params = restore_parameters(working_dir / "base")
_, ft_params = restore_parameters(working_dir / "fine_tune")
_, ft_no_pre_params = restore_parameters(working_dir / "fine_tune_no_pre_training")

diff_base_ft = l2_param_diff(base_params, ft_params)
diff_base_no_pre = l2_param_diff(base_params, ft_no_pre_params)

assert diff_base_ft < diff_base_no_pre
24 changes: 7 additions & 17 deletions tests/regression_tests/test_apax_training.py
@@ -3,24 +3,12 @@

import numpy as np
import pytest
import yaml

from apax.train.run import run
from tests.conftest import load_config_and_run_training

TEST_PATH = pathlib.Path(__file__).parent.resolve()


def load_config_and_run_training(config_path, **config_kwargs):
with open(config_path.as_posix(), "r") as stream:
config_dict = yaml.safe_load(stream)

for pydentic_model_key, config_mods in config_kwargs.items():
for h_param_key, value in config_mods.items():
config_dict[pydentic_model_key][h_param_key] = value

run(config_dict)


def load_csv(filename):
data = np.loadtxt(filename, delimiter=",", skiprows=1) # Skip the header row

@@ -39,12 +27,14 @@ def test_regression_model_training(get_md22_stachyose, get_tmp_path):
file_path = get_md22_stachyose

data_config_mods = {
"directory": working_dir.as_posix(),
"data_path": file_path.as_posix(),
"energy_unit": "kcal/mol",
"data": {
"directory": working_dir.as_posix(),
"data_path": file_path.as_posix(),
"energy_unit": "kcal/mol",
}
}

load_config_and_run_training(config_path, data=data_config_mods)
load_config_and_run_training(config_path, data_config_mods)

current_metrics = load_csv(working_dir / "test/log.csv")

4 changes: 2 additions & 2 deletions tests/unit_tests/transfer_learning/test_parameter_transfer.py
@@ -1,4 +1,4 @@
from apax.transfer_learning import param_transfer
from apax.transfer_learning import black_list_param_transfer


def test_param_transfer():
@@ -15,7 +15,7 @@ def test_param_transfer():
}
}
reinitialize_layers = ["basis"]
transfered_target = param_transfer(source, target, reinitialize_layers)
transfered_target = black_list_param_transfer(source, target, reinitialize_layers)

assert transfered_target["params"]["dense"]["w"] == source["params"]["dense"]["w"]
assert transfered_target["params"]["dense"]["b"] == source["params"]["dense"]["b"]