From 9f3cb6581ea4362ad93f8f8340883e1210a9ed40 Mon Sep 17 00:00:00 2001 From: Matthew Roeschke <10647082+mroeschke@users.noreply.github.com> Date: Fri, 24 Jan 2025 11:10:44 -0800 Subject: [PATCH] Remove cudf._lib.scalar in favor of pylibcudf (#17701) This PR changes `cudf.Scalar.device_scalar` to be a `pylibcudf.Scalar` object instead of a `cudf._lib.scalar.DeviceScalar`. Most of the conversion logic previously in `cudf._lib.scalar.DeviceScalar` now lives in `python/cudf/cudf/core/scalar.py` Some tests that exercised behaviors of `cudf.Scalar.device_scalar` when it was a `cudf._lib.scalar.DeviceScalar` were modified/removed. Authors: - Matthew Roeschke (https://github.com/mroeschke) Approvers: - Vyas Ramasubramani (https://github.com/vyasr) URL: https://github.com/rapidsai/cudf/pull/17701 --- python/cudf/cudf/_lib/CMakeLists.txt | 2 +- python/cudf/cudf/_lib/column.pyx | 18 +- python/cudf/cudf/_lib/scalar.pxd | 22 -- python/cudf/cudf/_lib/scalar.pyx | 243 ------------- python/cudf/cudf/api/types.py | 4 +- python/cudf/cudf/core/_internals/binaryop.py | 4 +- python/cudf/cudf/core/_internals/copying.py | 4 +- python/cudf/cudf/core/column/column.py | 24 +- python/cudf/cudf/core/column/lists.py | 7 + python/cudf/cudf/core/column/struct.py | 2 +- python/cudf/cudf/core/dtypes.py | 42 +++ python/cudf/cudf/core/index.py | 4 +- python/cudf/cudf/core/scalar.py | 348 +++++++++++++------ python/cudf/cudf/core/tools/datetimes.py | 4 +- python/cudf/cudf/tests/test_binops.py | 4 +- python/cudf/cudf/tests/test_list.py | 1 - python/cudf/cudf/tests/test_scalar.py | 52 +-- python/cudf/cudf/tests/test_struct.py | 9 +- 18 files changed, 340 insertions(+), 454 deletions(-) delete mode 100644 python/cudf/cudf/_lib/scalar.pxd delete mode 100644 python/cudf/cudf/_lib/scalar.pyx diff --git a/python/cudf/cudf/_lib/CMakeLists.txt b/python/cudf/cudf/_lib/CMakeLists.txt index ec44a6aa8c5..0ec9350e6ee 100644 --- a/python/cudf/cudf/_lib/CMakeLists.txt +++ b/python/cudf/cudf/_lib/CMakeLists.txt @@ -12,7 +12,7 @@ # the License. # ============================================================================= -set(cython_sources column.pyx scalar.pyx strings_udf.pyx) +set(cython_sources column.pyx strings_udf.pyx) set(linked_libraries cudf::cudf) rapids_cython_create_modules( diff --git a/python/cudf/cudf/_lib/column.pyx b/python/cudf/cudf/_lib/column.pyx index 114991dbe3e..00ecd53e70d 100644 --- a/python/cudf/cudf/_lib/column.pyx +++ b/python/cudf/cudf/_lib/column.pyx @@ -33,7 +33,11 @@ from libcpp.vector cimport vector from rmm.pylibrmm.device_buffer cimport DeviceBuffer -from pylibcudf cimport DataType as plc_DataType, Column as plc_Column +from pylibcudf cimport ( + DataType as plc_DataType, + Column as plc_Column, + Scalar as plc_Scalar, +) cimport pylibcudf.libcudf.copying as cpp_copying cimport pylibcudf.libcudf.types as libcudf_types cimport pylibcudf.libcudf.unary as libcudf_unary @@ -45,8 +49,6 @@ from pylibcudf.libcudf.column.column_view cimport column_view from pylibcudf.libcudf.lists.lists_column_view cimport lists_column_view from pylibcudf.libcudf.scalar.scalar cimport scalar -from cudf._lib.scalar cimport DeviceScalar - cdef get_element(column_view col_view, size_type index): @@ -55,10 +57,8 @@ cdef get_element(column_view col_view, size_type index): c_output = move( cpp_copying.get_element(col_view, index) ) - - return DeviceScalar.from_unique_ptr( - move(c_output), dtype=dtype_from_column_view(col_view) - ) + plc_scalar = plc_Scalar.from_libcudf(move(c_output)) + return pylibcudf.interop.to_arrow(plc_scalar).as_py() def dtype_from_pylibcudf_column(plc_Column col not None): @@ -767,7 +767,7 @@ cdef class Column: base_nbytes = 0 else: chars_size = get_element( - offset_child_column, offset_child_column.size()-1).value + offset_child_column, offset_child_column.size()-1) base_nbytes = chars_size if data_ptr: @@ -908,6 +908,6 @@ cdef class Column: def from_scalar(py_val, size_type size): return Column.from_pylibcudf( pylibcudf.Column.from_scalar( - py_val.device_value.c_value, size + py_val.device_value, size ) ) diff --git a/python/cudf/cudf/_lib/scalar.pxd b/python/cudf/cudf/_lib/scalar.pxd deleted file mode 100644 index a3a8a14e70f..00000000000 --- a/python/cudf/cudf/_lib/scalar.pxd +++ /dev/null @@ -1,22 +0,0 @@ -# Copyright (c) 2020-2024, NVIDIA CORPORATION. - -from libcpp cimport bool -from libcpp.memory cimport unique_ptr - -from pylibcudf.libcudf.scalar.scalar cimport scalar -from rmm.pylibrmm.memory_resource cimport DeviceMemoryResource - - -cdef class DeviceScalar: - cdef public object c_value - - cdef object _dtype - - cdef const scalar* get_raw_ptr(self) except * - - @staticmethod - cdef DeviceScalar from_unique_ptr(unique_ptr[scalar] ptr, dtype=*) - - cdef void _set_dtype(self, dtype=*) - - cpdef bool is_valid(DeviceScalar s) diff --git a/python/cudf/cudf/_lib/scalar.pyx b/python/cudf/cudf/_lib/scalar.pyx deleted file mode 100644 index 65607c91302..00000000000 --- a/python/cudf/cudf/_lib/scalar.pyx +++ /dev/null @@ -1,243 +0,0 @@ -# Copyright (c) 2020-2025, NVIDIA CORPORATION. - -import copy - -import numpy as np -import pandas as pd -import pyarrow as pa - -from libcpp cimport bool -from libcpp.memory cimport unique_ptr -from libcpp.utility cimport move - -import pylibcudf as plc - -import cudf -from cudf.core.dtypes import ListDtype, StructDtype -from cudf.core.missing import NA, NaT -from cudf.utils.dtypes import PYLIBCUDF_TO_SUPPORTED_NUMPY_TYPES - -# We currently need this cimport because some of the implementations here -# access the c_obj of the scalar, and because we need to be able to call -# pylibcudf.Scalar.from_libcudf. Both of those are temporarily acceptable until -# DeviceScalar is phased out entirely from cuDF Cython (at which point -# cudf.Scalar will be directly backed by pylibcudf.Scalar). -from pylibcudf cimport Scalar as plc_Scalar -from pylibcudf.libcudf.scalar.scalar cimport scalar - - -def _replace_nested(obj, check, replacement): - if isinstance(obj, list): - for i, item in enumerate(obj): - if check(item): - obj[i] = replacement - elif isinstance(item, (dict, list)): - _replace_nested(item, check, replacement) - elif isinstance(obj, dict): - for k, v in obj.items(): - if check(v): - obj[k] = replacement - elif isinstance(v, (dict, list)): - _replace_nested(v, check, replacement) - - -def gather_metadata(dtypes): - """Convert a dict of dtypes to a list of ColumnMetadata objects. - - The metadata is constructed recursively so that nested types are - represented as nested ColumnMetadata objects. - - Parameters - ---------- - dtypes : dict - A dict mapping column names to dtypes. - - Returns - ------- - List[ColumnMetadata] - A list of ColumnMetadata objects. - """ - out = [] - for name, dtype in dtypes.items(): - v = plc.interop.ColumnMetadata(name) - if isinstance(dtype, cudf.StructDtype): - v.children_meta = gather_metadata(dtype.fields) - elif isinstance(dtype, cudf.ListDtype): - # Offsets column is unnamed and has no children - v.children_meta.append(plc.interop.ColumnMetadata("")) - v.children_meta.extend( - gather_metadata({"": dtype.element_type}) - ) - out.append(v) - return out - - -cdef class DeviceScalar: - - # TODO: I think this should be removable, except that currently the way - # that from_unique_ptr is implemented is probably dereferencing this in an - # invalid state. See what the best way to fix that is. - def __cinit__(self, *args, **kwargs): - self.c_value = plc.Scalar.__new__(plc.Scalar) - - def __init__(self, value, dtype): - """ - Type representing an *immutable* scalar value on the device - - Parameters - ---------- - value : scalar - An object of scalar type, i.e., one for which - `np.isscalar()` returns `True`. Can also be `None`, - to represent a "null" scalar. In this case, - dtype *must* be provided. - dtype : dtype - A NumPy dtype. - """ - dtype = dtype if dtype.kind != 'U' else cudf.dtype('object') - - if cudf.utils.utils.is_na_like(value): - value = None - else: - # TODO: For now we always deepcopy the input value to avoid - # overwriting the input values when replacing nulls. Since it's - # just host values it's not that expensive, but we could consider - # alternatives. - value = copy.deepcopy(value) - _replace_nested(value, cudf.utils.utils.is_na_like, None) - - if isinstance(dtype, cudf.core.dtypes._BaseDtype): - pa_type = dtype.to_arrow() - elif pd.api.types.is_string_dtype(dtype): - # Have to manually convert object types, which we use internally - # for strings but pyarrow only supports as unicode 'U' - pa_type = pa.string() - else: - pa_type = pa.from_numpy_dtype(dtype) - - if isinstance(pa_type, pa.ListType) and value is None: - # pyarrow doesn't correctly handle None values for list types, so - # we have to create this one manually. - # https://github.com/apache/arrow/issues/40319 - pa_array = pa.array([None], type=pa_type) - else: - pa_array = pa.array([pa.scalar(value, type=pa_type)]) - - pa_table = pa.Table.from_arrays([pa_array], names=[""]) - table = plc.interop.from_arrow(pa_table) - - column = table.columns()[0] - if isinstance(dtype, cudf.core.dtypes.DecimalDtype): - if isinstance(dtype, cudf.core.dtypes.Decimal32Dtype): - column = plc.unary.cast( - column, plc.DataType(plc.TypeId.DECIMAL32, -dtype.scale) - ) - elif isinstance(dtype, cudf.core.dtypes.Decimal64Dtype): - column = plc.unary.cast( - column, plc.DataType(plc.TypeId.DECIMAL64, -dtype.scale) - ) - - self.c_value = plc.copying.get_element(column, 0) - self._dtype = dtype - - def _to_host_scalar(self): - is_datetime = self.dtype.kind == "M" - is_timedelta = self.dtype.kind == "m" - - null_type = NaT if is_datetime or is_timedelta else NA - - metadata = gather_metadata({"": self.dtype})[0] - ps = plc.interop.to_arrow(self.c_value, metadata) - if not ps.is_valid: - return null_type - - # TODO: The special handling of specific types below does not currently - # extend to nested types containing those types (e.g. List[timedelta] - # where the timedelta would overflow). We should eventually account for - # those cases, but that will require more careful consideration of how - # to traverse the contents of the nested data. - if is_datetime or is_timedelta: - time_unit, _ = np.datetime_data(self.dtype) - # Cast to int64 to avoid overflow - ps_cast = ps.cast('int64').as_py() - out_type = np.datetime64 if is_datetime else np.timedelta64 - ret = out_type(ps_cast, time_unit) - elif cudf.api.types.is_numeric_dtype(self.dtype): - ret = ps.type.to_pandas_dtype()(ps.as_py()) - else: - ret = ps.as_py() - - _replace_nested(ret, lambda item: item is None, NA) - return ret - - @property - def dtype(self): - """ - The NumPy dtype corresponding to the data type of the underlying - device scalar. - """ - return self._dtype - - @property - def value(self): - """ - Returns a host copy of the underlying device scalar. - """ - return self._to_host_scalar() - - cdef const scalar* get_raw_ptr(self) except *: - return ( self.c_value).c_obj.get() - - cpdef bool is_valid(self): - """ - Returns if the Scalar is valid or not(i.e., ). - """ - return self.c_value.is_valid() - - def __repr__(self): - if cudf.utils.utils.is_na_like(self.value): - return ( - f"{self.__class__.__name__}" - f"({self.value}, {repr(self.dtype)})" - ) - else: - return f"{self.__class__.__name__}({repr(self.value)})" - - @staticmethod - cdef DeviceScalar from_unique_ptr(unique_ptr[scalar] ptr, dtype=None): - """ - Construct a Scalar object from a unique_ptr. - """ - cdef DeviceScalar s = DeviceScalar.__new__(DeviceScalar) - # Note: This line requires pylibcudf to be cimported - s.c_value = plc_Scalar.from_libcudf(move(ptr)) - s._set_dtype(dtype) - return s - - @staticmethod - def from_pylibcudf(pscalar, dtype=None): - cdef DeviceScalar s = DeviceScalar.__new__(DeviceScalar) - s.c_value = pscalar - s._set_dtype(dtype) - return s - - cdef void _set_dtype(self, dtype=None): - cdtype_id = self.c_value.type().id() - if dtype is not None: - self._dtype = dtype - elif cdtype_id in { - plc.TypeID.DECIMAL32, - plc.TypeID.DECIMAL64, - plc.TypeID.DECIMAL128, - }: - raise TypeError( - "Must pass a dtype when constructing from a fixed-point scalar" - ) - elif cdtype_id == plc.TypeID.STRUCT: - self._dtype = StructDtype.from_arrow( - plc.interop.to_arrow(self.c_value).type - ) - elif cdtype_id == plc.TypeID.LIST: - self._dtype = ListDtype.from_arrow(plc.interop.to_arrow(self.c_value).type) - else: - self._dtype = PYLIBCUDF_TO_SUPPORTED_NUMPY_TYPES[cdtype_id] diff --git a/python/cudf/cudf/api/types.py b/python/cudf/cudf/api/types.py index cad4b1aa72c..35eb25e2a32 100644 --- a/python/cudf/cudf/api/types.py +++ b/python/cudf/cudf/api/types.py @@ -16,6 +16,8 @@ import pyarrow as pa from pandas.api import types as pd_types +import pylibcudf as plc + import cudf from cudf.core._compat import PANDAS_LT_300 from cudf.core.dtypes import ( # noqa: F401 @@ -143,8 +145,8 @@ def is_scalar(val): val, ( cudf.Scalar, - cudf._lib.scalar.DeviceScalar, cudf.core.tools.datetimes.DateOffset, + plc.Scalar, pa.Scalar, ), ) or ( diff --git a/python/cudf/cudf/core/_internals/binaryop.py b/python/cudf/cudf/core/_internals/binaryop.py index a9023f8fd59..4ad873b9825 100644 --- a/python/cudf/cudf/core/_internals/binaryop.py +++ b/python/cudf/cudf/core/_internals/binaryop.py @@ -50,10 +50,10 @@ def binaryop( plc.binaryop.binary_operation( lhs.to_pylibcudf(mode="read") if isinstance(lhs, Column) - else lhs.device_value.c_value, + else lhs.device_value, rhs.to_pylibcudf(mode="read") if isinstance(rhs, Column) - else rhs.device_value.c_value, + else rhs.device_value, plc.binaryop.BinaryOperator[op], dtype_to_pylibcudf_type(dtype), ) diff --git a/python/cudf/cudf/core/_internals/copying.py b/python/cudf/cudf/core/_internals/copying.py index 34c1850cb72..76122f89445 100644 --- a/python/cudf/cudf/core/_internals/copying.py +++ b/python/cudf/cudf/core/_internals/copying.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2024, NVIDIA CORPORATION. +# Copyright (c) 2020-2025, NVIDIA CORPORATION. from __future__ import annotations from typing import TYPE_CHECKING @@ -67,7 +67,7 @@ def scatter( plc_tbl = plc.copying.scatter( plc.Table([col.to_pylibcudf(mode="read") for col in sources]) # type: ignore[union-attr] if isinstance(sources[0], cudf._lib.column.Column) - else [slr.device_value.c_value for slr in sources], # type: ignore[union-attr] + else [slr.device_value for slr in sources], # type: ignore[union-attr] scatter_map.to_pylibcudf(mode="read"), plc.Table([col.to_pylibcudf(mode="read") for col in target_columns]), ) diff --git a/python/cudf/cudf/core/column/column.py b/python/cudf/cudf/core/column/column.py index e6f057f63f5..7c9ed0a911e 100644 --- a/python/cudf/cudf/core/column/column.py +++ b/python/cudf/cudf/core/column/column.py @@ -23,7 +23,6 @@ import rmm import cudf -from cudf import _lib as libcudf from cudf._lib.column import Column from cudf.api.types import ( _is_non_decimal_numeric_dtype, @@ -441,7 +440,7 @@ def _fill( self.to_pylibcudf(mode="read"), begin, end, - slr.device_value.c_value, + slr.device_value, ) ) if is_string_dtype(self.dtype): @@ -461,7 +460,7 @@ def _fill( self.to_pylibcudf(mode="write"), begin, end, - slr.device_value.c_value, + slr.device_value, ) return self @@ -472,7 +471,7 @@ def shift(self, offset: int, fill_value: ScalarLike) -> Self: plc_col = plc.copying.shift( self.to_pylibcudf(mode="read"), offset, - fill_value.device_value.c_value, + fill_value.device_value, ) return type(self).from_pylibcudf(plc_col) # type: ignore[return-value] @@ -588,14 +587,11 @@ def element_indexing(self, index: int): if idx > len(self) - 1 or idx < 0: raise IndexError("single positional indexer is out-of-bounds") with acquire_spill_lock(): - dscalar = libcudf.scalar.DeviceScalar.from_pylibcudf( - plc.copying.get_element( - self.to_pylibcudf(mode="read"), - idx, - ), - dtype=self.dtype, + plc_scalar = plc.copying.get_element( + self.to_pylibcudf(mode="read"), + idx, ) - return dscalar.value + return cudf.Scalar.from_pylibcudf(plc_scalar).value def slice(self, start: int, stop: int, stride: int | None = None) -> Self: stride = 1 if stride is None else stride @@ -742,7 +738,7 @@ def _scatter_by_column( plc_table = plc.copying.boolean_mask_scatter( plc.Table([value.to_pylibcudf(mode="read")]) if isinstance(value, Column) - else [value.device_value.c_value], + else [value.device_value], plc.Table([self.to_pylibcudf(mode="read")]), key.to_pylibcudf(mode="read"), ) @@ -822,7 +818,7 @@ def fillna( else plc.replace.ReplacePolicy.FOLLOWING ) elif is_scalar(fill_value): - plc_replace = cudf.Scalar(fill_value).device_value.c_value + plc_replace = cudf.Scalar(fill_value).device_value else: plc_replace = fill_value.to_pylibcudf(mode="read") plc_column = plc.replace.replace_nulls( @@ -1648,7 +1644,7 @@ def copy_if_else( return type(self).from_pylibcudf( # type: ignore[return-value] plc.copying.copy_if_else( self.to_pylibcudf(mode="read"), - other.device_value.c_value + other.device_value if isinstance(other, cudf.Scalar) else other.to_pylibcudf(mode="read"), boolean_mask.to_pylibcudf(mode="read"), diff --git a/python/cudf/cudf/core/column/lists.py b/python/cudf/cudf/core/column/lists.py index e7e69961db4..2b834a20726 100644 --- a/python/cudf/cudf/core/column/lists.py +++ b/python/cudf/cudf/core/column/lists.py @@ -111,6 +111,13 @@ def memory_usage(self): ) return n + def element_indexing(self, index: int) -> list: + result = super().element_indexing(index) + if isinstance(result, list): + return self.dtype._recursively_replace_fields(result) + else: + return result + def __setitem__(self, key, value): if isinstance(value, list): value = cudf.Scalar(value) diff --git a/python/cudf/cudf/core/column/struct.py b/python/cudf/cudf/core/column/struct.py index 052a68cec98..2e10166295b 100644 --- a/python/cudf/cudf/core/column/struct.py +++ b/python/cudf/cudf/core/column/struct.py @@ -120,7 +120,7 @@ def memory_usage(self) -> int: def element_indexing(self, index: int) -> dict: result = super().element_indexing(index) - return dict(zip(self.dtype.fields, result.values())) + return self.dtype._recursively_replace_fields(result) def __setitem__(self, key, value): if isinstance(value, dict): diff --git a/python/cudf/cudf/core/dtypes.py b/python/cudf/cudf/core/dtypes.py index ce7fb968069..32e695b32e3 100644 --- a/python/cudf/cudf/core/dtypes.py +++ b/python/cudf/cudf/core/dtypes.py @@ -518,6 +518,28 @@ def deserialize(cls, header: dict, frames: list): def itemsize(self): return self.element_type.itemsize + def _recursively_replace_fields(self, result: list) -> list: + """ + Return a new list result but with the keys of dict element by the keys in StructDtype.fields.keys(). + + Intended when result comes from pylibcudf without preserved nested field names. + """ + if isinstance(self.element_type, StructDtype): + return [ + self.element_type._recursively_replace_fields(res) + if isinstance(res, dict) + else res + for res in result + ] + elif isinstance(self.element_type, ListDtype): + return [ + self.element_type._recursively_replace_fields(res) + if isinstance(res, list) + else res + for res in result + ] + return result + class StructDtype(_BaseDtype): """ @@ -677,6 +699,26 @@ def itemsize(self): for field in self._typ ) + def _recursively_replace_fields(self, result: dict) -> dict: + """ + Return a new dict result but with the keys replaced by the keys in self.fields.keys(). + + Intended when result comes from pylibcudf without preserved nested field names. + """ + new_result = {} + for (new_field, field_dtype), result_value in zip( + self.fields.items(), result.values() + ): + if isinstance(field_dtype, StructDtype) and isinstance( + result_value, dict + ): + new_result[new_field] = ( + field_dtype._recursively_replace_fields(result_value) + ) + else: + new_result[new_field] = result_value + return new_result + decimal_dtype_template = textwrap.dedent( """ diff --git a/python/cudf/cudf/core/index.py b/python/cudf/cudf/core/index.py index 1f402c8ec8c..c13d62b39df 100644 --- a/python/cudf/cudf/core/index.py +++ b/python/cudf/cudf/core/index.py @@ -3385,8 +3385,8 @@ def interval_range( bin_edges = libcudf.column.Column.from_pylibcudf( plc.filling.sequence( size=periods + 1, - init=start.device_value.c_value, - step=freq.device_value.c_value, + init=start.device_value, + step=freq.device_value, ) ) return IntervalIndex.from_breaks(bin_edges, closed=closed, name=name) diff --git a/python/cudf/cudf/core/scalar.py b/python/cudf/cudf/core/scalar.py index 6630433c9a3..19b13a8e97d 100644 --- a/python/cudf/cudf/core/scalar.py +++ b/python/cudf/cudf/core/scalar.py @@ -1,19 +1,27 @@ # Copyright (c) 2020-2025, NVIDIA CORPORATION. from __future__ import annotations +import copy import decimal import functools import operator from collections import OrderedDict +from typing import TYPE_CHECKING, Any import numpy as np +import pandas as pd import pyarrow as pa import pylibcudf as plc import cudf from cudf.api.types import is_scalar -from cudf.core.dtypes import ListDtype, StructDtype +from cudf.core.dtypes import ( + Decimal32Dtype, + Decimal64Dtype, + ListDtype, + StructDtype, +) from cudf.core.missing import NA, NaT from cudf.core.mixins import BinaryOperand from cudf.utils.dtypes import ( @@ -21,6 +29,180 @@ to_cudf_compatible_scalar, ) +if TYPE_CHECKING: + from typing_extensions import Self + + from cudf._typing import Dtype, ScalarLike + + +def _preprocess_host_value(value, dtype) -> tuple[ScalarLike, Dtype]: + """ + Preprocess a value and dtype for host-side cudf.Scalar + + Parameters + ---------- + value: Scalarlike + dtype: dtypelike or None + + Returns + ------- + tuple[ScalarLike, Dtype] + """ + valid = not cudf.utils.utils._is_null_host_scalar(value) + + if isinstance(value, list): + if dtype is not None: + raise TypeError("Lists may not be cast to a different dtype") + else: + dtype = ListDtype.from_arrow( + pa.infer_type([value], from_pandas=True) + ) + return value, dtype + elif isinstance(dtype, ListDtype): + if value not in {None, NA}: + raise ValueError(f"Can not coerce {value} to ListDtype") + else: + return NA, dtype + + if isinstance(value, dict): + if dtype is None: + dtype = StructDtype.from_arrow( + pa.infer_type([value], from_pandas=True) + ) + return value, dtype + elif isinstance(dtype, StructDtype): + if value not in {None, NA}: + raise ValueError(f"Can not coerce {value} to StructDType") + else: + return NA, dtype + + if isinstance(dtype, cudf.core.dtypes.DecimalDtype): + value = pa.scalar( + value, type=pa.decimal128(dtype.precision, dtype.scale) + ).as_py() + if isinstance(value, decimal.Decimal) and dtype is None: + dtype = cudf.Decimal128Dtype._from_decimal(value) + + value = to_cudf_compatible_scalar(value, dtype=dtype) + + if dtype is None: + if not valid: + if value is NaT: + value = value.to_numpy() + + if isinstance(value, (np.datetime64, np.timedelta64)): + unit, _ = np.datetime_data(value) + if unit == "generic": + raise TypeError("Cant convert generic NaT to null scalar") + else: + dtype = value.dtype + else: + raise TypeError( + "dtype required when constructing a null scalar" + ) + else: + dtype = value.dtype + + if not isinstance(dtype, cudf.core.dtypes.DecimalDtype): + dtype = cudf.dtype(dtype) + + if not valid: + value = NaT if dtype.kind in "mM" else NA + + return value, dtype + + +def _replace_nested(obj, check, replacement): + if isinstance(obj, list): + for i, item in enumerate(obj): + if check(item): + obj[i] = replacement + elif isinstance(item, (dict, list)): + _replace_nested(item, check, replacement) + elif isinstance(obj, dict): + for k, v in obj.items(): + if check(v): + obj[k] = replacement + elif isinstance(v, (dict, list)): + _replace_nested(v, check, replacement) + + +def _maybe_nested_pa_scalar_to_py(pa_scalar: pa.Scalar) -> Any: + """ + Convert a "nested" pyarrow scalar to a Python object. + + These scalars come from pylibcudf.Scalar where field names can be + duplicate empty strings. + + Parameters + ---------- + pa_scalar: pa.Scalar + + Returns + ------- + Any + Python scalar + """ + if not pa_scalar.is_valid: + return pa_scalar.as_py() + elif pa.types.is_struct(pa_scalar.type): + return { + str(i): _maybe_nested_pa_scalar_to_py(val) + for i, (_, val) in enumerate(pa_scalar.items()) + } + elif pa.types.is_list(pa_scalar.type): + return [_maybe_nested_pa_scalar_to_py(val) for val in pa_scalar] + else: + return pa_scalar.as_py() + + +def _to_plc_scalar(value: ScalarLike, dtype: Dtype) -> plc.Scalar: + """ + Convert a value and dtype to a pylibcudf Scalar for device-side cudf.Scalar + + Parameters + ---------- + value: Scalarlike + dtype: dtypelike + + Returns + ------- + plc.Scalar + """ + if cudf.utils.utils.is_na_like(value): + value = None + else: + # TODO: For now we deepcopy the input value for nested values to avoid + # overwriting the input values when replacing nulls. Since it's + # just host values it's not that expensive, but we could consider + # alternatives. + if isinstance(value, (list, dict)): + value = copy.deepcopy(value) + _replace_nested(value, cudf.utils.utils.is_na_like, None) + + if isinstance(dtype, cudf.core.dtypes._BaseDtype): + pa_type = dtype.to_arrow() + elif pd.api.types.is_string_dtype(dtype): + # Have to manually convert object types, which we use internally + # for strings but pyarrow only supports as unicode 'U' + pa_type = pa.string() + else: + pa_type = pa.from_numpy_dtype(dtype) + + pa_scalar = pa.scalar(value, type=pa_type) + plc_scalar = plc.interop.from_arrow(pa_scalar) + if isinstance(dtype, (Decimal32Dtype, Decimal64Dtype)): + # pyarrow only supports decimal128 + if isinstance(dtype, Decimal32Dtype): + plc_type = plc.DataType(plc.TypeId.DECIMAL32, -dtype.scale) + elif isinstance(dtype, Decimal64Dtype): + plc_type = plc.DataType(plc.TypeId.DECIMAL64, -dtype.scale) + plc_column = plc.unary.cast( + plc.Column.from_scalar(plc_scalar, 1), plc_type + ) + plc_scalar = plc.copying.get_element(plc_column, 0) + return plc_scalar + @functools.lru_cache(maxsize=128) def pa_scalar_to_plc_scalar(pa_scalar: pa.Scalar) -> plc.Scalar: @@ -138,7 +320,7 @@ class Scalar(BinaryOperand, metaclass=CachedScalarInstanceMeta): def __init__(self, value, dtype=None): self._host_value = None self._host_dtype = None - self._device_value = None + self._device_value: plc.Scalar | None = None if isinstance(value, Scalar): if value._is_host_value_current: @@ -147,37 +329,34 @@ def __init__(self, value, dtype=None): else: self._device_value = value._device_value else: - self._host_value, self._host_dtype = self._preprocess_host_value( + self._host_value, self._host_dtype = _preprocess_host_value( value, dtype ) @classmethod - def from_device_scalar(cls, device_scalar): - if not isinstance(device_scalar, cudf._lib.scalar.DeviceScalar): + def from_pylibcudf(cls, scalar: plc.Scalar) -> Self: + if not isinstance(scalar, plc.Scalar): raise TypeError( - "Expected an instance of DeviceScalar, " - f"got {type(device_scalar).__name__}" + "Expected an instance of pylibcudf.Scalar, " + f"got {type(scalar).__name__}" ) obj = object.__new__(cls) obj._host_value = None obj._host_dtype = None - obj._device_value = device_scalar + obj._device_value = scalar return obj @property - def _is_host_value_current(self): + def _is_host_value_current(self) -> bool: return self._host_value is not None @property - def _is_device_value_current(self): + def _is_device_value_current(self) -> bool: return self._device_value is not None @property - def device_value(self): - if self._device_value is None: - self._device_value = cudf._lib.scalar.DeviceScalar( - self._host_value, self._host_dtype - ) + def device_value(self) -> plc.Scalar: + self._sync() return self._device_value @property @@ -186,92 +365,55 @@ def value(self): self._device_value_to_host() return self._host_value - # todo: change to cached property + # TODO: change to @functools.cached_property @property def dtype(self): - if self._is_host_value_current: - if isinstance(self._host_value, str): - return cudf.dtype("object") - else: - return self._host_dtype - else: - return self.device_value.dtype + if self._host_dtype is not None: + return self._host_dtype + if not self._is_host_value_current: + self._device_value_to_host() + _, host_dtype = _preprocess_host_value(self._host_value, None) + self._host_dtype = host_dtype + return self._host_dtype - def is_valid(self): + def is_valid(self) -> bool: if not self._is_host_value_current: self._device_value_to_host() return not cudf.utils.utils._is_null_host_scalar(self._host_value) - def _device_value_to_host(self): - self._host_value = self._device_value._to_host_scalar() - - def _preprocess_host_value(self, value, dtype): - valid = not cudf.utils.utils._is_null_host_scalar(value) - - if isinstance(value, list): - if dtype is not None: - raise TypeError("Lists may not be cast to a different dtype") - else: - dtype = ListDtype.from_arrow( - pa.infer_type([value], from_pandas=True) - ) - return value, dtype - elif isinstance(dtype, ListDtype): - if value not in {None, NA}: - raise ValueError(f"Can not coerce {value} to ListDtype") + def _device_value_to_host(self) -> None: + ps = plc.interop.to_arrow(self._device_value) + is_datetime = pa.types.is_timestamp(ps.type) + is_timedelta = pa.types.is_duration(ps.type) + if not ps.is_valid: + if is_datetime or is_timedelta: + self._host_value = NaT else: - return NA, dtype - - if isinstance(value, dict): - if dtype is None: - dtype = StructDtype.from_arrow( - pa.infer_type([value], from_pandas=True) - ) - return value, dtype - elif isinstance(dtype, StructDtype): - if value not in {None, NA}: - raise ValueError(f"Can not coerce {value} to StructDType") - else: - return NA, dtype - - if isinstance(dtype, cudf.core.dtypes.DecimalDtype): - value = pa.scalar( - value, type=pa.decimal128(dtype.precision, dtype.scale) - ).as_py() - if isinstance(value, decimal.Decimal) and dtype is None: - dtype = cudf.Decimal128Dtype._from_decimal(value) - - value = to_cudf_compatible_scalar(value, dtype=dtype) - - if dtype is None: - if not valid: - if value is NaT: - value = value.to_numpy() - - if isinstance(value, (np.datetime64, np.timedelta64)): - unit, _ = np.datetime_data(value) - if unit == "generic": - raise TypeError( - "Cant convert generic NaT to null scalar" - ) - else: - dtype = value.dtype - else: - raise TypeError( - "dtype required when constructing a null scalar" - ) + self._host_value = NA + else: + # TODO: The special handling of specific types below does not currently + # extend to nested types containing those types (e.g. List[timedelta] + # where the timedelta would overflow). We should eventually account for + # those cases, but that will require more careful consideration of how + # to traverse the contents of the nested data. + if is_datetime or is_timedelta: + time_unit = ps.type.unit + # Cast to int64 to avoid overflow + ps_cast = ps.cast(pa.int64()).as_py() + out_type = np.datetime64 if is_datetime else np.timedelta64 + self._host_value = out_type(ps_cast, time_unit) + elif ( + pa.types.is_integer(ps.type) + or pa.types.is_floating(ps.type) + or pa.types.is_boolean(ps.type) + ): + self._host_value = ps.type.to_pandas_dtype()(ps.as_py()) else: - dtype = value.dtype - - if not isinstance(dtype, cudf.core.dtypes.DecimalDtype): - dtype = cudf.dtype(dtype) + host_value = _maybe_nested_pa_scalar_to_py(ps) + _replace_nested(host_value, lambda item: item is None, NA) + self._host_value = host_value - if not valid: - value = NaT if dtype.kind in "mM" else NA - - return value, dtype - - def _sync(self): + def _sync(self) -> None: """ If the cache is not synched, copy either the device or host value to the host or device respectively. If cache is valid, do nothing @@ -279,27 +421,27 @@ def _sync(self): if self._is_host_value_current and self._is_device_value_current: return elif self._is_host_value_current and not self._is_device_value_current: - self._device_value = cudf._lib.scalar.DeviceScalar( + self._device_value = _to_plc_scalar( self._host_value, self._host_dtype ) elif self._is_device_value_current and not self._is_host_value_current: - self._host_value = self._device_value.value + self._device_value_to_host() self._host_dtype = self._host_value.dtype else: raise ValueError("Invalid cudf.Scalar") - def __index__(self): + def __index__(self) -> int: if self.dtype.kind not in {"u", "i"}: raise TypeError("Only Integer typed scalars may be used in slices") return int(self) - def __int__(self): + def __int__(self) -> int: return int(self.value) - def __float__(self): + def __float__(self) -> float: return float(self.value) - def __bool__(self): + def __bool__(self) -> bool: return bool(self.value) def __round__(self, n): @@ -321,7 +463,7 @@ def __invert__(self): def __neg__(self): return self._scalar_unaop("__neg__") - def __repr__(self): + def __repr__(self) -> str: # str() fixes a numpy bug with NaT # https://github.com/numpy/numpy/issues/17552 return ( @@ -403,13 +545,13 @@ def _unaop_result_type_or_error(self, op): return cudf.dtype("float64") return self.dtype - def _scalar_unaop(self, op): + def _scalar_unaop(self, op) -> None | Self: out_dtype = self._unaop_result_type_or_error(op) if not self.is_valid(): - result = None + return None else: result = self._dispatch_scalar_unaop(op) - return Scalar(result, dtype=out_dtype) + return Scalar(result, dtype=out_dtype) # type: ignore[return-value] def _dispatch_scalar_unaop(self, op): if op == "__floor__": @@ -418,7 +560,7 @@ def _dispatch_scalar_unaop(self, op): return np.ceil(self.value) return getattr(self.value, op)() - def astype(self, dtype): + def astype(self, dtype) -> Self: if self.dtype == dtype: return self - return Scalar(self.value, dtype) + return Scalar(self.value, dtype) # type: ignore[return-value] diff --git a/python/cudf/cudf/core/tools/datetimes.py b/python/cudf/cudf/core/tools/datetimes.py index 8be336021b1..4ca92be2498 100644 --- a/python/cudf/cudf/core/tools/datetimes.py +++ b/python/cudf/cudf/core/tools/datetimes.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2024, NVIDIA CORPORATION. +# Copyright (c) 2019-2025, NVIDIA CORPORATION. from __future__ import annotations import math @@ -998,7 +998,7 @@ def date_range( res = libcudf.column.Column.from_pylibcudf( plc.filling.calendrical_month_sequence( periods, - start.device_value.c_value, + start.device_value, months, ) ) diff --git a/python/cudf/cudf/tests/test_binops.py b/python/cudf/cudf/tests/test_binops.py index 0712a0de635..ef94b3cd176 100644 --- a/python/cudf/cudf/tests/test_binops.py +++ b/python/cudf/cudf/tests/test_binops.py @@ -1,4 +1,4 @@ -# Copyright (c) 2018-2024, NVIDIA CORPORATION. +# Copyright (c) 2018-2025, NVIDIA CORPORATION. import decimal import operator @@ -3200,7 +3200,7 @@ def set_null_cases(column_l, column_r, case): "lcol,rcol,ans,case", generate_test_null_equals_columnops_data() ) def test_null_equals_columnops(lcol, rcol, ans, case): - assert lcol.equals(rcol).all() == ans + assert lcol.equals(rcol) == ans def test_add_series_to_dataframe(): diff --git a/python/cudf/cudf/tests/test_list.py b/python/cudf/cudf/tests/test_list.py index b1f81edfc54..359660e76a7 100644 --- a/python/cudf/cudf/tests/test_list.py +++ b/python/cudf/cudf/tests/test_list.py @@ -689,7 +689,6 @@ def test_list_getitem(data): def test_list_scalar_host_construction(data): slr = cudf.Scalar(data) assert slr.value == data - assert slr.device_value.value == data @pytest.mark.parametrize( diff --git a/python/cudf/cudf/tests/test_scalar.py b/python/cudf/cudf/tests/test_scalar.py index c14fab4040b..ba2bd040c38 100644 --- a/python/cudf/cudf/tests/test_scalar.py +++ b/python/cudf/cudf/tests/test_scalar.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021-2024, NVIDIA CORPORATION. +# Copyright (c) 2021-2025, NVIDIA CORPORATION. import datetime import re @@ -145,14 +145,11 @@ def test_scalar_host_initialization(value): def test_scalar_device_initialization(value): column = cudf.Series([value], nan_as_null=False)._column with acquire_spill_lock(): - dev_slr = cudf._lib.scalar.DeviceScalar.from_pylibcudf( - plc.copying.get_element( - column.to_pylibcudf(mode="read"), - 0, - ), - dtype=column.dtype, + dev_slr = plc.copying.get_element( + column.to_pylibcudf(mode="read"), + 0, ) - s = cudf.Scalar.from_device_scalar(dev_slr) + s = cudf.Scalar.from_pylibcudf(dev_slr) assert s._is_device_value_current assert not s._is_host_value_current @@ -172,14 +169,11 @@ def test_scalar_device_initialization_decimal(value, decimal_type): dtype = decimal_type._from_decimal(value) column = cudf.Series([str(value)]).astype(dtype)._column with acquire_spill_lock(): - dev_slr = cudf._lib.scalar.DeviceScalar.from_pylibcudf( - plc.copying.get_element( - column.to_pylibcudf(mode="read"), - 0, - ), - dtype=column.dtype, + dev_slr = plc.copying.get_element( + column.to_pylibcudf(mode="read"), + 0, ) - s = cudf.Scalar.from_device_scalar(dev_slr) + s = cudf.Scalar.from_pylibcudf(dev_slr) assert s._is_device_value_current assert not s._is_host_value_current @@ -387,34 +381,6 @@ def test_scalar_invalid_implicit_conversion(cls, dtype): cls(slr) -@pytest.mark.parametrize("value", SCALAR_VALUES + DECIMAL_VALUES) -@pytest.mark.parametrize( - "decimal_type", - [cudf.Decimal32Dtype, cudf.Decimal64Dtype, cudf.Decimal128Dtype], -) -def test_device_scalar_direct_construction(value, decimal_type): - value = cudf.utils.dtypes.to_cudf_compatible_scalar(value) - - dtype = ( - value.dtype - if not isinstance(value, Decimal) - else decimal_type._from_decimal(value) - ) - - s = cudf._lib.scalar.DeviceScalar(value, dtype) - - assert s.value == value or np.isnan(s.value) and np.isnan(value) - if isinstance( - dtype, (cudf.Decimal64Dtype, cudf.Decimal128Dtype, cudf.Decimal32Dtype) - ): - assert s.dtype.precision == dtype.precision - assert s.dtype.scale == dtype.scale - elif dtype.char == "U": - assert s.dtype == "object" - else: - assert s.dtype == dtype - - @pytest.mark.parametrize("value", SCALAR_VALUES + DECIMAL_VALUES) def test_construct_from_scalar(value): value = cudf.utils.dtypes.to_cudf_compatible_scalar(value) diff --git a/python/cudf/cudf/tests/test_struct.py b/python/cudf/cudf/tests/test_struct.py index b85943626a6..e7fca63d980 100644 --- a/python/cudf/cudf/tests/test_struct.py +++ b/python/cudf/cudf/tests/test_struct.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2024, NVIDIA CORPORATION. +# Copyright (c) 2020-2025, NVIDIA CORPORATION. import numpy as np import pandas as pd @@ -6,7 +6,6 @@ import pytest import cudf -from cudf.core.dtypes import StructDtype from cudf.testing import assert_eq from cudf.testing._utils import DATETIME_TYPES, TIMEDELTA_TYPES @@ -161,7 +160,6 @@ def test_struct_setitem(data, item): def test_struct_scalar_host_construction(data): slr = cudf.Scalar(data) assert slr.value == data - assert list(slr.device_value.value.values()) == list(data.values()) @pytest.mark.parametrize( @@ -194,12 +192,11 @@ def test_struct_scalar_host_construction_no_dtype_inference(data, dtype): # is empty. slr = cudf.Scalar(data, dtype=dtype) assert slr.value == data - assert list(slr.device_value.value.values()) == list(data.values()) def test_struct_scalar_null(): - slr = cudf.Scalar(cudf.NA, dtype=StructDtype) - assert slr.device_value.value is cudf.NA + slr = cudf.Scalar(cudf.NA, dtype=cudf.StructDtype) + assert cudf.Scalar.from_pylibcudf(slr.device_value).value is cudf.NA def test_struct_explode():