Skip to content

Commit

Permalink
Add tests for string arrays
Browse files: browse the repository at this point in the history
  • Loading branch information
seisman committed Nov 3, 2024
1 parent bf8c9a5 commit e4807e2
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
5 changes: 4 additions & 1 deletion pygmt/clib/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,9 @@ def _to_ndarray(array: Any) -> np.ndarray:
"""
# A dictionary mapping unsupported dtypes to the expected numpy dtype.
dtypes: dict[str, type] = {
# "string" for "string[python]", "string[pyarrow]", "string[pyarrow_numpy]", and
# pa.string()
"string": np.str_,
"date32[day][pyarrow]": np.datetime64,
"date64[ms][pyarrow]": np.datetime64,
}
Expand All @@ -184,7 +187,7 @@ def _to_ndarray(array: Any) -> np.ndarray:
if hasattr(array, "isna") and array.isna().any():
array = array.astype(np.float64)

vec_dtype = str(getattr(array, "dtype", ""))
vec_dtype = str(getattr(array, "dtype", getattr(array, "type", "")))
array = np.ascontiguousarray(array, dtype=dtypes.get(vec_dtype))
return array

Expand Down
42 changes: 42 additions & 0 deletions pygmt/tests/test_clib_to_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,17 @@ def test_to_ndarray_numpy_ndarray_numpy_numeric(dtype):
npt.assert_array_equal(result, array)


@pytest.mark.parametrize("dtype", [None, np.str_])
def test_to_ndarray_numpy_ndarray_numpy_string(dtype):
    """
    Check that _to_ndarray handles a 1-D NumPy array of strings.

    The input is built once with NumPy's inferred dtype (``None``) and once with an
    explicit ``np.str_``. The converted array is passed to ``_check_result`` and must
    compare element-wise equal to the input.
    """
    strings = np.array(["a", "b", "c"], dtype=dtype)
    converted = _to_ndarray(strings)
    _check_result(converted)
    npt.assert_array_equal(converted, strings)


@pytest.mark.parametrize(
"dtype",
[
Expand Down Expand Up @@ -146,6 +157,26 @@ def test_to_ndarray_pandas_series_numeric_with_na(dtype):
npt.assert_array_equal(result, np.array([1, np.nan, 3], dtype=np.float64))


@pytest.mark.parametrize(
    "dtype",
    [
        # NOTE(review): the plain ``None`` and ``np.str_`` dtypes are not exercised
        # by this test (they were left disabled) - confirm whether to enable them.
        "string[python]",
        # The pyarrow-backed string dtypes are skipped when pyarrow is missing.
        pytest.param("string[pyarrow]", marks=skip_if_no(package="pyarrow")),
        pytest.param("string[pyarrow_numpy]", marks=skip_if_no(package="pyarrow")),
    ],
)
def test_to_ndarray_pandas_series_string(dtype):
    """
    Check that _to_ndarray handles a pandas Series with a string dtype.

    The Series holds strings of different lengths. The converted array is passed to
    ``_check_result`` and must compare element-wise equal to the Series.
    """
    strings = pd.Series(["a", "bcd", "12345"], dtype=dtype)
    converted = _to_ndarray(strings)
    _check_result(converted)
    npt.assert_array_equal(converted, strings)


@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
"dtype",
Expand Down Expand Up @@ -184,3 +215,14 @@ def test_to_ndarray_pyarrow_array_float16():
result = _to_ndarray(array)
_check_result(result)
npt.assert_array_equal(result, array)


@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
def test_to_ndarray_pyarrow_array_string():
    """
    Check that _to_ndarray handles a pyarrow array of type ``pa.string()``.

    The converted array is passed to ``_check_result`` and must compare element-wise
    equal to the pyarrow input. The test is skipped when pyarrow is not installed.
    """
    strings = pa.array(["a", "bcd", "12345"], type=pa.string())
    converted = _to_ndarray(strings)
    _check_result(converted)
    npt.assert_array_equal(converted, strings)

0 comments on commit e4807e2

Please sign in to comment.