Skip to content

Commit

Permalink
Generalize sst_trends (#99)
Browse files Browse the repository at this point in the history
* sst_trends now uses config file and fixed typo in plot common

* Refactored process_glorys function to be compatible with resampling, adding logging

* Added type hints to process_glorys to clarify potential return values

* Regrid mom to oisst, change plotting and metric variables to reflect new regridder

* Read rename variables from config file so that target grid contains corner points

* Removed do_regrid option from process_oisst

* Simplified get_3d_trends, passed it to process_glorys to avoid simplify it's return values

---------

Co-authored-by: Utheri Wagura <[email protected]>
Co-authored-by: Utheri Wagura <[email protected]>
Co-authored-by: Utheri Wagura <[email protected]>
Co-authored-by: Utheri Wagura <[email protected]>
Co-authored-by: Utheri Wagura <[email protected]>
Co-authored-by: Yi-Cheng Teng - NOAA GFDL <[email protected]>
  • Loading branch information
7 people authored Oct 28, 2024
1 parent 9675e2b commit 87b5c21
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 64 deletions.
9 changes: 9 additions & 0 deletions diagnostics/physics/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,13 @@ levels_step: 2
# Colorbar for sst difference plots
bias_min: -2
bias_max: 2.1
bias_min_trends: -1.5
bias_max_trends: 1.51
bias_step: 0.25

ticks: [-2, -1, 0, 1, 2]


# SST Trends Settings
start_year: "2005"
end_year: "2019"
43 changes: 36 additions & 7 deletions diagnostics/physics/plot_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,13 +176,14 @@ def load_config(config_path: str):
logger.error(f"Error loading configuration from {config_path}: {e}")
raise

def process_oisst(config, target_grid, model_ave):
def process_oisst(config, target_grid, model_ave, start=1993, end = 2020, resamp_freq = None):
"""Open and regrid OISST dataset, return relevant vars from dataset."""
try:
oisst = (
xarray.open_mfdataset([config['oisst'] + f'sst.month.mean.{y}.nc' for y in range(1993, 2020)])
xarray.open_mfdataset([config['oisst'] + f'sst.month.mean.{y}.nc' for y in range(start, end)])
.sst
.sel(lat=slice(config['lat']['south'], config['lat']['north']), lon=slice(config['lon']['west'], config['lon']['east']))
.load()
)
except Exception as e:
logger.error(f"Error processing OISST data: {e}")
Expand All @@ -193,20 +194,33 @@ def process_oisst(config, target_grid, model_ave):

oisst_lonc, oisst_latc = corners(oisst.lon, oisst.lat)
oisst_lonc -= 360

mom_to_oisst = xesmf.Regridder(
target_grid,
{'lat': oisst.lat, 'lon': oisst.lon, 'lat_b': oisst_latc, 'lon_b': oisst_lonc},
method='conservative_normed',
unmapped_to_nan=True
)

oisst_ave = oisst.mean('time').load()
# If a resample frequency is provided, use it to resample the oisst data over time before taking the average
if resamp_freq:
oisst = oisst.resample( time = resamp_freq )

oisst_ave = oisst.mean('time')

mom_rg = mom_to_oisst(model_ave)
logger.info("OISST data processed successfully.")
return mom_rg, oisst_ave, oisst_lonc, oisst_latc

def process_glorys(config, target_grid, var):
""" Open and regrid glorys data, return regridded glorys data """
def process_glorys(config, target_grid, var, sel_time = None, resamp_freq = None, preprocess_regrid = None):
"""
Open and regrid glorys data, return regridded glorys data
If a function is passed to the preprocess_regrid option, it will be called on the
data before it is passed to the regridder but after the regridder
is created and the average is calculated
NOTE: if preprocess_regrid returns numpy array, the return value of glorys_ave will
be a numpy array, not an xarray dataarray as is the default
"""
glorys = xarray.open_dataset( config['glorys'] ).squeeze(drop=True) #.rename({'longitude': 'lon', 'latitude': 'lat'})
if var in glorys:
glorys = glorys[var]
Expand All @@ -225,15 +239,30 @@ def process_glorys(config, target_grid, var):
logger.info("Glorys data is using longitude/latitude")
except:
logger.error("Name of longitude and latitude variables is unknown")
raise Exception("Error: Lat/Latitude, Lon/Longitdue not found in glorys data")
raise Exception("Error: Lat/Latitude, Lon/Longitude not found in glorys data")

# If a time slice is provided use it to select a portion of the glorys data
if sel_time:
glorys = glorys.sel( time = sel_time )

# If a resample frequency is provided, use it to resample the glorys data over time before taking the average
if resamp_freq:
glorys = glorys.resample(time = resamp_freq)

glorys_ave = glorys.mean('time').load()

glorys_to_mom = xesmf.Regridder(glorys_ave, target_grid, method='bilinear', unmapped_to_nan=True)
glorys_rg = glorys_to_mom(glorys_ave)

# If a preprocessing function is provided, call it before doing any regridding
# glorys_ave may not remain a xarray dataset after this step
if preprocess_regrid:
glorys_ave = preprocess_regrid(glorys_ave)

glorys_rg = glorys_to_mom(glorys_ave)
logger.info("Glorys data processed successfully.")
return glorys_rg, glorys_ave, glorys_lonc, glorys_latc


def get_end_of_climatology_period(clima_file):
"""
Determine the time period covered by the last climatology file. This function is needed
Expand Down
149 changes: 92 additions & 57 deletions diagnostics/physics/sst_trends.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""
Compare the model 2005-019 sea surface temperature trends from OISST and GLORYS.
How to use:
python sst_trends.py /archive/acr/fre/NWA/2023_04/NWA12_COBALT_2023_04_kpo4-coastatten-physics/gfdl.ncrc5-intel22-prod
python sst_trends.py -p /archive/acr/fre/NWA/2023_04/NWA12_COBALT_2023_04_kpo4-coastatten-physics/gfdl.ncrc5-intel22-prod -c config.yaml
"""
import cartopy.crs as ccrs
from cartopy.mpl.geoaxes import GeoAxes
Expand All @@ -10,116 +10,147 @@
from mpl_toolkits.axes_grid1 import AxesGrid
import numpy as np
import xarray
import xesmf
import logging

from plot_common import autoextend_colorbar, corners, get_map_norm, annotate_skill, open_var, save_figure
from plot_common import( autoextend_colorbar, corners, get_map_norm,
annotate_skill, open_var, save_figure, load_config,
process_glorys, process_oisst)

PC = ccrs.PlateCarree()
# Configure logging for sst_eval
logger = logging.getLogger(__name__)
logging.basicConfig(filename="sst_trends.log", format='%(asctime)s %(levelname)s:%(name)s: %(message)s',level=logging.INFO)


def get_3d_trends(x, y):
x = np.array(x)
def get_3d_trends(y):
x = np.array( y['time.year'] )
y2 = np.array(y).reshape((len(x), -1))
coefs = np.polyfit(x, y2, 1)
trends = coefs[0, :].reshape(y.shape[1:])
trends = coefs[0, :].reshape(y.shape[1:]) * 10 # -> C / decade

return trends


def plot_sst_trends(pp_root, label):
def plot_sst_trends(pp_root, label, config):
model = (
open_var(pp_root, 'ocean_monthly', 'tos')
.sel(time=slice('2005', '2019'))
.sel(time=slice(config['start_year'], config['end_year']))
.resample(time='1AS')
.mean('time')
.load()
)
model_grid = xarray.open_dataset('../data/geography/ocean_static.nc')
logger.info("MODEL: %s",model)
model_grid = xarray.open_dataset( config['model_grid'])
logger.info("MODEL_GRID: %s",model_grid)

# Verify that xh/yh are set as coordinates, then make sure model coordinates match grid data
model_grid = model_grid.assign_coords( {'xh':model_grid.xh, 'yh':model_grid.yh } )
model = xarray.align(model_grid, model, join='override', exclude='time')[1]
logger.info("Successfully modified coordinates of model grid, and aligned model coordinates to grid coordinates")

model_trend = get_3d_trends(model['time.year'], model) * 10 # -> C / decade
model_trend = get_3d_trends(model)
# Convert to Data Array, since xskillscore expects dataarrays to calculate skill metrics
model_trend = xarray.DataArray(model_trend, dims=['yh', 'xh'], coords={'yh': model.yh, 'xh': model.xh})
logger.info("MODEL_TREND: %s", model_trend)

oisst = (
xarray.open_mfdataset([f'/work/acr/oisstv2/sst.month.mean.{y}.nc' for y in range(2005, 2020)])
.sst
.sel(lat=slice(0, 60), lon=slice(360-100, 360-30))
.resample(time='1AS')
.mean('time')
.load()
)
oisst_trend = get_3d_trends(oisst['time.year'], oisst) * 10 # -> C / decade
target_grid = model_grid[ config['rename_map'].keys() ].rename( config['rename_map'] )

glorys = (
xarray.open_dataset('/work/acr/mom6/diagnostics/glorys/glorys_sfc.nc')
['thetao']
.sel(time=slice('2005', '2019'))
.resample(time='1AS')
.mean('time')
)
glorys_trend = get_3d_trends(glorys['time.year'], glorys) * 10 # -> C / decade
# Process OISST and get trend
mom_rg, oisst, oisst_lonc, oisst_latc = process_oisst(config, target_grid, model_trend, start = int(config['start_year']),
end = int(config['end_year'])+1, resamp_freq = '1AS')
logger.info("OISST: %s", oisst )
oisst_trend = get_3d_trends(oisst)
oisst_trend = xarray.DataArray(oisst_trend, dims=['lat','lon'], coords={'lat':oisst.lat,'lon':oisst.lon} )
logger.info("OISST_TREND: %s",oisst_trend)

oisst_lonc, oisst_latc = corners(oisst.lon, oisst.lat)
oisst_lonc -= 360
oisst_to_mom = xesmf.Regridder({'lat': oisst.lat, 'lon': oisst.lon}, model_grid[['geolon', 'geolat']].rename({'geolon': 'lon', 'geolat': 'lat'}), method='bilinear')
oisst_delta = mom_rg - oisst_trend
logger.info("MOM_RG: %s",mom_rg)
logger.info("OISST_DELTA: %s",oisst_delta)

glorys_lonc, glorys_latc = corners(glorys.lon, glorys.lat)
glorys_to_mom = xesmf.Regridder(glorys, model_grid[['geolon', 'geolat']].rename({'geolon': 'lon', 'geolat': 'lat'}), method='bilinear')
# Process Glorys and get trend
# NOTE: Glorys_ave is glorys_trends because we call get_3d_trends on it.
glorys_rg, glorys_trend, glorys_lonc, glorys_latc = process_glorys(config, target_grid, 'thetao',
sel_time = slice(config['start_year'], config['end_year']),
resamp_freq = '1AS', preprocess_regrid = get_3d_trends)
logger.info("GLORYS_TREND: %s",glorys_trend)

glorys_rg = glorys_to_mom(glorys_trend)
glorys_rg = xarray.DataArray(glorys_rg, dims=['yh', 'xh'], coords={'yh': model.yh, 'xh': model.xh})
glorys_delta = model_trend - glorys_rg
logger.info("GLORYS_RG: %s",glorys_rg)
logger.info("GLORYS_DELTA: %s",glorys_delta)

oisst_rg = oisst_to_mom(oisst_trend)
oisst_rg = xarray.DataArray(oisst_rg, dims=['yh', 'xh'], coords={'yh': model.yh, 'xh': model.xh})
oisst_delta = model_trend - oisst_rg
# Set projection of each grid in the plot
# For now, sst_eval.py will only support a projection for the arctic and a projection for all other domains
if config['projection_grid'] == 'NorthPolarStereo':
p = ccrs.NorthPolarStereo()
else:
p = ccrs.PlateCarree()

fig = plt.figure(figsize=(10, 14))
grid = AxesGrid(fig, 111,
nrows_ncols=(2, 3),
axes_class = (GeoAxes, dict(projection=PC)),
axes_class = (GeoAxes, dict(projection=p)),
axes_pad=0.3,
cbar_location='bottom',
cbar_mode='edge',
cbar_pad=0.2,
cbar_size='15%',
label_mode=''
label_mode='keep'
)
logger.info("Successfully created grid")

cmap, norm = get_map_norm('cet_CET_D1', np.arange(-2, 2.1, .25), no_offset=True)
cmap, norm = get_map_norm('cet_CET_D1', np.arange(config['bias_min'], config['bias_max'], config['bias_step']), no_offset=True)
common = dict(cmap=cmap, norm=norm)

bias_cmap, bias_norm = get_map_norm('RdBu_r', np.arange(-1.5, 1.51, .25), no_offset=True)
bias_cmap, bias_norm = get_map_norm('RdBu_r', np.arange(config['bias_min_trends'], config['bias_max_trends'], config['bias_step']), no_offset=True)
bias_common = dict(cmap=bias_cmap, norm=bias_norm)

p0 = grid[0].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, model_trend, **common)
# Set projection of input data files so that data is correctly tranformed when plotting
# For now, sst_eval.py will only support a projection for the arctic and a projection for all other domains
if config['projection_data'] == 'NorthPolarStereo':
proj = ccrs.NorthPolarStereo()
else:
proj = ccrs.PlateCarree()

# MODEL
p0 = grid[0].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, model_trend, transform = proj, **common)
grid[0].set_title('(a) Model')
cbar0 = autoextend_colorbar(grid.cbar_axes[0], p0)
cbar0.ax.set_xlabel('SST trend (°C / decade)')
cbar0.set_ticks([-2, -1, 0, 1, 2])
cbar0.set_ticklabels([-2, -1, 0, 1, 2])
cbar0.set_ticks( config['ticks'] )
cbar0.set_ticklabels( config['ticks'] )
logger.info("Successfully plotted model data")

p1 = grid[1].pcolormesh(oisst_lonc, oisst_latc, oisst_trend, **common)
# OISST
p1 = grid[1].pcolormesh(oisst_lonc, oisst_latc, oisst_trend, transform = proj, **common)
grid[1].set_title('(b) OISST')
logger.info("Successfully plotted oisst")

grid[2].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, oisst_delta, **bias_common)
# MODEL - OISST
grid[2].pcolormesh(oisst_lonc, oisst_latc, oisst_delta, transform = proj, **bias_common)
grid[2].set_title('(c) Model - OISST')
annotate_skill(model_trend, oisst_rg, grid[2], weights=model_grid.areacello)
# NOTE: Oisst dims are [lat,lon], so dim argument is needed. Must use mom_rg though, since oisst also contains
# an extra time dimension that changes output of xskillscore functions and leads to error when annotating plot
annotate_skill(mom_rg, oisst_trend, grid[2], dim= list(mom_rg.dims), x0=config['text_x'], y0=config['text_y'], xint=config['text_xint'], plot_lat=config['plot_lat'])
logger.info("Successfully plotted difference between model and oisst")

grid[4].pcolormesh(glorys_lonc, glorys_latc, glorys_trend, **common)
# GLORYS
grid[4].pcolormesh(glorys_lonc, glorys_latc, glorys_trend, transform = proj, **common)
grid[4].set_title('(d) GLORYS12')
cbar1 = autoextend_colorbar(grid.cbar_axes[1], p1)
cbar1.ax.set_xlabel('SST trend (°C / decade)')
cbar1.set_ticks([-2, -1, 0, 1, 2])
cbar1.set_ticklabels([-2, -1, 0, 1, 2])
cbar1.set_ticks( config['ticks'] )
cbar1.set_ticklabels( config['ticks'] )
logger.info("Successfully plotted glorys")

p2 = grid[5].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, glorys_delta, **bias_common)
# MODEL - GLORYS
p2 = grid[5].pcolormesh(model_grid.geolon_c, model_grid.geolat_c, glorys_delta, transform = proj, **bias_common)
grid[5].set_title('(e) Model - GLORYS12')
cbar2 = autoextend_colorbar(grid.cbar_axes[2], p2)
cbar2.ax.set_xlabel('SST trend difference (°C / decade)')
annotate_skill(model_trend, glorys_rg, grid[5], weights=model_grid.areacello)
annotate_skill(model_trend, glorys_rg, grid[5], weights=model_grid.areacello, x0=config['text_x'], y0=config['text_y'], xint=config['text_xint'], plot_lat=config['plot_lat'])
logger.info("Successfully plotted difference between glorys and model")

for i, ax in enumerate(grid):
ax.set_xlim(-99, -35)
ax.set_ylim(4, 59)
ax.set_extent([ config['x']['min'], config['x']['max'], config['y']['min'], config['y']['max'] ], crs=proj)
ax.set_xticks([])
ax.set_yticks([])
ax.set_xticklabels([])
Expand All @@ -128,14 +159,18 @@ def plot_sst_trends(pp_root, label):
ax.set_facecolor('#bbbbbb')
for s in ax.spines.values():
s.set_visible(False)
logger.info("Successfully set extent of each axis")

save_figure('sst_trends', label=label)
logger.info("Successfully saved figure")


if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('pp_root', help='Path to postprocessed data (up to but not including /pp/)')
parser.add_argument('-p','--pp_root', help='Path to postprocessed data (up to but not including /pp/)', required = True)
parser.add_argument('-c','--config', help='Path to yaml config file', required = True)
parser.add_argument('-l', '--label', help='Label to add to figure file names', type=str, default='')
args = parser.parse_args()
plot_sst_trends(args.pp_root, args.label)
config = load_config(args.config)
plot_sst_trends(args.pp_root, args.label, config)

0 comments on commit 87b5c21

Please sign in to comment.