Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove duplicate code in notebooks, analysis/ #111

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "056637ab52c5467c8542cd4c64a5f00c",
"model_id": "cfb7c47b46084063a7033a776b365d9b",
"version_major": 2,
"version_minor": 0
},
Expand Down
109 changes: 58 additions & 51 deletions src/gz21_ocean_momentum/analysis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"""

from gz21_ocean_momentum.analysis.analysis import TimeSeriesForPoint
from gz21_ocean_momentum.data.pangeo_catalog import get_patch, get_whole_data
import gz21_ocean_momentum.lib.data as datalib
from gz21_ocean_momentum.common.bounding_box import BoundingBox

import numpy as np
Expand Down Expand Up @@ -344,7 +344,9 @@ class GlobalPlotter:
continental data + showing a band near coastlines."""

def __init__(self, margin: int = 10, cbar: bool = True, ice: bool = True):
self.mask = self._get_global_u_mask()
# TODO: remove global catalog URL! pass in as arg
self.catalog = intake.open_catalog(CATALOG_URL)
self.mask = self._get_global_u_mask(catalog)
self.margin = margin
self.cbar = cbar
self.ticks = dict(x=None, y=None)
Expand Down Expand Up @@ -373,7 +375,7 @@ def margin(self):
@margin.setter
def margin(self, margin):
self._margin = margin
self.borders = self._get_continent_borders(self.mask, self.margin)
self.borders = _get_continent_borders(self.mask, self.margin)

@property
def x_ticks(self):
Expand All @@ -393,6 +395,7 @@ def y_ticks(self, value):

def plot(
self,
lon_dim: str, lat_dim: str,
u: xr.DataArray = None,
projection_cls=PlateCarree,
lon: float = -100.0,
Expand Down Expand Up @@ -428,12 +431,12 @@ def plot(
projection = projection_cls(lon)
if ax is None:
ax = plt.axes(projection=projection)
mesh_x, mesh_y = np.meshgrid(u["longitude"], u["latitude"])
mesh_x, mesh_y = np.meshgrid(u[lon_dim], u["latitude"])
if u is not None:
extra = self.mask.isel(longitude=slice(0, 10))
extra["longitude"] = extra["longitude"] + 360
mask = xr.concat((self.mask, extra), dim="longitude")
mask = mask.interp({k: u.coords[k] for k in ("longitude", "latitude")})
extra = self.mask.isel(lon_dim=slice(0, 10))
extra[lon_dim] = extra[lon_dim] + 360
mask = xr.concat((self.mask, extra), dim=lon_dim)
mask = mask.interp({k: u.coords[k] for k in (lon_dim, "latitude")})
u = u * mask
im = ax.pcolormesh(
mesh_x,
Expand All @@ -451,11 +454,11 @@ def plot(
ax.coastlines()
# "Gray-out" near continental locations
if self.margin > 0:
extra = self.borders.isel(longitude=slice(0, 10))
extra["longitude"] = extra["longitude"] + 360
borders = xr.concat((self.borders, extra), dim="longitude")
extra = self.borders.isel(lon_dim=slice(0, 10))
extra[lon_dim] = extra[lon_dim] + 360
borders = xr.concat((self.borders, extra), dim=lon_dim)
borders = borders.interp(
{k: u.coords[k] for k in ("longitude", "latitude")}
{k: u.coords[k] for k in (lon_dim, "latitude")}
)
borders_cmap = colors.ListedColormap(
[
Expand All @@ -473,11 +476,11 @@ def plot(
)
# Add locations of ice
if self.ice:
ice = self._get_ice_border()
ice = self._get_ice_border(self.catalog)
ice = xr.where(ice, 1.0, 0.0)
ice = ice.interp({k: u.coords[k] for k in ("longitude", "latitude")})
ice = ice.interp({k: u.coords[k] for k in (lon_dim, "latitude")})
ice = xr.where(ice != 0, 1.0, 0.0)
ice = abs(ice.diff(dim="longitude")) + abs(ice.diff(dim="latitude"))
ice = abs(ice.diff(dim=lon_dim)) + abs(ice.diff(dim="latitude"))
ice = xr.where(ice != 0.0, 1, np.nan)
ice_cmap = colors.ListedColormap(
[
Expand All @@ -500,7 +503,10 @@ def plot(
return ax

@staticmethod
def _get_global_u_mask(factor: int = 4, base_mask: xr.DataArray = None):
def _get_global_u_mask(
catalog,
lon_dim: str, lat_dim: str,
factor: int, base_mask: xr.DataArray = None) -> xr.DataArray:
"""
Return the global mask of the low-resolution surface velocities for
plots. While the coarse-grained velocities might be defined on
Expand All @@ -511,8 +517,8 @@ def _get_global_u_mask(factor: int = 4, base_mask: xr.DataArray = None):

Parameters
----------
factor : int, optional
Coarse-graining factor. The default is 4.
factor : int
Coarse-graining factor.

base_mask: xr.DataArray, optional
# TODO
Expand All @@ -526,58 +532,59 @@ def _get_global_u_mask(factor: int = 4, base_mask: xr.DataArray = None):
if base_mask is not None:
mask = base_mask
else:
_, grid_info = get_whole_data(CATALOG_URL, 0)
grid_info = datalib.retrieve_cm2_6_grid(catalog)
mask = grid_info["wet"]
mask = mask.coarsen(dict(xt_ocean=factor, yt_ocean=factor))
mask_ = mask.max()
mask_ = mask_.where(mask_ > 0.1)
mask_ = mask_.rename(dict(xt_ocean="longitude", yt_ocean="latitude"))
mask_ = mask_.rename(dict(xt_ocean=lon_dim, yt_ocean="latitude"))
return mask_.compute()

@staticmethod
def _get_ice_border():
def _get_ice_border(catalog):
"""Return an xarray.DataArray that indicates the locations of ice
in the oceans."""
temperature, _ = get_patch(CATALOG_URL, 1, None, 0, "surface_temp")
velocities = datalib.retrieve_cm2_6_velocities(catalog, co2_increase=False)
temperature = velocities[["surface_temp"]]
temperature = temperature.isel(time=slice(0, 1))
temperature = temperature.rename(
dict(xt_ocean="longitude", yt_ocean="latitude")
)
# maybe superfluous! we already drop other vars & slice above
temperature = temperature["surface_temp"].isel(time=0)
ice = xr.where(temperature <= 0.0, True, False)
return ice

@staticmethod
def _get_continent_borders(base_mask: xr.DataArray, margin: int):
"""
Returns a boolean xarray DataArray corresponding to a mask of the
continents' coasts, which we do not process.
Hence margin should be set according to the model.

Parameters
----------
mask : xr.DataArray
Mask taking value 1 where coarse velocities are defined and used
as input and nan elsewhere.
margin : int
Margin imposed by the model used, i.e. number of points lost on
one side of a square.
def _get_continent_borders(base_mask: xr.DataArray, margin: int):
"""
Returns a boolean xarray DataArray corresponding to a mask of the
continents' coasts, which we do not process.
Hence margin should be set according to the model.

Returns
-------
mask : xr.DataArray
Boolean DataArray taking value True for continents.
Parameters
----------
mask : xr.DataArray
Mask taking value 1 where coarse velocities are defined and used
as input and nan elsewhere.
margin : int
Margin imposed by the model used, i.e. number of points lost on
one side of a square.

"""
assert margin >= 0, "The margin parameter should be a non-negative" " integer"
assert base_mask.ndim <= 2, "Velocity array should have two" " dims"
# Small trick using the guassian filter function
mask = xr.apply_ufunc(
lambda x: gaussian_filter(x, 1.0, truncate=margin), base_mask
)
mask = np.logical_and(np.isnan(mask), ~np.isnan(base_mask))
mask = mask.where(mask)
return mask.compute()
Returns
-------
mask : xr.DataArray
Boolean DataArray taking value True for continents.

"""
assert margin >= 0, "The margin parameter should be a non-negative integer"
assert base_mask.ndim <= 2, "Velocity array should have two dims"
# Small trick using the guassian filter function
mask = xr.apply_ufunc(
lambda x: gaussian_filter(x, 1.0, truncate=margin), base_mask
)
mask = np.logical_and(np.isnan(mask), ~np.isnan(base_mask))
mask = mask.where(mask)
return mask.compute()

def apply_complete_mask(array, pred, uv_plotter):
mask = uv_plotter.borders
Expand Down
11 changes: 0 additions & 11 deletions src/gz21_ocean_momentum/data/pangeo_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,3 @@ def get_whole_data(url, c02_level):
"""
data, grid = get_patch(url, None, None, c02_level, "usurf", "vsurf")
return data, grid


if __name__ == "__main__":
import os

os.environ[
"GOOGLE_APPLICATION_CREDENTIALS"
] = "~/.config/gcloud/application_default_credentials.json"
CATALOG_URL = "https://raw.githubusercontent.com/pangeo-data/pangeo-datastore\
/master/intake-catalogs/master.yaml"
retrieved_data = get_whole_data(CATALOG_URL, 0)