Skip to content

Commit

Permalink
Split Gulf Steam plot from ssh_eval and rewrite ssh_eval to use confi…
Browse files Browse the repository at this point in the history
…g file (#101)

* Separated gulfstream from ssh, added comments to plotting section of ssh_eval.py

* Modified ssh_eval.py to use config file, added option to pass projection to add_ticks in plot_common

* Update docstrings at top of script

* Changed projection var from proj to p for consistency with other scripts, Added noted about lat/lon consistency

* Made names of ssh_eval plots consistent with other scripts, mv gulf_stream plot ot NWA12

* fixed small bugs

* Changed output dir in gulfstream plot to figures dir, updated note about model_grid coordinates

---------

Co-authored-by: Utheri Wagura <[email protected]>
Co-authored-by: Utheri Wagura <[email protected]>
Co-authored-by: Utheri Wagura <[email protected]>
Co-authored-by: Utheri Wagura <[email protected]>
Co-authored-by: yuchengt900 <[email protected]>
Co-authored-by: Yi-Cheng Teng - NOAA GFDL <[email protected]>
  • Loading branch information
7 people authored Oct 28, 2024
1 parent 87b5c21 commit c313f29
Show file tree
Hide file tree
Showing 4 changed files with 243 additions and 146 deletions.
157 changes: 157 additions & 0 deletions diagnostics/physics/NWA12/plot_gulf_stream.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
"""
Plot of the Gulf Stream position and index,
Uses whatever model data can be found within the directory pp_root,
and does not try to match the model and observed time periods.
How to use:
python plot_gulf_stream.py -p /archive/acr/fre/NWA/2023_04/NWA12_COBALT_2023_04_kpo4-coastatten-physics/gfdl.ncrc5-intel22-prod
"""
import xarray
import xesmf
import pandas as pd
import numpy as np
import cartopy.feature as feature
import cartopy.crs as ccrs
from cartopy.mpl.geoaxes import GeoAxes
import matplotlib.gridspec as gridspec
from matplotlib.lines import Line2D
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import AxesGrid

# Need to append physics dir to path to access plot common
import sys
sys.path.append("..")
from plot_common import open_var, add_ticks, save_figure

def compute_gs(ssh, data_grid=None):
lons = np.arange(360-72, 360-51.9, 1)
lats = np.arange(36, 42, 0.1)
target_grid = {'lat': lats, 'lon': lons}

if data_grid is None:
data_grid = {'lat': ssh.lat, 'lon': ssh.lon}

ssh_to_grid = xesmf.Regridder(
data_grid,
target_grid,
method='bilinear'
)

# Interpolate the SSH data onto the index grid.
regridded = ssh_to_grid(ssh)

# Find anomalies relative to the calendar month mean SSH over the full model run.
anom = regridded.groupby('time.month') - regridded.groupby('time.month').mean('time')

# For each longitude point, the Gulf Stream is located at the latitude with the maximum SSH anomaly variance.
stdev = anom.std('time')
amax = stdev.argmax('lat').compute()
gs_points = stdev.lat.isel(lat=amax).compute()

# The index is the mean latitude of the Gulf Stream, divided by the standard deviation of the mean latitude of the Gulf Stream.
index = ((anom.isel(lat=amax).mean('lon')) / anom.isel(lat=amax).mean('lon').std('time')).compute()

# Move times to the beginning of the month to match observations.
monthly_index = index.to_pandas().resample('1MS').first()
return monthly_index, gs_points

def plot_gulf_stream(pp_root, label):

# Load Natural Earth Shapefiles
_LAND_50M = feature.NaturalEarthFeature(
'physical', 'land', '50m',
edgecolor='face',
facecolor='#999999'
)

# Get model grid
model_grid = xarray.open_dataset( '../../data/geography/ocean_static.nc' )

# Get model thetao data TODO: maki this comment better
model_thetao = open_var(pp_root, 'ocean_monthly_z', 'thetao')

if '01_l' in model_thetao.coords:
model_thetao = model_thetao.rename({'01_l': 'z_l'})

model_t200 = model_thetao.interp(z_l=200).mean('time')

# Ideally would use SSH, but some diag_tables only saved zos
try:
model_ssh = open_var(pp_root, 'ocean_monthly', 'ssh')
except:
print('Using zos')
model_ssh = open_var(pp_root, 'ocean_monthly', 'zos')

model_ssh_index, model_ssh_points = compute_gs(
model_ssh,
data_grid=model_grid[['geolon', 'geolat']].rename({'geolon': 'lon', 'geolat': 'lat'})
)

# Get Glorys data
glorys_t200 = xarray.open_dataarray('../../data/diagnostics/glorys_T200.nc')

# Get satellite points
#satellite_ssh_index, satellite_ssh_points = compute_gs(satellite['adt'])
#satellite_ssh_points.to_netcdf('../data/obs/satellite_ssh_points.nc')
#satellite_ssh_index.to_pickle('../data/obs/satellite_ssh_index.pkl')
#read pre-calculate satellite_ssh_index and points
satellite_ssh_points = xarray.open_dataset('../../data/obs/satellite_ssh_points.nc')
satellite_ssh_index = pd.read_pickle('../../data/obs/satellite_ssh_index.pkl')
satellite_rolled = satellite_ssh_index.rolling(25, center=True, min_periods=25).mean().dropna()

#satellite = xarray.open_mfdataset([f'/net2/acr/altimetry/SEALEVEL_GLO_PHY_L4_MY_008_047/adt_{y}_{m:02d}.nc' for y in range(1993, 2020) for m in range(1, 13)])
#satellite = satellite.rename({'longitude': 'lon', 'latitude': 'lat'})
#satellite = satellite.resample(time='1MS').mean('time')

# Get rolling averages and correlations
model_rolled = model_ssh_index.rolling(25, center=True, min_periods=25).mean().dropna()
corr = pd.concat((model_ssh_index, satellite_ssh_index), axis=1).corr().iloc[0, 1]
corr_rolled = pd.concat((model_rolled, satellite_rolled), axis=1).corr().iloc[0, 1]

# Plot of Gulf Stream position and index based on SSH,
# plus position based on T200.
fig = plt.figure(figsize=(10, 6), tight_layout=True)
gs = gridspec.GridSpec(2, 2, hspace=.25)

# Set projection of input data files so that data is correctly tranformed when plotting
proj = ccrs.PlateCarree()

ax = fig.add_subplot(gs[0, 0], projection = proj)
ax.add_feature(_LAND_50M)
ax.contour(model_grid.geolon, model_grid.geolat, model_t200, levels=[15], colors='r')
ax.contour(glorys_t200.longitude, glorys_t200.latitude, glorys_t200, levels=[15], colors='k')
add_ticks(ax, xlabelinterval=5)
ax.set_extent([-82, -50, 25, 41])
ax.set_title('(a) Gulf Stream position based on T200')
custom_lines = [Line2D([0], [0], color=c, lw=2) for c in ['r', 'k']]
ax.legend(custom_lines, ['Model', 'GLORYS12'], loc='lower right', frameon=False)

ax = fig.add_subplot(gs[0, 1], projection = proj)
ax.add_feature(_LAND_50M)
ax.plot(model_ssh_points.lon-360, model_ssh_points, c='r')
ax.plot(satellite_ssh_points.lon-360, satellite_ssh_points['__xarray_dataarray_variable__'], c='k')
add_ticks(ax, xlabelinterval=5)
ax.set_extent([-82, -50, 25, 41])
ax.set_title('(b) Gulf Stream position based on SSH variance')
ax.legend(custom_lines, ['Model', 'Altimetry'], loc='lower right', frameon=False)

ax = fig.add_subplot(gs[1, :])
model_ssh_index.plot(ax=ax, c='#ffbbbb', label='Model')
satellite_ssh_index.plot(ax=ax, c='#bbbbbb', label=f'Altimetry (r={corr:2.2f})')
model_rolled.plot(ax=ax, c='r', label='Model rolling mean')
satellite_rolled.plot(ax=ax, c='k', label=f'Altimetry rolling mean (r={corr_rolled:2.2f})')
ax.set_title('(c) Gulf Stream index based on SSH variance')
ax.set_xlabel('')
ax.set_ylim(-3, 3)
ax.set_ylabel('Index (positive north)')
ax.legend(ncol=4, loc='lower right', frameon=False, fontsize=8)

# default to saving figures in current dir instead of dedicated figures dir
save_figure('gulfstream_eval', label=label, pdf=True, output_dir='../figures/')

if __name__ == '__main__':
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('-p','--pp_root', help='Path to postprocessed data (up to but not including /pp/)', required = True)
parser.add_argument('-l', '--label', help='Label to add to figure file names', type=str, default='')
args = parser.parse_args()
plot_gulf_stream(args.pp_root, args.label)
8 changes: 6 additions & 2 deletions diagnostics/physics/config.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
figures_dir: 'figures/'
glorys: '/work/acr/mom6/diagnostics/glorys/glorys_sfc.nc'
glorys_zos: '/work/acr/glorys/GLOBAL_MULTIYEAR_PHY_001_030/monthly/glorys_monthly_z_fine_*.nc'
model_grid: '../data/geography/ocean_static.nc'

# Variables to rename
Expand Down Expand Up @@ -68,10 +69,13 @@ bias_max: 2.1
bias_min_trends: -1.5
bias_max_trends: 1.51
bias_step: 0.25

ticks: [-2, -1, 0, 1, 2]


# SST Trends Settings
start_year: "2005"
end_year: "2019"

# Colorbar for ssh plots
ssh_levels_min: -1.1
ssh_levels_max: .8
ssh_levels_step: .1
9 changes: 5 additions & 4 deletions diagnostics/physics/plot_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def get_map_norm(cmap, levels, no_offset=True):
norm = BoundaryNorm(levels, ncolors=nlev, clip=False)
return cmap, norm

def annotate_skill(model, obs, ax, dim=['yh', 'xh'], x0=-98.5, y0=54, yint=4, xint=4, weights=None, cols=1, proj = ccrs.PlateCarree(), plot_lat=False,**kwargs):
def annotate_skill(model, obs, ax, dim=['yh', 'xh'], x0=-98.5, y0=54, yint=4, xint=4, weights=None, cols=1, proj = ccrs.PlateCarree(), plot_lat=False, **kwargs):
"""
Annotate an axis with model vs obs skill metrics
"""
Expand All @@ -65,6 +65,7 @@ def annotate_skill(model, obs, ax, dim=['yh', 'xh'], x0=-98.5, y0=54, yint=4, xi
medae = xskillscore.median_absolute_error(model, obs, dim=dim, skipna=True)

ax.text(x0, y0, f'Bias: {float(bias):2.2f}', transform=proj, **kwargs)

# Set plot_lat=True in order to plot skill along a line of latitude. Otherwise, plot along longitude
if plot_lat:
ax.text(x0-xint, y0, f'RMSE: {float(rmse):2.2f}', transform=proj, **kwargs)
Expand Down Expand Up @@ -113,20 +114,20 @@ def autoextend_colorbar(ax, plot, plot_array=None, **kwargs):
extend = 'neither'
return ax.colorbar(plot, extend=extend, **kwargs)

def add_ticks(ax, xticks=np.arange(-100, -31, 1), yticks=np.arange(2, 61, 1), xlabelinterval=2, ylabelinterval=2, fontsize=10, **kwargs):
def add_ticks(ax, xticks=np.arange(-100, -31, 1), yticks=np.arange(2, 61, 1), xlabelinterval=2, ylabelinterval=2, fontsize=10, projection = ccrs.PlateCarree(), **kwargs):
"""
Add lat and lon ticks and labels to a plot axis.
By default, tick at 1 degree intervals for x and y, and label every other tick.
Additional kwargs are passed to LongitudeFormatter and LatitudeFormatter.
"""
ax.yaxis.tick_right()
ax.set_xticks(xticks, crs=ccrs.PlateCarree())
ax.set_xticks(xticks, crs = projection)
if xlabelinterval == 0:
plt.setp(ax.get_xticklabels(), visible=False)
else:
plt.setp([l for i, l in enumerate(ax.get_xticklabels()) if i % xlabelinterval != 0], visible=False, fontsize=fontsize)
plt.setp([l for i, l in enumerate(ax.get_xticklabels()) if i % xlabelinterval == 0], fontsize=fontsize)
ax.set_yticks(yticks, crs=ccrs.PlateCarree())
ax.set_yticks(yticks, crs = projection)
if ylabelinterval == 0:
plt.setp(ax.get_yticklabels(), visible=False)
else:
Expand Down
Loading

0 comments on commit c313f29

Please sign in to comment.