Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Typing #311

Closed
wants to merge 17 commits into from
Closed
2 changes: 1 addition & 1 deletion .prettierrc.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
tabWidth = 2
tabWidth = 4
semi = false
singleQuote = true
93 changes: 51 additions & 42 deletions xesmf/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,18 @@

import os
import warnings
from typing import Optional

try:
import esmpy as ESMF
except ImportError:
import ESMF

import numpy as np
import numpy.lib.recfunctions as nprec

huard marked this conversation as resolved.
Show resolved Hide resolved

def warn_f_contiguous(a):
def warn_f_contiguous(a: np.ndarray) -> None:
"""
Give a warning if input array if not Fortran-ordered.

Expand All @@ -37,11 +39,11 @@ def warn_f_contiguous(a):
----------
a : numpy array
"""
if not a.flags['F_CONTIGUOUS']:
warnings.warn('Input array is not F_CONTIGUOUS. ' 'Will affect performance.')
if not a.flags["F_CONTIGUOUS"]:
warnings.warn("Input array is not F_CONTIGUOUS. " "Will affect performance.")


def warn_lat_range(lat):
def warn_lat_range(lat: np.ndarray) -> None:
"""
Give a warning if latitude is outside of [-90, 90]

Expand All @@ -53,12 +55,18 @@ def warn_lat_range(lat):
lat : numpy array
"""
if (lat.max() > 90.0) or (lat.min() < -90.0):
warnings.warn('Latitude is outside of [-90, 90]')
warnings.warn("Latitude is outside of [-90, 90]")


class Grid(ESMF.Grid):
@classmethod
def from_xarray(cls, lon, lat, periodic=False, mask=None):
def from_xarray(
cls,
lon: np.ndarray,
lat: np.ndarray,
periodic: bool = False,
mask: Optional[np.ndarray] = None,
):
"""
Create an ESMF.Grid object, for constructing ESMF.Field and ESMF.Regrid.

Expand Down Expand Up @@ -97,8 +105,8 @@ def from_xarray(cls, lon, lat, periodic=False, mask=None):

# ESMF.Grid can actually take 3D array (lon, lat, radius),
# but regridding only works for 2D array
assert lon.ndim == 2, 'Input grid must be 2D array'
assert lon.shape == lat.shape, 'lon and lat must have same shape'
assert lon.ndim == 2, "Input grid must be 2D array"
assert lon.shape == lat.shape, "lon and lat must have same shape"

staggerloc = ESMF.StaggerLoc.CENTER # actually just integer 0

Expand Down Expand Up @@ -136,8 +144,8 @@ def from_xarray(cls, lon, lat, periodic=False, mask=None):
grid_mask = mask.astype(np.int32)
if not (grid_mask.shape == lon.shape):
raise ValueError(
'mask must have the same shape as the latitude/longitude'
'coordinates, got: mask.shape = %s, lon.shape = %s' % (mask.shape, lon.shape)
"mask must have the same shape as the latitude/longitude"
"coordinates, got: mask.shape = %s, lon.shape = %s" % (mask.shape, lon.shape)
)
grid.add_item(ESMF.GridItem.MASK, staggerloc=ESMF.StaggerLoc.CENTER, from_file=False)
grid.mask[0][:] = grid_mask
Expand All @@ -151,33 +159,33 @@ def get_shape(self, loc=ESMF.StaggerLoc.CENTER):

class LocStream(ESMF.LocStream):
@classmethod
def from_xarray(cls, lon, lat):
def from_xarray(cls, lon: np.ndarray, lat: np.ndarray) -> ESMF.LocStream:
"""
Create an ESMF.LocStream object, for contrusting ESMF.Field and ESMF.Regrid

Parameters
----------
lon, lat : 1D numpy array
Longitute/Latitude of cell centers.
Longitute/Latitude of cell centers.

Returns
-------
locstream : ESMF.LocStream object
"""

if len(lon.shape) > 1:
raise ValueError('lon can only be 1d')
raise ValueError("lon can only be 1d")
if len(lat.shape) > 1:
raise ValueError('lat can only be 1d')
raise ValueError("lat can only be 1d")

assert lon.shape == lat.shape

location_count = len(lon)

locstream = cls(location_count, coord_sys=ESMF.CoordSys.SPH_DEG)

locstream['ESMF:Lon'] = lon.astype(np.dtype('f8'))
locstream['ESMF:Lat'] = lat.astype(np.dtype('f8'))
locstream["ESMF:Lon"] = lon.astype(np.dtype("f8"))
locstream["ESMF:Lat"] = lat.astype(np.dtype("f8"))

return locstream

Expand Down Expand Up @@ -212,12 +220,12 @@ def add_corner(grid, lon_b, lat_b):

warn_lat_range(lat_b)

assert lon_b.ndim == 2, 'Input grid must be 2D array'
assert lon_b.shape == lat_b.shape, 'lon_b and lat_b must have same shape'
assert np.array_equal(lon_b.shape, grid.max_index + 1), 'lon_b should be size (Nx+1, Ny+1)'
assert lon_b.ndim == 2, "Input grid must be 2D array"
assert lon_b.shape == lat_b.shape, "lon_b and lat_b must have same shape"
assert np.array_equal(lon_b.shape, grid.max_index + 1), "lon_b should be size (Nx+1, Ny+1)"
assert (grid.num_peri_dims == 0) and (
grid.periodic_dim is None
), 'Cannot add corner for periodic grid'
), "Cannot add corner for periodic grid"

grid.add_coords(staggerloc=staggerloc)

Expand All @@ -230,7 +238,7 @@ def add_corner(grid, lon_b, lat_b):

class Mesh(ESMF.Mesh):
@classmethod
def from_polygons(cls, polys, element_coords='centroid'):
def from_polygons(cls, polys, element_coords="centroid"):
"""
Create an ESMF.Mesh object from a list of polygons.

Expand All @@ -254,13 +262,13 @@ def from_polygons(cls, polys, element_coords='centroid'):
node_num = sum(len(e.exterior.coords) - 1 for e in polys)
elem_num = len(polys)
# Pre alloc arrays. Special structure for coords makes the code faster.
crd_dt = np.dtype([('x', np.float32), ('y', np.float32)])
crd_dt = np.dtype([("x", np.float32), ("y", np.float32)])
node_coords = np.empty(node_num, dtype=crd_dt)
node_coords[:] = (np.nan, np.nan) # Fill with impossible values
element_types = np.empty(elem_num, dtype=np.uint32)
element_conn = np.empty(node_num, dtype=np.uint32)
# Flag for centroid calculation
calc_centroid = isinstance(element_coords, str) and element_coords == 'centroid'
calc_centroid = isinstance(element_coords, str) and element_coords == "centroid"
if calc_centroid:
element_coords = np.empty(elem_num, dtype=crd_dt)
inode = 0
Expand Down Expand Up @@ -303,7 +311,7 @@ def from_polygons(cls, polys, element_coords='centroid'):
)
except ValueError as err:
raise ValueError(
'ESMF failed to create the Mesh, this usually happen when some polygons are invalid (test with `poly.is_valid`)'
"ESMF failed to create the Mesh, this usually happen when some polygons are invalid (test with `poly.is_valid`)"
) from err
return mesh

Expand Down Expand Up @@ -388,45 +396,45 @@ def esmf_regrid_build(

# use shorter, clearer names for options in ESMF.RegridMethod
method_dict = {
'bilinear': ESMF.RegridMethod.BILINEAR,
'conservative': ESMF.RegridMethod.CONSERVE,
'conservative_normed': ESMF.RegridMethod.CONSERVE,
'patch': ESMF.RegridMethod.PATCH,
'nearest_s2d': ESMF.RegridMethod.NEAREST_STOD,
'nearest_d2s': ESMF.RegridMethod.NEAREST_DTOS,
"bilinear": ESMF.RegridMethod.BILINEAR,
"conservative": ESMF.RegridMethod.CONSERVE,
"conservative_normed": ESMF.RegridMethod.CONSERVE,
"patch": ESMF.RegridMethod.PATCH,
"nearest_s2d": ESMF.RegridMethod.NEAREST_STOD,
"nearest_d2s": ESMF.RegridMethod.NEAREST_DTOS,
}
try:
esmf_regrid_method = method_dict[method]
except Exception:
raise ValueError('method should be chosen from ' '{}'.format(list(method_dict.keys())))
raise ValueError("method should be chosen from " "{}".format(list(method_dict.keys())))

# use shorter, clearer names for options in ESMF.ExtrapMethod
extrap_dict = {
'inverse_dist': ESMF.ExtrapMethod.NEAREST_IDAVG,
'nearest_s2d': ESMF.ExtrapMethod.NEAREST_STOD,
"inverse_dist": ESMF.ExtrapMethod.NEAREST_IDAVG,
"nearest_s2d": ESMF.ExtrapMethod.NEAREST_STOD,
None: None,
}
try:
esmf_extrap_method = extrap_dict[extrap_method]
except KeyError:
raise KeyError(
'`extrap_method` should be chosen from ' '{}'.format(list(extrap_dict.keys()))
"`extrap_method` should be chosen from " "{}".format(list(extrap_dict.keys()))
)

# until ESMPy updates ESMP_FieldRegridStoreFile, extrapolation is not possible
# if files are written on disk
if (extrap_method is not None) & (filename is not None):
raise ValueError('`extrap_method` cannot be used along with `filename`.')
raise ValueError("`extrap_method` cannot be used along with `filename`.")

# conservative regridding needs cell corner information
if method in ['conservative', 'conservative_normed']:
if method in ["conservative", "conservative_normed"]:
if not isinstance(sourcegrid, ESMF.Mesh) and not sourcegrid.has_corners:
raise ValueError(
'source grid has no corner information. ' 'cannot use conservative regridding.'
"source grid has no corner information. " "cannot use conservative regridding."
)
if not isinstance(destgrid, ESMF.Mesh) and not destgrid.has_corners:
raise ValueError(
'destination grid has no corner information. ' 'cannot use conservative regridding.'
"destination grid has no corner information. " "cannot use conservative regridding."
)

# ESMF.Regrid requires Field (Grid+data) as input, not just Grid.
Expand Down Expand Up @@ -454,11 +462,11 @@ def esmf_regrid_build(
if filename is not None:
assert not os.path.exists(
filename
), 'Weight file already exists! Please remove it or use a new name.'
), "Weight file already exists! Please remove it or use a new name."

# re-normalize conservative regridding results
# https://github.com/JiaweiZhuang/xESMF/issues/17
if method == 'conservative_normed':
if method == "conservative_normed":
norm_type = ESMF.NormType.FRACAREA
else:
norm_type = ESMF.NormType.DSTAREA
Expand Down Expand Up @@ -565,14 +573,15 @@ def esmf_regrid_finalize(regrid):

def esmf_locstream(lon, lat):
warnings.warn(
'`esmf_locstream` is being deprecated in favor of `LocStream.from_xarray`',
"`esmf_locstream` is being deprecated in favor of `LocStream.from_xarray`",
DeprecationWarning,
)
return LocStream.from_xarray(lon, lat)


def esmf_grid(lon, lat, periodic=False, mask=None):
warnings.warn(
'`esmf_grid` is being deprecated in favor of `Grid.from_xarray`', DeprecationWarning
"`esmf_grid` is being deprecated in favor of `Grid.from_xarray`",
DeprecationWarning,
)
return Grid.from_xarray(lon, lat)
34 changes: 20 additions & 14 deletions xesmf/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,43 +2,49 @@
Standard test data for regridding benchmark.
"""

from typing import Any

import numpy as np
import numpy.typing as npt
import xarray


def wave_smooth(lon, lat):
r"""
def wave_smooth( # type: ignore
lon: npt.NDArray[np.floating[Any]] | xarray.DataArray,
lat: npt.NDArray[np.floating[Any]] | xarray.DataArray,
) -> npt.NDArray[np.floating[Any]] | xarray.DataArray:
"""
Spherical harmonic with low frequency.

Parameters
----------
lon, lat : 2D numpy array or xarray DataArray
Longitute/Latitude of cell centers
Longitude/Latitude of cell centers

Returns
-------
f : 2D numpy array or xarray DataArray depending on input
2D wave field
f : 2D numpy array or xarray DataArray depending on input2D wave field

Notes
-------
Equation from [1]_ [2]_:

.. math:: Y_2^2 = 2 + \cos^2(\\theta) \cos(2 \phi)
.. math:: Y_2^2 = 2 + cos^2(lat) * cos(2 * lon)

References
----------
.. [1] Jones, P. W. (1999). First-and second-order conservative remapping
schemes for grids in spherical coordinates. Monthly Weather Review,
127(9), 2204-2210.
schemes for grids in spherical coordinates. Monthly Weather Review,
127(9), 2204-2210.

.. [2] Ullrich, P. A., Lauritzen, P. H., & Jablonowski, C. (2009).
Geometrically exact conservative remapping (GECoRe): regular
latitudelongitude and cubed-sphere grids. Monthly Weather Review,
137(6), 1721-1741.
Geometrically exact conservative remapping (GECoRe): regular
latitude-longitude and cubed-sphere grids. Monthly Weather Review,
137(6), 1721-1741.
"""
# degree to radius, make a copy
lat = lat / 180.0 * np.pi
lon = lon / 180.0 * np.pi
lat *= np.pi / 180.0
lon *= np.pi / 180.0

f = 2 + np.cos(lat) ** 2 * np.cos(2 * lon)
f = 2 + pow(np.cos(lat), 2) * np.cos(2 * lon)
return f
Loading
Loading