diff --git a/docs/source/release_notes.rst b/docs/source/release_notes.rst
index de411e8ce..04eb05119 100644
--- a/docs/source/release_notes.rst
+++ b/docs/source/release_notes.rst
@@ -14,6 +14,7 @@ Future Release
         * Fixed ``IntegerNullable`` inference by checking values are within valid ``Int64`` bounds (:pr:`1572`)
         * Update demo dataset links to point to new endpoint (:pr:`1570`)
         * Fix DivisionByZero error in ``type_system.py`` (:pr:`1571`)
+        * Fix Categorical dtype inference for ``PostalCode`` logical type (:pr:`1574`)
         * Fixed issue where forcing a ``Boolean`` logical type on a column of 0.0s and 1.0s caused incorrect transformation (:pr:`1576`)
     * Changes
         * Unpin dask dependency (:pr:`1561`)
diff --git a/woodwork/tests/type_system/conftest.py b/woodwork/tests/type_system/conftest.py
index 2b48563f8..2059aad34 100644
--- a/woodwork/tests/type_system/conftest.py
+++ b/woodwork/tests/type_system/conftest.py
@@ -282,6 +282,33 @@ def natural_language(request):
     return request.getfixturevalue(request.param)
 
 
+# Postal Inference Fixtures
+@pytest.fixture
+def pandas_postal_codes():
+    return [
+        pd.Series(10 * ["77002", "55106"]),
+        pd.Series(10 * ["77002-0000", "55106-0000"]),
+        pd.Series(10 * ["12345", "12345", "12345-6789", "12345-0000"]),
+    ]
+
+
+@pytest.fixture
+def dask_postal_codes(pandas_postal_codes):
+    return [pd_to_dask(series) for series in pandas_postal_codes]
+
+
+@pytest.fixture
+def spark_postal_codes(pandas_postal_codes):
+    return [pd_to_spark(series) for series in pandas_postal_codes]
+
+
+@pytest.fixture(
+    params=["pandas_postal_codes", "dask_postal_codes", "spark_postal_codes"],
+)
+def postal(request):
+    return request.getfixturevalue(request.param)
+
+
 # Unknown Inference Fixtures
 @pytest.fixture
 def pandas_strings():
@@ -357,7 +384,7 @@ def pyspark_empty_series(pandas_empty_series):
 
 
 @pytest.fixture(
-    params=["pandas_empty_series", "dask_empty_series", "pyspark_empty_series"]
+    params=["pandas_empty_series", "dask_empty_series", "pyspark_empty_series"],
 )
 def empty_series(request):
     return request.getfixturevalue(request.param)
diff --git a/woodwork/tests/type_system/test_ltype_inference.py b/woodwork/tests/type_system/test_ltype_inference.py
index 9a3d41af6..b5c7b461a 100644
--- a/woodwork/tests/type_system/test_ltype_inference.py
+++ b/woodwork/tests/type_system/test_ltype_inference.py
@@ -13,6 +13,7 @@
     IntegerNullable,
     LogicalType,
     NaturalLanguage,
+    PostalCode,
     Timedelta,
     Unknown,
 )
@@ -132,6 +133,19 @@ def test_categorical_inference(categories):
             assert isinstance(inferred_type, Categorical)
 
 
+def test_postal_inference(postal):
+    """Infer ``PostalCode`` for postal code series cast to category or string."""
+    dtypes = ["category", "string"]
+    # Every series in the fixture shares one backend, so resolve the Spark
+    # dtypes once up front, as test_natural_language_inference does.
+    if _is_spark_series(postal[0]):
+        dtypes = get_spark_dtypes(dtypes)
+    for series in postal:
+        for dtype in dtypes:
+            inferred_dtype = ww.type_system.infer_logical_type(series.astype(dtype))
+            assert isinstance(inferred_dtype, PostalCode)
+
+
 def test_natural_language_inference(natural_language):
     dtypes = ["object", "string"]
     if _is_spark_series(natural_language[0]):
diff --git a/woodwork/type_sys/inference_functions.py b/woodwork/type_sys/inference_functions.py
index f6d1699d5..af543be63 100644
--- a/woodwork/type_sys/inference_functions.py
+++ b/woodwork/type_sys/inference_functions.py
@@ -156,7 +156,10 @@ def __call__(self, series: pd.Series) -> bool:
         regex = self.get_regex()
 
         # Includes a check for object dtypes
-        if not pdtypes.is_string_dtype(series.dtype):
+        if not (
+            pdtypes.is_categorical_dtype(series.dtype)
+            or pdtypes.is_string_dtype(series.dtype)
+        ):
             return False
 
         try: