diff --git a/bin/deterministic/model-size b/bin/deterministic/model-size
index 0d9a7df25..ab955065d 100755
--- a/bin/deterministic/model-size
+++ b/bin/deterministic/model-size
@@ -13,6 +13,7 @@
 import yaml
 
 from ml_downscaling_emulator.deterministic.utils import create_model
+from ml_downscaling_emulator.utils import model_size, param_count
 
 logger = logging.getLogger()
 
@@ -49,10 +50,14 @@ def main(
     config_path = os.path.join(workdir, "config.yml")
     config = load_config(config_path)
     model = load_model(config)["model"]
-    num_score_model_parameters = sum(p.numel() for p in model.parameters())
+    num_score_model_parameters = param_count(model)
 
     typer.echo(f"Model has {num_score_model_parameters} parameters")
 
+    size_all_mb = model_size(model)
+
+    typer.echo("model size: {:.3f}MB".format(size_all_mb))
+
 
 if __name__ == "__main__":
     app()
diff --git a/bin/model-size b/bin/model-size
index d38f6a090..1acb7efdb 100755
--- a/bin/model-size
+++ b/bin/model-size
@@ -27,6 +27,8 @@ from ml_downscaling_emulator.score_sde_pytorch.models import (  # noqa: F401
     normalization,  # noqa: F401
 )  # noqa: F401
 
+from ml_downscaling_emulator.utils import model_size, param_count
+
 logger = logging.getLogger()
 logger.setLevel("INFO")
 
@@ -58,12 +60,16 @@ def main(
     config_path = os.path.join(workdir, "config.yml")
     config = load_config(config_path)
     score_model, location_params = load_model(config)
-    num_score_model_parameters = sum(p.numel() for p in score_model.parameters())
-    num_location_parameters = sum(p.numel() for p in location_params.parameters())
+    num_score_model_parameters = param_count(score_model)
+    num_location_parameters = param_count(location_params)
 
     typer.echo(f"Score model has {num_score_model_parameters} parameters")
     typer.echo(f"Location parameters have {num_location_parameters} parameters")
 
+    size_all_mb = sum(model_size(model) for model in [score_model, location_params])
+
+    typer.echo("model size: {:.3f}MB".format(size_all_mb))
+
 
 if __name__ == "__main__":
     app()
diff --git a/src/ml_downscaling_emulator/utils.py b/src/ml_downscaling_emulator/utils.py
new file mode 100644
index 000000000..a15ab7a54
--- /dev/null
+++ b/src/ml_downscaling_emulator/utils.py
@@ -0,0 +1,18 @@
+"""Helper methods"""
+
+
+def param_count(model):
+    """Count the number of parameters in a model."""
+    return sum(p.numel() for p in model.parameters())
+
+
+def model_size(model):
+    """Compute size in memory of model in MB."""
+    param_size = sum(
+        param.nelement() * param.element_size() for param in model.parameters()
+    )
+    buffer_size = sum(
+        buffer.nelement() * buffer.element_size() for buffer in model.buffers()
+    )
+
+    return (param_size + buffer_size) / 1024**2
diff --git a/tests/deterministic/smoke-test b/tests/deterministic/smoke-test
index d0fdc140e..2f549acf3 100755
--- a/tests/deterministic/smoke-test
+++ b/tests/deterministic/smoke-test
@@ -4,8 +4,8 @@ set -euo pipefail
 
 SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
 
-cpm_dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"
-gcm_dataset="bham_60km-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"
+cpm_dataset="bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr-historic"
+gcm_dataset="bham64_gcm-4x_12em_psl-sphum4th-temp4th-vort4th_pr-historic"
 workdir="output/test/unet/test-run"
 config_path="src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py"
diff --git a/tests/smoke-test b/tests/smoke-test
index 0c5ae3792..1895b9ea5 100755
--- a/tests/smoke-test
+++ b/tests/smoke-test
@@ -14,7 +14,7 @@
 train_batch_size=2
 epoch=2
 
 rm -rf ${workdir}
-WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True --config.model.num_scales=10
+WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True --config.model.num_scales=10 --config.data.target_transforms.tmean150cm=stan
 num_samples=2
 eval_batch_size=32