Add regional forecasting (#3)
* add regional forecasting

* add usage for regional forecasting

* fix typo
tung-nd authored Feb 14, 2023
1 parent 8a37186 commit d6dec5f
Showing 10 changed files with 891 additions and 10 deletions.
199 changes: 199 additions & 0 deletions configs/regional_forecast_climax.yaml
@@ -0,0 +1,199 @@
seed_everything: 42

# ---------------------------- TRAINER -------------------------------------------
trainer:
  default_root_dir: ${oc.env:OUTPUT_DIR,/home/t-tungnguyen/ClimaX/exps/regional_forecast_climax}

  precision: 16

  gpus: null
  num_nodes: 1
  accelerator: gpu
  strategy: ddp

  min_epochs: 1
  max_epochs: 100
  enable_progress_bar: true

  sync_batchnorm: True
  enable_checkpointing: True
  resume_from_checkpoint: null

  # debugging
  fast_dev_run: false

  logger:
    class_path: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
    init_args:
      save_dir: ${trainer.default_root_dir}/logs
      name: null
      version: null
      log_graph: False
      default_hp_metric: True
      prefix: ""

  callbacks:
    - class_path: pytorch_lightning.callbacks.LearningRateMonitor
      init_args:
        logging_interval: "step"

    - class_path: pytorch_lightning.callbacks.ModelCheckpoint
      init_args:
        dirpath: "${trainer.default_root_dir}/checkpoints"
        monitor: "val/w_rmse" # name of the logged metric which determines when model is improving
        mode: "min" # "max" means higher metric value is better, can be also "min"
        save_top_k: 1 # save k best models (determined by above metric)
        save_last: True # additionally always save the model from the last epoch
        verbose: False
        filename: "epoch_{epoch:03d}"
        auto_insert_metric_name: False

    - class_path: pytorch_lightning.callbacks.EarlyStopping
      init_args:
        monitor: "val/w_rmse" # name of the logged metric which determines when model is improving
        mode: "min" # "max" means higher metric value is better, can be also "min"
        patience: 5 # how many validation epochs of not improving until training stops
        min_delta: 0. # minimum change in the monitored metric needed to qualify as an improvement

    - class_path: pytorch_lightning.callbacks.RichModelSummary
      init_args:
        max_depth: -1

    - class_path: pytorch_lightning.callbacks.RichProgressBar

# ---------------------------- MODEL -------------------------------------------
model:
  lr: 5e-4
  beta_1: 0.9
  beta_2: 0.99
  weight_decay: 1e-5
  warmup_epochs: 10000
  max_epochs: 100000
  warmup_start_lr: 1e-8
  eta_min: 1e-8
  pretrained_path: ""

  net:
    class_path: climax.regional_forecast.arch.RegionalClimaX
    init_args:
      default_vars: [
        "land_sea_mask",
        "orography",
        "lattitude",
        "2m_temperature",
        "10m_u_component_of_wind",
        "10m_v_component_of_wind",
        "geopotential_50",
        "geopotential_250",
        "geopotential_500",
        "geopotential_600",
        "geopotential_700",
        "geopotential_850",
        "geopotential_925",
        "u_component_of_wind_50",
        "u_component_of_wind_250",
        "u_component_of_wind_500",
        "u_component_of_wind_600",
        "u_component_of_wind_700",
        "u_component_of_wind_850",
        "u_component_of_wind_925",
        "v_component_of_wind_50",
        "v_component_of_wind_250",
        "v_component_of_wind_500",
        "v_component_of_wind_600",
        "v_component_of_wind_700",
        "v_component_of_wind_850",
        "v_component_of_wind_925",
        "temperature_50",
        "temperature_250",
        "temperature_500",
        "temperature_600",
        "temperature_700",
        "temperature_850",
        "temperature_925",
        "relative_humidity_50",
        "relative_humidity_250",
        "relative_humidity_500",
        "relative_humidity_600",
        "relative_humidity_700",
        "relative_humidity_850",
        "relative_humidity_925",
        "specific_humidity_50",
        "specific_humidity_250",
        "specific_humidity_500",
        "specific_humidity_600",
        "specific_humidity_700",
        "specific_humidity_850",
        "specific_humidity_925",
      ]
      img_size: [32, 64]
      patch_size: 2
      embed_dim: 1024
      depth: 8
      decoder_depth: 2
      num_heads: 16
      mlp_ratio: 4
      drop_path: 0.1
      drop_rate: 0.1

# ---------------------------- DATA -------------------------------------------
data:
  root_dir: /datadrive/datasets/5.625deg_equally_np/
  variables: [
    "land_sea_mask",
    "orography",
    "lattitude",
    "2m_temperature",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "geopotential_50",
    "geopotential_250",
    "geopotential_500",
    "geopotential_600",
    "geopotential_700",
    "geopotential_850",
    "geopotential_925",
    "u_component_of_wind_50",
    "u_component_of_wind_250",
    "u_component_of_wind_500",
    "u_component_of_wind_600",
    "u_component_of_wind_700",
    "u_component_of_wind_850",
    "u_component_of_wind_925",
    "v_component_of_wind_50",
    "v_component_of_wind_250",
    "v_component_of_wind_500",
    "v_component_of_wind_600",
    "v_component_of_wind_700",
    "v_component_of_wind_850",
    "v_component_of_wind_925",
    "temperature_50",
    "temperature_250",
    "temperature_500",
    "temperature_600",
    "temperature_700",
    "temperature_850",
    "temperature_925",
    "relative_humidity_50",
    "relative_humidity_250",
    "relative_humidity_500",
    "relative_humidity_600",
    "relative_humidity_700",
    "relative_humidity_850",
    "relative_humidity_925",
    "specific_humidity_50",
    "specific_humidity_250",
    "specific_humidity_500",
    "specific_humidity_600",
    "specific_humidity_700",
    "specific_humidity_850",
    "specific_humidity_925",
  ]
  out_variables: ["geopotential_500", "temperature_850", "2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind"]
  region: "NorthAmerica"
  predict_range: 72
  hrs_each_step: 1
  buffer_size: 10000
  batch_size: 128
  num_workers: 1
  pin_memory: False
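As a quick sanity check on the geometry this config implies: with `img_size: [32, 64]` and `patch_size: 2`, the global token grid is 16 × 32 = 512 patches per variable. A minimal sketch of that arithmetic (plain Python, illustration only, not part of this commit):

```python
# Token-grid arithmetic implied by the config above (illustration only, not committed code).
img_size = (32, 64)          # 5.625-degree global grid: 32 latitudes x 64 longitudes
patch_size = 2               # each token covers a 2 x 2 block of grid points
embed_dim = 1024             # dimension each patch is embedded into

grid_h = img_size[0] // patch_size   # 16 patch rows
grid_w = img_size[1] // patch_size   # 32 patch columns
num_patches = grid_h * grid_w        # 512 tokens per variable before variable aggregation

print(grid_h, grid_w, num_patches)   # 16 32 512
```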
27 changes: 27 additions & 0 deletions docs/usage.md
@@ -86,6 +86,33 @@ python src/climax/global_forecast/train.py --config configs/global_forecast_clim
```
To train ClimaX from scratch, set `--model.pretrained_path=""`.

## Regional Forecasting

### Data Preparation

We use the same ERA5 data as in global forecasting and extract the regional data on the fly during training. If you have already downloaded and preprocessed the data, you do not have to do it again.

### Training

To finetune ClimaX for regional forecasting, use
```
python src/climax/regional_forecast/train.py --config <path/to/config>
```
For example, to finetune ClimaX on North America using 8 GPUs, use
```bash
python src/climax/regional_forecast/train.py --config configs/regional_forecast_climax.yaml \
--trainer.strategy=ddp --trainer.devices=8 \
--trainer.max_epochs=50 \
--data.root_dir=/mnt/data/5.625deg_npz \
--data.region="NorthAmerica" \
--data.predict_range=72 --data.out_variables=['z_500','t_850','t2m'] \
--data.batch_size=16 \
--model.pretrained_path=/mnt/checkpoints/climax_5.625deg.ckpt \
--model.lr=5e-7 --model.beta_1="0.9" --model.beta_2="0.99" \
--model.weight_decay=1e-5
```
To train ClimaX from scratch, set `--model.pretrained_path=""`.

## Visualization

Coming soon
6 changes: 3 additions & 3 deletions src/climax/arch.py
@@ -153,15 +153,15 @@ def get_var_emb(self, var_emb, vars):
         ids = self.get_var_ids(vars, var_emb.device)
         return var_emb[:, ids, :]
 
-    def unpatchify(self, x: torch.Tensor):
+    def unpatchify(self, x: torch.Tensor, h=None, w=None):
         """
         x: (B, L, V * patch_size**2)
         return imgs: (B, V, H, W)
         """
         p = self.patch_size
         c = len(self.default_vars)
-        h = self.img_size[0] // p
-        w = self.img_size[1] // p
+        h = self.img_size[0] // p if h is None else h // p
+        w = self.img_size[1] // p if w is None else w // p
         assert h * w == x.shape[1]
 
         x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
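The new `h`/`w` arguments let `unpatchify` reassemble a regional crop whose patch grid is smaller than the global one; when they are `None`, the previous global behavior is unchanged. Below is a standalone toy version of the reshape-permute-reshape sequence this kind of unpatchify performs (made-up sizes, MAE-style permutation assumed; not the ClimaX method itself):

```python
import torch

# Toy illustration of unpatchifying a regional crop (not the ClimaX method).
B, V, p = 4, 3, 2                      # batch, number of variables, patch size (made-up numbers)
H, W = 8, 16                           # crop height/width in grid points, i.e. the h and w passed in
h, w = H // p, W // p                  # 4 x 8 = 32 patches in the crop
x = torch.randn(B, h * w, V * p * p)   # (B, L, V * patch_size**2), as produced by the prediction head

x = x.reshape(B, h, w, p, p, V)
x = torch.einsum("nhwpqc->nchpwq", x)  # MAE-style permutation (assumed here)
imgs = x.reshape(B, V, h * p, w * p)   # (B, V, H, W)
assert imgs.shape == (B, V, H, W)
```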
12 changes: 5 additions & 7 deletions src/climax/pretrain/dataset.py
@@ -96,23 +96,21 @@ def __iter__(self):


 class IndividualForecastDataIter(IterableDataset):
-    def __init__(self, dataset, transforms: torch.nn.Module, output_transforms: torch.nn.Module):
+    def __init__(self, dataset, transforms: torch.nn.Module, output_transforms: torch.nn.Module, region_info = None):
         super().__init__()
         self.dataset = dataset
         self.transforms = transforms
         self.output_transforms = output_transforms
+        self.region_info = region_info
 
     def __iter__(self):
         for (inp, out, lead_times, variables, out_variables) in self.dataset:
             assert inp.shape[0] == out.shape[0]
             for i in range(inp.shape[0]):
-                # TODO: should we unsqueeze the first dimension?
-                if self.transforms is not None:
-                    yield self.transforms(inp[i]), self.output_transforms(out[i]), lead_times[
-                        i
-                    ], variables, out_variables
+                if self.region_info is not None:
+                    yield self.transforms(inp[i]), self.output_transforms(out[i]), lead_times[i], variables, out_variables, self.region_info
                 else:
-                    yield inp[i], out[i], lead_times[i], variables, out_variables
+                    yield self.transforms(inp[i]), self.output_transforms(out[i]), lead_times[i], variables, out_variables


class ShuffleIterableDataset(IterableDataset):
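With `region_info` attached, every sample the iterator yields carries the region metadata next to the tensors, so downstream code can slice to the region without recomputing it. A hedged sketch of how a consumer might unpack the two tuple layouts (the helper name is an assumption, not part of this commit):

```python
# Illustrative consumer of IndividualForecastDataIter samples (helper name is hypothetical).
def unpack_sample(sample):
    """Handle both tuple layouts yielded by the iterator."""
    if len(sample) == 6:
        # regional mode: region_info appended as the last element
        inp, out, lead_time, variables, out_variables, region_info = sample
    else:
        # global mode: no region_info
        inp, out, lead_time, variables, out_variables = sample
        region_info = None
    return inp, out, lead_time, variables, out_variables, region_info
```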
2 changes: 2 additions & 0 deletions src/climax/regional_forecast/__init__.py
@@ -0,0 +1,2 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
94 changes: 94 additions & 0 deletions src/climax/regional_forecast/arch.py
@@ -0,0 +1,94 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import torch
from climax.arch import ClimaX

class RegionalClimaX(ClimaX):
    def __init__(self, default_vars, img_size=..., patch_size=2, embed_dim=1024, depth=8, decoder_depth=2, num_heads=16, mlp_ratio=4, drop_path=0.1, drop_rate=0.1):
        super().__init__(default_vars, img_size, patch_size, embed_dim, depth, decoder_depth, num_heads, mlp_ratio, drop_path, drop_rate)

    def forward_encoder(self, x: torch.Tensor, lead_times: torch.Tensor, variables, region_info):
        # x: `[B, V, H, W]` shape.

        if isinstance(variables, list):
            variables = tuple(variables)

        # tokenize each variable separately
        embeds = []
        var_ids = self.get_var_ids(variables, x.device)
        for i in range(len(var_ids)):
            id = var_ids[i]
            embeds.append(self.token_embeds[id](x[:, i : i + 1]))
        x = torch.stack(embeds, dim=1) # B, V, L, D

        # add variable embedding
        var_embed = self.get_var_emb(self.var_embed, variables)
        x = x + var_embed.unsqueeze(2) # B, V, L, D

        # get the patch ids corresponding to the region
        region_patch_ids = region_info['patch_ids']
        x = x[:, :, region_patch_ids, :]

        # variable aggregation
        x = self.aggregate_variables(x) # B, L, D

        # add pos embedding
        x = x + self.pos_embed[:, region_patch_ids, :]

        # add lead time embedding
        lead_time_emb = self.lead_time_embed(lead_times.unsqueeze(-1)) # B, D
        lead_time_emb = lead_time_emb.unsqueeze(1)
        x = x + lead_time_emb # B, L, D

        x = self.pos_drop(x)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)

        return x

    def forward(self, x, y, lead_times, variables, out_variables, metric, lat, region_info):
        """Forward pass through the model.
        Args:
            x: `[B, Vi, H, W]` shape. Input weather/climate variables
            y: `[B, Vo, H, W]` shape. Target weather/climate variables
            lead_times: `[B]` shape. Forecasting lead times of each element of the batch.
            region_info: Containing the region's information
        Returns:
            loss (list): Different metrics.
            preds (torch.Tensor): `[B, Vo, H, W]` shape. Predicted weather/climate variables.
        """
        out_transformers = self.forward_encoder(x, lead_times, variables, region_info) # B, L, D
        preds = self.head(out_transformers) # B, L, V*p*p

        min_h, max_h = region_info['min_h'], region_info['max_h']
        min_w, max_w = region_info['min_w'], region_info['max_w']
        preds = self.unpatchify(preds, h = max_h - min_h + 1, w = max_w - min_w + 1)
        out_var_ids = self.get_var_ids(tuple(out_variables), preds.device)
        preds = preds[:, out_var_ids]

        y = y[:, :, min_h:max_h+1, min_w:max_w+1]
        lat = lat[min_h:max_h+1]

        if metric is None:
            loss = None
        else:
            loss = [m(preds, y, out_variables, lat) for m in metric]

        return loss, preds

    def evaluate(self, x, y, lead_times, variables, out_variables, transform, metrics, lat, clim, log_postfix, region_info):
        _, preds = self.forward(x, y, lead_times, variables, out_variables, metric=None, lat=lat, region_info=region_info)

        min_h, max_h = region_info['min_h'], region_info['max_h']
        min_w, max_w = region_info['min_w'], region_info['max_w']
        y = y[:, :, min_h:max_h+1, min_w:max_w+1]
        lat = lat[min_h:max_h+1]
        clim = clim[:, min_h:max_h+1, min_w:max_w+1]

        return [m(preds, y, transform, out_variables, lat, clim, log_postfix) for m in metrics]
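`RegionalClimaX` expects `region_info` to carry the crop boundaries in grid coordinates (`min_h`, `max_h`, `min_w`, `max_w`) plus the flat indices of the tokens inside the crop (`patch_ids`), indexed on the global patch grid used by `pos_embed`. A hedged sketch of how such a dictionary could be built from a latitude/longitude bounding box (the helper below is illustrative and is not the function shipped in this commit):

```python
import numpy as np

def make_region_info(lat, lon, lat_range, lon_range, patch_size):
    """Illustrative construction of region_info for RegionalClimaX (not the commit's own helper).

    lat: (H,) grid latitudes; lon: (W,) grid longitudes;
    lat_range / lon_range: (min, max) bounds of the region in degrees.
    """
    h_ids = np.where((lat >= lat_range[0]) & (lat <= lat_range[1]))[0]
    w_ids = np.where((lon >= lon_range[0]) & (lon <= lon_range[1]))[0]
    min_h, max_h = int(h_ids.min()), int(h_ids.max())
    min_w, max_w = int(w_ids.min()), int(w_ids.max())

    # Flat indices (on the global patch grid) of the tokens whose patch lies inside the crop.
    p = patch_size
    grid_w = len(lon) // p
    patch_ids = []
    for i in range(min_h // p, max_h // p + 1):
        for j in range(min_w // p, max_w // p + 1):
            patch_ids.append(i * grid_w + j)

    return {
        "patch_ids": patch_ids,
        "min_h": min_h,
        "max_h": max_h,
        "min_w": min_w,
        "max_w": max_w,
    }
```

Note that `forward` calls `unpatchify(preds, h=max_h - min_h + 1, w=max_w - min_w + 1)`, so the crop should align with patch boundaries for the reshape to go through.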
