Skip to content

Commit

Permalink
enabled passing atoms_list to run
Browse files Browse the repository at this point in the history
  • Loading branch information
Tetracarbonylnickel committed Oct 24, 2023
1 parent f2f9e1c commit 99faf18
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
2 changes: 1 addition & 1 deletion apax/data/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def compute(inputs, labels, shift_options) -> np.ndarray:
class IsolatedAtomEnergyShift:
name = "isolated_atom_energy_shift"
parameters = ["E0s"]
dtypes = [dict[int,float]]
dtypes = [dict[int, float]]

@staticmethod
def compute(inputs, labels, shift_options):
Expand Down
24 changes: 19 additions & 5 deletions apax/train/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,15 @@ class RawDataset:
additional_labels: Optional[dict] = None


def load_data_files(data_config, model_version_path):
def load_data_files(
data_config, model_version_path, train_atoms_list=None, val_atoms_list=None
):
log.info("Running Input Pipeline")
if data_config.data_path is not None:
if train_atoms_list is not None and val_atoms_list is not None:
train_label_dict = None
val_label_dict = None

elif data_config.data_path is not None:
log.info(f"Read data file {data_config.data_path}")
atoms_list, label_dict = load_data(data_config.data_path)

Expand All @@ -63,7 +69,7 @@ def load_data_files(data_config, model_version_path):
train_atoms_list, train_label_dict = load_data(data_config.train_data_path)
val_atoms_list, val_label_dict = load_data(data_config.val_data_path)
else:
raise ValueError("input data path/paths not defined")
raise ValueError("either define input data path/paths or atoms_lists")

train_raw_ds = RawDataset(
atoms_list=train_atoms_list, additional_labels=train_label_dict
Expand Down Expand Up @@ -194,7 +200,13 @@ def setup_logging(log_file, log_level):
logging.basicConfig(filename=log_file, level=log_levels[log_level])


def run(user_config, log_file="train.log", log_level="error"):
def run(
user_config,
log_file="train.log",
log_level="error",
train_atoms_list=None,
val_atoms_list=None,
):
setup_logging(log_file, log_level)
log.info("Loading user config")
config = parse_config(user_config)
Expand All @@ -215,7 +227,9 @@ def run(user_config, log_file="train.log", log_level="error"):
loss_fn = initialize_loss_fn(config.loss)
Metrics = initialize_metrics(config.metrics)

train_raw_ds, val_raw_ds = load_data_files(config.data, model_version_path)
train_raw_ds, val_raw_ds = load_data_files(
config.data, model_version_path, train_atoms_list, val_atoms_list
)
train_ds, ds_stats = initialize_dataset(config, train_raw_ds)
val_ds = initialize_dataset(config, val_raw_ds, calc_stats=False)

Expand Down

0 comments on commit 99faf18

Please sign in to comment.