Skip to content

Commit

Permalink
fix(sql): properly parenthesize binary ops containing named expressions
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Sep 5, 2024
1 parent d5ca729 commit f8e945b
Show file tree
Hide file tree
Showing 15 changed files with 105 additions and 49 deletions.
8 changes: 1 addition & 7 deletions ibis/backends/sql/compilers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -548,10 +548,7 @@ def cast(self, arg, to: dt.DataType) -> sge.Cast:
def _prepare_params(self, params):
result = {}
for param, value in params.items():
node = param.op()
if isinstance(node, ops.Alias):
node = node.arg
result[node] = value
result[param.op()] = value
return result

def to_sqlglot(
Expand Down Expand Up @@ -689,9 +686,6 @@ def visit_Cast(self, op, *, arg, to):
def visit_ScalarSubquery(self, op, *, rel):
return rel.this.subquery(copy=False)

def visit_Alias(self, op, *, arg, name):
return arg

def visit_Literal(self, op, *, value, dtype):
"""Compile a literal value.
Expand Down
14 changes: 10 additions & 4 deletions ibis/backends/sql/compilers/bigquery/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,14 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop):
return sg.select(column).from_(parent)

def visit_TableUnnest(
self, op, *, parent, column, offset: str | None, keep_empty: bool
self,
op,
*,
parent,
column,
column_name: str,
offset: str | None,
keep_empty: bool,
):
quoted = self.quoted

Expand All @@ -1029,9 +1036,8 @@ def visit_TableUnnest(

table = sg.to_identifier(parent.alias_or_name, quoted=quoted)

opname = op.column.name
overlaps_with_parent = opname in op.parent.schema
computed_column = column_alias.as_(opname, quoted=quoted)
overlaps_with_parent = column_name in op.parent.schema
computed_column = column_alias.as_(column_name, quoted=quoted)

# replace the existing column if the unnested column hasn't been
# renamed
Expand Down
14 changes: 10 additions & 4 deletions ibis/backends/sql/compilers/clickhouse.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,14 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop):
return sg.select(column).from_(parent)

def visit_TableUnnest(
self, op, *, parent, column, offset: str | None, keep_empty: bool
self,
op,
*,
parent,
column,
column_name: str,
offset: str | None,
keep_empty: bool,
):
quoted = self.quoted

Expand All @@ -697,9 +704,8 @@ def visit_TableUnnest(

selcols = []

opname = op.column.name
overlaps_with_parent = opname in op.parent.schema
computed_column = column_alias.as_(opname, quoted=quoted)
overlaps_with_parent = column_name in op.parent.schema
computed_column = column_alias.as_(column_name, quoted=quoted)

if offset is not None:
if overlaps_with_parent:
Expand Down
18 changes: 13 additions & 5 deletions ibis/backends/sql/compilers/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -609,15 +609,21 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop):
return sg.select(column).from_(parent)

def visit_TableUnnest(
self, op, *, parent, column, offset: str | None, keep_empty: bool
self,
op,
*,
parent,
column,
column_name: str,
offset: str | None,
keep_empty: bool,
):
quoted = self.quoted

column_alias = sg.to_identifier(gen_name("table_unnest_column"), quoted=quoted)

opname = op.column.name
overlaps_with_parent = opname in op.parent.schema
computed_column = column_alias.as_(opname, quoted=quoted)
overlaps_with_parent = column_name in op.parent.schema
computed_column = column_alias.as_(column_name, quoted=quoted)

selcols = []

Expand All @@ -627,7 +633,9 @@ def visit_TableUnnest(
# TODO: clean this up once WITH ORDINALITY is supported in DuckDB
# no need for struct_extract once that's upstream
column = self.f.list_zip(column, self.f.range(self.f.len(column)))
extract = self.f.struct_extract(column_alias, 1).as_(opname, quoted=quoted)
extract = self.f.struct_extract(column_alias, 1).as_(
column_name, quoted=quoted
)

if overlaps_with_parent:
replace = sge.Column(this=sge.Star(replace=[extract]), table=table)
Expand Down
14 changes: 10 additions & 4 deletions ibis/backends/sql/compilers/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,18 +718,24 @@ def visit_Hash(self, op, *, arg):
)

def visit_TableUnnest(
self, op, *, parent, column, offset: str | None, keep_empty: bool
self,
op,
*,
parent,
column,
column_name: str,
offset: str | None,
keep_empty: bool,
):
quoted = self.quoted

column_alias = sg.to_identifier(gen_name("table_unnest_column"), quoted=quoted)

parent_alias = parent.alias_or_name

opname = op.column.name
parent_schema = op.parent.schema
overlaps_with_parent = opname in parent_schema
computed_column = column_alias.as_(opname, quoted=quoted)
overlaps_with_parent = column_name in parent_schema
computed_column = column_alias.as_(column_name, quoted=quoted)

selcols = []

Expand Down
14 changes: 10 additions & 4 deletions ibis/backends/sql/compilers/pyspark.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,16 +500,22 @@ def visit_HexDigest(self, op, *, arg, how):
raise NotImplementedError(f"No available hashing function for {how}")

def visit_TableUnnest(
self, op, *, parent, column, offset: str | None, keep_empty: bool
self,
op,
*,
parent,
column,
column_name: str,
offset: str | None,
keep_empty: bool,
):
quoted = self.quoted

column_alias = sg.to_identifier(gen_name("table_unnest_column"), quoted=quoted)

opname = op.column.name
parent_schema = op.parent.schema
overlaps_with_parent = opname in parent_schema
computed_column = column_alias.as_(opname, quoted=quoted)
overlaps_with_parent = column_name in parent_schema
computed_column = column_alias.as_(column_name, quoted=quoted)

parent_alias = parent.alias_or_name

Expand Down
17 changes: 11 additions & 6 deletions ibis/backends/sql/compilers/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,14 @@ def visit_DropColumns(self, op, *, parent, columns_to_drop):
return sg.select(column).from_(parent)

def visit_TableUnnest(
self, op, *, parent, column, offset: str | None, keep_empty: bool
self,
op,
*,
parent,
column,
column_name: str,
offset: str | None,
keep_empty: bool,
):
quoted = self.quoted

Expand All @@ -825,12 +832,10 @@ def visit_TableUnnest(

selcols = []

opcol = op.column
opname = opcol.name
overlaps_with_parent = opname in op.parent.schema
overlaps_with_parent = column_name in op.parent.schema
computed_column = self.cast(
self.f.nullif(column_alias, null_sentinel), opcol.dtype.value_type
).as_(opname, quoted=quoted)
self.f.nullif(column_alias, null_sentinel), op.column.dtype.value_type
).as_(column_name, quoted=quoted)

if overlaps_with_parent:
selcols.append(
Expand Down
14 changes: 10 additions & 4 deletions ibis/backends/sql/compilers/trino.py
Original file line number Diff line number Diff line change
Expand Up @@ -546,16 +546,22 @@ def visit_ToJSONArray(self, op, *, arg):
)

def visit_TableUnnest(
self, op, *, parent, column, offset: str | None, keep_empty: bool
self,
op,
*,
parent,
column,
column_name: str,
offset: str | None,
keep_empty: bool,
):
quoted = self.quoted

column_alias = sg.to_identifier(gen_name("table_unnest_column"), quoted=quoted)

opname = op.column.name
parent_schema = op.parent.schema
overlaps_with_parent = opname in parent_schema
computed_column = column_alias.as_(opname, quoted=quoted)
overlaps_with_parent = column_name in parent_schema
computed_column = column_alias.as_(column_name, quoted=quoted)

parent_alias_or_name = parent.alias_or_name

Expand Down
9 changes: 8 additions & 1 deletion ibis/backends/sql/rewrites.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def fill_null_to_select(_, **kwargs):
for name in _.parent.schema.names:
col = ops.Field(_.parent, name)
if (value := mapping.get(name)) is not None:
col = ops.Alias(ops.Coalesce((col, value)), name)
col = ops.Coalesce((col, value))
selections[name] = col

return Select(_.parent, selections=selections)
Expand Down Expand Up @@ -206,6 +206,12 @@ def first_to_firstvalue(_, **kwargs):
return _.copy(func=klass(_.func.arg))


@replace(p.Alias)
def remove_aliases(_, **kwargs):
"""Remove all remaining aliases, they're not needed for remaining compilation."""
return _.arg


def complexity(node):
"""Assign a complexity score to a node.
Expand Down Expand Up @@ -372,6 +378,7 @@ def sqlize(
context = {"params": params}
result = node.replace(
replace_parameter
| remove_aliases
| project_to_select
| filter_to_select
| sort_to_select
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
},
)
param = ibis.param("timestamp")
f = alltypes.filter((alltypes.timestamp_col < param.name("my_param")))
f = alltypes.filter((alltypes.timestamp_col < param))
agg = f.aggregate([f.float_col.sum().name("foo")], by=[f.string_col])

result = agg.foo.count()
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
SELECT
(
"t0"."a" + "t0"."b"
) * "t0"."c" AS "x"
FROM "t" AS "t0"
2 changes: 1 addition & 1 deletion ibis/backends/tests/sql/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ def test_subquery_where_location(snapshot):
],
name="alltypes",
)
param = ibis.param("timestamp").name("my_param")
param = ibis.param("timestamp")
expr = (
t[["float_col", "timestamp_col", "int_col", "string_col"]][
lambda t: t.timestamp_col < param
Expand Down
6 changes: 6 additions & 0 deletions ibis/backends/tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ def test_binop_parens(snapshot, opname, dtype, associative):
snapshot.assert_match(combined, "out.sql")


def test_binop_with_alias_still_parenthesized(snapshot):
t = ibis.table({"a": "int", "b": "int", "c": "int"}, name="t")
sql = to_sql(((t.a + t.b).name("d") * t.c).name("x"))
snapshot.assert_match(sql, "out.sql")


@pytest.mark.parametrize(
"expr_fn",
[
Expand Down
11 changes: 4 additions & 7 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,7 @@ class TableUnnest(Relation):

parent: Relation
column: Value[dt.Array]
column_name: str
offset: typing.Union[str, None]
keep_empty: bool

Expand All @@ -507,15 +508,11 @@ def values(self):

@attribute
def schema(self):
column = self.column
offset = self.offset

base = self.parent.schema.fields.copy()
base[self.column_name] = self.column.dtype.value_type

base[column.name] = column.dtype.value_type

if offset is not None:
base[offset] = dt.int64
if self.offset is not None:
base[self.offset] = dt.int64

return Schema(base)

Expand Down
6 changes: 5 additions & 1 deletion ibis/expr/types/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4890,7 +4890,11 @@ def unnest(
"""
(column,) = self.bind(column)
return ops.TableUnnest(
parent=self, column=column, offset=offset, keep_empty=keep_empty
parent=self,
column=column,
column_name=column.get_name(),
offset=offset,
keep_empty=keep_empty,
).to_expr()


Expand Down

0 comments on commit f8e945b

Please sign in to comment.