From 378251ef1ec1be860ed5b18beee361ab95b82d8f Mon Sep 17 00:00:00 2001
From: Jim Crist-Harif
Date: Mon, 6 May 2024 15:44:22 -0500
Subject: [PATCH] refactor(sql): use a rewrite rule to implement FillNa/DropNa

---
 ibis/backends/sql/compiler.py | 58 +----------------------------------
 ibis/backends/sql/rewrites.py | 53 +++++++++++++++++++++++++++++++-
 ibis/expr/rewrites.py         | 51 ------------------------------
 3 files changed, 53 insertions(+), 109 deletions(-)

diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py
index 4129fb74185c..fb0c5a974c8b 100644
--- a/ibis/backends/sql/compiler.py
+++ b/ibis/backends/sql/compiler.py
@@ -6,7 +6,6 @@
 import math
 import operator
 import string
-from collections.abc import Mapping
 from functools import partial, reduce
 from typing import TYPE_CHECKING, Any, Callable, ClassVar
 
@@ -35,7 +34,7 @@
 from ibis.expr.rewrites import lower_stringslice
 
 if TYPE_CHECKING:
-    from collections.abc import Iterable
+    from collections.abc import Iterable, Mapping
 
     import ibis.expr.schema as sch
     import ibis.expr.types as ir
@@ -1340,61 +1339,6 @@ def visit_Distinct(self, op, *, parent):
             sg.select(STAR, copy=False).distinct(copy=False).from_(parent, copy=False)
         )
 
-    def visit_DropNa(self, op, *, parent, how, subset):
-        if subset is None:
-            subset = [
-                sg.column(
-                    name, table=parent.alias_or_name, quoted=self.quoted, copy=False
-                )
-                for name in op.schema.names
-            ]
-
-        if subset:
-            predicate = reduce(
-                sg.and_ if how == "any" else sg.or_,
-                (sg.not_(col.is_(NULL), copy=False) for col in subset),
-            )
-        elif how == "all":
-            predicate = FALSE
-        else:
-            predicate = None
-
-        if predicate is None:
-            return parent
-
-        try:
-            return parent.where(predicate, copy=False)
-        except AttributeError:
-            return (
-                sg.select(STAR, copy=False)
-                .from_(parent, copy=False)
-                .where(predicate, copy=False)
-            )
-
-    def visit_FillNa(self, op, *, parent, replacements):
-        if isinstance(replacements, Mapping):
-            mapping = replacements
-        else:
-            mapping = {
-                name: replacements
-                for name, dtype in op.schema.items()
-                if dtype.nullable
-            }
-        exprs = {
-            col: (
-                self.f.coalesce(
-                    sg.column(col, quoted=self.quoted, copy=False),
-                    sge.convert(alt),
-                )
-                if (alt := mapping.get(col)) is not None
-                else sg.column(col, quoted=self.quoted)
-            )
-            for col in op.schema.keys()
-        }
-        return sg.select(*self._cleanup_names(exprs), copy=False).from_(
-            parent, copy=False
-        )
-
     def visit_CTE(self, op, *, parent):
         return sg.table(parent.alias_or_name, quoted=self.quoted)
 
diff --git a/ibis/backends/sql/rewrites.py b/ibis/backends/sql/rewrites.py
index 342b428a6cb7..650709a29a86 100644
--- a/ibis/backends/sql/rewrites.py
+++ b/ibis/backends/sql/rewrites.py
@@ -3,6 +3,7 @@
 from __future__ import annotations
 
 import operator
+from collections.abc import Mapping
 from functools import reduce
 from typing import TYPE_CHECKING, Any
 
@@ -22,7 +23,7 @@
 from ibis.expr.schema import Schema
 
 if TYPE_CHECKING:
-    from collections.abc import Mapping, Sequence
+    from collections.abc import Sequence
 
 x = var("x")
 y = var("y")
@@ -110,6 +111,54 @@ def sort_to_select(_, **kwargs):
     return Select(_.parent, selections=_.values, sort_keys=_.keys)
 
 
+@replace(p.FillNa)
+def fillna_to_select(_, **kwargs):
+    """Rewrite FillNa to a Select node."""
+    if isinstance(_.replacements, Mapping):
+        mapping = _.replacements
+    else:
+        mapping = {
+            name: _.replacements
+            for name, type in _.parent.schema.items()
+            if type.nullable
+        }
+
+    if not mapping:
+        return _.parent
+
+    selections = {}
+    for name in _.parent.schema.names:
+        col = ops.Field(_.parent, name)
+        if (value := mapping.get(name)) is not None:
+            col = ops.Alias(ops.Coalesce((col, value)), name)
+        selections[name] = col
+
+    return Select(_.parent, selections=selections)
+
+
+@replace(p.DropNa)
+def dropna_to_select(_, **kwargs):
+    """Rewrite DropNa to a Select node."""
+    if _.subset is None:
+        columns = [ops.Field(_.parent, name) for name in _.parent.schema.names]
+    else:
+        columns = _.subset
+
+    if columns:
+        preds = [
+            reduce(
+                ops.And if _.how == "any" else ops.Or,
+                [ops.NotNull(c) for c in columns],
+            )
+        ]
+    elif _.how == "all":
+        preds = [ops.Literal(False, dtype=dt.bool)]
+    else:
+        return _.parent
+
+    return Select(_.parent, selections=_.values, predicates=tuple(preds))
+
+
 @replace(p.WindowFunction(p.First | p.Last))
 def first_to_firstvalue(_, **kwargs):
     """Convert a First or Last node to a FirstValue or LastValue node."""
@@ -241,6 +290,8 @@ def sqlize(
         | project_to_select
         | filter_to_select
         | sort_to_select
+        | fillna_to_select
+        | dropna_to_select
         | first_to_firstvalue,
         context=context,
     )
diff --git a/ibis/expr/rewrites.py b/ibis/expr/rewrites.py
index 10646502e677..894f94b6dc2d 100644
--- a/ibis/expr/rewrites.py
+++ b/ibis/expr/rewrites.py
@@ -2,13 +2,10 @@
 
 from __future__ import annotations
 
-import functools
 from collections import defaultdict
-from collections.abc import Mapping
 
 import toolz
 
-import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
 from ibis.common.collections import FrozenDict  # noqa: TCH001
 from ibis.common.deferred import Item, _, deferred, var
@@ -205,54 +202,6 @@ def replace_parameter(_, params, **kwargs):
     return ops.Literal(value=params[_], dtype=_.dtype)
 
 
-@replace(p.FillNa)
-def rewrite_fillna(_):
-    """Rewrite FillNa expressions to use more common operations."""
-    if isinstance(_.replacements, Mapping):
-        mapping = _.replacements
-    else:
-        mapping = {
-            name: _.replacements
-            for name, type in _.parent.schema.items()
-            if type.nullable
-        }
-
-    if not mapping:
-        return _.parent
-
-    selections = []
-    for name in _.parent.schema.names:
-        col = ops.Field(_.parent, name)
-        if (value := mapping.get(name)) is not None:
-            col = ops.Alias(ops.Coalesce((col, value)), name)
-        selections.append(col)
-
-    return ops.Project(_.parent, selections)
-
-
-@replace(p.DropNa)
-def rewrite_dropna(_):
-    """Rewrite DropNa expressions to use more common operations."""
-    if _.subset is None:
-        columns = [ops.Field(_.parent, name) for name in _.parent.schema.names]
-    else:
-        columns = _.subset
-
-    if columns:
-        preds = [
-            functools.reduce(
-                ops.And if _.how == "any" else ops.Or,
-                [ops.NotNull(c) for c in columns],
-            )
-        ]
-    elif _.how == "all":
-        preds = [ops.Literal(False, dtype=dt.bool)]
-    else:
-        return _.parent
-
-    return ops.Filter(_.parent, tuple(preds))
-
-
 @replace(p.StringSlice)
 def lower_stringslice(_, **kwargs):
     """Rewrite StringSlice in terms of Substring."""