Skip to content

Commit

Permalink
Fix categorical dtype inference for postal code (#1574)
Browse files Browse the repository at this point in the history
* Fix categorical dtype inference for postal code
  • Loading branch information
sbadithe authored Nov 23, 2022
1 parent 807ea4f commit c9f9401
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/source/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Future Release
* Fixed ``IntegerNullable`` inference by checking values are within valid ``Int64`` bounds (:pr:`1572`)
* Update demo dataset links to point to new endpoint (:pr:`1570`)
* Fix DivisionByZero error in ``type_system.py`` (:pr:`1571`)
* Fix Categorical dtype inference for ``PostalCode`` logical type (:pr:`1574`)
* Fixed issue where forcing a ``Boolean`` logical type on a column of 0.0s and 1.0s caused incorrect transformation (:pr:`1576`)
* Changes
* Unpin dask dependency (:pr:`1561`)
Expand Down
29 changes: 28 additions & 1 deletion woodwork/tests/type_system/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,33 @@ def natural_language(request):
return request.getfixturevalue(request.param)


# Postal Inference Fixtures
@pytest.fixture
def pandas_postal_codes():
    """Provide pandas Series of postal-code strings for inference tests.

    Three Series are returned: five-digit codes only, hyphenated
    five-plus-four-digit codes only, and a mix of both formats. Each base
    list of values is repeated ten times.
    """
    base_values = (
        ["77002", "55106"],
        ["77002-0000", "55106-0000"],
        ["12345", "12345", "12345-6789", "12345-0000"],
    )
    return [pd.Series(values * 10) for values in base_values]


@pytest.fixture
def dask_postal_codes(pandas_postal_codes):
    """Return the pandas postal-code Series, each passed through ``pd_to_dask``."""
    return list(map(pd_to_dask, pandas_postal_codes))


@pytest.fixture
def spark_postal_codes(pandas_postal_codes):
    """Return the pandas postal-code Series, each passed through ``pd_to_spark``."""
    return list(map(pd_to_spark, pandas_postal_codes))


@pytest.fixture(
    params=["pandas_postal_codes", "dask_postal_codes", "spark_postal_codes"],
)
def postal(request):
    """Resolve to the postal-code fixture named by the current parameter.

    The fixture is parametrized over the pandas, dask and spark postal-code
    fixtures, so a test requesting ``postal`` runs once per listed fixture.
    """
    selected_fixture = request.param
    return request.getfixturevalue(selected_fixture)


# Unknown Inference Fixtures
@pytest.fixture
def pandas_strings():
Expand Down Expand Up @@ -357,7 +384,7 @@ def pyspark_empty_series(pandas_empty_series):


@pytest.fixture(
params=["pandas_empty_series", "dask_empty_series", "pyspark_empty_series"]
params=["pandas_empty_series", "dask_empty_series", "pyspark_empty_series"],
)
def empty_series(request):
return request.getfixturevalue(request.param)
Expand Down
11 changes: 11 additions & 0 deletions woodwork/tests/type_system/test_ltype_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
IntegerNullable,
LogicalType,
NaturalLanguage,
PostalCode,
Timedelta,
Unknown,
)
Expand Down Expand Up @@ -132,6 +133,16 @@ def test_categorical_inference(categories):
assert isinstance(inferred_type, Categorical)


def test_postal_inference(postal):
    """Check that postal-code series are inferred as ``PostalCode``.

    Every series supplied by the ``postal`` fixture is cast to each candidate
    dtype and passed to ``ww.type_system.infer_logical_type``; the result
    must be a ``PostalCode`` instance.
    """
    base_dtypes = ["category", "string"]
    for series in postal:
        # Derive the dtype list for this series from the untouched base list,
        # so the Spark conversion is never re-applied to an already-converted
        # list on later iterations.
        if _is_spark_series(series):
            dtypes = get_spark_dtypes(base_dtypes)
        else:
            dtypes = base_dtypes
        for dtype in dtypes:
            inferred_type = ww.type_system.infer_logical_type(series.astype(dtype))
            assert isinstance(inferred_type, PostalCode)


def test_natural_language_inference(natural_language):
dtypes = ["object", "string"]
if _is_spark_series(natural_language[0]):
Expand Down
5 changes: 4 additions & 1 deletion woodwork/type_sys/inference_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,10 @@ def __call__(self, series: pd.Series) -> bool:
regex = self.get_regex()

# pdtypes.is_string_dtype also returns True for object dtypes
if not pdtypes.is_string_dtype(series.dtype):
if not (
pdtypes.is_categorical_dtype(series.dtype)
or pdtypes.is_string_dtype(series.dtype)
):
return False

try:
Expand Down

0 comments on commit c9f9401

Please sign in to comment.