Skip to content

Commit

Permalink
preserve dtypes in DaskExecutor.fit_phase
Browse files Browse the repository at this point in the history
  • Loading branch information
rjzamora committed Jul 26, 2024
1 parent 6a177d8 commit 8a3d7f1
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions merlin/dag/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,6 @@ def transform(
# If so, we should perform column selection at the ddf level.
# Otherwise, Dask will not push the column selection into the
# IO function.

if not nodes:
return ddf[_get_unique(additional_columns)] if additional_columns else ddf

Expand All @@ -389,19 +388,20 @@ def transform(
if col_dtype:
output_dtypes[col_name] = md.dtype(col_dtype).to_numpy

def empty_like(df):
# Construct an empty DataFrame with the same dtypes as df
return df._constructor(
{k: df._constructor_sliced([], dtype=df[k].dtype) for k in df.columns}
)

if isinstance(output_dtypes, dict) and isinstance(ddf._meta, pd.DataFrame):
dtypes = output_dtypes
output_dtypes = type(ddf._meta)({k: [] for k in columns})
output_dtypes = empty_like(ddf._meta[columns])
for col_name, col_dtype in dtypes.items():
output_dtypes[col_name] = output_dtypes[col_name].astype(col_dtype)

elif not output_dtypes:
# TODO: constructing meta like this loses dtype information on the ddf
# and sets it all to 'float64'. We should propagate dtype information along
# with column names in the columngroup graph. This currently only
# happens during intermediate 'fit' transforms, so as long as statoperators
# don't require dtype information on the DDF this doesn't matter all that much
output_dtypes = type(ddf._meta)({k: [] for k in columns})
output_dtypes = empty_like(ddf._meta[columns])

return ensure_optimize_dataframe_graph(
ddf=ddf.map_partitions(
Expand Down

0 comments on commit 8a3d7f1

Please sign in to comment.