diff --git a/python/cudf/cudf/core/dataframe.py b/python/cudf/cudf/core/dataframe.py index 1a780cc9e9f..8a3dbe77787 100644 --- a/python/cudf/cudf/core/dataframe.py +++ b/python/cudf/cudf/core/dataframe.py @@ -1390,10 +1390,21 @@ def _get_numeric_data(self): return self[columns] @_cudf_nvtx_annotate - def assign(self, **kwargs): + def assign(self, **kwargs: Union[Callable[[Self], Any], Any]): """ Assign columns to DataFrame from keyword arguments. + Parameters + ---------- + **kwargs: dict mapping string column names to values + The value for each key can either be a literal column (or + something that can be converted to a column), or + a callable of one argument that will be given the + dataframe as an argument and should return the new column + (without modifying the input argument). + Columns are added in-order, so callables can refer to + column names constructed in the assignment. + Examples -------- >>> import cudf @@ -1405,15 +1416,9 @@ def assign(self, **kwargs): 1 1 4 2 2 5 """ - new_df = cudf.DataFrame(index=self.index.copy()) - for name, col in self._data.items(): - if name in kwargs: - new_df[name] = kwargs.pop(name) - else: - new_df._data[name] = col.copy() - + new_df = self.copy(deep=False) for k, v in kwargs.items(): - new_df[k] = v + new_df[k] = v(new_df) if callable(v) else v return new_df @classmethod diff --git a/python/cudf/cudf/tests/test_dataframe.py b/python/cudf/cudf/tests/test_dataframe.py index 6180162ecdd..2f531afdeb7 100644 --- a/python/cudf/cudf/tests/test_dataframe.py +++ b/python/cudf/cudf/tests/test_dataframe.py @@ -1327,6 +1327,25 @@ def test_assign(): np.testing.assert_equal(gdf2.y.to_numpy(), [2, 3, 4]) +@pytest.mark.parametrize( + "mapping", + [ + {"y": 1, "z": lambda df: df["x"] + df["y"]}, + { + "x": lambda df: df["x"] * 2, + "y": lambda df: 2, + "z": lambda df: df["x"] / df["y"], + }, + ], +) +def test_assign_callable(mapping): + df = pd.DataFrame({"x": [1, 2, 3]}) + cdf = cudf.from_pandas(df) + expect = df.assign(**mapping) + actual = cdf.assign(**mapping) + assert_eq(expect, actual) + + @pytest.mark.parametrize("nrows", [1, 8, 100, 1000]) @pytest.mark.parametrize("method", ["murmur3", "md5"]) @pytest.mark.parametrize("seed", [None, 42])