Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clib.conversion._to_numpy: Add tests for pandas.Series with pandas string dtype #3607

Merged
merged 8 commits into from
Nov 15, 2024
8 changes: 7 additions & 1 deletion pygmt/clib/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Functions to convert data types into ctypes friendly formats.
"""

import contextlib
import ctypes as ctp
import warnings
from collections.abc import Sequence
Expand Down Expand Up @@ -160,7 +161,7 @@ def _to_numpy(data: Any) -> np.ndarray:
dtypes: dict[str, type | str] = {
# For string dtypes.
"large_string": np.str_, # pa.large_string and pa.large_utf8
"string": np.str_, # pa.string and pa.utf8
"string": np.str_, # pa.string, pa.utf8, pd.StringDtype
"string_view": np.str_, # pa.string_view
# For datetime dtypes.
"date32[day][pyarrow]": "datetime64[D]",
Expand All @@ -180,6 +181,11 @@ def _to_numpy(data: Any) -> np.ndarray:
else:
vec_dtype = str(getattr(data, "dtype", getattr(data, "type", "")))
array = np.ascontiguousarray(data, dtype=dtypes.get(vec_dtype))

# Check if a np.object_ array can be converted to np.str_.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is necessary to support pd.Series string like:

x = pd.Series(["abc", "defg", "12345"], dtype=None)
x = pd.Series(["abc", "defg", "12345"], dtype=np.str_)
x = pd.Series(["abc", "defg", "12345"], dtype="U10")

if array.dtype == np.object_:
with contextlib.suppress(TypeError, ValueError):
return np.ascontiguousarray(array, dtype=np.str_)
return array


Expand Down
6 changes: 3 additions & 3 deletions pygmt/clib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -1475,7 +1475,7 @@ def virtualfile_from_vectors(
# 2 columns contains coordinates like longitude, latitude, or datetime string
# types.
for col, array in enumerate(arrays[2:]):
if pd.api.types.is_string_dtype(array.dtype):
if np.issubdtype(array.dtype, np.str_):
Copy link
Member Author

@seisman seisman Nov 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes in this PR were initially added in #684 to support string arrays with np.object_ dtype. It's no longer necessary after dac7e8e because the array has been processed by the _to_numpy function when calling vectors_to_arrays at line 1471.

columns = col + 2
break

Expand Down Expand Up @@ -1506,9 +1506,9 @@ def virtualfile_from_vectors(
strings = string_arrays[0]
elif len(string_arrays) > 1:
strings = np.array(
[" ".join(vals) for vals in zip(*string_arrays, strict=True)]
[" ".join(vals) for vals in zip(*string_arrays, strict=True)],
dtype=np.str_,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Specifying dtype is not necesary here, but I feel it's good to expicitly tell that here we're expecting a np.str_ array.

)
strings = np.asanyarray(a=strings, dtype=np.str_)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this line was added in PR #684 and is no longer needed after dac7e8e

self.put_strings(
dataset, family="GMT_IS_VECTOR|GMT_IS_DUPLICATE", strings=strings
)
Expand Down
26 changes: 26 additions & 0 deletions pygmt/tests/test_clib_to_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import pytest
from packaging.version import Version
from pygmt.clib.conversion import _to_numpy
from pygmt.helpers.testing import skip_if_no

try:
import pyarrow as pa
Expand Down Expand Up @@ -174,6 +175,31 @@ def test_to_numpy_pandas_series_numpy_dtypes_numeric(dtype, expected_dtype):
npt.assert_array_equal(result, series)


@pytest.mark.parametrize(
"dtype",
[
None,
np.str_,
"U10",
"string[python]",
pytest.param("string[pyarrow]", marks=skip_if_no(package="pyarrow")),
pytest.param("string[pyarrow_numpy]", marks=skip_if_no(package="pyarrow")),
],
)
def test_to_numpy_pandas_series_pandas_dtypes_string(dtype):
"""
Test the _to_numpy function with pandas.Series of pandas string types.

In pandas, string arrays can be specified in multiple ways.

Reference: https://pandas.pydata.org/docs/reference/api/pandas.StringDtype.html
"""
array = pd.Series(["abc", "defg", "12345"], dtype=dtype)
result = _to_numpy(array)
_check_result(result, np.str_)
npt.assert_array_equal(result, array)


@pytest.mark.skipif(not _HAS_PYARROW, reason="pyarrow is not installed")
@pytest.mark.parametrize(
("dtype", "expected_dtype"),
Expand Down
Loading