Skip to content

Commit

Permalink
Handle schema inference in Dataset with empty list col (#319)
Browse files Browse the repository at this point in the history
* Add test for schema inference with empty list

* Only use `list_val_dtype` if col is a list
  • Loading branch information
oliverholworthy authored May 12, 2023
1 parent 5acb1b7 commit e1eaf26
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 12 deletions.
15 changes: 10 additions & 5 deletions merlin/core/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import enum
import functools
import itertools
from typing import Callable, Union
from typing import Callable, Optional, Union

import dask.dataframe as dd
import numpy as np
Expand Down Expand Up @@ -311,7 +311,7 @@ def series_has_nulls(s):
return s.has_nulls


def list_val_dtype(ser: SeriesLike) -> np.dtype:
def list_val_dtype(ser: SeriesLike) -> Optional[np.dtype]:
"""
Return the dtype of the leaves from a list or nested list
Expand All @@ -322,16 +322,21 @@ def list_val_dtype(ser: SeriesLike) -> np.dtype:
Returns
-------
np.dtype
The dtype of the innermost elements
Optional[np.dtype]
The dtype of the innermost elements if we find one
"""
if is_list_dtype(ser):
if cudf is not None and isinstance(ser, cudf.Series):
if is_list_dtype(ser):
ser = ser.list.leaves
return ser.dtype
elif isinstance(ser, pd.Series):
return pd.core.dtypes.cast.infer_dtype_from(next(iter(pd.core.common.flatten(ser))))[0]
try:
return pd.core.dtypes.cast.infer_dtype_from(
next(iter(pd.core.common.flatten(ser)))
)[0]
except StopIteration:
return None
if isinstance(ser, np.ndarray):
return ser.dtype
# adds detection when in merlin column
Expand Down
13 changes: 6 additions & 7 deletions merlin/io/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,13 +1212,12 @@ def sample_dtypes(self, n=1, annotate_lists=False):

if annotate_lists:
_real_meta = self._real_meta[n]
annotated = {
col: {
"dtype": list_val_dtype(_real_meta[col]) or _real_meta[col].dtype,
"is_list": is_list_dtype(_real_meta[col]),
}
for col in _real_meta.columns
}
annotated = {}
for col in _real_meta.columns:
is_list = is_list_dtype(_real_meta[col])
dtype = list_val_dtype(_real_meta[col]) if is_list else _real_meta[col].dtype
annotated[col] = {"dtype": dtype, "is_list": is_list}

return annotated

return self._real_meta[n].dtypes
Expand Down
6 changes: 6 additions & 0 deletions tests/unit/io/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,9 @@ def test_false_with_cudf_and_gpu(self):
def test_false_missing_cudf_or_gpu(self):
with pytest.raises(RuntimeError):
Dataset(make_df({"a": [1, 2, 3]}), cpu=False)


def test_infer_list_dtype_unknown():
df = pd.DataFrame({"col": [[], []]})
dataset = Dataset(df, cpu=True)
assert dataset.schema["col"].dtype.element_type.value == "unknown"

0 comments on commit e1eaf26

Please sign in to comment.