From 6945c4f8b9a0f8497b1f9f662a2015bdc4992048 Mon Sep 17 00:00:00 2001 From: GALI PREM SAGAR Date: Thu, 7 Sep 2023 12:04:23 -0500 Subject: [PATCH] Raise `MixedTypeError` when a column of mixed-dtype is being constructed (#14050) Fixes #14038 This PR raises a `MixedTypeError` when a column is constructed from `object`-dtype data whose values are neither strings nor booleans. Authors: - GALI PREM SAGAR (https://github.com/galipremsagar) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: https://github.com/rapidsai/cudf/pull/14050 --- python/cudf/cudf/core/column/column.py | 19 ++++++++++++++----- python/cudf/cudf/tests/test_index.py | 3 ++- python/cudf/cudf/tests/test_parquet.py | 4 ++-- python/cudf/cudf/tests/test_series.py | 6 +++++- 4 files changed, 23 insertions(+), 9 deletions(-) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index a8735a1dd8d..b4ad6765207 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -2062,10 +2062,15 @@ def as_column( ) else: pyarrow_array = pa.array(arbitrary, from_pandas=nan_as_null) - if arbitrary.dtype == cudf.dtype("object") and isinstance( - pyarrow_array, (pa.DurationArray, pa.TimestampArray) + if ( + arbitrary.dtype == cudf.dtype("object") + and cudf.dtype(pyarrow_array.type.to_pandas_dtype()) + != cudf.dtype(arbitrary.dtype) + and not is_bool_dtype( + cudf.dtype(pyarrow_array.type.to_pandas_dtype()) + ) ): - raise TypeError("Cannot create column with mixed types") + raise MixedTypeError("Cannot create column with mixed types") if isinstance(pyarrow_array.type, pa.Decimal128Type): pyarrow_type = cudf.Decimal128Dtype.from_arrow( pyarrow_array.type @@ -2436,8 +2441,12 @@ def as_column( if ( isinstance(arbitrary, pd.Index) and arbitrary.dtype == cudf.dtype("object") - and isinstance( - pyarrow_array, (pa.DurationArray, pa.TimestampArray) + and ( + cudf.dtype(pyarrow_array.type.to_pandas_dtype()) + != 
cudf.dtype(arbitrary.dtype) + and not is_bool_dtype( + cudf.dtype(pyarrow_array.type.to_pandas_dtype()) + ) ) ): raise MixedTypeError( diff --git a/python/cudf/cudf/tests/test_index.py b/python/cudf/cudf/tests/test_index.py index 5730ecc4ae7..819527ac312 100644 --- a/python/cudf/cudf/tests/test_index.py +++ b/python/cudf/cudf/tests/test_index.py @@ -2676,10 +2676,11 @@ def test_scalar_getitem(self, index_values, i): 12, 20, ], + [1, 2, 3, 4], ], ) def test_index_mixed_dtype_error(data): - pi = pd.Index(data) + pi = pd.Index(data, dtype="object") with pytest.raises(TypeError): cudf.Index(pi) diff --git a/python/cudf/cudf/tests/test_parquet.py b/python/cudf/cudf/tests/test_parquet.py index 66c4a253423..b892cc62ac4 100644 --- a/python/cudf/cudf/tests/test_parquet.py +++ b/python/cudf/cudf/tests/test_parquet.py @@ -2374,11 +2374,11 @@ def test_parquet_writer_list_statistics(tmpdir): for i, col in enumerate(pd_slice): stats = pq_file.metadata.row_group(rg).column(i).statistics - actual_min = cudf.Series(pd_slice[col].explode().explode()).min() + actual_min = pd_slice[col].explode().explode().dropna().min() stats_min = stats.min assert normalized_equals(actual_min, stats_min) - actual_max = cudf.Series(pd_slice[col].explode().explode()).max() + actual_max = pd_slice[col].explode().explode().dropna().max() stats_max = stats.max assert normalized_equals(actual_max, stats_max) diff --git a/python/cudf/cudf/tests/test_series.py b/python/cudf/cudf/tests/test_series.py index 51c6bb1634d..783d7d31d7f 100644 --- a/python/cudf/cudf/tests/test_series.py +++ b/python/cudf/cudf/tests/test_series.py @@ -2187,11 +2187,15 @@ def test_series_init_error(): ) -@pytest.mark.parametrize("dtype", ["datetime64[ns]", "timedelta64[ns]"]) +@pytest.mark.parametrize( + "dtype", ["datetime64[ns]", "timedelta64[ns]", "object", "str"] +) def test_series_mixed_dtype_error(dtype): ps = pd.concat([pd.Series([1, 2, 3], dtype=dtype), pd.Series([10, 11])]) with pytest.raises(TypeError): cudf.Series(ps) + 
with pytest.raises(TypeError): + cudf.Series(ps.array) @pytest.mark.parametrize("data", [[True, False, None], [10, 200, 300]])