Skip to content

Commit

Permalink
refactor(sql): use a rewrite rule to implement FillNa/DropNa
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed May 6, 2024
1 parent f217127 commit 54c78e4
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 109 deletions.
58 changes: 1 addition & 57 deletions ibis/backends/sql/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import math
import operator
import string
from collections.abc import Mapping
from functools import partial, reduce
from typing import TYPE_CHECKING, Any, Callable, ClassVar

Expand Down Expand Up @@ -35,7 +34,7 @@
from ibis.expr.rewrites import lower_stringslice

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Mapping

import ibis.expr.schema as sch
import ibis.expr.types as ir
Expand Down Expand Up @@ -1340,61 +1339,6 @@ def visit_Distinct(self, op, *, parent):
sg.select(STAR, copy=False).distinct(copy=False).from_(parent, copy=False)
)

def visit_DropNa(self, op, *, parent, how, subset):
if subset is None:
subset = [
sg.column(
name, table=parent.alias_or_name, quoted=self.quoted, copy=False
)
for name in op.schema.names
]

if subset:
predicate = reduce(
sg.and_ if how == "any" else sg.or_,
(sg.not_(col.is_(NULL), copy=False) for col in subset),
)
elif how == "all":
predicate = FALSE
else:
predicate = None

if predicate is None:
return parent

try:
return parent.where(predicate, copy=False)
except AttributeError:
return (
sg.select(STAR, copy=False)
.from_(parent, copy=False)
.where(predicate, copy=False)
)

def visit_FillNa(self, op, *, parent, replacements):
if isinstance(replacements, Mapping):
mapping = replacements
else:
mapping = {
name: replacements
for name, dtype in op.schema.items()
if dtype.nullable
}
exprs = {
col: (
self.f.coalesce(
sg.column(col, quoted=self.quoted, copy=False),
sge.convert(alt),
)
if (alt := mapping.get(col)) is not None
else sg.column(col, quoted=self.quoted)
)
for col in op.schema.keys()
}
return sg.select(*self._cleanup_names(exprs), copy=False).from_(
parent, copy=False
)

def visit_CTE(self, op, *, parent):
return sg.table(parent.alias_or_name, quoted=self.quoted)

Expand Down
53 changes: 52 additions & 1 deletion ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import operator
from collections.abc import Mapping
from functools import reduce
from typing import TYPE_CHECKING, Any

Expand All @@ -22,7 +23,7 @@
from ibis.expr.schema import Schema

if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from collections.abc import Sequence

x = var("x")
y = var("y")
Expand Down Expand Up @@ -110,6 +111,54 @@ def sort_to_select(_, **kwargs):
return Select(_.parent, selections=_.values, sort_keys=_.keys)


@replace(p.FillNa)
def fillna_to_select(_, **kwargs):
"""Rewrite FillNa to a Select node."""
if isinstance(_.replacements, Mapping):
mapping = _.replacements
else:
mapping = {
name: _.replacements
for name, type in _.parent.schema.items()
if type.nullable
}

if not mapping:
return _.parent

selections = {}
for name in _.parent.schema.names:
col = ops.Field(_.parent, name)
if (value := mapping.get(name)) is not None:
col = ops.Alias(ops.Coalesce((col, value)), name)
selections[name] = col

return Select(_.parent, selections=selections)


@replace(p.DropNa)
def dropna_to_select(_, **kwargs):
"""Rewrite DropNa to a Select node."""
if _.subset is None:
columns = [ops.Field(_.parent, name) for name in _.parent.schema.names]
else:
columns = _.subset

if columns:
preds = [
reduce(
ops.And if _.how == "any" else ops.Or,
[ops.NotNull(c) for c in columns],
)
]
elif _.how == "all":
preds = [ops.Literal(False, dtype=dt.bool)]
else:
return _.parent

return Select(_.parent, selections=_.values, predicates=tuple(preds))


@replace(p.WindowFunction(p.First | p.Last))
def first_to_firstvalue(_, **kwargs):
"""Convert a First or Last node to a FirstValue or LastValue node."""
Expand Down Expand Up @@ -241,6 +290,8 @@ def sqlize(
| project_to_select
| filter_to_select
| sort_to_select
| fillna_to_select
| dropna_to_select
| first_to_firstvalue,
context=context,
)
Expand Down
51 changes: 0 additions & 51 deletions ibis/expr/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,10 @@

from __future__ import annotations

import functools
from collections import defaultdict
from collections.abc import Mapping

import toolz

import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
from ibis.common.collections import FrozenDict # noqa: TCH001
from ibis.common.deferred import Item, _, deferred, var
Expand Down Expand Up @@ -205,54 +202,6 @@ def replace_parameter(_, params, **kwargs):
return ops.Literal(value=params[_], dtype=_.dtype)


@replace(p.FillNa)
def rewrite_fillna(_):
"""Rewrite FillNa expressions to use more common operations."""
if isinstance(_.replacements, Mapping):
mapping = _.replacements
else:
mapping = {
name: _.replacements
for name, type in _.parent.schema.items()
if type.nullable
}

if not mapping:
return _.parent

selections = []
for name in _.parent.schema.names:
col = ops.Field(_.parent, name)
if (value := mapping.get(name)) is not None:
col = ops.Alias(ops.Coalesce((col, value)), name)
selections.append(col)

return ops.Project(_.parent, selections)


@replace(p.DropNa)
def rewrite_dropna(_):
"""Rewrite DropNa expressions to use more common operations."""
if _.subset is None:
columns = [ops.Field(_.parent, name) for name in _.parent.schema.names]
else:
columns = _.subset

if columns:
preds = [
functools.reduce(
ops.And if _.how == "any" else ops.Or,
[ops.NotNull(c) for c in columns],
)
]
elif _.how == "all":
preds = [ops.Literal(False, dtype=dt.bool)]
else:
return _.parent

return ops.Filter(_.parent, tuple(preds))


@replace(p.StringSlice)
def lower_stringslice(_):
"""Rewrite StringSlice in terms of Substring."""
Expand Down

0 comments on commit 54c78e4

Please sign in to comment.