Commit 386e03d

Code update

PiperOrigin-RevId: 720765321

zhong1wan authored and The swirl_dynamics Authors committed Jan 29, 2025
1 parent a13d14d commit 386e03d
Showing 17 changed files with 3,368 additions and 0 deletions.

swirl_dynamics/projects/probabilistic_diffusion/downscaling/era5/beam/compute_dailymax_stats.py
@@ -0,0 +1,207 @@
# Copyright 2024 The swirl_dynamics Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Calculate daily max stats for an ERA5 dataset.

The stats represent the smoothed average of the daily max values within a year
range. The data may also be downsampled in time (e.g. --downsample=2 keeps
every other time step) before the stats are computed.

Example usage:
```
python swirl_dynamics/projects/probabilistic_diffusion/downscaling/era5/beam:compute_dailymax_stats.par -- \
--input_path=/data/era5/selected_variables/0p25deg_hourly_7vars_windspeed_1959-2023.zarr \
--output_path=/data/era5/stats/0p25deg_7vars_windspeed_2hourly_dailymax_1961-2000_copy.zarr \
--downsample=2 \
--start_year=1961 \
--end_year=2000
```
"""

import functools
from typing import TypeVar

from absl import app
from absl import flags
import apache_beam as beam
import numpy as np
import xarray
import xarray_beam as xbeam

Dataset = TypeVar("Dataset", xarray.Dataset, xarray.DataArray)

# Flags.
INPUT_PATH = flags.DEFINE_string("input_path", None, help="Input Zarr path.")
OUTPUT_PATH = flags.DEFINE_string("output_path", None, help="Output Zarr path.")
START_YEAR = flags.DEFINE_string("start_year", None, "Start of time range.")
END_YEAR = flags.DEFINE_string("end_year", None, "End of time range.")
WINDOW_SIZE = flags.DEFINE_integer("window_size", 61, "Window size.")
DOWNSAMPLE = flags.DEFINE_integer("downsample", 1, "Downsample factor.")
RUNNER = flags.DEFINE_string("runner", None, "beam.runners.Runner")

# Constants.
WORKING_CHUNKS = {"latitude": 4, "longitude": 4, "level": 1, "time": -1}
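# time=-1 keeps the full time series in each spatial chunk, since both the
# daily max and the windowed stats below are computed along the time dim.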
VARIABLES = [
    "2m_temperature",
    "10m_u_component_of_wind",
    "10m_v_component_of_wind",
    "specific_humidity",
    "mean_sea_level_pressure",
    "total_precipitation",
    "geopotential",
]


def replace_time_with_doy(ds: xarray.Dataset) -> xarray.Dataset:
  """Replace time coordinate with days of the year."""
  return ds.assign_coords({"time": ds.time.dt.dayofyear}).rename(
      {"time": "dayofyear"}
  )
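# For example, a time coordinate of 2023-06-01 becomes dayofyear 152, so the
# same calendar day from different years can later be stacked along a new
# "year" dimension.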


def create_window_weights(window_size: int) -> xarray.DataArray:
  """Create linearly decaying window weights."""
  assert window_size % 2 == 1, "Window size must be odd."
  half_window_size = window_size // 2
  window_weights = np.concatenate([
      np.linspace(0, 1, half_window_size + 1),
      np.linspace(1, 0, half_window_size + 1)[1:],
  ])
  window_weights = window_weights / window_weights.mean()
  window_weights = xarray.DataArray(window_weights, dims=["window"])
  return window_weights
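# For example, window_size=5 gives raw weights [0, 0.5, 1, 0.5, 0]; dividing
# by their mean (0.4) yields [0, 1.25, 2.5, 1.25, 0], a triangular window that
# weights the center day most and the window edges not at all.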


def compute_windowed_stat(
    ds: xarray.Dataset, window_weights: xarray.DataArray
) -> xarray.Dataset:
  """Compute rolling-windowed statistics."""
  window_size = len(window_weights)
  half_window_size = window_size // 2
  stacked = xarray.concat(
      [
          replace_time_with_doy(ds.sel(time=str(year)))
          for year in np.unique(ds.time.dt.year)
      ],
      dim="year",
  )
  # Fill the gap day (366) of non-leap years with values from day 365.
  stacked = stacked.fillna(stacked.sel(dayofyear=365))
  # Pad edges so the rolling window wraps around year boundaries.
  stacked = stacked.pad(pad_width={"dayofyear": half_window_size}, mode="wrap")
  # Weighted rolling mean.
  stacked = stacked.rolling(dayofyear=window_size, center=True).construct(
      "window"
  )
  rolling_mean = (
      stacked.weighted(window_weights)
      .mean(dim=("window", "year"))
      .isel(dayofyear=slice(half_window_size, -half_window_size))
  )
  return rolling_mean
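# Shape sketch (assuming window_size=61 and years 1961-2000): `stacked` has
# dims (year=40, dayofyear=366, ...); wrap-padding extends dayofyear to 426;
# the weighted mean collapses the window and year dims; and the final slice
# drops the 30 padded days on each side, restoring dayofyear=366.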


def compute_chunk_stats(
    key: xbeam.Key,
    chunk: xarray.Dataset,
    *,
    window_size: int,
) -> tuple[xbeam.Key, xarray.Dataset]:
  """Compute stats for a chunk."""
  dims = (
      ["longitude", "latitude", "level"]
      if "level" in chunk.dims
      else ["longitude", "latitude"]
  )
  new_key = key.with_offsets(time=None, dayofyear=0)
  chunk = chunk.resample(time="D").max()
  output_chunk = compute_windowed_stat(
      ds=chunk, window_weights=create_window_weights(window_size)
  )
  output_chunk = output_chunk.transpose("dayofyear", *dims)
  return new_key, output_chunk
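# For illustration (hypothetical offsets): an input key with offsets
# {time: 0, latitude: 8, longitude: 12} becomes an output key with offsets
# {dayofyear: 0, latitude: 8, longitude: 12}: the time offset is dropped and a
# dayofyear offset is added to match the output template.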


def main(argv):
  input_store = INPUT_PATH.value
  source_dataset, source_chunks = xbeam.open_zarr(
      input_store, consolidated=True
  )

  # Enforce year range.
  start = np.datetime64(START_YEAR.value, "D") if START_YEAR.value else None
  end = np.datetime64(END_YEAR.value, "D") if END_YEAR.value else None
  # NOTE: Rechunk does not work well when only a subset of the variables is
  # selected.
  source_dataset = source_dataset.sel(time=slice(start, end))
  source_dataset = source_dataset.isel(time=slice(None, None, DOWNSAMPLE.value))

  # Create template.
  template_ds = source_dataset.isel(time=0, drop=True)
  template_ds["latitude"] = np.sort(template_ds["latitude"].to_numpy())
  template_ds["longitude"] = np.sort(template_ds["longitude"].to_numpy())

  template_ds = template_ds.expand_dims(dayofyear=np.arange(366) + 1, axis=1)
  working_time_chunks = {"dayofyear": -1}
  output_time_chunks = {"dayofyear": 1}

  template_ds = template_ds.transpose(
      *working_time_chunks.keys(), "longitude", "latitude", "level"
  )
  template = xbeam.make_template(template_ds)

  # Define intermediate and final chunks.
  stats_working_chunks = dict(**WORKING_CHUNKS, **working_time_chunks)
  del stats_working_chunks["time"]

  output_chunks = {k: v for k, v in source_chunks.items() if k != "time"}
  output_chunks.update(output_time_chunks)

  output_store = OUTPUT_PATH.value

  with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
    _ = (
        root
        | xbeam.DatasetToChunks(source_dataset, source_chunks, num_threads=16)
        | xbeam.SplitVariables()
        | "Rechunk input"
        >> xbeam.Rechunk(
            source_dataset.sizes,  # pytype: disable=wrong-arg-types
            source_chunks,
            WORKING_CHUNKS,
            itemsize=4,
            max_mem=2**30 * 16,
        )
        | beam.MapTuple(
            functools.partial(
                compute_chunk_stats, window_size=WINDOW_SIZE.value
            )
        )
        | "Rechunk output"
        >> xbeam.Rechunk(
            template.sizes,  # pytype: disable=wrong-arg-types
            stats_working_chunks,
            output_chunks,
            itemsize=4,
            max_mem=2**30 * 16,
        )
        | xbeam.ChunksToZarr(
            output_store, template, output_chunks, num_threads=16
        )
    )


if __name__ == "__main__":
  app.run(main)

swirl_dynamics/projects/probabilistic_diffusion/downscaling/era5/beam/compute_normalized.py
@@ -0,0 +1,131 @@
# Copyright 2024 The swirl_dynamics Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Computes the normalized version of a dataset.

Example usage:

Computing the normalized residual between the high-resolution target and the
interpolated low-resolution ERA5:
```
python swirl_dynamics/projects/probabilistic_diffusion/downscaling/era5/beam:compute_normalized.par -- \
--raw_path=/data/era5/selected_variables/0p25deg_minus_1p5deg_linear_hourly_7vars_windspeed_1959-2023.zarr \
--stats_path=/data/era5/stats/0p25deg_minus_1p5deg_linear_hourly_7vars_windspeed_1961-2000_dayofyear_hour.zarr \
--output_path=/data/era5/selected_variables/0p25deg_minus_1p5deg_linear_hourly_7vars_windspeed_1959-2023_dayofyear_hour_normalized.zarr \
--stats_type=hourly
```
Computing the normalized input:
```
python swirl_dynamics/projects/probabilistic_diffusion/downscaling/era5/beam:compute_normalized.par -- \
--raw_path=/data/era5/selected_variables/1p5deg_dailymean_7vars_windspeed_1959-2023.zarr \
--stats_path=/data/era5/stats/1p5deg_7vars_windspeed_1961-2000_all.zarr \
--output_path=/data/era5/selected_variables/1p5deg_dailymean_7vars_windspeed_1959-2023_all_normalized.zarr \
--stats_type=all
```
"""

import functools
from typing import TypeVar

from absl import app
from absl import flags
import apache_beam as beam
import xarray
import xarray_beam as xbeam
from zarr.google import gfile_store

Dataset = TypeVar("Dataset", xarray.Dataset, xarray.DataArray)

# Flags.
RAW_DATA_PATH = flags.DEFINE_string(
    "raw_path", None, help="Zarr path containing the unnormalized raw data."
)
STATS_PATH = flags.DEFINE_string(
    "stats_path", None, help="Zarr path containing the normalization stats."
)
STATS_TYPE = flags.DEFINE_string(
    "stats_type",
    "hourly",
    help="Type of stats. One of [`all`, `daily` or `hourly`].",
)
OUTPUT_PATH = flags.DEFINE_string("output_path", None, help="Output Zarr path.")
RUNNER = flags.DEFINE_string("runner", None, "beam.runners.Runner")


def normalize(
    key: xbeam.Key,
    raw_chunk: xarray.Dataset,
    stats_ds: xarray.Dataset,
    stats_type: str,
) -> tuple[xbeam.Key, xarray.Dataset]:
  """Normalizes a chunk with the corresponding mean and std."""
  stats_chunk = stats_ds.sel(
      longitude=raw_chunk["longitude"],
      latitude=raw_chunk["latitude"],
      level=raw_chunk["level"],
  )
  match stats_type:
    case "all":
      pass
    case "daily":
      stats_chunk = stats_chunk.sel(
          dayofyear=raw_chunk["time"].dt.dayofyear
      ).drop_vars(["dayofyear"])
    case "hourly":
      stats_chunk = stats_chunk.sel(
          dayofyear=raw_chunk["time"].dt.dayofyear,
          hour=raw_chunk["time"].dt.hour,
      ).drop_vars(["dayofyear", "hour"])
    case _:
      raise ValueError(
          f"Unknown stats type: {stats_type}. Currently supported: [`all`,"
          " `daily` or `hourly`]."
      )
  normalized_chunk = (
      raw_chunk - stats_chunk.sel(stats="mean", drop=True)
  ) / stats_chunk.sel(stats="std", drop=True)
  return key, normalized_chunk
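# Hypothetical numeric example: if the stats hold mean=280.0 and std=5.0 for a
# given cell, a raw value of 290.0 normalizes to (290.0 - 280.0) / 5.0 = 2.0.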


def main(argv):
  raw_store = gfile_store.GFileStore(RAW_DATA_PATH.value)
  raw_ds, raw_chunks = xbeam.open_zarr(raw_store)

  stats_store = gfile_store.GFileStore(STATS_PATH.value)
  stats_ds = xarray.open_zarr(stats_store, chunks=None)
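  # `chunks=None` opens the stats as plain (non-dask) arrays, so the lookups
  # in `normalize` do not go through dask task graphs.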

  template = xbeam.make_template(raw_ds)

  output_store = OUTPUT_PATH.value
  output_store = gfile_store.GFileStore(output_store)

  with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
    _ = (
        root
        | xbeam.DatasetToChunks(raw_ds, raw_chunks, num_threads=16)
        | beam.MapTuple(
            functools.partial(
                normalize, stats_ds=stats_ds, stats_type=STATS_TYPE.value
            )
        )
        | xbeam.ChunksToZarr(output_store, template, raw_chunks, num_threads=16)
    )


if __name__ == "__main__":
  app.run(main)