Skip to content

Commit

Permalink
Revert some old pieces of DeviceScalar._set_dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
mroeschke committed Dec 18, 2024
1 parent 088fcbe commit 23f7991
Showing 1 changed file with 36 additions and 10 deletions.
46 changes: 36 additions & 10 deletions python/cudf/cudf/_lib/scalar.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@ from libcpp.utility cimport move
import pylibcudf as plc

import cudf
from cudf._lib.types import dtype_from_pylibcudf_column
from cudf.core.dtypes import ListDtype, StructDtype
from cudf._lib.types import PYLIBCUDF_TO_SUPPORTED_NUMPY_TYPES
from cudf._lib.types cimport dtype_from_column_view, underlying_type_t_type_id
from cudf.core.missing import NA, NaT

# We currently need this cimport because some of the implementations here
# access the c_obj of the scalar, and because we need to be able to call
# pylibcudf.Scalar.from_libcudf. Both of those are temporarily acceptable until
# DeviceScalar is phased out entirely from cuDF Cython (at which point
# cudf.Scalar will be directly backed by pylibcudf.Scalar).
from pylibcudf cimport Scalar as plc_Scalar
from pylibcudf.libcudf.scalar.scalar cimport scalar
from pylibcudf cimport Scalar as plc_Scalar, type_id as plc_TypeID
from pylibcudf.libcudf.scalar.scalar cimport list_scalar, scalar, struct_scalar


def _replace_nested(obj, check, replacement):
Expand Down Expand Up @@ -221,19 +223,43 @@ cdef class DeviceScalar:
return s

cdef void _set_dtype(self, dtype=None):
cdef plc_TypeID cdtype_id = self.c_value.type().id()
if dtype is not None:
self._dtype = dtype

plc_scalar = self.c_value
if plc_scalar.type().id() in {
plc.TypeId.DECIMAL32,
plc.TypeId.DECIMAL64,
plc.TypeId.DECIMAL128,
elif cdtype_id in {
plc_TypeID.DECIMAL32,
plc_TypeID.DECIMAL64,
plc_TypeID.DECIMAL128,
}:
raise TypeError(
"Must pass a dtype when constructing from a fixed-point scalar"
)
self._dtype = dtype_from_pylibcudf_column(plc.Column.from_scalar(plc_scalar, 1))
elif cdtype_id == plc_TypeID.STRUCT:
struct_table_view = (<struct_scalar*>self.get_raw_ptr())[0].view()
self._dtype = StructDtype({
str(i): dtype_from_column_view(struct_table_view.column(i))
for i in range(struct_table_view.num_columns())
})
elif cdtype_id == plc_TypeID.LIST:
if (
<list_scalar*>self.get_raw_ptr()
)[0].view().type().id() == plc_TypeID.LIST:
self._dtype = dtype_from_column_view(
(<list_scalar*>self.get_raw_ptr())[0].view()
)
else:
self._dtype = ListDtype(
PYLIBCUDF_TO_SUPPORTED_NUMPY_TYPES[
<underlying_type_t_type_id>(
(<list_scalar*>self.get_raw_ptr())[0]
.view().type().id()
)
]
)
else:
self._dtype = PYLIBCUDF_TO_SUPPORTED_NUMPY_TYPES[
<underlying_type_t_type_id>(cdtype_id)
]


def as_device_scalar(val, dtype=None):
Expand Down

0 comments on commit 23f7991

Please sign in to comment.