diff --git a/python/cudf/cudf/core/frame.py b/python/cudf/cudf/core/frame.py index e1b2f7d674d..123f13f8733 100644 --- a/python/cudf/cudf/core/frame.py +++ b/python/cudf/cudf/core/frame.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. from __future__ import annotations @@ -99,9 +99,14 @@ def _has_nulls(self): @_cudf_nvtx_annotate def serialize(self): + # TODO: See if self._data can be serialized outright header = { "type-serialized": pickle.dumps(type(self)), "column_names": pickle.dumps(tuple(self._data.names)), + "column_rangeindex": pickle.dumps(self._data.rangeindex), + "column_multiindex": pickle.dumps(self._data.multiindex), + "column_label_dtype": pickle.dumps(self._data.label_dtype), + "column_level_names": pickle.dumps(self._data._level_names), } header["columns"], frames = serialize_columns(self._columns) return header, frames @@ -112,7 +117,20 @@ def deserialize(cls, header, frames): cls_deserialize = pickle.loads(header["type-serialized"]) column_names = pickle.loads(header["column_names"]) columns = deserialize_columns(header["columns"], frames) - return cls_deserialize._from_data(dict(zip(column_names, columns))) + kwargs = {} + for metadata in [ + "rangeindex", + "multiindex", + "label_dtype", + "level_names", + ]: + key = f"column_{metadata}" + if key in header: + kwargs[metadata] = pickle.loads(header[key]) + col_accessor = ColumnAccessor( + data=dict(zip(column_names, columns)), **kwargs + ) + return cls_deserialize._from_data(col_accessor) @classmethod @_cudf_nvtx_annotate diff --git a/python/cudf/cudf/tests/test_serialize.py b/python/cudf/cudf/tests/test_serialize.py index e2788e4f03b..cac170cce55 100644 --- a/python/cudf/cudf/tests/test_serialize.py +++ b/python/cudf/cudf/tests/test_serialize.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2023, NVIDIA CORPORATION. +# Copyright (c) 2018-2024, NVIDIA CORPORATION. import pickle @@ -8,7 +8,6 @@ import pytest import cudf -from cudf.core._compat import PANDAS_GE_200 from cudf.testing import _utils as utils from cudf.testing._utils import assert_eq @@ -302,12 +301,9 @@ def test_serialize_string(): "frames", [ (cudf.Series([], dtype="str"), pd.Series([], dtype="str")), - pytest.param( - (cudf.DataFrame([]), pd.DataFrame([])), - marks=pytest.mark.xfail( - not PANDAS_GE_200, reason=".column returns Index[object]" - ), - ), + (cudf.DataFrame(), pd.DataFrame()), + (cudf.DataFrame([]), pd.DataFrame([])), + (cudf.DataFrame({}), pd.DataFrame({})), (cudf.DataFrame([1]).head(0), pd.DataFrame([1]).head(0)), (cudf.DataFrame({"a": []}), pd.DataFrame({"a": []})), ( @@ -401,3 +397,19 @@ def test_serialize_sliced_string(): recreated = cudf.Series.deserialize(*sliced.serialize()) assert_eq(recreated.to_pandas(nullable=True), pd_series) + + +@pytest.mark.parametrize( + "columns", + [ + cudf.RangeIndex(2), + cudf.Index([1, 2], dtype="int8"), + cudf.MultiIndex( + levels=[["a", "b"], [1, 2]], codes=[[0, 1], [0, 1]], names=["a", 0] + ), + ], +) +def test_serialize_column_types_preserved(columns): + expected = cudf.DataFrame([[10, 11]], columns=columns) + result = cudf.DataFrame.deserialize(*expected.serialize()) + assert_eq(result, expected)