Skip to content

Commit

Permalink
Improve the data type checking for 2-D arrays passed to the GMT C API (
Browse files Browse the repository at this point in the history
  • Loading branch information
seisman authored Oct 31, 2024
1 parent 8eb2b4f commit e5ecee9
Showing 1 changed file with 33 additions and 32 deletions.
65 changes: 33 additions & 32 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,8 @@

REGISTRATIONS = ["GMT_GRID_NODE_REG", "GMT_GRID_PIXEL_REG"]

DTYPES = {
# Dictionary for mapping numpy dtypes to GMT data types.
DTYPES_NUMERIC = {
np.int8: "GMT_CHAR",
np.int16: "GMT_SHORT",
np.int32: "GMT_INT",
Expand All @@ -93,10 +94,14 @@
np.uint64: "GMT_ULONG",
np.float32: "GMT_FLOAT",
np.float64: "GMT_DOUBLE",
np.timedelta64: "GMT_LONG",
}
DTYPES_TEXT = {
np.str_: "GMT_TEXT",
np.datetime64: "GMT_DATETIME",
np.timedelta64: "GMT_LONG",
}
DTYPES = DTYPES_NUMERIC | DTYPES_TEXT

# Dictionary for storing the values of GMT constants.
GMT_CONSTANTS = {}

Expand Down Expand Up @@ -879,63 +884,59 @@ def _parse_constant(
integer_value = sum(self[part] for part in parts)
return integer_value

def _check_dtype_and_dim(self, array, ndim):
def _check_dtype_and_dim(self, array: np.ndarray, ndim: int) -> int:
"""
Check that a numpy array has the given number of dimensions and is a valid data
type.
Parameters
----------
array : numpy.ndarray
array
The array to be tested.
ndim : int
ndim
The desired number of array dimensions.
Returns
-------
gmt_type : int
gmt_type
The GMT constant value representing this data type.
Raises
------
GMTInvalidInput
If the array has the wrong number of dimensions or
is an unsupported data type.
If the array has the wrong number of dimensions or is an unsupported data
type.
Examples
--------
>>> import numpy as np
>>> data = np.array([1, 2, 3], dtype="float64")
>>> with Session() as ses:
... gmttype = ses._check_dtype_and_dim(data, ndim=1)
... gmttype == ses["GMT_DOUBLE"]
>>> with Session() as lib:
... gmttype = lib._check_dtype_and_dim(data, ndim=1)
... gmttype == lib["GMT_DOUBLE"]
True
>>> data = np.ones((5, 2), dtype="float32")
>>> with Session() as ses:
... gmttype = ses._check_dtype_and_dim(data, ndim=2)
... gmttype == ses["GMT_FLOAT"]
>>> with Session() as lib:
... gmttype = lib._check_dtype_and_dim(data, ndim=2)
... gmttype == lib["GMT_FLOAT"]
True
"""
# Check that the array has the given number of dimensions
# Check that the array has the given number of dimensions.
if array.ndim != ndim:
raise GMTInvalidInput(
f"Expected a numpy {ndim}-D array, got {array.ndim}-D."
)
msg = f"Expected a numpy {ndim}-D array, got {array.ndim}-D."
raise GMTInvalidInput(msg)

# Check that the array has a valid/known data type
if array.dtype.type not in DTYPES:
try:
if array.dtype.type is np.object_:
# Try to convert unknown object type to np.datetime64
array = array_to_datetime(array)
else:
raise ValueError
except ValueError as e:
raise GMTInvalidInput(
f"Unsupported numpy data type '{array.dtype.type}'."
) from e
return self[DTYPES[array.dtype.type]]
# For 1-D arrays, try to convert unknown object type to np.datetime64.
if ndim == 1 and array.dtype.type is np.object_:
with contextlib.suppress(ValueError):
array = array_to_datetime(array)

# 1-D arrays can be numeric or text, 2-D arrays can only be numeric.
valid_dtypes = DTYPES if ndim == 1 else DTYPES_NUMERIC
if (dtype := array.dtype.type) not in valid_dtypes:
msg = f"Unsupported numpy data type '{dtype}'."
raise GMTInvalidInput(msg)
return self[DTYPES[dtype]]

def put_vector(self, dataset: ctp.c_void_p, column: int, vector: np.ndarray):
r"""
Expand Down

0 comments on commit e5ecee9

Please sign in to comment.