diff --git a/ibis/backends/bigquery/compiler.py b/ibis/backends/bigquery/compiler.py
index a858b5226c66..b32a6476c48f 100644
--- a/ibis/backends/bigquery/compiler.py
+++ b/ibis/backends/bigquery/compiler.py
@@ -18,10 +18,8 @@
     exclude_unsupported_window_frame_from_ops,
     exclude_unsupported_window_frame_from_rank,
     exclude_unsupported_window_frame_from_row_number,
-    rewrite_sample_as_filter,
 )
 from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit
-from ibis.expr.rewrites import rewrite_stringslice
 
 
 _NAME_REGEX = re.compile(r'[^!"$()*,./;?@[\\\]^`{}~\n]+')
@@ -31,11 +29,9 @@ class BigQueryCompiler(SQLGlotCompiler):
     type_mapper = BigQueryType
     udf_type_mapper = BigQueryUDFType
     rewrites = (
-        rewrite_sample_as_filter,
         exclude_unsupported_window_frame_from_ops,
         exclude_unsupported_window_frame_from_row_number,
         exclude_unsupported_window_frame_from_rank,
-        rewrite_stringslice,
         *SQLGlotCompiler.rewrites,
     )
 
diff --git a/ibis/backends/clickhouse/compiler.py b/ibis/backends/clickhouse/compiler.py
index 1f9fbee52244..c99635b5fac2 100644
--- a/ibis/backends/clickhouse/compiler.py
+++ b/ibis/backends/clickhouse/compiler.py
@@ -14,8 +14,6 @@
 from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
 from ibis.backends.sql.datatypes import ClickHouseType
 from ibis.backends.sql.dialects import ClickHouse
-from ibis.backends.sql.rewrites import rewrite_sample_as_filter
-from ibis.expr.rewrites import rewrite_stringslice
 
 
 class ClickHouseCompiler(SQLGlotCompiler):
@@ -23,11 +21,6 @@ class ClickHouseCompiler(SQLGlotCompiler):
     dialect = ClickHouse
     type_mapper = ClickHouseType
-    rewrites = (
-        rewrite_sample_as_filter,
-        rewrite_stringslice,
-        *SQLGlotCompiler.rewrites,
-    )
 
     UNSUPPORTED_OPS = (
         ops.RowID,
diff --git a/ibis/backends/datafusion/compiler.py b/ibis/backends/datafusion/compiler.py
index 817c17d40949..b4aebd187736 100644
--- a/ibis/backends/datafusion/compiler.py
+++ b/ibis/backends/datafusion/compiler.py
@@ -14,10 +14,8 @@
 from ibis.backends.sql.compiler import FALSE, NULL, STAR, SQLGlotCompiler
 from ibis.backends.sql.datatypes import DataFusionType
 from ibis.backends.sql.dialects import DataFusion
-from ibis.backends.sql.rewrites import rewrite_sample_as_filter
 from ibis.common.temporal import IntervalUnit, TimestampUnit
 from ibis.expr.operations.udf import InputType
-from ibis.expr.rewrites import rewrite_stringslice
 from ibis.formats.pyarrow import PyArrowType
 
 
@@ -26,11 +24,6 @@ class DataFusionCompiler(SQLGlotCompiler):
     dialect = DataFusion
     type_mapper = DataFusionType
-    rewrites = (
-        rewrite_sample_as_filter,
-        rewrite_stringslice,
-        *SQLGlotCompiler.rewrites,
-    )
 
     UNSUPPORTED_OPS = (
         ops.ArgMax,
diff --git a/ibis/backends/druid/compiler.py b/ibis/backends/druid/compiler.py
index a2e05985f22f..ab4eb2020e4b 100644
--- a/ibis/backends/druid/compiler.py
+++ b/ibis/backends/druid/compiler.py
@@ -9,11 +9,6 @@
 from ibis.backends.sql.compiler import NULL, SQLGlotCompiler
 from ibis.backends.sql.datatypes import DruidType
 from ibis.backends.sql.dialects import Druid
-from ibis.backends.sql.rewrites import (
-    rewrite_capitalize,
-    rewrite_sample_as_filter,
-)
-from ibis.expr.rewrites import rewrite_stringslice
 
 
 class DruidCompiler(SQLGlotCompiler):
@@ -21,15 +16,8 @@ class DruidCompiler(SQLGlotCompiler):
     dialect = Druid
     type_mapper = DruidType
-    rewrites = (
-        rewrite_sample_as_filter,
-        rewrite_stringslice,
-        *(
-            rewrite
-            for rewrite in SQLGlotCompiler.rewrites
-            if rewrite is not rewrite_capitalize
-        ),
-    )
+
+    LOWERED_OPS = {ops.Capitalize: None}
 
     UNSUPPORTED_OPS = (
         ops.ApproxMedian,
diff --git a/ibis/backends/duckdb/compiler.py b/ibis/backends/duckdb/compiler.py
index 2489d293dd46..d17ba9360ba4 100644
--- a/ibis/backends/duckdb/compiler.py
+++ b/ibis/backends/duckdb/compiler.py
@@ -33,6 +33,11 @@ class DuckDBCompiler(SQLGlotCompiler):
     dialect = DuckDB
     type_mapper = DuckDBType
 
+    LOWERED_OPS = {
+        ops.Sample: None,
+        ops.StringSlice: None,
+    }
+
     SIMPLE_OPS = {
         ops.Arbitrary: "any_value",
         ops.ArrayPosition: "list_indexof",
diff --git a/ibis/backends/exasol/compiler.py b/ibis/backends/exasol/compiler.py
index 5b9adcf6eee5..de06d99b03cb 100644
--- a/ibis/backends/exasol/compiler.py
+++ b/ibis/backends/exasol/compiler.py
@@ -14,9 +14,7 @@
     exclude_unsupported_window_frame_from_rank,
     exclude_unsupported_window_frame_from_row_number,
     rewrite_empty_order_by_window,
-    rewrite_sample_as_filter,
 )
-from ibis.expr.rewrites import rewrite_stringslice
 
 
 class ExasolCompiler(SQLGlotCompiler):
@@ -25,12 +23,10 @@ class ExasolCompiler(SQLGlotCompiler):
     dialect = Exasol
     type_mapper = ExasolType
     rewrites = (
-        rewrite_sample_as_filter,
         exclude_unsupported_window_frame_from_ops,
         exclude_unsupported_window_frame_from_rank,
         exclude_unsupported_window_frame_from_row_number,
         rewrite_empty_order_by_window,
-        rewrite_stringslice,
         *SQLGlotCompiler.rewrites,
     )
 
diff --git a/ibis/backends/flink/compiler.py b/ibis/backends/flink/compiler.py
index 9404c7dbd4cf..211ce3c9882a 100644
--- a/ibis/backends/flink/compiler.py
+++ b/ibis/backends/flink/compiler.py
@@ -15,9 +15,7 @@
     exclude_unsupported_window_frame_from_ops,
     exclude_unsupported_window_frame_from_rank,
     exclude_unsupported_window_frame_from_row_number,
-    rewrite_sample_as_filter,
 )
-from ibis.expr.rewrites import rewrite_stringslice
 
 
 class FlinkCompiler(SQLGlotCompiler):
@@ -25,11 +23,9 @@ class FlinkCompiler(SQLGlotCompiler):
     dialect = Flink
     type_mapper = FlinkType
     rewrites = (
-        rewrite_sample_as_filter,
         exclude_unsupported_window_frame_from_row_number,
         exclude_unsupported_window_frame_from_ops,
         exclude_unsupported_window_frame_from_rank,
-        rewrite_stringslice,
         *SQLGlotCompiler.rewrites,
     )
 
diff --git a/ibis/backends/impala/compiler.py b/ibis/backends/impala/compiler.py
index dbf15f3da615..21fa2bee1d15 100644
--- a/ibis/backends/impala/compiler.py
+++ b/ibis/backends/impala/compiler.py
@@ -12,9 +12,7 @@
 from ibis.backends.sql.dialects import Impala
 from ibis.backends.sql.rewrites import (
     rewrite_empty_order_by_window,
-    rewrite_sample_as_filter,
 )
-from ibis.expr.rewrites import rewrite_stringslice
 
 
 class ImpalaCompiler(SQLGlotCompiler):
@@ -23,9 +21,7 @@ class ImpalaCompiler(SQLGlotCompiler):
     dialect = Impala
     type_mapper = ImpalaType
     rewrites = (
-        rewrite_sample_as_filter,
         rewrite_empty_order_by_window,
-        rewrite_stringslice,
         *SQLGlotCompiler.rewrites,
     )
 
diff --git a/ibis/backends/mssql/compiler.py b/ibis/backends/mssql/compiler.py
index f164caae10e7..920d1a99b5e1 100644
--- a/ibis/backends/mssql/compiler.py
+++ b/ibis/backends/mssql/compiler.py
@@ -23,10 +23,8 @@
     exclude_unsupported_window_frame_from_row_number,
     p,
     replace,
-    rewrite_sample_as_filter,
 )
 from ibis.common.deferred import var
-from ibis.expr.rewrites import rewrite_stringslice
 
 y = var("y")
 start = var("start")
@@ -59,11 +57,9 @@ class MSSQLCompiler(SQLGlotCompiler):
     dialect = MSSQL
     type_mapper = MSSQLType
     rewrites = (
-        rewrite_sample_as_filter,
         exclude_unsupported_window_frame_from_ops,
         exclude_unsupported_window_frame_from_row_number,
         rewrite_rows_range_order_by_window,
-        rewrite_stringslice,
         *SQLGlotCompiler.rewrites,
     )
     copy_func_args = True
diff --git a/ibis/backends/mysql/compiler.py b/ibis/backends/mysql/compiler.py
index 64548c5f8ffa..d8bf40d56c02 100644
--- a/ibis/backends/mysql/compiler.py
+++ b/ibis/backends/mysql/compiler.py
@@ -18,10 +18,9 @@
     exclude_unsupported_window_frame_from_rank,
     exclude_unsupported_window_frame_from_row_number,
     rewrite_empty_order_by_window,
-    rewrite_sample_as_filter,
 )
 from ibis.common.patterns import replace
-from ibis.expr.rewrites import p, rewrite_stringslice
+from ibis.expr.rewrites import p
 
 
 @replace(p.Limit)
@@ -50,12 +49,10 @@ class MySQLCompiler(SQLGlotCompiler):
     type_mapper = MySQLType
     rewrites = (
         rewrite_limit,
-        rewrite_sample_as_filter,
         exclude_unsupported_window_frame_from_ops,
         exclude_unsupported_window_frame_from_rank,
         exclude_unsupported_window_frame_from_row_number,
         rewrite_empty_order_by_window,
-        rewrite_stringslice,
         *SQLGlotCompiler.rewrites,
     )
 
diff --git a/ibis/backends/oracle/compiler.py b/ibis/backends/oracle/compiler.py
index 6ed67e51a064..4129aab09082 100644
--- a/ibis/backends/oracle/compiler.py
+++ b/ibis/backends/oracle/compiler.py
@@ -15,12 +15,10 @@
     LastValue,
     exclude_unsupported_window_frame_from_ops,
     exclude_unsupported_window_frame_from_row_number,
-    replace_log2,
-    replace_log10,
+    lower_log2,
+    lower_log10,
     rewrite_empty_order_by_window,
-    rewrite_sample_as_filter,
 )
-from ibis.expr.rewrites import rewrite_stringslice
 
 
 @public
@@ -33,10 +31,6 @@ class OracleCompiler(SQLGlotCompiler):
         exclude_unsupported_window_frame_from_row_number,
         exclude_unsupported_window_frame_from_ops,
         rewrite_empty_order_by_window,
-        rewrite_sample_as_filter,
-        rewrite_stringslice,
-        replace_log2,
-        replace_log10,
         *SQLGlotCompiler.rewrites,
     )
 
@@ -49,6 +43,11 @@ class OracleCompiler(SQLGlotCompiler):
     NEG_INF = sge.Literal.number("-binary_double_infinity")
     """Backend's negative infinity literal."""
 
+    LOWERED_OPS = {
+        ops.Log2: lower_log2,
+        ops.Log10: lower_log10,
+    }
+
     UNSUPPORTED_OPS = (
         ops.ArgMax,
         ops.ArgMin,
diff --git a/ibis/backends/pandas/rewrites.py b/ibis/backends/pandas/rewrites.py
index 53f2b6a813e3..87e55026dae5 100644
--- a/ibis/backends/pandas/rewrites.py
+++ b/ibis/backends/pandas/rewrites.py
@@ -12,7 +12,7 @@
 from ibis.common.collections import FrozenDict
 from ibis.common.patterns import InstanceOf, replace
 from ibis.common.typing import VarTuple  # noqa: TCH001
-from ibis.expr.rewrites import p, replace_parameter, rewrite_stringslice
+from ibis.expr.rewrites import lower_stringslice, p, replace_parameter
 from ibis.expr.schema import Schema
 from ibis.util import gen_name
 
@@ -354,7 +354,7 @@ def plan(node, backend, params):
         | rewrite_join
         | rewrite_limit
         | replace_parameter
-        | rewrite_stringslice
+        | lower_stringslice
         | bind_unbound_table,
         context=ctx,
     )
diff --git a/ibis/backends/polars/__init__.py b/ibis/backends/polars/__init__.py
index f2e4ffa24bed..d16a0d0b4f2f 100644
--- a/ibis/backends/polars/__init__.py
+++ b/ibis/backends/polars/__init__.py
@@ -20,7 +20,7 @@
 )
 from ibis.backends.polars.compiler import translate
 from ibis.backends.sql.dialects import Polars
-from ibis.expr.rewrites import rewrite_stringslice
+from ibis.expr.rewrites import lower_stringslice
 from ibis.formats.polars import PolarsSchema
 from ibis.util import gen_name, normalize_filename
 
@@ -406,7 +406,7 @@ def compile(
         node = expr.as_table().op()
 
         node = node.replace(
-            rewrite_join | replace_parameter | bind_unbound_table | rewrite_stringslice,
+            rewrite_join | replace_parameter | bind_unbound_table | lower_stringslice,
             context={"params": params, "backend": self},
         )
 
diff --git a/ibis/backends/postgres/compiler.py b/ibis/backends/postgres/compiler.py
index 97c59fd72a5f..7131f061afb2 100644
--- a/ibis/backends/postgres/compiler.py
+++ b/ibis/backends/postgres/compiler.py
@@ -14,8 +14,6 @@
 from ibis.backends.sql.compiler import NULL, STAR, SQLGlotCompiler
 from ibis.backends.sql.datatypes import PostgresType
 from ibis.backends.sql.dialects import Postgres
-from ibis.backends.sql.rewrites import rewrite_sample_as_filter
-from ibis.expr.rewrites import rewrite_stringslice
 
 
 class PostgresUDFNode(ops.Value):
@@ -28,11 +26,6 @@ class PostgresCompiler(SQLGlotCompiler):
     dialect = Postgres
     type_mapper = PostgresType
-    rewrites = (
-        rewrite_sample_as_filter,
-        *SQLGlotCompiler.rewrites,
-        rewrite_stringslice,
-    )
 
     NAN = sge.Literal.number("'NaN'::double precision")
     POS_INF = sge.Literal.number("'Inf'::double precision")
diff --git a/ibis/backends/pyspark/compiler.py b/ibis/backends/pyspark/compiler.py
index 151c13b044e2..27525f169296 100644
--- a/ibis/backends/pyspark/compiler.py
+++ b/ibis/backends/pyspark/compiler.py
@@ -18,7 +18,6 @@
 from ibis.backends.sql.rewrites import FirstValue, LastValue, p
 from ibis.common.patterns import replace
 from ibis.config import options
-from ibis.expr.rewrites import rewrite_stringslice
 from ibis.util import gen_name
 
 
@@ -51,7 +50,7 @@ class PySparkCompiler(SQLGlotCompiler):
     dialect = PySpark
     type_mapper = PySparkType
-    rewrites = (offset_to_filter, *SQLGlotCompiler.rewrites, rewrite_stringslice)
+    rewrites = (offset_to_filter, *SQLGlotCompiler.rewrites)
 
     UNSUPPORTED_OPS = (
         ops.RowID,
@@ -59,6 +58,10 @@ class PySparkCompiler(SQLGlotCompiler):
         ops.RandomUUID,
     )
 
+    LOWERED_OPS = {
+        ops.Sample: None,
+    }
+
     SIMPLE_OPS = {
         ops.ArrayDistinct: "array_distinct",
         ops.ArrayFlatten: "flatten",
diff --git a/ibis/backends/snowflake/compiler.py b/ibis/backends/snowflake/compiler.py
index 4be340acf16c..473918aad66d 100644
--- a/ibis/backends/snowflake/compiler.py
+++ b/ibis/backends/snowflake/compiler.py
@@ -17,11 +17,10 @@
 from ibis.backends.sql.rewrites import (
     exclude_unsupported_window_frame_from_ops,
     exclude_unsupported_window_frame_from_row_number,
-    replace_log2,
-    replace_log10,
+    lower_log2,
+    lower_log10,
     rewrite_empty_order_by_window,
 )
-from ibis.expr.rewrites import rewrite_stringslice
 
 
 class SnowflakeFuncGen(FuncGen):
@@ -39,12 +38,15 @@ class SnowflakeCompiler(SQLGlotCompiler):
         exclude_unsupported_window_frame_from_row_number,
         exclude_unsupported_window_frame_from_ops,
         rewrite_empty_order_by_window,
-        rewrite_stringslice,
-        replace_log2,
-        replace_log10,
         *SQLGlotCompiler.rewrites,
     )
 
+    LOWERED_OPS = {
+        ops.Log2: lower_log2,
+        ops.Log10: lower_log10,
+        ops.Sample: None,
+    }
+
     UNSUPPORTED_OPS = (
         ops.ArrayMap,
         ops.ArrayFilter,
diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py
index 5e927e5ac983..4129fb74185c 100644
--- a/ibis/backends/sql/compiler.py
+++ b/ibis/backends/sql/compiler.py
@@ -4,6 +4,7 @@
 import calendar
 import itertools
 import math
+import operator
 import string
 from collections.abc import Mapping
 from functools import partial, reduce
@@ -23,13 +24,15 @@
     add_one_to_nth_value_input,
     add_order_by_to_empty_ranking_window_functions,
     empty_in_values_right_side,
+    lower_bucket,
+    lower_capitalize,
+    lower_sample,
     one_to_zero_index,
-    rewrite_capitalize,
     sqlize,
 )
 from ibis.config import options
 from ibis.expr.operations.udf import InputType
-from ibis.expr.rewrites import replace_bucket
+from ibis.expr.rewrites import lower_stringslice
 
 if TYPE_CHECKING:
     from collections.abc import Iterable
@@ -167,23 +170,14 @@ def wrapper(self, op, *, left, right):
 class SQLGlotCompiler(abc.ABC):
     __slots__ = "agg", "f", "v"
 
-    rewrites: tuple = (
+    rewrites: tuple[type[pats.Replace], ...] = (
         empty_in_values_right_side,
         add_order_by_to_empty_ranking_window_functions,
         one_to_zero_index,
         add_one_to_nth_value_input,
-        replace_bucket,
-        rewrite_capitalize,
     )
     """A sequence of rewrites to apply to the expression tree before compilation."""
 
-    extra_supported_ops: frozenset = frozenset(
-        (ops.Project, ops.Filter, ops.Sort, ops.WindowFunction)
-    )
-    """A frozenset of ops classes that are supported, but don't have explicit
-    `visit_*` methods (usually due to being handled by rewrite rules). Used by
-    `has_operation`"""
-
     no_limit_value: sge.Null | None = None
     """The value to use to indicate no limit."""
 
@@ -208,9 +202,29 @@ class SQLGlotCompiler(abc.ABC):
     )
     """Backend's negative infinity literal."""
 
-    UNSUPPORTED_OPS: tuple[type[ops.Node]] = ()
+    EXTRA_SUPPORTED_OPS: tuple[type[ops.Node], ...] = (
+        ops.Project,
+        ops.Filter,
+        ops.Sort,
+        ops.WindowFunction,
+    )
+    """A tuple of ops classes that are supported, but don't have explicit
+    `visit_*` methods (usually due to being handled by rewrite rules). Used by
+    `has_operation`"""
+
+    UNSUPPORTED_OPS: tuple[type[ops.Node], ...] = ()
     """Tuple of operations the backend doesn't support."""
 
+    LOWERED_OPS: dict[type[ops.Node], pats.Replace | None] = {
+        ops.Bucket: lower_bucket,
+        ops.Capitalize: lower_capitalize,
+        ops.Sample: lower_sample,
+        ops.StringSlice: lower_stringslice,
+    }
+    """A mapping from an operation class to either a rewrite rule for rewriting that
+    operation to one composed of lower-level operations ("lowering"), or `None` to
+    remove an existing rewrite rule for that operation added in a base class"""
+
     SIMPLE_OPS = {
         ops.Abs: "abs",
         ops.Acos: "acos",
@@ -326,6 +340,11 @@ class SQLGlotCompiler(abc.ABC):
 
     NEEDS_PARENS = BINARY_INFIX_OPS + (ops.IsNull,)
 
+    # Constructed dynamically in `__init_subclass__` from their respective
+    # UPPERCASE values to handle inheritance, do not modify directly here.
+    extra_supported_ops: ClassVar[frozenset[type[ops.Node]]] = frozenset()
+    lowered_ops: ClassVar[dict[type[ops.Node], pats.Replace]] = {}
+
     def __init__(self) -> None:
         self.agg = AggGen(aggfunc=self._aggregate)
         self.f = FuncGen(copy=self.__class__.copy_func_args)
@@ -359,29 +378,28 @@ def impl(self, _, *, _name: str = target_name, **kw):
             # TODO: handle geoespatial ops as a separate case?
             setattr(cls, methodname(op), cls.visit_Undefined)
 
-        # override existing base class implementations
         for op, target_name in cls.SIMPLE_OPS.items():
             setattr(cls, methodname(op), make_impl(op, target_name))
 
-        # add simple ops that are not already implemented
-        for op, target_name in SQLGlotCompiler.SIMPLE_OPS.items():
-            name = methodname(op)
-            if not hasattr(cls, name):
-                setattr(cls, name, make_impl(op, target_name))
-
         # raise on any remaining unsupported operations
         for op in ALL_OPERATIONS:
             name = methodname(op)
             if not hasattr(cls, name):
                 setattr(cls, name, cls.visit_Undefined)
 
-        # Expand extra_supported_ops with any rewrite rules
+        # Amend `lowered_ops` and `extra_supported_ops` using their
+        # respective UPPERCASE classvar values.
         extra_supported_ops = set(cls.extra_supported_ops)
-        for rule in cls.rewrites:
-            if isinstance(rule, pats.Replace) and isinstance(
-                rule.matcher, pats.InstanceOf
-            ):
-                extra_supported_ops.add(rule.matcher.type)
+        lowered_ops = dict(cls.lowered_ops)
+        extra_supported_ops.update(cls.EXTRA_SUPPORTED_OPS)
+        for op_cls, rewrite in cls.LOWERED_OPS.items():
+            if rewrite is not None:
+                lowered_ops[op_cls] = rewrite
+                extra_supported_ops.add(op_cls)
+            else:
+                lowered_ops.pop(op_cls, None)
+                extra_supported_ops.discard(op_cls)
+        cls.lowered_ops = lowered_ops
         cls.extra_supported_ops = frozenset(extra_supported_ops)
 
     @property
@@ -454,6 +472,8 @@ def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression:
         # substitute parameters immediately to avoid having to define a
         # ScalarParameter translation rule
         params = self._prepare_params(params)
+        if self.lowered_ops:
+            op = op.replace(reduce(operator.or_, self.lowered_ops.values()))
         op, ctes = sqlize(
             op,
             params=params,
@@ -1498,3 +1518,8 @@ def visit_Unsupported(self, op, **_):
         raise com.UnsupportedOperationError(
             f"{type(op).__name__!r} operation is not supported in the {self.dialect} backend"
         )
+
+
+# `__init_subclass__` is uncalled for subclasses - we manually call it here to
+# autogenerate the base class implementations as well.
+SQLGlotCompiler.__init_subclass__()
diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py
index 579dbcc9c0e2..342b428a6cb7 100644
--- a/ibis/backends/sql/rewrites.py
+++ b/ibis/backends/sql/rewrites.py
@@ -16,7 +16,7 @@
 from ibis.common.collections import FrozenDict  # noqa: TCH001
 from ibis.common.deferred import var
 from ibis.common.graph import Graph
-from ibis.common.patterns import InstanceOf, Object, Pattern, _, replace
+from ibis.common.patterns import InstanceOf, Object, Pattern, replace
 from ibis.common.typing import VarTuple  # noqa: TCH001
 from ibis.expr.rewrites import d, p, replace_parameter
 from ibis.expr.schema import Schema
@@ -266,16 +266,10 @@ def wrap(node, _, **kwargs):
 
 # supplemental rewrites selectively used on a per-backend basis
 
-"""Replace `log2` and `log10` with `log`."""
-replace_log2 = p.Log2 >> d.Log(_.arg, base=2)
-replace_log10 = p.Log10 >> d.Log(_.arg, base=10)
-
-
-"""Add an ORDER BY clause to rank window functions that don't have one."""
-
 
 @replace(p.WindowFunction(func=p.NTile(y), order_by=()))
 def add_order_by_to_empty_ranking_window_functions(_, **kwargs):
+    """Add an ORDER BY clause to rank window functions that don't have one."""
     return _.copy(order_by=(y,))
 
 
@@ -303,31 +297,6 @@ def add_one_to_nth_value_input(_, **kwargs):
     return _.copy(nth=nth)
 
 
-@replace(p.Capitalize)
-def rewrite_capitalize(_, **kwargs):
-    """Rewrite Capitalize in terms of substring, concat, upper, and lower."""
-    first = ops.Uppercase(ops.Substring(_.arg, start=0, length=1))
-    # use length instead of length - 1 to avoid backends complaining about
-    # asking for negative length
-    #
-    # there are at most length - 1 characters, so asking for length is fine
-    rest = ops.Lowercase(ops.Substring(_.arg, start=1, length=ops.StringLength(_.arg)))
-    return ops.StringConcat((first, rest))
-
-
-@replace(p.Sample)
-def rewrite_sample_as_filter(_, **kwargs):
-    """Rewrite Sample as `t.filter(random() <= fraction)`.
-
-    Errors as unsupported if a `seed` is specified.
-    """
-    if _.seed is not None:
-        raise com.UnsupportedOperationError(
-            "`Table.sample` with a random seed is unsupported"
-        )
-    return ops.Filter(_.parent, (ops.LessEqual(ops.RandomScalar(), _.fraction),))
-
-
 @replace(p.WindowFunction(order_by=()))
 def rewrite_empty_order_by_window(_, **kwargs):
     return _.copy(order_by=(ops.NULL,))
@@ -352,3 +321,98 @@ def exclude_unsupported_window_frame_from_rank(_, **kwargs):
 )
 def exclude_unsupported_window_frame_from_ops(_, **kwargs):
     return _.copy(start=None, end=0, order_by=_.order_by or (ops.NULL,))
+
+
+# Rewrite rules for lowering a high-level operation into one composed of more
+# primitive operations.
+
+
+@replace(p.Log2)
+def lower_log2(_, **kwargs):
+    """Rewrite `log2` as `log`."""
+    return ops.Log(_.arg, base=2)
+
+
+@replace(p.Log10)
+def lower_log10(_, **kwargs):
+    """Rewrite `log10` as `log`."""
+    return ops.Log(_.arg, base=10)
+
+
+@replace(p.Bucket)
+def lower_bucket(_, **kwargs):
+    """Rewrite `Bucket` as `SearchedCase`."""
+    cases = []
+    results = []
+
+    if _.closed == "left":
+        l_cmp = ops.LessEqual
+        r_cmp = ops.Less
+    else:
+        l_cmp = ops.Less
+        r_cmp = ops.LessEqual
+
+    user_num_buckets = len(_.buckets) - 1
+
+    bucket_id = 0
+    if _.include_under:
+        if user_num_buckets > 0:
+            cmp = ops.Less if _.close_extreme else r_cmp
+        else:
+            cmp = ops.LessEqual if _.closed == "right" else ops.Less
+        cases.append(cmp(_.arg, _.buckets[0]))
+        results.append(bucket_id)
+        bucket_id += 1
+
+    for j, (lower, upper) in enumerate(zip(_.buckets, _.buckets[1:])):
+        if _.close_extreme and (
+            (_.closed == "right" and j == 0)
+            or (_.closed == "left" and j == (user_num_buckets - 1))
+        ):
+            cases.append(
+                ops.And(ops.LessEqual(lower, _.arg), ops.LessEqual(_.arg, upper))
+            )
+            results.append(bucket_id)
+        else:
+            cases.append(ops.And(l_cmp(lower, _.arg), r_cmp(_.arg, upper)))
+            results.append(bucket_id)
+        bucket_id += 1
+
+    if _.include_over:
+        if user_num_buckets > 0:
+            cmp = ops.Less if _.close_extreme else l_cmp
+        else:
+            cmp = ops.Less if _.closed == "right" else ops.LessEqual
+
+        cases.append(cmp(_.buckets[-1], _.arg))
+        results.append(bucket_id)
+        bucket_id += 1
+
+    return ops.SearchedCase(
+        cases=tuple(cases), results=tuple(results), default=ops.NULL
+    )
+
+
+@replace(p.Capitalize)
+def lower_capitalize(_, **kwargs):
+    """Rewrite Capitalize in terms of substring, concat, upper, and lower."""
+    first = ops.Uppercase(ops.Substring(_.arg, start=0, length=1))
+    # use length instead of length - 1 to avoid backends complaining about
+    # asking for negative length
+    #
+    # there are at most length - 1 characters, so asking for length is fine
+    rest = ops.Lowercase(ops.Substring(_.arg, start=1, length=ops.StringLength(_.arg)))
+    return ops.StringConcat((first, rest))
+
+
+@replace(p.Sample)
+def lower_sample(_, **kwargs):
+    """Rewrite Sample as `t.filter(random() <= fraction)`.
+
+    Errors as unsupported if a `seed` is specified.
+    """
+    if _.seed is not None:
+        raise com.UnsupportedOperationError(
+            "`Table.sample` with a random seed is unsupported"
+        )
+    return ops.Filter(_.parent, (ops.LessEqual(ops.RandomScalar(), _.fraction),))
diff --git a/ibis/backends/sqlite/compiler.py b/ibis/backends/sqlite/compiler.py
index d10753f505ab..2e7e6b279c16 100644
--- a/ibis/backends/sqlite/compiler.py
+++ b/ibis/backends/sqlite/compiler.py
@@ -12,9 +12,7 @@
 from ibis.backends.sql.compiler import NULL, SQLGlotCompiler
 from ibis.backends.sql.datatypes import SQLiteType
 from ibis.backends.sql.dialects import SQLite
-from ibis.backends.sql.rewrites import rewrite_sample_as_filter
 from ibis.common.temporal import DateUnit, IntervalUnit
-from ibis.expr.rewrites import rewrite_stringslice
 
 
 @public
@@ -23,11 +21,6 @@ class SQLiteCompiler(SQLGlotCompiler):
     dialect = SQLite
     type_mapper = SQLiteType
-    rewrites = (
-        rewrite_sample_as_filter,
-        rewrite_stringslice,
-        *SQLGlotCompiler.rewrites,
-    )
 
     NAN = NULL
     POS_INF = sge.Literal.number("1e999")
diff --git a/ibis/backends/trino/compiler.py b/ibis/backends/trino/compiler.py
index dcdcf6d60a84..e53fb73454b5 100644
--- a/ibis/backends/trino/compiler.py
+++ b/ibis/backends/trino/compiler.py
@@ -14,7 +14,6 @@
 from ibis.backends.sql.datatypes import TrinoType
 from ibis.backends.sql.dialects import Trino
 from ibis.backends.sql.rewrites import exclude_unsupported_window_frame_from_ops
-from ibis.expr.rewrites import rewrite_stringslice
 
 
 class TrinoCompiler(SQLGlotCompiler):
@@ -24,7 +23,6 @@ class TrinoCompiler(SQLGlotCompiler):
     type_mapper = TrinoType
     rewrites = (
         exclude_unsupported_window_frame_from_ops,
-        rewrite_stringslice,
         *SQLGlotCompiler.rewrites,
     )
     quoted = True
@@ -41,6 +39,10 @@ class TrinoCompiler(SQLGlotCompiler):
         ops.TimestampBucket,
     )
 
+    LOWERED_OPS = {
+        ops.Sample: None,
+    }
+
     SIMPLE_OPS = {
         ops.Arbitrary: "any_value",
         ops.Pi: "pi",
diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py
index b53158c59fb2..10646502e677 100644
--- a/ibis/expr/rewrites.py
+++ b/ibis/expr/rewrites.py
@@ -254,7 +254,7 @@ def rewrite_dropna(_):
 
 
 @replace(p.StringSlice)
-def rewrite_stringslice(_, **kwargs):
+def lower_stringslice(_, **kwargs):
     """Rewrite StringSlice in terms of Substring."""
     if _.end is None:
         return ops.Substring(_.arg, start=_.start)
@@ -379,59 +379,6 @@ def rewrite_window_input(value, window):
     return node.replace(window_merge_frames, filter=p.Value, context=context)
 
 
-@replace(p.Bucket)
-def replace_bucket(_):
-    cases = []
-    results = []
-
-    if _.closed == "left":
-        l_cmp = ops.LessEqual
-        r_cmp = ops.Less
-    else:
-        l_cmp = ops.Less
-        r_cmp = ops.LessEqual
-
-    user_num_buckets = len(_.buckets) - 1
-
-    bucket_id = 0
-    if _.include_under:
-        if user_num_buckets > 0:
-            cmp = ops.Less if _.close_extreme else r_cmp
-        else:
-            cmp = ops.LessEqual if _.closed == "right" else ops.Less
-        cases.append(cmp(_.arg, _.buckets[0]))
-        results.append(bucket_id)
-        bucket_id += 1
-
-    for j, (lower, upper) in enumerate(zip(_.buckets, _.buckets[1:])):
-        if _.close_extreme and (
-            (_.closed == "right" and j == 0)
-            or (_.closed == "left" and j == (user_num_buckets - 1))
-        ):
-            cases.append(
-                ops.And(ops.LessEqual(lower, _.arg), ops.LessEqual(_.arg, upper))
-            )
-            results.append(bucket_id)
-        else:
-            cases.append(ops.And(l_cmp(lower, _.arg), r_cmp(_.arg, upper)))
-            results.append(bucket_id)
-        bucket_id += 1
-
-    if _.include_over:
-        if user_num_buckets > 0:
-            cmp = ops.Less if _.close_extreme else l_cmp
-        else:
-            cmp = ops.Less if _.closed == "right" else ops.LessEqual
-
-        cases.append(cmp(_.buckets[-1], _.arg))
-        results.append(bucket_id)
-        bucket_id += 1
-
-    return ops.SearchedCase(
-        cases=tuple(cases), results=tuple(results), default=ops.NULL
-    )
-
-
 # TODO(kszucs): schema comparison should be updated to not distinguish between
 # different column order
 @replace(p.Project(y @ p.Relation) & Check(_.schema == y.schema))