Skip to content

Commit

Permalink
add a helper to compute model size for deterministic models
Browse files Browse the repository at this point in the history
  • Loading branch information
henryaddison committed Jun 11, 2024
1 parent db29ea9 commit 642e535
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
58 changes: 58 additions & 0 deletions bin/deterministic/model-size
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#!/usr/bin/env python
# calculate the number of parameters in a deterministic model

import logging
import os
from pathlib import Path

from ml_collections import config_dict
from mlde_utils.training.dataset import get_variables
import torch
import typer
import yaml


from ml_downscaling_emulator.deterministic.utils import create_model


logger = logging.getLogger()
logger.setLevel("INFO")

app = typer.Typer()


def load_config(config_path):
logger.info(f"Loading config from {config_path}")
with open(config_path) as f:
config = config_dict.ConfigDict(yaml.unsafe_load(f))

return config


def load_model(config):
num_predictors = len(get_variables(config.data.dataset_name)[0])
if config.data.time_inputs:
num_predictors += 3
model = torch.nn.DataParallel(
create_model(config, num_predictors).to(device=config.device)
)
optimizer = torch.optim.Adam(model.parameters())
state = dict(step=0, epoch=0, optimizer=optimizer, model=model)

return state


@app.command()
def main(
workdir: Path,
):
config_path = os.path.join(workdir, "config.yml")
config = load_config(config_path)
model = load_model(config)["model"]
num_score_model_parameters = sum(p.numel() for p in model.parameters())

typer.echo(f"Model has {num_score_model_parameters} parameters")


if __name__ == "__main__":
app()
1 change: 1 addition & 0 deletions bin/model-size
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#!/usr/bin/env python
# calculate the number of parameters in a model

import os
from pathlib import Path
Expand Down

0 comments on commit 642e535

Please sign in to comment.