Skip to content

Commit

Permalink
add memory size of model to model-size scritps
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Jul 18, 2024
1 parent 22d0ad2 commit 77e81dd
Show file tree
Hide file tree
Showing 5 changed files with 35 additions and 6 deletions.
7 changes: 6 additions & 1 deletion bin/deterministic/model-size
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import yaml


from ml_downscaling_emulator.deterministic.utils import create_model
from ml_downscaling_emulator.utils import model_size, param_count


logger = logging.getLogger()
Expand Down Expand Up @@ -49,10 +50,14 @@ def main(
config_path = os.path.join(workdir, "config.yml")
config = load_config(config_path)
model = load_model(config)["model"]
num_score_model_parameters = sum(p.numel() for p in model.parameters())
num_score_model_parameters = param_count(model)

typer.echo(f"Model has {num_score_model_parameters} parameters")

size_all_mb = model_size(model)

typer.echo("model size: {:.3f}MB".format(size_all_mb))


if __name__ == "__main__":
app()
10 changes: 8 additions & 2 deletions bin/model-size
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ from ml_downscaling_emulator.score_sde_pytorch.models import ( # noqa: F401
normalization, # noqa: F401
) # noqa: F401

from ml_downscaling_emulator.utils import model_size, param_count

logger = logging.getLogger()
logger.setLevel("INFO")

Expand Down Expand Up @@ -58,12 +60,16 @@ def main(
config_path = os.path.join(workdir, "config.yml")
config = load_config(config_path)
score_model, location_params = load_model(config)
num_score_model_parameters = sum(p.numel() for p in score_model.parameters())
num_location_parameters = sum(p.numel() for p in location_params.parameters())
num_score_model_parameters = param_count(score_model)
num_location_parameters = param_count(location_params)

typer.echo(f"Score model has {num_score_model_parameters} parameters")
typer.echo(f"Location parameters have {num_location_parameters} parameters")

size_all_mb = sum(model_size(model) for model in [score_model, location_params])

typer.echo("model size: {:.3f}MB".format(size_all_mb))


if __name__ == "__main__":
app()
18 changes: 18 additions & 0 deletions src/ml_downscaling_emulator/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
"""Helper methods"""


def param_count(model):
"""Count the number of parameters in a model."""
return sum(p.numel() for p in model.parameters())


def model_size(model):
"""Compute size in memory of model in MB."""
param_size = sum(
param.nelement() * param.element_size() for param in model.parameters()
)
buffer_size = sum(
buffer.nelement() * buffer.element_size() for buffer in model.buffers()
)

return (param_size + buffer_size) / 1024**2
4 changes: 2 additions & 2 deletions tests/deterministic/smoke-test
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@ set -euo pipefail

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

cpm_dataset="bham_gcmx-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"
gcm_dataset="bham_60km-4x_1em_psl-sphum4th-temp4th-vort4th_eqvt_random-season-historic"
cpm_dataset="bham64_ccpm-4x_12em_psl-sphum4th-temp4th-vort4th_pr-historic"
gcm_dataset="bham64_gcm-4x_12em_psl-sphum4th-temp4th-vort4th_pr-historic"
workdir="output/test/unet/test-run"

config_path="src/ml_downscaling_emulator/deterministic/configs/ukcp_local_12em_pr_unet.py"
Expand Down
2 changes: 1 addition & 1 deletion tests/smoke-test
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ train_batch_size=2
epoch=2

rm -rf ${workdir}
WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True --config.model.num_scales=10
WANDB_EXPERIMENT_NAME="test" python bin/main.py --workdir ${workdir} --config ${config_path} --mode train --config.data.dataset_name=${dataset} --config.training.snapshot_freq=5 --config.training.eval_freq=100 --config.training.log_freq=50 --config.training.n_epochs=${epoch} --config.model.loc_spec_channels=${loc_spec_channels} --config.training.batch_size=${train_batch_size} --config.data.input_transform_key=stan --config.data.target_transform_key=sqrturrecen --config.data.time_inputs=True --config.model.num_scales=10 --config.data.target_transforms.tmean150cm=stan

num_samples=2
eval_batch_size=32
Expand Down

0 comments on commit 77e81dd

Please sign in to comment.