From c9199ec01359fcd97efbf506f6ac73e93fc504de Mon Sep 17 00:00:00 2001 From: brandon-b-miller Date: Tue, 15 Oct 2024 11:35:30 -0700 Subject: [PATCH] address reviews --- .../cudf_polars/dsl/expressions/unary.py | 56 +++++++------------ .../cudf_polars/cudf_polars/utils/dtypes.py | 15 ++++- 2 files changed, 33 insertions(+), 38 deletions(-) diff --git a/python/cudf_polars/cudf_polars/dsl/expressions/unary.py b/python/cudf_polars/cudf_polars/dsl/expressions/unary.py index 8cf28801986..c9e52e5056f 100644 --- a/python/cudf_polars/cudf_polars/dsl/expressions/unary.py +++ b/python/cudf_polars/cudf_polars/dsl/expressions/unary.py @@ -9,7 +9,7 @@ import pyarrow as pa import pylibcudf as plc -from pylibcudf.traits import is_floating_point, is_integral_not_bool +from pylibcudf.traits import is_floating_point from cudf_polars.containers import Column from cudf_polars.dsl.expressions.base import AggInfo, ExecutionContext, Expr @@ -24,10 +24,6 @@ __all__ = ["Cast", "UnaryFunction", "Len"] -def _is_int_or_float(dtype: plc.DataType) -> bool: - return is_integral_not_bool(dtype) or is_floating_point(dtype) - - class Cast(Expr): """Class representing a cast of an expression.""" @@ -38,21 +34,9 @@ class Cast(Expr): def __init__(self, dtype: plc.DataType, value: Expr) -> None: super().__init__(dtype) self.children = (value,) - if ( - self.dtype.id() == plc.TypeId.STRING - or value.dtype.id() == plc.TypeId.STRING - ): - if not ( - (self.dtype.id() == plc.TypeId.STRING and _is_int_or_float(value.dtype)) - or ( - _is_int_or_float(self.dtype) - and value.dtype.id() == plc.TypeId.STRING - ) - ): - raise NotImplementedError("Only string to float cast is supported") - elif not dtypes.can_cast(value.dtype, self.dtype): + if not dtypes.can_cast(value.dtype, self.dtype): raise NotImplementedError( - f"Can't cast {self.dtype.id().name} to {value.dtype.id().name}" + f"Can't cast {value.dtype.id().name} to {self.dtype.id().name}" ) def do_evaluate( @@ -69,26 +53,28 @@ def do_evaluate( self.dtype.id() == plc.TypeId.STRING or column.obj.type().id() == plc.TypeId.STRING ): - if self.dtype.id() == plc.TypeId.STRING: - if is_floating_point(column.obj.type()): - result = plc.strings.convert.convert_floats.from_floats(column.obj) - else: - result = plc.strings.convert.convert_integers.from_integers( - column.obj - ) - else: - if is_floating_point(self.dtype): - result = plc.strings.convert.convert_floats.to_floats( - column.obj, self.dtype - ) - else: - result = plc.strings.convert.convert_integers.to_integers( - column.obj, self.dtype - ) + result = self._handle_string_cast(column) else: result = plc.unary.cast(column.obj, self.dtype) return Column(result).sorted_like(column) + def _handle_string_cast(self, column: Column) -> plc.Column: + if self.dtype.id() == plc.TypeId.STRING: + if is_floating_point(column.obj.type()): + result = plc.strings.convert.convert_floats.from_floats(column.obj) + else: + result = plc.strings.convert.convert_integers.from_integers(column.obj) + else: + if is_floating_point(self.dtype): + result = plc.strings.convert.convert_floats.to_floats( + column.obj, self.dtype + ) + else: + result = plc.strings.convert.convert_integers.to_integers( + column.obj, self.dtype + ) + return result + def collect_agg(self, *, depth: int) -> AggInfo: """Collect information about aggregations in groupbys.""" # TODO: Could do with sort-based groupby and segmented filter diff --git a/python/cudf_polars/cudf_polars/utils/dtypes.py b/python/cudf_polars/cudf_polars/utils/dtypes.py index 4154a404e98..fb7ed0aaf2b 100644 --- a/python/cudf_polars/cudf_polars/utils/dtypes.py +++ b/python/cudf_polars/cudf_polars/utils/dtypes.py @@ -9,6 +9,7 @@ import pyarrow as pa import pylibcudf as plc +from pylibcudf.traits import is_floating_point, is_integral_not_bool from typing_extensions import assert_never import polars as pl @@ -45,6 +46,10 @@ def downcast_arrow_lists(typ: pa.DataType) -> pa.DataType: return typ +def _is_int_or_float(dtype: plc.DataType) -> bool: + return is_integral_not_bool(dtype) or is_floating_point(dtype) + + def can_cast(from_: plc.DataType, to: plc.DataType) -> bool: """ Can we cast (via :func:`~.pylibcudf.unary.cast`) between two datatypes. @@ -61,9 +66,13 @@ def can_cast(from_: plc.DataType, to: plc.DataType) -> bool: True if casting is supported, False otherwise """ return ( - plc.traits.is_fixed_width(to) - and plc.traits.is_fixed_width(from_) - and plc.unary.is_supported_cast(from_, to) + ( + plc.traits.is_fixed_width(to) + and plc.traits.is_fixed_width(from_) + and plc.unary.is_supported_cast(from_, to) + ) + or (from_.id() == plc.TypeId.STRING and _is_int_or_float(to)) + or (_is_int_or_float(from_) and to.id() == plc.TypeId.STRING) )