Skip to content

Commit

Permalink
address reviews
Browse files (browse the repository at this point in the history)
  • Loading branch information
brandon-b-miller committed Oct 15, 2024
1 parent 0258478 commit c9199ec
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 38 deletions.
56 changes: 21 additions & 35 deletions python/cudf_polars/cudf_polars/dsl/expressions/unary.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import pyarrow as pa
import pylibcudf as plc
from pylibcudf.traits import is_floating_point, is_integral_not_bool
from pylibcudf.traits import is_floating_point

from cudf_polars.containers import Column
from cudf_polars.dsl.expressions.base import AggInfo, ExecutionContext, Expr
Expand All @@ -24,10 +24,6 @@
__all__ = ["Cast", "UnaryFunction", "Len"]


def _is_int_or_float(dtype: plc.DataType) -> bool:
    """
    Check whether a datatype is a non-boolean integer or a floating point type.

    Parameters
    ----------
    dtype
        Datatype to inspect.

    Returns
    -------
    True if ``is_integral_not_bool`` accepts the datatype, otherwise
    the result of ``is_floating_point`` for it.
    """
    # Integral check first; the floating point check only runs when it fails.
    if is_integral_not_bool(dtype):
        return True
    return is_floating_point(dtype)


class Cast(Expr):
"""Class representing a cast of an expression."""

Expand All @@ -38,21 +34,9 @@ class Cast(Expr):
def __init__(self, dtype: plc.DataType, value: Expr) -> None:
super().__init__(dtype)
self.children = (value,)
if (
self.dtype.id() == plc.TypeId.STRING
or value.dtype.id() == plc.TypeId.STRING
):
if not (
(self.dtype.id() == plc.TypeId.STRING and _is_int_or_float(value.dtype))
or (
_is_int_or_float(self.dtype)
and value.dtype.id() == plc.TypeId.STRING
)
):
raise NotImplementedError("Only string to float cast is supported")
elif not dtypes.can_cast(value.dtype, self.dtype):
if not dtypes.can_cast(value.dtype, self.dtype):
raise NotImplementedError(
f"Can't cast {self.dtype.id().name} to {value.dtype.id().name}"
f"Can't cast {value.dtype.id().name} to {self.dtype.id().name}"
)

def do_evaluate(
Expand All @@ -69,26 +53,28 @@ def do_evaluate(
self.dtype.id() == plc.TypeId.STRING
or column.obj.type().id() == plc.TypeId.STRING
):
if self.dtype.id() == plc.TypeId.STRING:
if is_floating_point(column.obj.type()):
result = plc.strings.convert.convert_floats.from_floats(column.obj)
else:
result = plc.strings.convert.convert_integers.from_integers(
column.obj
)
else:
if is_floating_point(self.dtype):
result = plc.strings.convert.convert_floats.to_floats(
column.obj, self.dtype
)
else:
result = plc.strings.convert.convert_integers.to_integers(
column.obj, self.dtype
)
result = self._handle_string_cast(column)
else:
result = plc.unary.cast(column.obj, self.dtype)
return Column(result).sorted_like(column)

def _handle_string_cast(self, column: Column) -> plc.Column:
    """
    Convert a column to or from strings with the pylibcudf string converters.

    Parameters
    ----------
    column
        Column whose underlying ``obj`` is converted.

    Returns
    -------
    The converted pylibcudf column.

    Notes
    -----
    The direction is chosen from ``self.dtype`` alone: a STRING target
    means "numeric to string", anything else means "string to numeric".
    Within each direction, floating point types use the float converter
    and every other type falls through to the integer converter.
    """
    target = self.dtype
    if target.id() != plc.TypeId.STRING:
        # Parsing strings: the target dtype decides which parser is used.
        parse = (
            plc.strings.convert.convert_floats.to_floats
            if is_floating_point(target)
            else plc.strings.convert.convert_integers.to_integers
        )
        return parse(column.obj, target)

    # Formatting to strings: the source column's type decides the formatter.
    source = column.obj
    stringify = (
        plc.strings.convert.convert_floats.from_floats
        if is_floating_point(source.type())
        else plc.strings.convert.convert_integers.from_integers
    )
    return stringify(source)

def collect_agg(self, *, depth: int) -> AggInfo:
"""Collect information about aggregations in groupbys."""
# TODO: Could do with sort-based groupby and segmented filter
Expand Down
15 changes: 12 additions & 3 deletions python/cudf_polars/cudf_polars/utils/dtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pyarrow as pa
import pylibcudf as plc
from pylibcudf.traits import is_floating_point, is_integral_not_bool
from typing_extensions import assert_never

import polars as pl
Expand Down Expand Up @@ -45,6 +46,10 @@ def downcast_arrow_lists(typ: pa.DataType) -> pa.DataType:
return typ


def _is_int_or_float(dtype: plc.DataType) -> bool:
    """
    Check whether a datatype is a non-boolean integer or a floating point type.

    Parameters
    ----------
    dtype
        Datatype to inspect.

    Returns
    -------
    True if ``is_integral_not_bool`` accepts the datatype, otherwise
    the result of ``is_floating_point`` for it.
    """
    # Integral check first; the floating point check only runs when it fails.
    if is_integral_not_bool(dtype):
        return True
    return is_floating_point(dtype)


def can_cast(from_: plc.DataType, to: plc.DataType) -> bool:
"""
Can we cast (via :func:`~.pylibcudf.unary.cast`) between two datatypes.
Expand All @@ -61,9 +66,13 @@ def can_cast(from_: plc.DataType, to: plc.DataType) -> bool:
True if casting is supported, False otherwise
"""
return (
plc.traits.is_fixed_width(to)
and plc.traits.is_fixed_width(from_)
and plc.unary.is_supported_cast(from_, to)
(
plc.traits.is_fixed_width(to)
and plc.traits.is_fixed_width(from_)
and plc.unary.is_supported_cast(from_, to)
)
or (from_.id() == plc.TypeId.STRING and _is_int_or_float(to))
or (_is_int_or_float(from_) and to.id() == plc.TypeId.STRING)
)


Expand Down

0 comments on commit c9199ec

Please sign in to comment.