Commit

Cleanup + indices

ipcamit committed Jun 5, 2024
1 parent 289a2a9 commit 1f4e6eb

Showing 4 changed files with 63 additions and 62 deletions.
22 changes: 15 additions & 7 deletions kliff/models/kim.py
@@ -752,16 +752,24 @@ def get_model_from_manifest(model_manifest: dict, param_manifest: dict = None):

# ensure model is installed
if model_type.lower() == "kim":
is_model_installed = install_kim_model(model_name, model_collection)
if not is_model_installed:
logger.error(
f"Mode: {model_name} neither installed nor available in the KIM API collections. Please check the model name and try again."
)
raise KIMModelError(f"Model {model_name} not found.")
is_model_installed = is_kim_model_installed(model_name)
if is_model_installed:
logger.info(f"Model {model_name} is already installed, continuing ...")
else:
logger.info(
f"Model {model_name} is present in {model_collection} collection."
f"Model {model_name} not installed on system, attempting to installing ..."
)
was_install_success = install_kim_model(model_name, model_collection)
if not was_install_success:
logger.error(
f"Model {model_name} not found in the KIM API collections. Please check the model name and try again."
)
raise KIMModelError(f"Model {model_name} not found.")
else:
logger.info(
f"Model {model_name} installed in {model_collection} collection."
)

elif model_type.lower() == "tar":
archive_content = tarfile.open(model_path + "/" + model_name)
model = archive_content.getnames()[0]
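For context, a usage sketch of the flow this hunk implements: the model is looked up first, and only installed if missing. The manifest keys and the import path below are assumptions inferred from the hunk, not confirmed kliff API.

```python
# Hypothetical usage sketch -- keys and import path inferred from this hunk.
from kliff.models.kim import KIMModel  # adjust to wherever kliff exposes this

model_manifest = {
    "type": "kim",  # routes to the branch above: check install, install if missing
    "name": "SW_StillingerWeber_1985_Si__MO_405512056662_006",
    "collection": "user",  # KIM API collection to install into if missing
}

model = KIMModel.get_model_from_manifest(model_manifest)
```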
79 changes: 29 additions & 50 deletions kliff/trainer/base_trainer.py
@@ -47,15 +47,13 @@ def __init__(self, training_manifest: dict, model=None):
"name": "kliff_workspace",
"seed": 12345,
"resume": False,
"walltime": "2:00:00:00",
}

# dataset variables
self.dataset_manifest: dict = {
"type": "kliff",
"path": "./",
"save": False,
"shuffle": False,
"keys": {"energy": "energy", "forces": "forces"},
"dynamic_loading": False,
"colabfit_dataset": {
@@ -144,13 +142,12 @@ def __init__(self, training_manifest: dict, model=None):
"run_hash": None,
"start_time": None,
"end_time": None,
"best_loss": None,
"best_loss": np.inf,
"best_model": None,
"loss": None,
"epoch": 0,
"step": 0,
"device": "cpu",
"expected_end_time": None,
"warned_once": False,
"dataset_hash": None,
"data_dir": None,
@@ -188,19 +185,6 @@ def parse_manifest(self, manifest: dict):
else:
self.workspace |= workspace_block

if isinstance(self.workspace["walltime"], int):
expected_end_time = datetime.now() + timedelta(
seconds=self.workspace["walltime"]
)
else:
expected_end_time = datetime.now() + timedelta(
days=int(self.workspace["walltime"].split(":")[0]),
hours=int(self.workspace["walltime"].split(":")[1]),
minutes=int(self.workspace["walltime"].split(":")[2]),
seconds=int(self.workspace["walltime"].split(":")[3]),
)
self.current["expected_end_time"] = expected_end_time

# Dataset manifest #################################################
dataset_manifest: Union[None, dict] = manifest.get("dataset", None)
if dataset_manifest is None:
@@ -283,12 +267,13 @@ def config_to_dict(self):
"""
Convert the configuration to a dictionary.
"""
config = {}
config |= self.workspace
config |= self.dataset_manifest
config |= self.model_manifest
config |= self.transform_manifest
config |= self.training_manifest
config = {
"workspace": self.workspace,
"dataset": self.dataset_manifest,
"model": self.model_manifest,
"transforms": self.transform_manifest,
"training": self.training_manifest,
}
return config

@classmethod
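The nested return value replaces the earlier flat `config |= block` merges, under which blocks sharing a key name could silently overwrite one another. A sketch of the resulting shape, using defaults visible elsewhere in this diff (other keys elided):

```python
# Sketch of config_to_dict()'s return shape; values are the defaults shown
# in this diff, remaining keys elided.
config = {
    "workspace": {"name": "kliff_workspace", "seed": 12345, "resume": False},
    "dataset": {"type": "kliff", "path": "./", "save": False},
    # "model", "transforms", and "training" follow the same pattern
}
assert config["workspace"]["name"] == "kliff_workspace"
```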
@@ -335,7 +320,7 @@ def initialize(self):
# Step 4.5 - Set up the dataset transforms
self.setup_dataset_transforms()
# Step 5 - Set up the test and train datasets, based on the provided indices
self.setup_test_train_datasets()
self.setup_dataset_split()
logger.info(f"Train and validation datasets set up.")
# Step 6 - Set up the model
self.setup_model()
@@ -527,7 +512,7 @@ def setup_optimizer(self):
"""
raise TrainerError("setup_optimizer not implemented.")

def setup_test_train_datasets(self):
def setup_dataset_split(self):
"""
Simple test train split for now, will have more options like stratification
in the future.
@@ -565,41 +550,35 @@ def setup_test_train_datasets(self):

# check if indices are provided
train_indices = self.dataset_sample_manifest.get("train_indices")
if train_indices is None:
train_indices = np.arange(train_size, dtype=int)
elif isinstance(train_indices, str):
val_indices = None
if isinstance(train_indices, str):
train_indices = np.genfromtxt(train_indices, dtype=int)
if val_size > 0:
val_indices = np.genfromtxt(
self.dataset_sample_manifest.get("val_indices"), dtype=int
)
else:
TrainerError("train_indices should be a numpy array or a path to a file.")

val_indices = self.dataset_sample_manifest.get("val_indices")
if val_indices is None:
val_indices = np.arange(train_size, train_size + val_size, dtype=int)
elif isinstance(val_indices, str):
val_indices = np.genfromtxt(val_indices, dtype=int)
else:
TrainerError("val_indices should be a numpy array or a path to a file.")

if self.dataset_manifest.get("shuffle", False):
# instead of shuffling the main dataset, validation/train indices are shuffled
# this gives better control over future active learning scenarios
np.random.shuffle(train_indices)
np.random.shuffle(val_indices)
TrainerError(f"Could not load indices from {train_indices}.")

train_dataset = self.dataset[train_indices]

if val_size > 0:
val_dataset = self.dataset[val_indices]
else:
val_dataset = None
if train_indices is None:
indices = np.random.permutation(train_size + val_size)
train_indices = indices[:train_size]
if val_size > 0:
val_indices = indices[-val_size:]

self.dataset_sample_manifest["train_size"] = train_size
self.dataset_sample_manifest["val_size"] = val_size
self.dataset_sample_manifest["train_indices"] = train_indices
self.dataset_sample_manifest["val_indices"] = val_indices

train_dataset = self.dataset[train_indices]
self.train_dataset = train_dataset
self.val_dataset = val_dataset

if val_size > 0:
val_dataset = self.dataset[val_indices]
self.val_dataset = val_dataset
else:
self.val_dataset = None

# save the indices if generated
if isinstance(train_indices, str):
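For reference, a minimal self-contained sketch of the index handling this hunk introduces (explicit index files when given, otherwise a single permutation split); the helper name and signature are illustrative, not part of kliff:

```python
import numpy as np

def split_indices(train_size, val_size, train_index_file=None, val_index_file=None):
    """Illustrative sketch of setup_dataset_split's index logic."""
    val_indices = None
    if train_index_file is not None:
        # indices supplied as files on disk
        train_indices = np.genfromtxt(train_index_file, dtype=int)
        if val_size > 0:
            val_indices = np.genfromtxt(val_index_file, dtype=int)
    else:
        # no indices supplied: shuffle once, take train from the front
        # and validation from the back
        indices = np.random.permutation(train_size + val_size)
        train_indices = indices[:train_size]
        if val_size > 0:
            val_indices = indices[-val_size:]
    return train_indices, val_indices
```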
4 changes: 4 additions & 0 deletions kliff/trainer/kim_trainer.py
@@ -162,6 +162,10 @@ def _wrapper_func(x):
def _get_loss_fn(self):
if self.loss_manifest["function"].lower() == "mse":
return MSE_residuals
else:
raise TrainerError(
f"Loss function {self.loss_manifest['function']} not supported."
)

def save_kim_model(self):
if self.export_manifest["model_type"].lower() == "kim":
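A standalone sketch of the dispatch this hunk completes; the MSE stand-in below is illustrative, not kliff's actual MSE_residuals:

```python
import numpy as np

def get_loss_fn(loss_manifest: dict):
    """Sketch: return a residual function for 'mse', fail loudly otherwise."""
    def mse_residuals(predictions, targets):  # stand-in for MSE_residuals
        return np.mean((np.asarray(predictions) - np.asarray(targets)) ** 2)

    if loss_manifest["function"].lower() == "mse":
        return mse_residuals
    raise ValueError(f"Loss function {loss_manifest['function']} not supported.")

loss_fn = get_loss_fn({"function": "MSE"})  # case-insensitive, returns a callable
# get_loss_fn({"function": "mae"})          # now raises instead of returning None
```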
20 changes: 15 additions & 5 deletions kliff/trainer/torch_trainer.py
@@ -59,7 +59,20 @@ def checkpoint(self):
},
f"{self.current['run_dir']}/checkpoint_{self.current['step']}.pkl",
)
# self.model.save(f"{self.current['run_dir']}/model_{self.current['step']}.pt")

# save best and last model
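# (assumes a validation split exists, so current["loss"]["val"] is populated)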
if self.current["loss"]["val"] < self.current["best_loss"]:
self.current["best_loss"] = self.current["loss"]["val"]
torch.save(
self.model.state_dict(),
f"{self.current['run_dir']}/best_model.pth",
)

torch.save(
self.model.state_dict(),
f"{self.current['run_dir']}/last_model.pth",
)

with open(f"{self.current['run_dir']}/log.txt", "a") as f:
f.write(
f"Step: {self.current['step']}, Train Loss: {self.current['loss']['train']}, Val Loss: {self.current['loss']['val']}\n"
@@ -84,10 +97,7 @@ def get_optimizer(self):
# TODO: Scheduler and ema

def _get_loss_function(self):
if (
self.loss_manifest["function"].lower() == "mseloss"
or self.loss_manifest["function"].lower() == "mse"
):
if self.loss_manifest["function"].lower() == "mse":
return torch.nn.MSELoss()
else:
raise TrainerError(
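A short usage sketch for restoring the files written by checkpoint() above; run_dir stands for whatever self.current["run_dir"] resolved to for the run:

```python
import torch

def load_checkpointed_weights(model: torch.nn.Module, run_dir: str, best: bool = True):
    """Restore weights saved by checkpoint(): best_model.pth (lowest val loss)
    or last_model.pth (most recent)."""
    fname = "best_model.pth" if best else "last_model.pth"
    model.load_state_dict(torch.load(f"{run_dir}/{fname}"))
    return model
```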
