From 642e53510e0f21f61722a03f168c66ed4e1c9d1e Mon Sep 17 00:00:00 2001 From: Henry Addison Date: Tue, 11 Jun 2024 14:15:28 +0100 Subject: [PATCH] add a helper to compute model size for deterministic models --- bin/deterministic/model-size | 58 ++++++++++++++++++++++++++++++++++++ bin/model-size | 1 + 2 files changed, 59 insertions(+) create mode 100755 bin/deterministic/model-size diff --git a/bin/deterministic/model-size b/bin/deterministic/model-size new file mode 100755 index 000000000..0d9a7df25 --- /dev/null +++ b/bin/deterministic/model-size @@ -0,0 +1,58 @@ +#!/usr/bin/env python +# calculate the number of parameters in a deterministic model + +import logging +import os +from pathlib import Path + +from ml_collections import config_dict +from mlde_utils.training.dataset import get_variables +import torch +import typer +import yaml + + +from ml_downscaling_emulator.deterministic.utils import create_model + + +logger = logging.getLogger() +logger.setLevel("INFO") + +app = typer.Typer() + + +def load_config(config_path): + logger.info(f"Loading config from {config_path}") + with open(config_path) as f: + config = config_dict.ConfigDict(yaml.unsafe_load(f)) + + return config + + +def load_model(config): + num_predictors = len(get_variables(config.data.dataset_name)[0]) + if config.data.time_inputs: + num_predictors += 3 + model = torch.nn.DataParallel( + create_model(config, num_predictors).to(device=config.device) + ) + optimizer = torch.optim.Adam(model.parameters()) + state = dict(step=0, epoch=0, optimizer=optimizer, model=model) + + return state + + +@app.command() +def main( + workdir: Path, +): + config_path = os.path.join(workdir, "config.yml") + config = load_config(config_path) + model = load_model(config)["model"] + num_score_model_parameters = sum(p.numel() for p in model.parameters()) + + typer.echo(f"Model has {num_score_model_parameters} parameters") + + +if __name__ == "__main__": + app() diff --git a/bin/model-size b/bin/model-size index fee6586d3..d38f6a090 100755 --- a/bin/model-size +++ b/bin/model-size @@ -1,4 +1,5 @@ #!/usr/bin/env python +# calculate the number of parameters in a model import os from pathlib import Path