diff --git a/merlin/dag/executors.py b/merlin/dag/executors.py index fcb8c9535..ba2550ee1 100644 --- a/merlin/dag/executors.py +++ b/merlin/dag/executors.py @@ -389,19 +389,26 @@ def transform( if col_dtype: output_dtypes[col_name] = md.dtype(col_dtype).to_numpy - if isinstance(output_dtypes, dict) and isinstance(ddf._meta, pd.DataFrame): - dtypes = output_dtypes - output_dtypes = type(ddf._meta)({k: [] for k in columns}) - for col_name, col_dtype in dtypes.items(): - output_dtypes[col_name] = output_dtypes[col_name].astype(col_dtype) + def make_empty(df, cols): + # Construct an empty DataFrame - elif not output_dtypes: # TODO: constructing meta like this loses dtype information on the ddf # and sets it all to 'float64'. We should propagate dtype information along # with column names in the columngroup graph. This currently only # happens during intermediate 'fit' transforms, so as long as statoperators # don't require dtype information on the DDF this doesn't matter all that much - output_dtypes = type(ddf._meta)({k: [] for k in columns}) + return df._constructor( + {col: df._constructor_sliced([], dtype="float64") for col in cols} + ) + + if isinstance(output_dtypes, dict) and isinstance(ddf._meta, pd.DataFrame): + dtypes = output_dtypes + output_dtypes = make_empty(ddf._meta, columns) + for col_name, col_dtype in dtypes.items(): + output_dtypes[col_name] = output_dtypes[col_name].astype(col_dtype) + + elif not output_dtypes: + output_dtypes = make_empty(ddf._meta, columns) return ensure_optimize_dataframe_graph( ddf=ddf.map_partitions(