Skip to content

Commit

Permalink
remove score_sde_pytorch namespace
Browse files Browse the repository at this point in the history
no reason anymore to have it in this namespace now that parallel deterministic namespace has been removed
  • Loading branch information
henryaddison committed Oct 22, 2024
1 parent f3cf227 commit f76547f
Show file tree
Hide file tree
Showing 64 changed files with 105 additions and 85 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ repos:
hooks:
- id: black
language_version: python3.9
exclude: ^src/ml_downscaling_emulator/score_sde_pytorch/
exclude: ^src/ml_downscaling_emulator/(run_lib.py|sde_lib.py|likelihood.py|sampling.py|losses.py|models|op|configs)
- repo: https://github.com/pycqa/flake8
rev: '6.0.0' # pick a git hash / tag to point to
hooks:
- id: flake8
exclude: ^src/ml_downscaling_emulator/score_sde_pytorch/
exclude: ^src/ml_downscaling_emulator/(run_lib.py|sde_lib.py|likelihood.py|sampling.py|losses.py|models|op|configs)
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Recommended to run with a sample of the dataset.
Train models through `bin/main.py`, e.g. to train the model used in the paper use

```sh
python bin/main.py --config src/ml_downscaling_emulator/score_sde_pytorch/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train
python bin/main.py --config src/ml_downscaling_emulator/configs/subvpsde/ukcp_local_pr_12em_cncsnpp_continuous.py --workdir ${DERIVED_DATA}/path/to/models/paper-12em --mode train
```

```sh
Expand Down
2 changes: 1 addition & 1 deletion bin/bp/queue-training
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def train_cmd(sde, workdir, config, config_overrides=list):
train_basecmd = ["python", f"bin/main.py"]

train_opts = {
"--config": f"src/ml_downscaling_emulator/score_sde_pytorch/configs/{sde}/{config}.py",
"--config": f"src/ml_downscaling_emulator/configs/{sde}/{config}.py",
"--workdir": workdir,
"--mode": "train",
}
Expand Down
2 changes: 1 addition & 1 deletion bin/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

"""Training"""

import ml_downscaling_emulator.score_sde_pytorch.run_lib as run_lib
import ml_downscaling_emulator.run_lib as run_lib
from absl import app
from absl import flags
from ml_collections.config_flags import config_flags
Expand Down
14 changes: 7 additions & 7 deletions bin/model-size
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,20 @@ import typer
import logging
import yaml

from ml_downscaling_emulator.score_sde_pytorch.models.location_params import (
from ml_downscaling_emulator.models.location_params import (
LocationParams,
)

from ml_downscaling_emulator.score_sde_pytorch.models import utils as mutils
from ml_downscaling_emulator.models import utils as mutils

from ml_downscaling_emulator.score_sde_pytorch.models import cncsnpp # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch.models import cunet # noqa: F401
from ml_downscaling_emulator.models import cncsnpp # noqa: F401
from ml_downscaling_emulator.models import cunet # noqa: F401

from ml_downscaling_emulator.score_sde_pytorch.models import ( # noqa: F401
from ml_downscaling_emulator.models import ( # noqa: F401
layerspp, # noqa: F401
) # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch.models import layers # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch.models import ( # noqa: F401
from ml_downscaling_emulator.models import layers # noqa: F401
from ml_downscaling_emulator.models import ( # noqa: F401
normalization, # noqa: F401
) # noqa: F401

Expand Down
28 changes: 14 additions & 14 deletions bin/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,33 +22,33 @@
from mlde_utils import samples_path, DEFAULT_ENSEMBLE_MEMBER
from mlde_utils.training.dataset import get_variables

from ml_downscaling_emulator.score_sde_pytorch.losses import get_optimizer
from ml_downscaling_emulator.score_sde_pytorch.models.ema import (
from ml_downscaling_emulator.losses import get_optimizer
from ml_downscaling_emulator.models.ema import (
ExponentialMovingAverage,
)
from ml_downscaling_emulator.score_sde_pytorch.models.location_params import (
from ml_downscaling_emulator.models.location_params import (
LocationParams,
)

from ml_downscaling_emulator.score_sde_pytorch.utils import restore_checkpoint
from ml_downscaling_emulator.utils import restore_checkpoint

import ml_downscaling_emulator.score_sde_pytorch.models as models # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch.models import utils as mutils
import ml_downscaling_emulator.models as models # noqa: F401
from ml_downscaling_emulator.models import utils as mutils

from ml_downscaling_emulator.score_sde_pytorch.models import cncsnpp # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch.models import cunet # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch.models import det_cunet # noqa: F401
from ml_downscaling_emulator.models import cncsnpp # noqa: F401
from ml_downscaling_emulator.models import cunet # noqa: F401
from ml_downscaling_emulator.models import det_cunet # noqa: F401

from ml_downscaling_emulator.score_sde_pytorch.models import ( # noqa: F401
from ml_downscaling_emulator.models import ( # noqa: F401
layerspp, # noqa: F401
) # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch.models import layers # noqa: F401
from ml_downscaling_emulator.score_sde_pytorch.models import ( # noqa: F401
from ml_downscaling_emulator.models import layers # noqa: F401
from ml_downscaling_emulator.models import ( # noqa: F401
normalization, # noqa: F401
) # noqa: F401
import ml_downscaling_emulator.score_sde_pytorch.sampling as sampling
import ml_downscaling_emulator.sampling as sampling

from ml_downscaling_emulator.score_sde_pytorch.sde_lib import (
from ml_downscaling_emulator.sde_lib import (
VESDE,
VPSDE,
subVPSDE,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ dynamic = ["dependencies"]
dependencies = { file = ["requirements.txt"] }

[tool.black]
extend-exclude = '^/src/ml_downscaling_emulator/score_sde_pytorch/'
extend-exclude = '^/src/ml_downscaling_emulator/(run_lib.py|sde_lib.py|likelihood.py|sampling.py|losses.py|models|op|configs)'
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ml_collections
import torch

from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs as get_base_configs
from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs as get_base_configs


def get_default_configs():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Lint as: python3
"""Training NCSN++ on precip data in a deterministic fashion."""

from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs
from ml_downscaling_emulator.configs.deterministic.default_configs import get_default_configs

def get_config():
config = get_default_configs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
but training it in a deterministic fashion.
"""

from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs
from ml_downscaling_emulator.configs.deterministic.default_configs import get_default_configs

def get_config():
config = get_default_configs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
but training it in a deterministic fashion.
"""

from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs
from ml_downscaling_emulator.configs.deterministic.default_configs import get_default_configs

def get_config():
config = get_default_configs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Lint as: python3
"""Training NCSN++ on precip data in a deterministic fashion."""

from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs
from ml_downscaling_emulator.configs.deterministic.default_configs import get_default_configs

def get_config():
config = get_default_configs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# Lint as: python3
"""Debug config for training in a deterministic fashion."""

from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.default_configs import get_default_configs
from ml_downscaling_emulator.configs.deterministic.default_configs import get_default_configs

def get_config():
config = get_default_configs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
but training it in a deterministic fashion.
"""

from ml_downscaling_emulator.score_sde_pytorch.configs.deterministic.ukcp_local_pr_12em_tuned_plain_unet import get_config as get_default_configs
from ml_downscaling_emulator.configs.deterministic.ukcp_local_pr_12em_tuned_plain_unet import get_config as get_default_configs

def get_config():
config = get_default_configs()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training NCSN++ on precip data with sub-VP SDE."""
from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_12em_configs import get_default_configs
from ml_downscaling_emulator.configs.default_ukcp_local_pr_12em_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# Lint as: python3
"""Training conditional U-Net on precip data with sub-VP SDE.
DEBUGGING ONLY"""
from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs
from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training NCSN++ on precip data with sub-VP SDE."""
from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_12em_configs import get_default_configs
from ml_downscaling_emulator.configs.default_ukcp_local_pr_12em_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training NCSN++ on precip data with sub-VP SDE."""
from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs
from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training NCSN++ on precip data with sub-VP SDE."""
from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs
from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
# Lint as: python3
"""Training conditional U-Net on precip data with sub-VP SDE.
DEBUGGING ONLY"""
from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs
from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training NCSN++ on precip data with sub-VP SDE."""
from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_12em_configs import get_default_configs
from ml_downscaling_emulator.configs.default_ukcp_local_pr_12em_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training NCSN++ on precip data with sub-VP SDE."""
from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_12em_configs import get_default_configs
from ml_downscaling_emulator.configs.default_ukcp_local_pr_12em_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training NCSN++ on precip data with VE SDE."""
from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs
from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs


def get_config():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

# Lint as: python3
"""Training UNet on XArray with VE SDE."""
from ml_downscaling_emulator.score_sde_pytorch.configs.default_ukcp_local_pr_1em_configs import get_default_configs
from ml_downscaling_emulator.configs.default_ukcp_local_pr_1em_configs import get_default_configs


def get_config():
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
Empty file.
36 changes: 0 additions & 36 deletions src/ml_downscaling_emulator/score_sde_pytorch/utils.py

This file was deleted.

File renamed without changes.
58 changes: 57 additions & 1 deletion src/ml_downscaling_emulator/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,60 @@
"""Helper methods"""
# coding=utf-8
# Copyright 2020 The Google Research Authors.
# Modifications copyright 2024 Henry Addison
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Significant modifications to the original work have been made by Henry Addison
# to allow for location-specific parameters and iterating by epoch using PyTorch
# DataLoaders and helpers for determining a model size.

import torch
import os
import logging


def restore_checkpoint(ckpt_dir, state, device):
if not os.path.exists(ckpt_dir):
os.makedirs(os.path.dirname(ckpt_dir), exist_ok=True)
logging.warning(
f"No checkpoint found at {ckpt_dir}. " f"Returned the same state as input"
)
return state, False
else:
loaded_state = torch.load(ckpt_dir, map_location=device)
state["optimizer"].load_state_dict(loaded_state["optimizer"])
state["model"].load_state_dict(loaded_state["model"], strict=False)
state["ema"].load_state_dict(loaded_state["ema"])
state["location_params"].load_state_dict(loaded_state["location_params"])
state["step"] = loaded_state["step"]
state["epoch"] = loaded_state["epoch"]
logging.info(
f"Checkpoint found at {ckpt_dir}. "
f"Returned the state from {state['epoch']}/{state['step']}"
)
return state, True


def save_checkpoint(ckpt_dir, state):
saved_state = {
"optimizer": state["optimizer"].state_dict(),
"model": state["model"].state_dict(),
"ema": state["ema"].state_dict(),
"step": state["step"],
"epoch": state["epoch"],
"location_params": state["location_params"].state_dict(),
}
torch.save(saved_state, ckpt_dir)


def param_count(model):
Expand Down
2 changes: 1 addition & 1 deletion tests/smoke-tests/test-det-debug-cunet
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ set -euo pipefail
config_name="ukcp_local_pr_debug"

workdir="output/test/deterministic/${config_name}/test-run"
config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/${config_name}.py"
config_path="src/ml_downscaling_emulator/configs/deterministic/${config_name}.py"

loc_spec_channels=0

Expand Down
2 changes: 1 addition & 1 deletion tests/smoke-tests/test-det-det_cunet
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ set -euo pipefail
config_name="ukcp_local_pr_plain_unet_debug"

workdir="output/test/deterministic/${config_name}/test-run"
config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/deterministic/${config_name}.py"
config_path="src/ml_downscaling_emulator/configs/deterministic/${config_name}.py"

loc_spec_channels=2

Expand Down
2 changes: 1 addition & 1 deletion tests/smoke-tests/test-subvpsde-debug-cunet
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ config_name="ukcp_local_mv_debug"
dataset="debug-sample-mv"

workdir="output/test/${sde}/${config_name}/test-run"
config_path="src/ml_downscaling_emulator/score_sde_pytorch/configs/${sde}/${config_name}.py"
config_path="src/ml_downscaling_emulator/configs/${sde}/${config_name}.py"

loc_spec_channels=2
train_batch_size=2
Expand Down
Loading

0 comments on commit f76547f

Please sign in to comment.