Commit 8688338: Code update

PiperOrigin-RevId: 711558862
ilopezgp authored and The swirl_dynamics Authors committed Jan 3, 2025
Showing 1 changed file with 375 additions and 0 deletions.
# Copyright 2024 The swirl_dynamics Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Script to perform downscaling with the STAR-ESDM method.
This script implements the STAR-ESDM method for bias correction and downscaling,
following Hayhoe et al (2024): https://doi.org/10.1029/2023EF004107. This method
is supposed to relax the stationarity assumptions of other empirical statistical
downscaling methods. The methodology is applied to a single domain
discretization (the high-resolution grid), similarly to the STAR-ESDM paper.
The method assumes we have the following information from the low-resolution
dataset:
- A low-resolution dataset to be downscaled,
- a third-order parametric fit of its long-term trend (per pixel),
- its detrended climatology over a "training" period for which we have
high-resolution data. That is, the climatology of the low-resolution data
after the parametric trend has been removed (see
`compute_detrended_climatology.py`),
- its detrended "dynamic" climatology, which denotes the current climatology for
the dates to be downscaled (e.g., the 2090-2100 detrended climatology for
downscaling years in the 2090s), and
- its temporal mean over the "training" period for which we have high-resolution
data.

Regarding the target high-resolution dataset used to downscale the
low-resolution dataset, we assume we have:
- its detrended climatology over the "training" period,
- its temporal mean over the "training" period, used for debiasing.

All these terms are further described in Hayhoe et al. (2024). In this
implementation, we assume climatologies have hourly granularity.

The final result is the sum of three terms, as sketched in Fig. 1 of
Hayhoe et al. (2024):
1. The debiased long-term trend of the low-resolution data. This is the
long-term trend of the low-resolution data, plus the difference between the
mean of the high-resolution data and the mean of the low-resolution data over
the training period, for each location and variable.
2. The "dynamically-adjusted" detrended climatological mean of the
high-resolution data. This is constructed as the climatological mean of the
high-resolution data over the training period, plus the difference between
the dynamic climatology of the low-resolution data and the climatology
of the low-resolution data over the training period, for each location and
variable.
3. The quantile-mapped anomaly of the low-resolution data. The low-resolution
sample is detrended, and its probability is computed according to the CDF of
the detrended dynamic low-resolution climatology. The anomalies corresponding
to this probability under the detrended high-resolution climatology and the
detrended low-resolution climatology are then computed, and the ratio of
these two anomalies is retained. The final anomaly is the product of the
anomaly with respect to the detrended dynamic low-resolution climatology and
this precomputed ratio. See equations 4 and 5 of Hayhoe et al. (2024), and
the sketch below, for more details.
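
Schematically, for a given pixel and variable, the computation implemented
below can be sketched as follows (illustrative pseudocode rather than
verbatim library calls; `t` denotes time and `x` the low-resolution sample):

```
# Component 1: debiased long-term trend.
trend = polyval(t, trend_coeffs) + (target_mean - input_mean)
# Component 2: dynamically-adjusted high-resolution climatological mean.
clim = target_clim_mean + (input_dyn_clim_mean - input_clim_mean)
# Component 3: quantile-mapped anomaly under quasi-Gaussian assumptions.
z = (x - polyval(t, trend_coeffs) - input_dyn_clim_mean) / input_dyn_clim_std
anom = (z * input_dyn_clim_std) * (z * target_clim_std) / (z * input_clim_std)
downscaled = trend + clim + anom
```
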
The method relies on quasi-Gaussian assumptions, and is therefore only
applicable to variables with a quasi-Gaussian distribution. Modifications are
required for variables such as precipitation.

Example usage:
```
FORCING_DATASET=<parent_dir>/canesm5_r1i1p2f1_ssp370_bc
INPUT_TREND=<input_trend>
INPUT_TEMPORAL_MEAN=<input_mean>
INPUT_DETRENDED_CLIM=<input_clim>
INPUT_DETRENDED_DYNAMIC_CLIM=<input_dyn_clim>
TARGET_TEMPORAL_MEAN=<target_mean>
TARGET_DETRENDED_CLIM=<target_clim>
INPUT_DATA=${FORCING_DATASET}/hourly_d01_cubic_interpolated_to_d02_with_prates.zarr
TIME_START=2094
TIME_STOP=2097
OUTPUT_PATH=${FORCING_DATASET}/baselines/staresdm_with_prates_from_canesm5_${TIME_START}_${TIME_STOP}.zarr
python swirl_dynamics/projects/probabilistic_diffusion/downscaling/gcm_wrf/analysis/staresdm.py \
--input_data=${INPUT_DATA} \
--input_detrended_clim=${INPUT_DETRENDED_CLIM} \
--input_detrended_dynamic_clim=${INPUT_DETRENDED_DYNAMIC_CLIM} \
--input_temporal_mean=${INPUT_TEMPORAL_MEAN} \
--input_trend=${INPUT_TREND} \
--target_detrended_clim=${TARGET_DETRENDED_CLIM} \
--target_temporal_mean=${TARGET_TEMPORAL_MEAN} \
--time_start=${TIME_START} \
--time_stop=${TIME_STOP} \
--output_path=${OUTPUT_PATH}
```
"""

import typing

from absl import app
from absl import flags
import apache_beam as beam
import numpy as np
import xarray as xr
import xarray_beam as xbeam


INPUT_DATA = flags.DEFINE_string(
'input_data',
None,
help='Zarr path pointing to the input data to be processed.',
)
INPUT_TREND = flags.DEFINE_string(
'input_trend',
None,
help=(
'Zarr path pointing to the input trend, stored in terms of np.polyfit'
' coefficients.'
),
)
INPUT_TEMPORAL_MEAN = flags.DEFINE_string(
'input_temporal_mean',
None,
help='Zarr path pointing to the input temporal mean.',
)
INPUT_DETRENDED_CLIM = flags.DEFINE_string(
'input_detrended_clim',
None,
help='Zarr path pointing to the input detrended climatology.',
)
INPUT_DETRENDED_DYNAMIC_CLIM = flags.DEFINE_string(
'input_detrended_dynamic_clim',
None,
help='Zarr path pointing to the input detrended dynamic climatology.',
)
TARGET_DETRENDED_CLIM = flags.DEFINE_string(
'target_detrended_clim',
None,
help='Zarr path pointing to the target detrended climatology.',
)
TARGET_TEMPORAL_MEAN = flags.DEFINE_string(
'target_temporal_mean',
None,
help='Zarr path pointing to the target temporal mean.',
)
OUTPUT_PATH = flags.DEFINE_string(
    'output_path', None, help='Output Zarr path.'
)
TIME_DIM = flags.DEFINE_string(
'time_dim', 'time', help='Name for the time dimension to slice data on.'
)
TIME_START = flags.DEFINE_string(
'time_start',
None,
    help='ISO 8601 timestamp (inclusive) at which to start evaluation.',
)
TIME_STOP = flags.DEFINE_string(
'time_stop',
None,
    help='ISO 8601 timestamp (inclusive) at which to stop evaluation.',
)
RUNNER = flags.DEFINE_string('runner', None, 'beam.runners.Runner')


def _get_climatology_mean(
climatology: xr.Dataset, variables: list[str], **sel_kwargs
) -> xr.Dataset:
"""Returns the climatological mean of the given variables.
The climatology dataset is assumed to have been produced through
the weatherbench2 compute_climatology.py script,
(https://github.com/google-research/weatherbench2/blob/main/scripts/compute_climatology.py)
and statistics `mean`, and `std`. The convention is that the climatological
means do not have a suffix, and standard deviations have a `_std` suffix.
Args:
climatology: The climatology dataset.
variables: The variables to extract from the climatology.
**sel_kwargs: Additional selection criteria for the variables.
Returns:
The climatological mean of the given variables.
"""
climatology_mean = climatology[variables]
return typing.cast(xr.Dataset, climatology_mean.sel(**sel_kwargs).compute())


def _get_climatology_std(
climatology: xr.Dataset, variables: list[str], **sel_kwargs
) -> xr.Dataset:
"""Returns the climatological standard deviation of the given variables.
The climatology dataset is assumed to have been produced through
the weatherbench2 compute_climatology.py script, and statistics
`mean`, and `std`. The convention is that the climatological means do not
have a suffix, and standard deviations have a `_std` suffix.
Args:
climatology: The climatology dataset.
variables: The variables to extract from the climatology.
**sel_kwargs: Additional selection criteria for the variables.
Returns:
The climatological standard deviation of the given variables.
"""
clim_std_dict = {key + '_std': key for key in variables} # pytype: disable=unsupported-operands
climatology_std = climatology[list(clim_std_dict.keys())].rename(
clim_std_dict
)
return typing.cast(xr.Dataset, climatology_std.sel(**sel_kwargs).compute())
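

# Illustrative sketch of the naming convention assumed by the two helpers
# above, using a hypothetical variable `t2`: a weatherbench2-style climatology
# stores the mean as `t2` and the standard deviation as `t2_std`, e.g.
#
#   clim = xr.Dataset(
#       {'t2': ('hour', np.full(24, 280.0)),
#        't2_std': ('hour', np.full(24, 2.0))},
#       coords={'hour': np.arange(24)},
#   )
#   _get_climatology_mean(clim, ['t2'], hour=0)  # Dataset with t2 == 280.0
#   _get_climatology_std(clim, ['t2'], hour=0)   # Dataset with t2 == 2.0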


def _staresdm_on_chunks(
source: xr.Dataset,
*,
input_clim: xr.Dataset,
input_dynamic_clim: xr.Dataset,
input_trend: xr.Dataset,
input_temporal_mean: xr.Dataset,
target_clim: xr.Dataset,
target_temporal_mean: xr.Dataset,
) -> xr.Dataset:
"""Process an input data chunk with the STAR-ESDM downscaling method.
All input climatologies are assumed to have hourly granularity.
Args:
source: The source data chunk to be processed with the STAR-ESDM method.
input_clim: The detrended climatology of the source low-resolution data.
input_dynamic_clim: The dynamic detrended climatology of the source
low-resolution data.
input_trend: The trend of the source low-resolution data, stored in terms of
np.polyfit coefficients. For each variable, the trend coefficients are
stored in variables with name <variable>_polyfit_coefficients.
input_temporal_mean: The temporal mean of the source low-resolution data.
target_clim: The detrended climatology of the target high-resolution data.
target_temporal_mean: The temporal mean of the target high-resolution data.
Returns:
The downscaled chunks using the STAR-ESDM method as an xarray Dataset.
"""
variables = [str(key) for key in source.keys()]

# Component 1: Debiased long-term trend.
staresdm = typing.cast(
xr.Dataset,
target_temporal_mean[variables] - input_temporal_mean[variables],
)

# Detrend the source data and add trend value to staresdm.
for source_var in variables:
coeff = input_trend[str(source_var) + '_polyfit_coefficients']
source_trend = xr.polyval(source['time'], coeff)
source[str(source_var)] = source[str(source_var)] - source_trend
staresdm[str(source_var)] = staresdm[str(source_var)] + source_trend

staresdm = staresdm.transpose('time', ...)

# Component 2: Dynamically-adjusted high-resolution climatological mean.
sel = dict(
dayofyear=source['time'].dt.dayofyear,
hour=source['time'].dt.hour,
drop=True,
)

# Static input low-resolution climatology.
input_clim_mean = _get_climatology_mean(input_clim, variables, **sel)
input_clim_std = _get_climatology_std(input_clim, variables, **sel)
# Dynamic input low-resolution climatology.
input_dynamic_clim_mean = _get_climatology_mean(
input_dynamic_clim, variables, **sel
)
input_dynamic_clim_std = _get_climatology_std(
input_dynamic_clim, variables, **sel
)
# Target high-resolution climatology.
target_clim_mean = _get_climatology_mean(target_clim, variables, **sel)
target_clim_std = _get_climatology_std(target_clim, variables, **sel)

staresdm = staresdm + target_clim_mean
staresdm = staresdm + (input_dynamic_clim_mean - input_clim_mean)

# Component 3: Quantile-mapped low-resolution anomaly.

# Standardize with respect to the input dynamic climatology.
source_standard = (source - input_dynamic_clim_mean) / input_dynamic_clim_std

# Construct proxy of high-resolution dynamic anomaly.
source_hr_anom = source_standard * target_clim_std
source_lr_anom = source_standard * input_clim_std
source_lr_dyn_anom = source_standard * input_dynamic_clim_std
source_hr_dyn_anom = source_hr_anom * source_lr_dyn_anom / source_lr_anom
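  # Note: algebraically, source_hr_dyn_anom reduces to
  # (source - input_dynamic_clim_mean) * target_clim_std / input_clim_std,
  # i.e. the dynamic anomaly rescaled by the ratio of the target to input
  # climatological standard deviations (cf. equations 4 and 5 of Hayhoe et
  # al. (2024)).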

# Add debiased anomaly to the computation
staresdm = staresdm + source_hr_dyn_anom

# TODO: Add wet days correction to precipitation variables.
# Drop hour and dayofyear dimensions.
return staresdm.drop_vars(['hour', 'dayofyear'])


def _impose_data_selection(ds: xr.Dataset) -> xr.Dataset:
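  """Restricts `ds` to the [TIME_START, TIME_STOP] slice along TIME_DIM.

  For example, `--time_start=2094 --time_stop=2097` results in
  `ds.sel(time=slice('2094', '2097'))`, with both endpoints inclusive.
  """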
if TIME_START.value is None or TIME_STOP.value is None:
return ds
selection = {TIME_DIM.value: slice(TIME_START.value, TIME_STOP.value)}
return ds.sel({k: v for k, v in selection.items() if k in ds.dims})


def main(argv: list[str]) -> None:

source_dataset, source_chunks = xbeam.open_zarr(INPUT_DATA.value)
source_dataset = _impose_data_selection(source_dataset)

input_clim = xr.open_zarr(INPUT_DETRENDED_CLIM.value)
input_dynamic_clim = xr.open_zarr(INPUT_DETRENDED_DYNAMIC_CLIM.value)
input_trend = xr.open_zarr(INPUT_TREND.value)
input_temporal_mean = xr.open_zarr(INPUT_TEMPORAL_MEAN.value)
target_clim = xr.open_zarr(TARGET_DETRENDED_CLIM.value)
target_temporal_mean = xr.open_zarr(TARGET_TEMPORAL_MEAN.value)

  source_chunks = dict(source_chunks)
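
  # Process the data one time slice per chunk; each chunk then maps to a
  # single (dayofyear, hour) entry of the hourly climatologies.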
in_working_chunks = source_chunks.copy()
in_working_chunks['time'] = 1

output_chunks = source_chunks.copy()
output_chunks['time'] = 1
unassigned_coords = {
dim: np.arange(source_dataset.sizes[dim])
for dim in source_dataset.dims
if dim not in source_dataset.coords
}
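  # Dimensions without coordinates are assigned integer indices so that the
  # output Zarr template has explicit coordinates for every dimension.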
template = xbeam.make_template(source_dataset).assign_coords(
**unassigned_coords
)

# Static kwargs for _staresdm_on_chunks.
staresdm_kwargs = dict(
input_clim=input_clim,
input_dynamic_clim=input_dynamic_clim,
input_trend=input_trend,
input_temporal_mean=input_temporal_mean,
target_clim=target_clim,
target_temporal_mean=target_temporal_mean,
)
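
  # Pipeline: read the source in working chunks, apply STAR-ESDM to each
  # chunk, consolidate back to the output chunking, and write to Zarr.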

with beam.Pipeline(runner=RUNNER.value, argv=argv) as root:
_ = (
root
| xbeam.DatasetToChunks(source_dataset, in_working_chunks)
| 'STAR-ESDM'
>> beam.MapTuple(
lambda k, v: (
k,
_staresdm_on_chunks(v, **staresdm_kwargs),
)
)
| xbeam.ConsolidateChunks(output_chunks)
| xbeam.ChunksToZarr(
OUTPUT_PATH.value,
template,
output_chunks,
)
)


if __name__ == '__main__':
app.run(main)
