From 84a786d236912839a4be4fb05b6e2e8097bb01a3 Mon Sep 17 00:00:00 2001 From: Phillip Cloud <417981+cpcloud@users.noreply.github.com> Date: Mon, 5 Aug 2024 11:33:57 -0400 Subject: [PATCH] refactor(sql): make compilers usable with a base install (#9766) --- ibis/__init__.py | 3 +- ibis/backends/__init__.py | 13 -- ibis/backends/bigquery/__init__.py | 157 ++----------- .../test_cross_project_query/out.sql | 6 - .../test_multiple_project_queries/out.sql | 5 - .../out.sql | 5 - .../bigquery/tests/system/test_client.py | 15 +- .../bigquery/tests/unit/udf/test_core.py | 5 +- .../bigquery/tests/unit/udf/test_find.py | 2 +- ibis/backends/clickhouse/__init__.py | 8 +- ibis/backends/datafusion/__init__.py | 9 +- ibis/backends/druid/__init__.py | 4 +- ibis/backends/duckdb/__init__.py | 51 +---- ibis/backends/exasol/__init__.py | 6 +- ibis/backends/flink/__init__.py | 22 +- ibis/backends/impala/__init__.py | 4 +- ibis/backends/mssql/__init__.py | 21 +- ibis/backends/mysql/__init__.py | 6 +- ibis/backends/oracle/__init__.py | 10 +- ibis/backends/postgres/__init__.py | 93 +------- ibis/backends/pyspark/__init__.py | 4 +- ibis/backends/risingwave/__init__.py | 6 +- ibis/backends/snowflake/__init__.py | 117 +--------- ibis/backends/sql/__init__.py | 82 +++---- ibis/backends/sql/compilers/base.py | 46 ++++ .../{bigquery.py => bigquery/__init__.py} | 211 ++++++++++++++++++ .../compilers}/bigquery/udf/__init__.py | 0 .../{ => sql/compilers}/bigquery/udf/core.py | 4 +- .../{ => sql/compilers}/bigquery/udf/find.py | 0 .../compilers}/bigquery/udf/rewrite.py | 0 ibis/backends/sql/compilers/clickhouse.py | 3 + ibis/backends/sql/compilers/datafusion.py | 3 + ibis/backends/sql/compilers/druid.py | 3 + ibis/backends/sql/compilers/duckdb.py | 37 +++ ibis/backends/sql/compilers/exasol.py | 3 + ibis/backends/sql/compilers/flink.py | 3 + ibis/backends/sql/compilers/impala.py | 3 + ibis/backends/sql/compilers/mssql.py | 46 +++- ibis/backends/sql/compilers/mysql.py | 3 + ibis/backends/sql/compilers/oracle.py | 3 + ibis/backends/sql/compilers/postgres.py | 77 +++++++ ibis/backends/sql/compilers/pyspark.py | 3 + ibis/backends/sql/compilers/risingwave.py | 3 + ibis/backends/sql/compilers/snowflake.py | 97 ++++++++ ibis/backends/sql/compilers/sqlite.py | 3 + ibis/backends/sql/compilers/trino.py | 3 + ibis/backends/sqlite/__init__.py | 13 +- ibis/backends/tests/test_generic.py | 34 +-- ibis/backends/tests/test_sql.py | 57 +++-- ibis/backends/trino/__init__.py | 6 +- ibis/expr/schema.py | 4 + ibis/expr/sql.py | 26 ++- ibis/tests/expr/mocks.py | 5 - ibis/tests/expr/test_sql_builtins.py | 9 + 54 files changed, 733 insertions(+), 629 deletions(-) delete mode 100644 ibis/backends/bigquery/tests/system/snapshots/test_client/test_cross_project_query/out.sql delete mode 100644 ibis/backends/bigquery/tests/system/snapshots/test_client/test_multiple_project_queries/out.sql delete mode 100644 ibis/backends/bigquery/tests/system/snapshots/test_client/test_multiple_project_queries_database_api/out.sql rename ibis/backends/sql/compilers/{bigquery.py => bigquery/__init__.py} (81%) rename ibis/backends/{ => sql/compilers}/bigquery/udf/__init__.py (100%) rename ibis/backends/{ => sql/compilers}/bigquery/udf/core.py (99%) rename ibis/backends/{ => sql/compilers}/bigquery/udf/find.py (100%) rename ibis/backends/{ => sql/compilers}/bigquery/udf/rewrite.py (100%) diff --git a/ibis/__init__.py b/ibis/__init__.py index 43d010c239db..8085ae6f529f 100644 --- a/ibis/__init__.py +++ b/ibis/__init__.py @@ -100,7 +100,6 @@ def load_backend(name: str) -> 
BaseBackend: # - compile # - has_operation # - _from_url - # - _to_sqlglot # # We also copy over the docstring from `do_connect` to the proxy `connect` # method, since that's where all the backend-specific kwargs are currently @@ -121,7 +120,7 @@ def connect(*args, **kwargs): proxy.has_operation = backend.has_operation proxy.name = name proxy._from_url = backend._from_url - proxy._to_sqlglot = backend._to_sqlglot + # Add any additional methods that should be exposed at the top level for attr in getattr(backend, "_top_level_methods", ()): setattr(proxy, attr, getattr(backend, attr)) diff --git a/ibis/backends/__init__.py b/ibis/backends/__init__.py index b3c6a6e30107..566e2501d32e 100644 --- a/ibis/backends/__init__.py +++ b/ibis/backends/__init__.py @@ -1032,14 +1032,9 @@ def _register_in_memory_table(self, op: ops.InMemoryTable): def _run_pre_execute_hooks(self, expr: ir.Expr) -> None: """Backend-specific hooks to run before an expression is executed.""" - self._define_udf_translation_rules(expr) self._register_udfs(expr) self._register_in_memory_tables(expr) - def _define_udf_translation_rules(self, expr: ir.Expr): - if self.supports_python_udfs: - raise NotImplementedError(self.name) - def compile( self, expr: ir.Expr, @@ -1048,14 +1043,6 @@ def compile( """Compile an expression.""" return self.compiler.to_sql(expr, params=params) - def _to_sqlglot(self, expr: ir.Expr, **kwargs) -> sg.exp.Expression: - """Convert an Ibis expression to a sqlglot expression. - - Called by `ibis.to_sql`; gives the backend an opportunity to generate - nicer SQL for human consumption. - """ - raise NotImplementedError(f"Backend '{self.name}' backend doesn't support SQL") - def execute(self, expr: ir.Expr) -> Any: """Execute an expression.""" diff --git a/ibis/backends/bigquery/__init__.py b/ibis/backends/bigquery/__init__.py index 5c94651eb389..a1bef8f57f2f 100644 --- a/ibis/backends/bigquery/__init__.py +++ b/ibis/backends/bigquery/__init__.py @@ -19,6 +19,7 @@ from pydata_google_auth import cache import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import ibis.expr.operations as ops import ibis.expr.schema as sch @@ -32,9 +33,7 @@ schema_from_bigquery_table, ) from ibis.backends.bigquery.datatypes import BigQuerySchema -from ibis.backends.bigquery.udf.core import PythonToJavaScriptTranslator from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import BigQueryCompiler from ibis.backends.sql.datatypes import BigQueryType if TYPE_CHECKING: @@ -150,7 +149,7 @@ def _force_quote_table(table: sge.Table) -> sge.Table: class Backend(SQLBackend, CanCreateDatabase, CanCreateSchema): name = "bigquery" - compiler = BigQueryCompiler() + compiler = sc.bigquery.compiler supports_in_memory_tables = True supports_python_udfs = False @@ -652,68 +651,6 @@ def _get_schema_using_query(self, query: str) -> sch.Schema: ) return BigQuerySchema.to_ibis(job.schema) - def _to_sqlglot( - self, - expr: ir.Expr, - limit: str | None = None, - params: Mapping[ir.Expr, Any] | None = None, - **kwargs, - ) -> Any: - """Compile an Ibis expression. - - Parameters - ---------- - expr - Ibis expression - limit - For expressions yielding result sets; retrieve at most this number - of values/rows. Overrides any limit already set on the expression. - params - Named unbound parameters - kwargs - Keyword arguments passed to the compiler - - Returns - ------- - Any - The output of compilation. The type of this value depends on the - backend. 
- - """ - self._define_udf_translation_rules(expr) - sql = super()._to_sqlglot(expr, limit=limit, params=params, **kwargs) - - table_expr = expr.as_table() - geocols = [ - name for name, typ in table_expr.schema().items() if typ.is_geospatial() - ] - - query = sql.transform( - _qualify_memtable, - dataset=getattr(self._session_dataset, "dataset_id", None), - project=getattr(self._session_dataset, "project", None), - ).transform(_remove_null_ordering_from_unsupported_window) - - if not geocols: - return query - - # if there are any geospatial columns, we have to convert them to WKB, - # so interactive mode knows how to display them - # - # by default bigquery returns data to python as WKT, and there's really - # no point in supporting both if we don't need to. - compiler = self.compiler - quoted = compiler.quoted - f = compiler.f - return sg.select( - sge.Star( - replace=[ - f.st_asbinary(sg.column(col, quoted=quoted)).as_(col, quoted=quoted) - for col in geocols - ] - ) - ).from_(query.subquery()) - def raw_sql(self, query: str, params=None, page_size: int | None = None): query_parameters = [ bigquery_param( @@ -747,19 +684,25 @@ def current_database(self) -> str | None: return self.dataset def compile( - self, expr: ir.Expr, limit: str | None = None, params=None, **kwargs: Any + self, + expr: ir.Expr, + limit: str | None = None, + params=None, + pretty: bool = True, + **kwargs: Any, ): """Compile an Ibis expression to a SQL string.""" - query = self._to_sqlglot(expr, limit=limit, params=params, **kwargs) - udf_sources = [] - for udf_node in expr.op().find(ops.ScalarUDF): - compile_func = getattr( - self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" - ) - if sql := compile_func(udf_node): - udf_sources.append(sql.sql(self.name, pretty=True)) - - sql = ";\n".join([*udf_sources, query.sql(dialect=self.name, pretty=True)]) + session_dataset = self._session_dataset + query = self.compiler.to_sqlglot( + expr, + limit=limit, + params=params, + session_dataset_id=getattr(session_dataset, "dataset_id", None), + session_project=getattr(session_dataset, "project", None), + **kwargs, + ) + queries = query if isinstance(query, list) else [query] + sql = ";\n".join(query.sql(self.dialect, pretty=pretty) for query in queries) self._log(sql) return sql @@ -1202,68 +1145,6 @@ def _clean_up_cached_table(self, name): force=True, ) - def _get_udf_source(self, udf_node: ops.ScalarUDF): - name = type(udf_node).__name__ - type_mapper = self.compiler.udf_type_mapper - - body = PythonToJavaScriptTranslator(udf_node.__func__).compile() - config = udf_node.__config__ - libraries = config.get("libraries", []) - - signature = [ - sge.ColumnDef( - this=sg.to_identifier(name, quoted=self.compiler.quoted), - kind=type_mapper.from_ibis(param.annotation.pattern.dtype), - ) - for name, param in udf_node.__signature__.parameters.items() - ] - - lines = ['"""'] - - if config.get("strict", True): - lines.append('"use strict";') - - lines += [ - body, - "", - f"return {udf_node.__func_name__}({', '.join(udf_node.argnames)});", - '"""', - ] - - func = sge.Create( - kind="FUNCTION", - this=sge.UserDefinedFunction( - this=sg.to_identifier(name), expressions=signature, wrapped=True - ), - # not exactly what I had in mind, but it works - # - # quoting is too simplistic to handle multiline strings - expression=sge.Var(this="\n".join(lines)), - exists=False, - properties=sge.Properties( - expressions=[ - sge.TemporaryProperty(), - sge.ReturnsProperty(this=type_mapper.from_ibis(udf_node.dtype)), - sge.StabilityProperty( - 
this="IMMUTABLE" if config.get("determinism") else "VOLATILE" - ), - sge.LanguageProperty(this=sg.to_identifier("js")), - ] - + [ - sge.Property( - this=sg.to_identifier("library"), - value=self.compiler.f.array(*libraries), - ) - ] - * bool(libraries) - ), - ) - - return func - - def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> None: - return self._get_udf_source(udf_node) - def _register_udfs(self, expr: ir.Expr) -> None: """No op because UDFs made with CREATE TEMPORARY FUNCTION must be followed by a query.""" diff --git a/ibis/backends/bigquery/tests/system/snapshots/test_client/test_cross_project_query/out.sql b/ibis/backends/bigquery/tests/system/snapshots/test_client/test_cross_project_query/out.sql deleted file mode 100644 index 819adf5b9db3..000000000000 --- a/ibis/backends/bigquery/tests/system/snapshots/test_client/test_cross_project_query/out.sql +++ /dev/null @@ -1,6 +0,0 @@ -SELECT - `t0`.`title`, - `t0`.`tags` -FROM `bigquery-public-data`.`stackoverflow`.`posts_questions` AS `t0` -WHERE - strpos(`t0`.`tags`, 'ibis') > 0 \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/system/snapshots/test_client/test_multiple_project_queries/out.sql b/ibis/backends/bigquery/tests/system/snapshots/test_client/test_multiple_project_queries/out.sql deleted file mode 100644 index fbae8c0def18..000000000000 --- a/ibis/backends/bigquery/tests/system/snapshots/test_client/test_multiple_project_queries/out.sql +++ /dev/null @@ -1,5 +0,0 @@ -SELECT - `t2`.`title` -FROM `bigquery-public-data`.`stackoverflow`.`posts_questions` AS `t2` -INNER JOIN `nyc-tlc`.`yellow`.`trips` AS `t3` - ON `t2`.`tags` = `t3`.`rate_code` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/system/snapshots/test_client/test_multiple_project_queries_database_api/out.sql b/ibis/backends/bigquery/tests/system/snapshots/test_client/test_multiple_project_queries_database_api/out.sql deleted file mode 100644 index f9d06ecd8b53..000000000000 --- a/ibis/backends/bigquery/tests/system/snapshots/test_client/test_multiple_project_queries_database_api/out.sql +++ /dev/null @@ -1,5 +0,0 @@ -SELECT - t0.`title` -FROM `bigquery-public-data`.stackoverflow.posts_questions AS t0 -INNER JOIN `nyc-tlc`.yellow.trips AS t1 - ON t0.`tags` = t1.`rate_code` \ No newline at end of file diff --git a/ibis/backends/bigquery/tests/system/test_client.py b/ibis/backends/bigquery/tests/system/test_client.py index bc5d7eddca33..8fb248a4d3fa 100644 --- a/ibis/backends/bigquery/tests/system/test_client.py +++ b/ibis/backends/bigquery/tests/system/test_client.py @@ -199,11 +199,9 @@ def test_parted_column(con, kind): assert t.columns == [expected_column, "string_col", "int_col"] -def test_cross_project_query(public, snapshot): +def test_cross_project_query(public): table = public.table("posts_questions") expr = table[table.tags.contains("ibis")][["title", "tags"]] - result = expr.compile() - snapshot.assert_match(result, "out.sql") n = 5 df = expr.limit(n).execute() assert len(df) == n @@ -226,17 +224,6 @@ def test_exists_table_different_project(con): assert "foobar" not in con.list_tables(database=dataset) -def test_multiple_project_queries(con, snapshot): - so = con.table( - "posts_questions", - database=("bigquery-public-data", "stackoverflow"), - ) - trips = con.table("trips", database="nyc-tlc.yellow") - join = so.join(trips, so.tags == trips.rate_code)[[so.title]] - result = join.compile() - snapshot.assert_match(result, "out.sql") - - def test_multiple_project_queries_execute(con): posts_questions = 
con.table( "posts_questions", database="bigquery-public-data.stackoverflow" diff --git a/ibis/backends/bigquery/tests/unit/udf/test_core.py b/ibis/backends/bigquery/tests/unit/udf/test_core.py index 6486a595bbb4..1209ec4fd22b 100644 --- a/ibis/backends/bigquery/tests/unit/udf/test_core.py +++ b/ibis/backends/bigquery/tests/unit/udf/test_core.py @@ -6,7 +6,10 @@ import pytest -from ibis.backends.bigquery.udf.core import PythonToJavaScriptTranslator, SymbolTable +from ibis.backends.sql.compilers.bigquery.udf.core import ( + PythonToJavaScriptTranslator, + SymbolTable, +) def test_symbol_table(): diff --git a/ibis/backends/bigquery/tests/unit/udf/test_find.py b/ibis/backends/bigquery/tests/unit/udf/test_find.py index db77d1a9bd93..4435474e2cb0 100644 --- a/ibis/backends/bigquery/tests/unit/udf/test_find.py +++ b/ibis/backends/bigquery/tests/unit/udf/test_find.py @@ -2,7 +2,7 @@ import ast -from ibis.backends.bigquery.udf.find import find_names +from ibis.backends.sql.compilers.bigquery.udf.find import find_names from ibis.util import is_iterable diff --git a/ibis/backends/clickhouse/__init__.py b/ibis/backends/clickhouse/__init__.py index ec7ce922359d..d93a522c19c9 100644 --- a/ibis/backends/clickhouse/__init__.py +++ b/ibis/backends/clickhouse/__init__.py @@ -17,6 +17,7 @@ from clickhouse_connect.driver.external import ExternalData import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import ibis.config import ibis.expr.operations as ops @@ -26,7 +27,6 @@ from ibis.backends import BaseBackend, CanCreateDatabase from ibis.backends.clickhouse.converter import ClickHousePandasData from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import ClickHouseCompiler from ibis.backends.sql.compilers.base import C if TYPE_CHECKING: @@ -44,7 +44,7 @@ def _to_memtable(v): class Backend(SQLBackend, CanCreateDatabase): name = "clickhouse" - compiler = ClickHouseCompiler() + compiler = sc.clickhouse.compiler # ClickHouse itself does, but the client driver does not supports_temporary_tables = False @@ -732,7 +732,7 @@ def create_table( expression = None if obj is not None: - expression = self._to_sqlglot(obj) + expression = self.compiler.to_sqlglot(obj) external_tables.update(self._collect_in_memory_tables(obj)) code = sge.Create( @@ -759,7 +759,7 @@ def create_view( database: str | None = None, overwrite: bool = False, ) -> ir.Table: - expression = self._to_sqlglot(obj) + expression = self.compiler.to_sqlglot(obj) src = sge.Create( this=sg.table(name, db=database), kind="VIEW", diff --git a/ibis/backends/datafusion/__init__.py b/ibis/backends/datafusion/__init__.py index f88956337f36..e53e8112c1da 100644 --- a/ibis/backends/datafusion/__init__.py +++ b/ibis/backends/datafusion/__init__.py @@ -15,6 +15,7 @@ import sqlglot.expressions as sge import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops @@ -23,7 +24,6 @@ from ibis import util from ibis.backends import CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import DataFusionCompiler from ibis.backends.sql.compilers.base import C from ibis.common.dispatch import lazy_singledispatch from ibis.expr.operations.udf import InputType @@ -68,7 +68,7 @@ class Backend(SQLBackend, CanCreateCatalog, CanCreateDatabase, CanCreateSchema, name = "datafusion" supports_in_memory_tables = True supports_arrays = True - compiler = 
DataFusionCompiler() + compiler = sc.datafusion.compiler @property def version(self): @@ -629,16 +629,17 @@ def create_table( # If it's a memtable, it will get registered in the pre-execute hooks self._run_pre_execute_hooks(table) + compiler = self.compiler relname = "_" query = sg.select( *( - self.compiler.cast( + compiler.cast( sg.column(col, table=relname, quoted=quoted), dtype ).as_(col, quoted=quoted) for col, dtype in table.schema().items() ) ).from_( - self._to_sqlglot(table).subquery( + compiler.to_sqlglot(table).subquery( sg.to_identifier(relname, quoted=quoted) ) ) diff --git a/ibis/backends/druid/__init__.py b/ibis/backends/druid/__init__.py index 1b794279e105..571b471ead6c 100644 --- a/ibis/backends/druid/__init__.py +++ b/ibis/backends/druid/__init__.py @@ -10,11 +10,11 @@ import pydruid.db import sqlglot as sg +import ibis.backends.sql.compilers as sc import ibis.expr.datatypes as dt import ibis.expr.schema as sch from ibis import util from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import DruidCompiler from ibis.backends.sql.compilers.base import STAR from ibis.backends.sql.datatypes import DruidType @@ -31,7 +31,7 @@ class Backend(SQLBackend): name = "druid" - compiler = DruidCompiler() + compiler = sc.druid.compiler supports_create_or_replace = False supports_in_memory_tables = True diff --git a/ibis/backends/duckdb/__init__.py b/ibis/backends/duckdb/__init__.py index 0a1f05212eed..fc36fd4b1e1a 100644 --- a/ibis/backends/duckdb/__init__.py +++ b/ibis/backends/duckdb/__init__.py @@ -18,6 +18,7 @@ import sqlglot.expressions as sge import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as exc import ibis.expr.operations as ops import ibis.expr.schema as sch @@ -26,7 +27,6 @@ from ibis.backends import CanCreateDatabase, CanCreateSchema, UrlFromPath from ibis.backends.duckdb.converter import DuckDBPandasData from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import DuckDBCompiler from ibis.backends.sql.compilers.base import STAR, C from ibis.common.dispatch import lazy_singledispatch from ibis.expr.operations.udf import InputType @@ -68,10 +68,7 @@ def __repr__(self): class Backend(SQLBackend, CanCreateDatabase, CanCreateSchema, UrlFromPath): name = "duckdb" - compiler = DuckDBCompiler() - - def _define_udf_translation_rules(self, expr): - """No-op: UDF translation rules are defined in the compiler.""" + compiler = sc.duckdb.compiler @property def settings(self) -> _Settings: @@ -95,34 +92,6 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: query = query.sql(dialect=self.name) return self.con.execute(query, **kwargs) - def _to_sqlglot( - self, expr: ir.Expr, limit: str | None = None, params=None, **_: Any - ): - sql = super()._to_sqlglot(expr, limit=limit, params=params) - - table_expr = expr.as_table() - geocols = [ - name for name, typ in table_expr.schema().items() if typ.is_geospatial() - ] - - if not geocols: - return sql - else: - self._load_extensions(["spatial"]) - - compiler = self.compiler - quoted = compiler.quoted - return sg.select( - sge.Star( - replace=[ - compiler.f.st_aswkb(sg.column(col, quoted=quoted)).as_( - col, quoted=quoted - ) - for col in geocols - ] - ) - ).from_(sql.subquery()) - def create_table( self, name: str, @@ -195,7 +164,7 @@ def create_table( self._run_pre_execute_hooks(table) - query = self._to_sqlglot(table) + query = self.compiler.to_sqlglot(table) else: query = None @@ -1325,6 +1294,8 @@ def _to_duckdb_relation( 
self._run_pre_execute_hooks(expr) table_expr = expr.as_table() sql = self.compile(table_expr, limit=limit, params=params) + if table_expr.schema().geospatial: + self._load_extensions(["spatial"]) return self.con.sql(sql) def to_pyarrow_batches( @@ -1569,17 +1540,17 @@ def _register_udfs(self, expr: ir.Expr) -> None: con = self.con for udf_node in expr.op().find(ops.ScalarUDF): - compile_func = getattr( - self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" + register_func = getattr( + self, f"_register_{udf_node.__input_type__.name.lower()}_udf" ) with contextlib.suppress(duckdb.InvalidInputException): con.remove_function(udf_node.__class__.__name__) - registration_func = compile_func(udf_node) + registration_func = register_func(udf_node) if registration_func is not None: registration_func(con) - def _compile_udf(self, udf_node: ops.ScalarUDF): + def _register_udf(self, udf_node: ops.ScalarUDF): func = udf_node.__func__ name = type(udf_node).__name__ type_mapper = self.compiler.type_mapper @@ -1600,8 +1571,8 @@ def register_udf(con): return register_udf - _compile_python_udf = _compile_udf - _compile_pyarrow_udf = _compile_udf + _register_python_udf = _register_udf + _register_pyarrow_udf = _register_udf def _get_temp_view_definition(self, name: str, definition: str) -> str: return sge.Create( diff --git a/ibis/backends/exasol/__init__.py b/ibis/backends/exasol/__init__.py index 89ce6dec9fd2..759dfc940d1d 100644 --- a/ibis/backends/exasol/__init__.py +++ b/ibis/backends/exasol/__init__.py @@ -12,6 +12,7 @@ import sqlglot.expressions as sge import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops @@ -20,7 +21,6 @@ from ibis import util from ibis.backends import CanCreateDatabase, CanCreateSchema from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import ExasolCompiler from ibis.backends.sql.compilers.base import STAR, C if TYPE_CHECKING: @@ -39,7 +39,7 @@ class Backend(SQLBackend, CanCreateDatabase, CanCreateSchema): name = "exasol" - compiler = ExasolCompiler() + compiler = sc.exasol.compiler supports_temporary_tables = False supports_create_or_replace = False supports_in_memory_tables = False @@ -360,7 +360,7 @@ def create_table( self._run_pre_execute_hooks(table) - query = self._to_sqlglot(table) + query = self.compiler.to_sqlglot(table) else: query = None diff --git a/ibis/backends/flink/__init__.py b/ibis/backends/flink/__init__.py index 8445e5250161..f959ca1e5f0c 100644 --- a/ibis/backends/flink/__init__.py +++ b/ibis/backends/flink/__init__.py @@ -7,6 +7,7 @@ import sqlglot.expressions as sge import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as exc import ibis.expr.operations as ops import ibis.expr.schema as sch @@ -23,7 +24,6 @@ RenameTable, ) from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import FlinkCompiler from ibis.backends.tests.errors import Py4JJavaError from ibis.expr.operations.udf import InputType from ibis.util import gen_name @@ -44,7 +44,7 @@ class Backend(SQLBackend, CanCreateDatabase, NoUrl): name = "flink" - compiler = FlinkCompiler() + compiler = sc.flink.compiler supports_temporary_tables = True supports_python_udfs = True @@ -321,26 +321,27 @@ def version(self) -> str: def _register_udfs(self, expr: ir.Expr) -> None: for udf_node in expr.op().find(ops.ScalarUDF): register_func = getattr( - self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" + self, 
f"_register_{udf_node.__input_type__.name.lower()}_udf" ) register_func(udf_node) def _register_udf(self, udf_node: ops.ScalarUDF): - import pyflink.table.udf + from pyflink.table.udf import udf from ibis.backends.flink.datatypes import FlinkType name = type(udf_node).__name__ self._table_env.drop_temporary_function(name) - udf = pyflink.table.udf.udf( + + func = udf( udf_node.__func__, result_type=FlinkType.from_ibis(udf_node.dtype), func_type=_INPUT_TYPE_TO_FUNC_TYPE[udf_node.__input_type__], ) - self._table_env.create_temporary_function(name, udf) + self._table_env.create_temporary_function(name, func) - _compile_pandas_udf = _register_udf - _compile_python_udf = _register_udf + _register_pandas_udf = _register_udf + _register_python_udf = _register_udf def compile( self, @@ -354,11 +355,6 @@ def compile( expr, params=params, pretty=pretty ) # Discard `limit` and other kwargs. - def _to_sqlglot( - self, expr: ir.Expr, params: Mapping[ir.Expr, Any] | None = None, **_: Any - ) -> str: - return super()._to_sqlglot(expr, params=params) - def execute(self, expr: ir.Expr, **kwargs: Any) -> Any: """Execute an expression.""" self._register_udfs(expr) diff --git a/ibis/backends/impala/__init__.py b/ibis/backends/impala/__init__.py index 32bd8439e645..e6d8a3e4fc18 100644 --- a/ibis/backends/impala/__init__.py +++ b/ibis/backends/impala/__init__.py @@ -13,6 +13,7 @@ import sqlglot.expressions as sge from impala.error import Error as ImpylaError +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import ibis.config import ibis.expr.schema as sch @@ -38,7 +39,6 @@ wrap_udf, ) from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import ImpalaCompiler if TYPE_CHECKING: from collections.abc import Mapping @@ -64,7 +64,7 @@ class Backend(SQLBackend): name = "impala" - compiler = ImpalaCompiler() + compiler = sc.impala.compiler supports_in_memory_tables = True diff --git a/ibis/backends/mssql/__init__.py b/ibis/backends/mssql/__init__.py index a4e711b93312..a1bb4f0b0f09 100644 --- a/ibis/backends/mssql/__init__.py +++ b/ibis/backends/mssql/__init__.py @@ -14,6 +14,7 @@ import sqlglot.expressions as sge import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops @@ -22,7 +23,6 @@ from ibis import util from ibis.backends import CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import MSSQLCompiler from ibis.backends.sql.compilers.base import STAR, C if TYPE_CHECKING: @@ -75,7 +75,7 @@ def datetimeoffset_to_datetime(value): class Backend(SQLBackend, CanCreateCatalog, CanCreateDatabase, CanCreateSchema, NoUrl): name = "mssql" - compiler = MSSQLCompiler() + compiler = sc.mssql.compiler supports_create_or_replace = False @property @@ -597,7 +597,7 @@ def create_table( self._run_pre_execute_hooks(table) - query = self._to_sqlglot(table) + query = self.compiler.to_sqlglot(table) else: query = None @@ -719,21 +719,6 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: if not df.empty: cur.executemany(insert_stmt, data) - def _to_sqlglot( - self, expr: ir.Expr, *, limit: str | None = None, params=None, **_: Any - ): - """Compile an Ibis expression to a sqlglot object.""" - table_expr = expr.as_table() - conversions = { - name: ibis.ifelse(table_expr[name], 1, 0).cast("boolean") - for name, typ in table_expr.schema().items() - if typ.is_boolean() - } - - if 
conversions: - table_expr = table_expr.mutate(**conversions) - return super()._to_sqlglot(table_expr, limit=limit, params=params) - def _cursor_batches( self, expr: ir.Expr, diff --git a/ibis/backends/mysql/__init__.py b/ibis/backends/mysql/__init__.py index 9d1e3926376b..67dcb683a105 100644 --- a/ibis/backends/mysql/__init__.py +++ b/ibis/backends/mysql/__init__.py @@ -16,6 +16,7 @@ import sqlglot.expressions as sge import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import ibis.expr.operations as ops import ibis.expr.schema as sch @@ -24,7 +25,6 @@ from ibis.backends import CanCreateDatabase from ibis.backends.mysql.datatypes import _type_from_cursor_info from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import MySQLCompiler from ibis.backends.sql.compilers.base import STAR, TRUE, C if TYPE_CHECKING: @@ -38,7 +38,7 @@ class Backend(SQLBackend, CanCreateDatabase): name = "mysql" - compiler = MySQLCompiler() + compiler = sc.mysql.compiler supports_create_or_replace = False def _from_url(self, url: ParseResult, **kwargs): @@ -412,7 +412,7 @@ def create_table( self._run_pre_execute_hooks(table) - query = self._to_sqlglot(table) + query = self.compiler.to_sqlglot(table) else: query = None diff --git a/ibis/backends/oracle/__init__.py b/ibis/backends/oracle/__init__.py index 2339859f0000..a7831e53f0bb 100644 --- a/ibis/backends/oracle/__init__.py +++ b/ibis/backends/oracle/__init__.py @@ -17,6 +17,7 @@ import sqlglot.expressions as sge import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as exc import ibis.expr.datatypes as dt import ibis.expr.operations as ops @@ -24,9 +25,8 @@ import ibis.expr.types as ir from ibis import util from ibis.backends import CanListDatabase, CanListSchema -from ibis.backends.sql import STAR, SQLBackend -from ibis.backends.sql.compilers import OracleCompiler -from ibis.backends.sql.compilers.base import C +from ibis.backends.sql import SQLBackend +from ibis.backends.sql.compilers.base import STAR, C if TYPE_CHECKING: from urllib.parse import ParseResult @@ -79,7 +79,7 @@ def metadata_row_to_type( class Backend(SQLBackend, CanListDatabase, CanListSchema): name = "oracle" - compiler = OracleCompiler() + compiler = sc.oracle.compiler @cached_property def version(self): @@ -420,7 +420,7 @@ def create_table( self._run_pre_execute_hooks(table) - query = self._to_sqlglot(table) + query = self.compiler.to_sqlglot(table) else: query = None diff --git a/ibis/backends/postgres/__init__.py b/ibis/backends/postgres/__init__.py index 801514d320c7..12ec7342d01b 100644 --- a/ibis/backends/postgres/__init__.py +++ b/ibis/backends/postgres/__init__.py @@ -4,9 +4,6 @@ import contextlib import inspect -import textwrap -from functools import partial -from itertools import takewhile from operator import itemgetter from typing import TYPE_CHECKING, Any from urllib.parse import unquote_plus @@ -18,6 +15,7 @@ from pandas.api.types import is_float_dtype import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import ibis.common.exceptions as exc import ibis.expr.datatypes as dt @@ -27,9 +25,7 @@ from ibis import util from ibis.backends import CanCreateDatabase, CanCreateSchema, CanListCatalog from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import PostgresCompiler from ibis.backends.sql.compilers.base import TRUE, C, ColGen, F -from ibis.common.exceptions import InvalidDecoratorError if TYPE_CHECKING: from collections.abc import Callable 
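The pattern in this and the surrounding backend files is the same: backend-level `_to_sqlglot` overrides (MSSQL's boolean-to-bit conversion above, Postgres's geospatial-to-EWKB conversion below) are deleted, and that logic moves into the corresponding compiler's `to_sqlglot` (see the `MSSQLCompiler` and `DuckDBCompiler` changes later in this patch). A minimal sketch of the user-visible effect this is aiming for, assuming only a base ibis install with no backend driver present; the table and column names are hypothetical:

    import ibis

    # A hypothetical unbound table; nothing here touches a database.
    t = ibis.table({"a": "int64", "b": "boolean"}, name="t")
    expr = t.filter(t.b).a.sum()

    # After this refactor, rendering dialect-specific SQL should not
    # require importing the backend driver, since the rewrites now live
    # in the compiler rather than in the backend class.
    print(ibis.to_sql(expr, dialect="mssql"))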
@@ -41,15 +37,9 @@ import pyarrow as pa -def _verify_source_line(func_name: str, line: str): - if line.startswith("@"): - raise InvalidDecoratorError(func_name, line) - return line - - class Backend(SQLBackend, CanListCatalog, CanCreateDatabase, CanCreateSchema): name = "postgres" - compiler = PostgresCompiler() + compiler = sc.postgres.compiler supports_python_udfs = True def _from_url(self, url: ParseResult, **kwargs): @@ -509,69 +499,6 @@ def fake_func(*args, **kwargs): ... op = ops.udf.scalar.builtin(fake_func, database=database) return op - def _get_udf_source(self, udf_node: ops.ScalarUDF): - config = udf_node.__config__ - func = udf_node.__func__ - func_name = func.__name__ - - lines, _ = inspect.getsourcelines(func) - iter_lines = iter(lines) - - function_premable_lines = list( - takewhile(lambda line: not line.lstrip().startswith("def "), iter_lines) - ) - - if len(function_premable_lines) > 1: - raise InvalidDecoratorError( - name=func_name, lines="".join(function_premable_lines) - ) - - source = textwrap.dedent( - "".join(map(partial(_verify_source_line, func_name), iter_lines)) - ).strip() - - type_mapper = self.compiler.type_mapper - argnames = udf_node.argnames - return dict( - name=type(udf_node).__name__, - ident=self.compiler.__sql_name__(udf_node), - signature=", ".join( - f"{argname} {type_mapper.to_string(arg.dtype)}" - for argname, arg in zip(argnames, udf_node.args) - ), - return_type=type_mapper.to_string(udf_node.dtype), - language=config.get("language", "plpython3u"), - source=source, - args=", ".join(argnames), - ) - - def _define_udf_translation_rules(self, expr: ir.Expr) -> None: - """No-op, these are defined in the compiler.""" - - def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str: - return """\ -CREATE OR REPLACE FUNCTION {ident}({signature}) -RETURNS {return_type} -LANGUAGE {language} -AS $$ -{source} -return {name}({args}) -$$""".format(**self._get_udf_source(udf_node)) - - def _register_udfs(self, expr: ir.Expr) -> None: - udf_sources = [] - for udf_node in expr.op().find(ops.ScalarUDF): - compile_func = getattr( - self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" - ) - if sql := compile_func(udf_node): - udf_sources.append(sql) - if udf_sources: - # define every udf in one execution to avoid the overhead of - # database round trips per udf - with self._safe_raw_sql(";\n".join(udf_sources)): - pass - def get_schema( self, name: str, @@ -744,7 +671,7 @@ def create_table( self._run_pre_execute_hooks(table) - query = self._to_sqlglot(table) + query = self.compiler.to_sqlglot(table) else: query = None @@ -847,17 +774,3 @@ def raw_sql(self, query: str | sg.Expression, **kwargs: Any) -> Any: else: con.commit() return cursor - - def _to_sqlglot( - self, expr: ir.Expr, limit: str | None = None, params=None, **kwargs: Any - ): - table_expr = expr.as_table() - conversions = { - name: table_expr[name].as_ewkb() - for name, typ in table_expr.schema().items() - if typ.is_geospatial() - } - - if conversions: - table_expr = table_expr.mutate(**conversions) - return super()._to_sqlglot(table_expr, limit=limit, params=params, **kwargs) diff --git a/ibis/backends/pyspark/__init__.py b/ibis/backends/pyspark/__init__.py index 1226d1110f74..2c171394c320 100644 --- a/ibis/backends/pyspark/__init__.py +++ b/ibis/backends/pyspark/__init__.py @@ -14,6 +14,7 @@ from pyspark.sql import SparkSession from pyspark.sql.types import BooleanType, DoubleType, LongType, StringType +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import 
ibis.config import ibis.expr.operations as ops @@ -24,7 +25,6 @@ from ibis.backends.pyspark.converter import PySparkPandasData from ibis.backends.pyspark.datatypes import PySparkSchema, PySparkType from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import PySparkCompiler from ibis.expr.operations.udf import InputType from ibis.legacy.udf.vectorized import _coerce_to_series from ibis.util import deprecated @@ -104,7 +104,7 @@ def _interval_to_string(interval): class Backend(SQLBackend, CanListCatalog, CanCreateDatabase): name = "pyspark" - compiler = PySparkCompiler() + compiler = sc.pyspark.compiler class Options(ibis.config.Config): """PySpark options. diff --git a/ibis/backends/risingwave/__init__.py b/ibis/backends/risingwave/__init__.py index d43ca2b51d16..2270a67dc998 100644 --- a/ibis/backends/risingwave/__init__.py +++ b/ibis/backends/risingwave/__init__.py @@ -11,12 +11,12 @@ from psycopg2 import extras import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import ibis.expr.operations as ops import ibis.expr.types as ir from ibis import util from ibis.backends.postgres import Backend as PostgresBackend -from ibis.backends.sql.compilers import RisingWaveCompiler from ibis.util import experimental if TYPE_CHECKING: @@ -45,7 +45,7 @@ def format_properties(props): class Backend(PostgresBackend): name = "risingwave" - compiler = RisingWaveCompiler() + compiler = sc.risingwave.compiler supports_python_udfs = False def do_connect( @@ -202,7 +202,7 @@ def create_table( self._run_pre_execute_hooks(table) - query = self._to_sqlglot(table) + query = self.compiler.to_sqlglot(table) else: query = None diff --git a/ibis/backends/snowflake/__init__.py b/ibis/backends/snowflake/__init__.py index 64d57028100f..32350aa9de27 100644 --- a/ibis/backends/snowflake/__init__.py +++ b/ibis/backends/snowflake/__init__.py @@ -3,15 +3,11 @@ import contextlib import functools import glob -import inspect import itertools import json import os -import platform import shutil -import sys import tempfile -import textwrap import warnings from operator import itemgetter from pathlib import Path @@ -25,6 +21,7 @@ import sqlglot.expressions as sge import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops @@ -34,9 +31,7 @@ from ibis.backends import CanCreateCatalog, CanCreateDatabase, CanCreateSchema from ibis.backends.snowflake.converter import SnowflakePandasData from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import SnowflakeCompiler from ibis.backends.sql.compilers.base import STAR -from ibis.backends.sql.datatypes import SnowflakeType if TYPE_CHECKING: from collections.abc import Iterable, Iterator, Mapping @@ -145,10 +140,9 @@ class Backend(SQLBackend, CanCreateCatalog, CanCreateDatabase, CanCreateSchema): name = "snowflake" - compiler = SnowflakeCompiler() + compiler = sc.snowflake.compiler supports_python_udfs = True - _latest_udf_python_version = (3, 10) _top_level_methods = ("from_connection", "from_snowpark") def __init__(self, *args, _from_snowpark: bool = False, **kwargs) -> None: @@ -458,107 +452,6 @@ def reconnect(self) -> None: ) super().reconnect() - def _get_udf_source(self, udf_node: ops.ScalarUDF): - name = type(udf_node).__name__ - signature = ", ".join( - f"{name} {self.compiler.type_mapper.to_string(arg.dtype)}" - for name, arg in zip(udf_node.argnames, udf_node.args) - ) - return_type = 
SnowflakeType.to_string(udf_node.dtype) - lines, _ = inspect.getsourcelines(udf_node.__func__) - source = textwrap.dedent( - "".join( - itertools.dropwhile( - lambda line: not line.lstrip().startswith("def "), lines - ) - ) - ).strip() - - config = udf_node.__config__ - - preamble_lines = [*self._UDF_PREAMBLE_LINES] - - if imports := config.get("imports"): - preamble_lines.append(f"IMPORTS = ({', '.join(map(repr, imports))})") - - packages = "({})".format( - ", ".join(map(repr, ("pandas", *config.get("packages", ())))) - ) - preamble_lines.append(f"PACKAGES = {packages}") - - return dict( - source=source, - name=name, - func_name=udf_node.__func_name__, - preamble="\n".join(preamble_lines).format( - name=name, - signature=signature, - return_type=return_type, - comment=f"Generated by ibis {ibis.__version__} using Python {platform.python_version()}", - version=".".join( - map(str, min(sys.version_info[:2], self._latest_udf_python_version)) - ), - ), - ) - - _UDF_PREAMBLE_LINES = ( - "CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature})", - "RETURNS {return_type}", - "LANGUAGE PYTHON", - "IMMUTABLE", - "RUNTIME_VERSION = '{version}'", - "COMMENT = '{comment}'", - ) - - def _define_udf_translation_rules(self, expr): - """No-op, these are defined in the compiler.""" - - def _register_udfs(self, expr: ir.Expr) -> None: - udf_sources = [] - for udf_node in expr.op().find(ops.ScalarUDF): - compile_func = getattr( - self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" - ) - if sql := compile_func(udf_node): - udf_sources.append(sql) - if udf_sources: - # define every udf in one execution to avoid the overhead of db - # round trips per udf - with self._safe_raw_sql(";\n".join(udf_sources)): - pass - - def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> str: - return """\ -{preamble} -HANDLER = '{func_name}' -AS $$ -from __future__ import annotations - -from typing import * - -{source} -$$""".format(**self._get_udf_source(udf_node)) - - def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str: - template = """\ -{preamble} -HANDLER = 'wrapper' -AS $$ -from __future__ import annotations - -from typing import * - -import _snowflake -import pandas as pd - -{source} - -@_snowflake.vectorized(input=pd.DataFrame) -def wrapper(df): - return {func_name}(*(col for _, col in df.items())) -$$""" - return template.format(**self._get_udf_source(udf_node)) - def to_pyarrow( self, expr: ir.Expr, @@ -594,10 +487,10 @@ def to_pandas_batches( *, params: Mapping[ir.Scalar, Any] | None = None, limit: int | str | None = None, - **kwargs: Any, + chunk_size: int = 1_000_000, ) -> Iterator[pd.DataFrame | pd.Series | Any]: self._run_pre_execute_hooks(expr) - sql = self.compile(expr, limit=limit, params=params, **kwargs) + sql = self.compile(expr, limit=limit, params=params) target_schema = expr.as_table().schema() converter = functools.partial( SnowflakePandasData.convert_table, schema=target_schema @@ -950,7 +843,7 @@ def create_table( self._run_pre_execute_hooks(table) - query = self._to_sqlglot(table) + query = self.compiler.to_sqlglot(table) else: query = None diff --git a/ibis/backends/sql/__init__.py b/ibis/backends/sql/__init__.py index 7bb47cbf282a..99d9555456bb 100644 --- a/ibis/backends/sql/__init__.py +++ b/ibis/backends/sql/__init__.py @@ -14,7 +14,6 @@ import ibis.expr.types as ir from ibis import util from ibis.backends import BaseBackend -from ibis.backends.sql.compilers.base import STAR if TYPE_CHECKING: from collections.abc import Iterable, Mapping @@ -143,39 +142,15 @@ def table( 
namespace=ops.Namespace(catalog=catalog, database=database), ).to_expr() - def _to_sqlglot( - self, expr: ir.Expr, *, limit: str | None = None, params=None, **_: Any - ): - """Compile an Ibis expression to a sqlglot object.""" - table_expr = expr.as_table() - - if limit == "default": - limit = ibis.options.sql.default_limit - if limit is not None: - table_expr = table_expr.limit(limit) - - if params is None: - params = {} - - sql = self.compiler.translate(table_expr.op(), params=params) - assert not isinstance(sql, sge.Subquery) - - if isinstance(sql, sge.Table): - sql = sg.select(STAR, copy=False).from_(sql, copy=False) - - assert not isinstance(sql, sge.Subquery) - return sql - def compile( self, expr: ir.Expr, limit: str | None = None, - params=None, + params: Mapping[ir.Expr, Any] | None = None, pretty: bool = False, - **kwargs: Any, ): """Compile an Ibis expression to a SQL string.""" - query = self._to_sqlglot(expr, limit=limit, params=params, **kwargs) + query = self.compiler.to_sqlglot(expr, limit=limit, params=params) sql = query.sql(dialect=self.dialect, pretty=pretty, copy=False) self._log(sql) return sql @@ -220,7 +195,7 @@ def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema: compiler = self.compiler dialect = compiler.dialect - cte = self._to_sqlglot(table) + cte = compiler.to_sqlglot(table) parsed = sg.parse_one(query, read=dialect) parsed.args["with"] = cte.args.pop("with", []) parsed = parsed.with_( @@ -230,6 +205,21 @@ def _get_sql_string_view_schema(self, name, table, query) -> sch.Schema: sql = parsed.sql(dialect) return self._get_schema_using_query(sql) + def _register_udfs(self, expr: ir.Expr) -> None: + udf_sources = [] + compiler = self.compiler + for udf_node in expr.op().find(ops.ScalarUDF): + compile_func = getattr( + compiler, f"_compile_{udf_node.__input_type__.name.lower()}_udf" + ) + if sql := compile_func(udf_node): + udf_sources.append(sql) + if udf_sources: + # define every udf in one execution to avoid the overhead of db + # round trips per udf + with self._safe_raw_sql(";\n".join(udf_sources)): + pass + def create_view( self, name: str, @@ -568,24 +558,6 @@ def disconnect(self): # _most_ sqlglot backends self.con.close() - def _compile_builtin_udf(self, udf_node: ops.ScalarUDF | ops.AggUDF) -> None: - """Compile a built-in UDF. 
No-op by default.""" - - def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> None: - raise NotImplementedError( - f"Python UDFs are not supported in the {self.name} backend" - ) - - def _compile_pyarrow_udf(self, udf_node: ops.ScalarUDF) -> None: - raise NotImplementedError( - f"PyArrow UDFs are not supported in the {self.name} backend" - ) - - def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str: - raise NotImplementedError( - f"pandas UDFs are not supported in the {self.name} backend" - ) - def _to_catalog_db_tuple(self, table_loc: sge.Table): if (sg_cat := table_loc.args["catalog"]) is not None: sg_cat.args["quoted"] = False @@ -643,3 +615,21 @@ def _to_sqlglot_table(self, database): ) return database + + def _register_builtin_udf(self, udf_node: ops.ScalarUDF) -> None: + """No-op.""" + + def _register_python_udf(self, udf_node: ops.ScalarUDF) -> str: + raise NotImplementedError( + f"Python UDFs are not supported in the {self.dialect} backend" + ) + + def _register_pyarrow_udf(self, udf_node: ops.ScalarUDF) -> str: + raise NotImplementedError( + f"PyArrow UDFs are not supported in the {self.dialect} backend" + ) + + def _register_pandas_udf(self, udf_node: ops.ScalarUDF) -> str: + raise NotImplementedError( + f"pandas UDFs are not supported in the {self.dialect} backend" + ) diff --git a/ibis/backends/sql/compilers/base.py b/ibis/backends/sql/compilers/base.py index e66aeecc8245..6e201b52fa71 100644 --- a/ibis/backends/sql/compilers/base.py +++ b/ibis/backends/sql/compilers/base.py @@ -496,6 +496,24 @@ def dialect(self) -> str: def type_mapper(self) -> type[SqlglotType]: """The type mapper for the backend.""" + def _compile_builtin_udf(self, udf_node: ops.ScalarUDF) -> None: # noqa: B027 + """No-op.""" + + def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> None: + raise NotImplementedError( + f"Python UDFs are not supported in the {self.dialect} backend" + ) + + def _compile_pyarrow_udf(self, udf_node: ops.ScalarUDF) -> None: + raise NotImplementedError( + f"PyArrow UDFs are not supported in the {self.dialect} backend" + ) + + def _compile_pandas_udf(self, udf_node: ops.ScalarUDF) -> str: + raise NotImplementedError( + f"pandas UDFs are not supported in the {self.dialect} backend" + ) + # Concrete API def if_(self, condition, true, false: sge.Expression | None = None) -> sge.If: @@ -517,6 +535,34 @@ def _prepare_params(self, params): result[node] = value return result + def to_sqlglot( + self, + expr: ir.Expr, + *, + limit: str | None = None, + params: Mapping[ir.Expr, Any] | None = None, + ): + import ibis + + table_expr = expr.as_table() + + if limit == "default": + limit = ibis.options.sql.default_limit + if limit is not None: + table_expr = table_expr.limit(limit) + + if params is None: + params = {} + + sql = self.translate(table_expr.op(), params=params) + assert not isinstance(sql, sge.Subquery) + + if isinstance(sql, sge.Table): + sql = sg.select(STAR, copy=False).from_(sql, copy=False) + + assert not isinstance(sql, sge.Subquery) + return sql + def translate(self, op, *, params: Mapping[ir.Value, Any]) -> sge.Expression: """Translate an ibis operation to a sqlglot expression. 
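With `to_sqlglot` and the `_compile_*_udf` hooks now living on `SQLGlotCompiler`, a compiler instance is usable standalone. A minimal sketch, assuming only the base ibis install; the table is hypothetical, and `sc.duckdb.compiler` is one of the module-level instances this patch adds at the bottom of each compiler module:

    import ibis
    import ibis.backends.sql.compilers as sc

    t = ibis.table({"title": "string", "tags": "string"}, name="posts")
    expr = t.filter(t.tags.contains("ibis"))

    # to_sqlglot (added above) returns a sqlglot expression, not a
    # string; rendering it goes through the compiler's dialect.
    query = sc.duckdb.compiler.to_sqlglot(expr)
    print(query.sql(dialect="duckdb", pretty=True))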
diff --git a/ibis/backends/sql/compilers/bigquery.py b/ibis/backends/sql/compilers/bigquery/__init__.py similarity index 81% rename from ibis/backends/sql/compilers/bigquery.py rename to ibis/backends/sql/compilers/bigquery/__init__.py index 013600796c10..e0c225d45323 100644 --- a/ibis/backends/sql/compilers/bigquery.py +++ b/ibis/backends/sql/compilers/bigquery/__init__.py @@ -3,6 +3,7 @@ from __future__ import annotations import re +from typing import TYPE_CHECKING, Any import sqlglot as sg import sqlglot.expressions as sge @@ -13,6 +14,7 @@ import ibis.expr.operations as ops from ibis import util from ibis.backends.sql.compilers.base import NULL, STAR, AggGen, SQLGlotCompiler +from ibis.backends.sql.compilers.bigquery.udf.core import PythonToJavaScriptTranslator from ibis.backends.sql.datatypes import BigQueryType, BigQueryUDFType from ibis.backends.sql.rewrites import ( exclude_unsupported_window_frame_from_ops, @@ -21,9 +23,81 @@ ) from ibis.common.temporal import DateUnit, IntervalUnit, TimestampUnit, TimeUnit +if TYPE_CHECKING: + from collections.abc import Mapping + + import ibis.expr.types as ir + _NAME_REGEX = re.compile(r'[^!"$()*,./;?@[\\\]^`{}~\n]+') +_MEMTABLE_PATTERN = re.compile( + r"^_?ibis_(?:[A-Za-z_][A-Za-z_0-9]*)_memtable_[a-z0-9]{26}$" +) + + +def _qualify_memtable( + node: sge.Expression, *, dataset: str | None, project: str | None +) -> sge.Expression: + """Add a BigQuery dataset and project to memtable references.""" + if isinstance(node, sge.Table) and _MEMTABLE_PATTERN.match(node.name) is not None: + node.args["db"] = dataset + node.args["catalog"] = project + # make sure to quote table location + node = _force_quote_table(node) + return node + + +def _remove_null_ordering_from_unsupported_window( + node: sge.Expression, +) -> sge.Expression: + """Remove null ordering in window frame clauses not supported by BigQuery. + + BigQuery has only partial support for NULL FIRST/LAST in RANGE windows so + we remove it from any window frame clause that doesn't support it. + + Here's the support matrix: + + ✅ sum(x) over (order by y desc nulls last) + 🚫 sum(x) over (order by y asc nulls last) + ✅ sum(x) over (order by y asc nulls first) + 🚫 sum(x) over (order by y desc nulls first) + """ + if isinstance(node, sge.Window): + order = node.args.get("order") + if order is not None: + for key in order.args["expressions"]: + kargs = key.args + if kargs.get("desc") is True and kargs.get("nulls_first", False): + kargs["nulls_first"] = False + elif kargs.get("desc") is False and not kargs.setdefault( + "nulls_first", True + ): + kargs["nulls_first"] = True + return node + + +def _force_quote_table(table: sge.Table) -> sge.Table: + """Force quote all the parts of a bigquery path. + + The BigQuery identifier quoting semantics are bonkers + https://cloud.google.com/bigquery/docs/reference/standard-sql/lexical#identifiers + + my-table is OK, but not mydataset.my-table + + mytable-287 is OK, but not mytable-287a + + Just quote everything. 
+ """ + for key in ("this", "db", "catalog"): + if (val := table.args[key]) is not None: + if isinstance(val, sg.exp.Identifier) and not val.quoted: + val.args["quoted"] = True + else: + table.args[key] = sg.to_identifier(val, quoted=True) + return table + + class BigQueryCompiler(SQLGlotCompiler): dialect = BigQuery type_mapper = BigQueryType @@ -117,6 +191,140 @@ class BigQueryCompiler(SQLGlotCompiler): ops.ExtractHost: "net.host", } + def to_sqlglot( + self, + expr: ir.Expr, + *, + limit: str | None = None, + params: Mapping[ir.Expr, Any] | None = None, + session_dataset_id: str | None = None, + session_project: str | None = None, + ) -> Any: + """Compile an Ibis expression. + + Parameters + ---------- + expr + Ibis expression + limit + For expressions yielding result sets; retrieve at most this number + of values/rows. Overrides any limit already set on the expression. + params + Named unbound parameters + session_dataset_id + Optional dataset ID to qualify memtable references. + session_project + Optional project ID to qualify memtable references. + + Returns + ------- + Any + The output of compilation. The type of this value depends on the + backend. + + """ + sql = super().to_sqlglot(expr, limit=limit, params=params) + + table_expr = expr.as_table() + geocols = table_expr.schema().geospatial + + result = sql.transform( + _qualify_memtable, + dataset=session_dataset_id, + project=session_project, + ).transform(_remove_null_ordering_from_unsupported_window) + + if geocols: + # if there are any geospatial columns, we have to convert them to WKB, + # so interactive mode knows how to display them + # + # by default bigquery returns data to python as WKT, and there's really + # no point in supporting both if we don't need to. + quoted = self.quoted + result = sg.select( + sge.Star( + replace=[ + self.f.st_asbinary(sg.column(col, quoted=quoted)).as_( + col, quoted=quoted + ) + for col in geocols + ] + ) + ).from_(result.subquery()) + + sources = [] + + for udf_node in table_expr.op().find(ops.ScalarUDF): + compile_func = getattr( + self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" + ) + if sql := compile_func(udf_node): + sources.append(sql) + + if not sources: + return result + + sources.append(result) + return sources + + def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> sge.Create: + name = type(udf_node).__name__ + type_mapper = self.udf_type_mapper + + body = PythonToJavaScriptTranslator(udf_node.__func__).compile() + config = udf_node.__config__ + libraries = config.get("libraries", []) + + signature = [ + sge.ColumnDef( + this=sg.to_identifier(name, quoted=self.quoted), + kind=type_mapper.from_ibis(param.annotation.pattern.dtype), + ) + for name, param in udf_node.__signature__.parameters.items() + ] + + lines = ['"""'] + + if config.get("strict", True): + lines.append('"use strict";') + + lines += [ + body, + "", + f"return {udf_node.__func_name__}({', '.join(udf_node.argnames)});", + '"""', + ] + + func = sge.Create( + kind="FUNCTION", + this=sge.UserDefinedFunction( + this=sg.to_identifier(name), expressions=signature, wrapped=True + ), + # not exactly what I had in mind, but it works + # + # quoting is too simplistic to handle multiline strings + expression=sge.Var(this="\n".join(lines)), + exists=False, + properties=sge.Properties( + expressions=[ + sge.TemporaryProperty(), + sge.ReturnsProperty(this=type_mapper.from_ibis(udf_node.dtype)), + sge.StabilityProperty( + this="IMMUTABLE" if config.get("determinism") else "VOLATILE" + ), + 
sge.LanguageProperty(this=sg.to_identifier("js")), + ] + + [ + sge.Property( + this=sg.to_identifier("library"), value=self.f.array(*libraries) + ) + ] + * bool(libraries) + ), + ) + + return func + @staticmethod def _minimize_spec(start, end, spec): if ( @@ -817,3 +1025,6 @@ def visit_ArrayAny(self, op, *, arg): def visit_ArrayAll(self, op, *, arg): return self._array_reduction(arg=arg, reduction="logical_and") + + +compiler = BigQueryCompiler() diff --git a/ibis/backends/bigquery/udf/__init__.py b/ibis/backends/sql/compilers/bigquery/udf/__init__.py similarity index 100% rename from ibis/backends/bigquery/udf/__init__.py rename to ibis/backends/sql/compilers/bigquery/udf/__init__.py diff --git a/ibis/backends/bigquery/udf/core.py b/ibis/backends/sql/compilers/bigquery/udf/core.py similarity index 99% rename from ibis/backends/bigquery/udf/core.py rename to ibis/backends/sql/compilers/bigquery/udf/core.py index b75a49118d25..9cffb420840a 100644 --- a/ibis/backends/bigquery/udf/core.py +++ b/ibis/backends/sql/compilers/bigquery/udf/core.py @@ -10,8 +10,8 @@ from collections import ChainMap from typing import TYPE_CHECKING -from ibis.backends.bigquery.udf.find import find_names -from ibis.backends.bigquery.udf.rewrite import rewrite +from ibis.backends.sql.compilers.bigquery.udf.find import find_names +from ibis.backends.sql.compilers.bigquery.udf.rewrite import rewrite if TYPE_CHECKING: from collections.abc import Callable diff --git a/ibis/backends/bigquery/udf/find.py b/ibis/backends/sql/compilers/bigquery/udf/find.py similarity index 100% rename from ibis/backends/bigquery/udf/find.py rename to ibis/backends/sql/compilers/bigquery/udf/find.py diff --git a/ibis/backends/bigquery/udf/rewrite.py b/ibis/backends/sql/compilers/bigquery/udf/rewrite.py similarity index 100% rename from ibis/backends/bigquery/udf/rewrite.py rename to ibis/backends/sql/compilers/bigquery/udf/rewrite.py diff --git a/ibis/backends/sql/compilers/clickhouse.py b/ibis/backends/sql/compilers/clickhouse.py index 66cc9a421e58..ec94f7f4ebc0 100644 --- a/ibis/backends/sql/compilers/clickhouse.py +++ b/ibis/backends/sql/compilers/clickhouse.py @@ -795,3 +795,6 @@ def visit_MapContains(self, op, *, arg, key): return self.if_( sg.or_(arg.is_(NULL), key.is_(NULL)), NULL, self.f.mapContains(arg, key) ) + + +compiler = ClickHouseCompiler() diff --git a/ibis/backends/sql/compilers/datafusion.py b/ibis/backends/sql/compilers/datafusion.py index e5aa459fcdcb..e3c4ba65478d 100644 --- a/ibis/backends/sql/compilers/datafusion.py +++ b/ibis/backends/sql/compilers/datafusion.py @@ -497,3 +497,6 @@ def visit_GroupConcat(self, op, *, arg, sep, where, order_by): return super().visit_GroupConcat( op, arg=arg, sep=sep, where=where, order_by=order_by ) + + +compiler = DataFusionCompiler() diff --git a/ibis/backends/sql/compilers/druid.py b/ibis/backends/sql/compilers/druid.py index cf876aacc766..d1571363e14e 100644 --- a/ibis/backends/sql/compilers/druid.py +++ b/ibis/backends/sql/compilers/druid.py @@ -199,3 +199,6 @@ def visit_TimestampFromYMDHMS( "Z", ) ) + + +compiler = DruidCompiler() diff --git a/ibis/backends/sql/compilers/duckdb.py b/ibis/backends/sql/compilers/duckdb.py index e47e1bafe8ce..1724885ec0bb 100644 --- a/ibis/backends/sql/compilers/duckdb.py +++ b/ibis/backends/sql/compilers/duckdb.py @@ -2,6 +2,7 @@ import math from functools import partial, reduce +from typing import TYPE_CHECKING, Any import sqlglot as sg import sqlglot.expressions as sge @@ -16,6 +17,12 @@ from ibis.backends.sql.rewrites import 
exclude_nulls_from_array_collect
 from ibis.util import gen_name
 
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+    import ibis.expr.types as ir
+
+
 _INTERVAL_SUFFIXES = {
     "ms": "milliseconds",
     "us": "microseconds",
@@ -98,6 +105,33 @@ class DuckDBCompiler(SQLGlotCompiler):
         ops.GeoY: "st_y",
     }
 
+    def to_sqlglot(
+        self,
+        expr: ir.Expr,
+        *,
+        limit: str | None = None,
+        params: Mapping[ir.Expr, Any] | None = None,
+    ):
+        sql = super().to_sqlglot(expr, limit=limit, params=params)
+
+        table_expr = expr.as_table()
+        geocols = table_expr.schema().geospatial
+
+        if not geocols:
+            return sql
+
+        quoted = self.quoted
+        return sg.select(
+            sge.Star(
+                replace=[
+                    self.f.st_aswkb(sg.column(col, quoted=quoted)).as_(
+                        col, quoted=quoted
+                    )
+                    for col in geocols
+                ]
+            )
+        ).from_(sql.subquery())
+
     def visit_StructColumn(self, op, *, names, values):
         return sge.Struct.from_arg_list(
             [
@@ -614,3 +648,6 @@ def visit_TableUnnest(
         .from_(parent)
         .join(unnest, join_type="CROSS" if not keep_empty else "LEFT")
     )
+
+
+compiler = DuckDBCompiler()
diff --git a/ibis/backends/sql/compilers/exasol.py b/ibis/backends/sql/compilers/exasol.py
index d6d8fbb4e279..87f5aaa543d3 100644
--- a/ibis/backends/sql/compilers/exasol.py
+++ b/ibis/backends/sql/compilers/exasol.py
@@ -250,3 +250,6 @@ def visit_BitwiseOr(self, op, *, left, right):
 
     def visit_BitwiseXor(self, op, *, left, right):
         return self.cast(self.f.bit_xor(left, right), op.dtype)
+
+
+compiler = ExasolCompiler()
diff --git a/ibis/backends/sql/compilers/flink.py b/ibis/backends/sql/compilers/flink.py
index 462e496428af..4e4fb1586415 100644
--- a/ibis/backends/sql/compilers/flink.py
+++ b/ibis/backends/sql/compilers/flink.py
@@ -563,3 +563,6 @@ def visit_MapMerge(self, op: ops.MapMerge, *, left, right):
 
     def visit_StructColumn(self, op, *, names, values):
         return self.cast(sge.Struct(expressions=list(values)), op.dtype)
+
+
+compiler = FlinkCompiler()
diff --git a/ibis/backends/sql/compilers/impala.py b/ibis/backends/sql/compilers/impala.py
index 6288865d2ec6..f73a38751d08 100644
--- a/ibis/backends/sql/compilers/impala.py
+++ b/ibis/backends/sql/compilers/impala.py
@@ -320,3 +320,6 @@ def visit_Sign(self, op, *, arg):
         if not dtype.is_float32():
             return self.cast(sign, dtype)
         return sign
+
+
+compiler = ImpalaCompiler()
diff --git a/ibis/backends/sql/compilers/mssql.py b/ibis/backends/sql/compilers/mssql.py
index 31f67b52aeb5..0c6fe9a567ac 100644
--- a/ibis/backends/sql/compilers/mssql.py
+++ b/ibis/backends/sql/compilers/mssql.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 
 import calendar
+from typing import TYPE_CHECKING, Any
 
 import sqlglot as sg
 import sqlglot.expressions as sge
@@ -26,6 +27,11 @@
 )
 from ibis.common.deferred import var
 
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+    import ibis.expr.types as ir
+
 y = var("y")
 start = var("start")
 end = var("end")
@@ -133,17 +139,9 @@ class MSSQLCompiler(SQLGlotCompiler):
         ops.Max: "max",
     }
 
-    @property
-    def NAN(self):
-        return self.f.double("NaN")
-
-    @property
-    def POS_INF(self):
-        return self.f.double("Infinity")
-
-    @property
-    def NEG_INF(self):
-        return self.f.double("-Infinity")
+    NAN = sg.func("double", sge.convert("NaN"))
+    POS_INF = sg.func("double", sge.convert("Infinity"))
+    NEG_INF = sg.func("double", sge.convert("-Infinity"))
 
     @staticmethod
     def _generate_groups(groups):
@@ -160,7 +158,28 @@ def _minimize_spec(start, end, spec):
             return None
         return spec
 
-    def visit_RandomUUID(self, op, **kwargs):
+    def to_sqlglot(
+        self,
+        expr: ir.Expr,
+        *,
+        limit: str | None = None,
+        params: Mapping[ir.Expr, Any] | None = None,
+    ):
+        """Compile an Ibis expression to a sqlglot object."""
+        import ibis
+
+        table_expr = expr.as_table()
+        conversions = {
+            name: ibis.ifelse(table_expr[name], 1, 0).cast(dt.boolean)
+            for name, typ in table_expr.schema().items()
+            if typ.is_boolean()
+        }
+
+        if conversions:
+            table_expr = table_expr.mutate(**conversions)
+        return super().to_sqlglot(table_expr, limit=limit, params=params)
+
+    def visit_RandomUUID(self, op, **_):
         return self.f.newid()
 
     def visit_StringLength(self, op, *, arg):
@@ -480,3 +499,6 @@ def visit_Select(self, op, *, parent, selections, predicates, sort_keys):
             result = result.order_by(*sort_keys, copy=False)
 
         return result
+
+
+compiler = MSSQLCompiler()
diff --git a/ibis/backends/sql/compilers/mysql.py b/ibis/backends/sql/compilers/mysql.py
index c9278910a891..5244e2642b52 100644
--- a/ibis/backends/sql/compilers/mysql.py
+++ b/ibis/backends/sql/compilers/mysql.py
@@ -377,3 +377,6 @@ def visit_UnwrapJSONBoolean(self, op, *, arg):
             self.if_(arg.eq(sge.convert("true")), 1, 0),
             NULL,
         )
+
+
+compiler = MySQLCompiler()
diff --git a/ibis/backends/sql/compilers/oracle.py b/ibis/backends/sql/compilers/oracle.py
index a9d1d033627a..ee98dda5e842 100644
--- a/ibis/backends/sql/compilers/oracle.py
+++ b/ibis/backends/sql/compilers/oracle.py
@@ -459,3 +459,6 @@ def visit_GroupConcat(self, op, *, arg, where, sep, order_by):
             out = sge.WithinGroup(this=out, expression=sge.Order(expressions=order_by))
 
         return out
+
+
+compiler = OracleCompiler()
diff --git a/ibis/backends/sql/compilers/postgres.py b/ibis/backends/sql/compilers/postgres.py
index 7386e33e7ed8..9d47044aa835 100644
--- a/ibis/backends/sql/compilers/postgres.py
+++ b/ibis/backends/sql/compilers/postgres.py
@@ -1,7 +1,11 @@
 from __future__ import annotations
 
+import inspect
 import string
+import textwrap
 from functools import partial, reduce
+from itertools import takewhile
+from typing import TYPE_CHECKING, Any
 
 import sqlglot as sg
 import sqlglot.expressions as sge
@@ -14,8 +18,20 @@
 from ibis.backends.sql.datatypes import PostgresType
 from ibis.backends.sql.dialects import Postgres
 from ibis.backends.sql.rewrites import exclude_nulls_from_array_collect
+from ibis.common.exceptions import InvalidDecoratorError
 from ibis.util import gen_name
 
+if TYPE_CHECKING:
+    from collections.abc import Mapping
+
+    import ibis.expr.types as ir
+
+
+def _verify_source_line(func_name: str, line: str):
+    if line.startswith("@"):
+        raise InvalidDecoratorError(func_name, line)
+    return line
+
 
 class PostgresUDFNode(ops.Value):
     shape = rlz.shape_like("args")
@@ -99,6 +115,64 @@ class PostgresCompiler(SQLGlotCompiler):
         ops.TimeFromHMS: "make_time",
     }
 
+    def to_sqlglot(
+        self,
+        expr: ir.Expr,
+        *,
+        limit: str | None = None,
+        params: Mapping[ir.Expr, Any] | None = None,
+    ):
+        table_expr = expr.as_table()
+        geocols = table_expr.schema().geospatial
+        conversions = {name: table_expr[name].as_ewkb() for name in geocols}
+
+        if conversions:
+            table_expr = table_expr.mutate(**conversions)
+        return super().to_sqlglot(table_expr, limit=limit, params=params)
+
+    def _compile_python_udf(self, udf_node: ops.ScalarUDF):
+        config = udf_node.__config__
+        func = udf_node.__func__
+        func_name = func.__name__
+
+        lines, _ = inspect.getsourcelines(func)
+        iter_lines = iter(lines)
+
+        function_preamble_lines = list(
+            takewhile(lambda line: not line.lstrip().startswith("def "), iter_lines)
+        )
+
+        if len(function_preamble_lines) > 1:
+            raise InvalidDecoratorError(
+                name=func_name,
lines="".join(function_preamble_lines)
+            )
+
+        source = textwrap.dedent(
+            "".join(map(partial(_verify_source_line, func_name), iter_lines))
+        ).strip()
+
+        type_mapper = self.type_mapper
+        argnames = udf_node.argnames
+        return """\
+CREATE OR REPLACE FUNCTION {ident}({signature})
+RETURNS {return_type}
+LANGUAGE {language}
+AS $$
+{source}
+return {name}({args})
+$$""".format(
+            name=type(udf_node).__name__,
+            ident=self.__sql_name__(udf_node),
+            signature=", ".join(
+                f"{argname} {type_mapper.to_string(arg.dtype)}"
+                for argname, arg in zip(argnames, udf_node.args)
+            ),
+            return_type=type_mapper.to_string(udf_node.dtype),
+            language=config.get("language", "plpython3u"),
+            source=source,
+            args=", ".join(argnames),
+        )
+
     def visit_RandomUUID(self, op, **kwargs):
         return self.f.gen_random_uuid()
 
@@ -699,3 +773,6 @@ def visit_ArrayAny(self, op, *, arg):
 
     def visit_ArrayAll(self, op, *, arg):
         return self._array_reduction(arg=arg, reduction="bool_and")
+
+
+compiler = PostgresCompiler()
diff --git a/ibis/backends/sql/compilers/pyspark.py b/ibis/backends/sql/compilers/pyspark.py
index 1ac1ad6553d9..ee6060e4bfcc 100644
--- a/ibis/backends/sql/compilers/pyspark.py
+++ b/ibis/backends/sql/compilers/pyspark.py
@@ -634,3 +634,6 @@ def visit_ArraySum(self, op, *, arg):
 
     def visit_ArrayMean(self, op, *, arg):
         return self._array_reduction(dtype=op.dtype, arg=arg, output=operator.truediv)
+
+
+compiler = PySparkCompiler()
diff --git a/ibis/backends/sql/compilers/risingwave.py b/ibis/backends/sql/compilers/risingwave.py
index 8d1e86d1ce5f..35f741c17499 100644
--- a/ibis/backends/sql/compilers/risingwave.py
+++ b/ibis/backends/sql/compilers/risingwave.py
@@ -95,3 +95,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype):
         elif dtype.is_json():
             return sge.convert(str(value))
         return None
+
+
+compiler = RisingWaveCompiler()
diff --git a/ibis/backends/sql/compilers/snowflake.py b/ibis/backends/sql/compilers/snowflake.py
index 23d73f384824..c6ab7451ddbd 100644
--- a/ibis/backends/sql/compilers/snowflake.py
+++ b/ibis/backends/sql/compilers/snowflake.py
@@ -1,6 +1,10 @@
 from __future__ import annotations
 
+import inspect
 import itertools
+import platform
+import sys
+import textwrap
 from functools import partial
 
 import sqlglot as sg
@@ -36,6 +40,8 @@ class SnowflakeFuncGen(FuncGen):
 class SnowflakeCompiler(SQLGlotCompiler):
     __slots__ = ()
 
+    latest_udf_python_version = (3, 11)
+
     dialect = Snowflake
     type_mapper = SnowflakeType
     no_limit_value = NULL
@@ -95,6 +101,94 @@ def __init__(self):
         super().__init__()
         self.f = SnowflakeFuncGen()
 
+    _UDF_TEMPLATES = {
+        ops.udf.InputType.PYTHON: """\
+{preamble}
+HANDLER = '{func_name}'
+AS $$
+from __future__ import annotations
+
+from typing import *
+
+{source}
+$$""",
+        ops.udf.InputType.PANDAS: """\
+{preamble}
+HANDLER = 'wrapper'
+AS $$
+from __future__ import annotations
+
+from typing import *
+
+import _snowflake
+import pandas as pd
+
+{source}
+
+@_snowflake.vectorized(input=pd.DataFrame)
+def wrapper(df):
+    return {func_name}(*(col for _, col in df.items()))
+$$""",
+    }
+
+    _UDF_PREAMBLE_LINES = (
+        "CREATE OR REPLACE TEMPORARY FUNCTION {name}({signature})",
+        "RETURNS {return_type}",
+        "LANGUAGE PYTHON",
+        "IMMUTABLE",
+        "RUNTIME_VERSION = '{version}'",
+        "COMMENT = '{comment}'",
+    )
+
+    def _compile_udf(self, udf_node: ops.ScalarUDF):
+        import ibis
+
+        name = type(udf_node).__name__
+        signature = ", ".join(
+            f"{name} {self.type_mapper.to_string(arg.dtype)}"
+            for name, arg in zip(udf_node.argnames, udf_node.args)
+        )
+        return_type = 
SnowflakeType.to_string(udf_node.dtype) + lines, _ = inspect.getsourcelines(udf_node.__func__) + source = textwrap.dedent( + "".join( + itertools.dropwhile( + lambda line: not line.lstrip().startswith("def "), lines + ) + ) + ).strip() + + config = udf_node.__config__ + + preamble_lines = [*self._UDF_PREAMBLE_LINES] + + if imports := config.get("imports"): + preamble_lines.append(f"IMPORTS = ({', '.join(map(repr, imports))})") + + packages = "({})".format( + ", ".join(map(repr, ("pandas", *config.get("packages", ())))) + ) + preamble_lines.append(f"PACKAGES = {packages}") + + template = self._UDF_TEMPLATES[udf_node.__input_type__] + return template.format( + source=source, + name=name, + func_name=udf_node.__func_name__, + preamble="\n".join(preamble_lines).format( + name=name, + signature=signature, + return_type=return_type, + comment=f"Generated by ibis {ibis.__version__} using Python {platform.python_version()}", + version=".".join( + map(str, min(sys.version_info[:2], self.latest_udf_python_version)) + ), + ), + ) + + _compile_pandas_udf = _compile_udf + _compile_python_udf = _compile_udf + @staticmethod def _minimize_spec(start, end, spec): if ( @@ -774,3 +868,6 @@ def visit_ArraySum(self, op, *, arg): def visit_ArrayMean(self, op, *, arg): return self.cast(self.f.udf.array_avg(arg), op.dtype) + + +compiler = SnowflakeCompiler() diff --git a/ibis/backends/sql/compilers/sqlite.py b/ibis/backends/sql/compilers/sqlite.py index f8f21219a731..88f86c5bb979 100644 --- a/ibis/backends/sql/compilers/sqlite.py +++ b/ibis/backends/sql/compilers/sqlite.py @@ -480,3 +480,6 @@ def visit_NonNullLiteral(self, op, *, value, dtype): ): raise com.UnsupportedBackendType(f"Unsupported type: {dtype!r}") return super().visit_NonNullLiteral(op, value=value, dtype=dtype) + + +compiler = SQLiteCompiler() diff --git a/ibis/backends/sql/compilers/trino.py b/ibis/backends/sql/compilers/trino.py index 8ae7eb2eaa2f..90ee114c893c 100644 --- a/ibis/backends/sql/compilers/trino.py +++ b/ibis/backends/sql/compilers/trino.py @@ -652,3 +652,6 @@ def visit_ArraySum(self, op, *, arg): def visit_ArrayMean(self, op, *, arg): return self.visit_ArraySumAgg(op, arg=arg, output=operator.truediv) + + +compiler = TrinoCompiler() diff --git a/ibis/backends/sqlite/__init__.py b/ibis/backends/sqlite/__init__.py index 09980edd52f2..e770cecd72be 100644 --- a/ibis/backends/sqlite/__init__.py +++ b/ibis/backends/sqlite/__init__.py @@ -9,6 +9,7 @@ import sqlglot.expressions as sge import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import ibis.expr.datatypes as dt import ibis.expr.operations as ops @@ -17,7 +18,6 @@ from ibis import util from ibis.backends import UrlFromPath from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import SQLiteCompiler from ibis.backends.sql.compilers.base import C, F from ibis.backends.sqlite.converter import SQLitePandasData from ibis.backends.sqlite.udf import ignore_nulls, register_all @@ -45,7 +45,7 @@ def _quote(name: str) -> str: class Backend(SQLBackend, UrlFromPath): name = "sqlite" - compiler = SQLiteCompiler() + compiler = sc.sqlite.compiler supports_python_udfs = True @property @@ -365,9 +365,6 @@ def _register_in_memory_table(self, op: ops.InMemoryTable) -> None: cur.execute(create_stmt) cur.executemany(insert_stmt, data) - def _define_udf_translation_rules(self, expr): - """No-op, these are defined in the compiler.""" - def _register_udfs(self, expr: ir.Expr) -> None: import ibis.expr.operations as ops @@ -375,13 +372,13 @@ def 
_register_udfs(self, expr: ir.Expr) -> None: for udf_node in expr.op().find(ops.ScalarUDF): compile_func = getattr( - self, f"_compile_{udf_node.__input_type__.name.lower()}_udf" + self, f"_register_{udf_node.__input_type__.name.lower()}_udf" ) registration_func = compile_func(udf_node) if registration_func is not None: registration_func(con) - def _compile_python_udf(self, udf_node: ops.ScalarUDF) -> None: + def _register_python_udf(self, udf_node: ops.ScalarUDF) -> None: name = type(udf_node).__name__ nargs = len(udf_node.__signature__.parameters) func = udf_node.__func__ @@ -480,7 +477,7 @@ def create_table( self._run_pre_execute_hooks(obj) - insert_query = self._to_sqlglot(obj) + insert_query = self.compiler.to_sqlglot(obj) else: insert_query = None diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py index 8026ea3e9606..9be4dd25a6c0 100644 --- a/ibis/backends/tests/test_generic.py +++ b/ibis/backends/tests/test_generic.py @@ -1384,22 +1384,6 @@ def test_memtable_column_naming_mismatch(con, monkeypatch, df, columns): ibis.memtable(df, columns=columns) -@pytest.mark.notimpl( - ["dask", "pandas", "polars"], raises=NotImplementedError, reason="not a SQL backend" -) -def test_many_subqueries(con, snapshot): - def query(t, group_cols): - t2 = t.mutate(key=ibis.row_number().over(ibis.window(order_by=group_cols))) - return t2.inner_join(t2[["key"]], "key") - - t = ibis.table(dict(street="str"), name="data") - - t2 = query(t, group_cols=["street"]) - t3 = query(t2, group_cols=["street"]) - - snapshot.assert_match(str(ibis.to_sql(t3, dialect=con.name)), "out.sql") - - @pytest.mark.notimpl(["oracle", "exasol"], raises=com.OperationNotDefinedError) @pytest.mark.notimpl(["druid"], raises=AssertionError) @pytest.mark.notyet( @@ -2289,18 +2273,11 @@ def test_sample_with_seed(backend): backend.assert_frame_equal(df1, df2) -@pytest.mark.notimpl( - ["dask", "pandas", "polars"], raises=NotImplementedError, reason="not a SQL backend" -) def test_simple_memtable_construct(con): t = ibis.memtable({"a": [1, 2]}) expr = t.a expected = [1.0, 2.0] assert sorted(con.to_pandas(expr).tolist()) == expected - # we can't generically check for specific sql, even with a snapshot, - # because memtables have a unique name per table per process, so smoke test - # it - assert str(ibis.to_sql(expr, dialect=con.name)).startswith("SELECT") def test_select_mutate_with_dict(backend): @@ -2490,3 +2467,14 @@ def test_value_counts_on_tables(backend, df): ) expected = expected.sort_values(expected.columns.tolist()).reset_index(drop=True) backend.assert_frame_equal(result, expected, check_dtype=False) + + +def test_union_generates_predictable_aliases(con): + t = ibis.memtable( + data=[{"island": "Torgerson", "body_mass_g": 3750, "sex": "male"}] + ) + sub1 = t.inner_join(t.view(), "island").mutate(island_right=lambda t: t.island) + sub2 = t.inner_join(t.view(), "sex").mutate(sex_right=lambda t: t.sex) + expr = ibis.union(sub1, sub2) + df = con.execute(expr) + assert len(df) == 2 diff --git a/ibis/backends/tests/test_sql.py b/ibis/backends/tests/test_sql.py index fcc8e50d8f9c..4a299f8a83c1 100644 --- a/ibis/backends/tests/test_sql.py +++ b/ibis/backends/tests/test_sql.py @@ -42,17 +42,12 @@ ), ], ) -@pytest.mark.never( - ["pandas", "dask"], - raises=(exc.IbisError, NotImplementedError, ValueError), - reason="Not a SQL backend", -) -@pytest.mark.notimpl(["polars"], reason="Not clear how to extract SQL from the backend") +@pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL", 
raises=ValueError) def test_literal(backend, expr): assert "432" in ibis.to_sql(expr, dialect=backend.name()) -@pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL") +@pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL", raises=ValueError) def test_group_by_has_index(backend, snapshot): countries = ibis.table( dict(continent="string", population="int64"), name="countries" @@ -75,7 +70,7 @@ def test_group_by_has_index(backend, snapshot): snapshot.assert_match(sql, "out.sql") -@pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL") +@pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL", raises=ValueError) def test_cte_refs_in_topo_order(backend, snapshot): mr0 = ibis.table(schema=ibis.schema(dict(key="int")), name="leaf") @@ -88,7 +83,7 @@ def test_cte_refs_in_topo_order(backend, snapshot): snapshot.assert_match(sql, "out.sql") -@pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL") +@pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL", raises=ValueError) def test_isin_bug(con, snapshot): t = ibis.table(dict(x="int"), name="t") good = t[t.x > 2].x @@ -96,11 +91,7 @@ def test_isin_bug(con, snapshot): snapshot.assert_match(str(ibis.to_sql(expr, dialect=con.name)), "out.sql") -@pytest.mark.never( - ["pandas", "dask", "polars"], - reason="not SQL", - raises=NotImplementedError, -) +@pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL", raises=ValueError) @pytest.mark.notyet( ["exasol", "oracle", "flink"], reason="no unnest support", @@ -165,22 +156,7 @@ def test_union_aliasing(backend_name, snapshot): snapshot.assert_match(str(ibis.to_sql(result, dialect=backend_name)), "out.sql") -def test_union_generates_predictable_aliases(con): - t = ibis.memtable( - data=[{"island": "Torgerson", "body_mass_g": 3750, "sex": "male"}] - ) - sub1 = t.inner_join(t.view(), "island").mutate(island_right=lambda t: t.island) - sub2 = t.inner_join(t.view(), "sex").mutate(sex_right=lambda t: t.sex) - expr = ibis.union(sub1, sub2) - df = con.execute(expr) - assert len(df) == 2 - - -@pytest.mark.never( - ["pandas", "dask", "polars"], - reason="not SQL", - raises=NotImplementedError, -) +@pytest.mark.never(["pandas", "dask", "polars"], reason="not SQL", raises=ValueError) @pytest.mark.parametrize( "value", [ @@ -204,11 +180,28 @@ def test_selects_with_impure_operations_not_merged(con, snapshot, value): snapshot.assert_match(sql, "out.sql") -@pytest.mark.notyet(["polars"], reason="no sql generation") -@pytest.mark.never(["pandas", "dask"], reason="no sql generation") +@pytest.mark.never( + ["pandas", "dask", "polars"], reason="not SQL", raises=NotImplementedError +) def test_to_sql_default_backend(con, snapshot, monkeypatch): monkeypatch.setattr(ibis.options, "default_backend", con) t = ibis.memtable({"b": [1, 2]}, name="mytable") expr = t.select("b").count() snapshot.assert_match(ibis.to_sql(expr), "to_sql.sql") + + +@pytest.mark.notimpl( + ["dask", "pandas", "polars"], raises=ValueError, reason="not a SQL backend" +) +def test_many_subqueries(backend_name, snapshot): + def query(t, group_cols): + t2 = t.mutate(key=ibis.row_number().over(ibis.window(order_by=group_cols))) + return t2.inner_join(t2[["key"]], "key") + + t = ibis.table(dict(street="str"), name="data") + + t2 = query(t, group_cols=["street"]) + t3 = query(t2, group_cols=["street"]) + + snapshot.assert_match(str(ibis.to_sql(t3, dialect=backend_name)), "out.sql") diff --git a/ibis/backends/trino/__init__.py b/ibis/backends/trino/__init__.py index 0bbd32045d1c..f7af4c493649 
100644 --- a/ibis/backends/trino/__init__.py +++ b/ibis/backends/trino/__init__.py @@ -14,13 +14,13 @@ import trino import ibis +import ibis.backends.sql.compilers as sc import ibis.common.exceptions as com import ibis.expr.schema as sch import ibis.expr.types as ir from ibis import util from ibis.backends import CanCreateDatabase, CanCreateSchema, CanListCatalog from ibis.backends.sql import SQLBackend -from ibis.backends.sql.compilers import TrinoCompiler from ibis.backends.sql.compilers.base import C if TYPE_CHECKING: @@ -36,7 +36,7 @@ class Backend(SQLBackend, CanListCatalog, CanCreateDatabase, CanCreateSchema): name = "trino" - compiler = TrinoCompiler() + compiler = sc.trino.compiler supports_create_or_replace = False supports_temporary_tables = False @@ -490,7 +490,7 @@ def create_table( ) for name, typ in (schema or table.schema()).items() ) - ).from_(self._to_sqlglot(table).subquery()) + ).from_(self.compiler.to_sqlglot(table).subquery()) else: select = None diff --git a/ibis/expr/schema.py b/ibis/expr/schema.py index 3de925afbdae..54097476e06c 100644 --- a/ibis/expr/schema.py +++ b/ibis/expr/schema.py @@ -63,6 +63,10 @@ def names(self): def types(self): return tuple(self.values()) + @attribute + def geospatial(self) -> tuple[str, ...]: + return tuple(name for name, typ in self.fields.items() if typ.is_geospatial()) + @attribute def _name_locs(self) -> dict[str, int]: return {v: i for i, v in enumerate(self.names)} diff --git a/ibis/expr/sql.py b/ibis/expr/sql.py index 2e9c163cd6e9..f8af6c9960e8 100644 --- a/ibis/expr/sql.py +++ b/ibis/expr/sql.py @@ -362,26 +362,28 @@ def to_sql( Formatted SQL string """ + import ibis.backends.sql.compilers as sc + # try to infer from a non-str expression or if not possible fallback to # the default pretty dialect for expressions if dialect is None: try: - backend = expr._find_backend(use_default=True) + compiler_provider = expr._find_backend(use_default=True) except com.IbisError: # default to duckdb for SQL compilation because it supports the # widest array of ibis features for SQL backends - backend = ibis.duckdb - dialect = ibis.options.sql.default_dialect - else: - dialect = backend.dialect + compiler_provider = sc.duckdb else: try: - backend = getattr(ibis, dialect) - except AttributeError: - raise ValueError(f"Unknown dialect {dialect}") - else: - dialect = getattr(backend, "dialect", dialect) + compiler_provider = getattr(sc, dialect) + except AttributeError as e: + raise ValueError(f"Unknown dialect {dialect}") from e + + if (compiler := getattr(compiler_provider, "compiler", None)) is None: + raise NotImplementedError(f"{compiler_provider} is not a SQL backend") - sg_expr = backend._to_sqlglot(expr.unbind(), **kwargs) - sql = sg_expr.sql(dialect=dialect, pretty=pretty) + out = compiler.to_sqlglot(expr.unbind(), **kwargs) + queries = out if isinstance(out, list) else [out] + dialect = compiler.dialect + sql = ";\n".join(query.sql(dialect=dialect, pretty=pretty) for query in queries) return SQLString(sql) diff --git a/ibis/tests/expr/mocks.py b/ibis/tests/expr/mocks.py index fb4b79b35b78..54638f706d90 100644 --- a/ibis/tests/expr/mocks.py +++ b/ibis/tests/expr/mocks.py @@ -53,11 +53,6 @@ def list_tables(self): def list_databases(self): return ["mockdb"] - def _to_sqlglot(self, expr, **kwargs): - import ibis - - return ibis.duckdb._to_sqlglot(expr, **kwargs) - def fetch_from_cursor(self, cursor, schema): pass diff --git a/ibis/tests/expr/test_sql_builtins.py b/ibis/tests/expr/test_sql_builtins.py index 9a81085dba89..578f5fc83f88 100644 --- 
a/ibis/tests/expr/test_sql_builtins.py +++ b/ibis/tests/expr/test_sql_builtins.py @@ -16,6 +16,7 @@ import pytest import ibis +import ibis.backends.sql.compilers as sc import ibis.expr.operations as ops import ibis.expr.types as ir from ibis import _ @@ -223,3 +224,11 @@ def test_no_arguments_errors(function): SignatureValidationError, match=".+ has failed due to the following errors:" ): function() + + +@pytest.mark.parametrize( + "name", [name.lower().removesuffix("compiler") for name in sc.__all__] +) +def test_compile_without_dependencies(name): + table = ibis.table({"a": "int64"}, name="t") + assert isinstance(ibis.to_sql(table, dialect=name), str)
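
Usage sketch (editorial note, not part of the patch): with only a base ibis
install, SQL generation now goes through the compiler modules rather than a
backend instance. The snippet below mirrors the new `ibis.expr.sql.to_sql`
code path from this patch; the table and expression are illustrative.

    import ibis
    import ibis.backends.sql.compilers as sc

    t = ibis.table({"a": "int64"}, name="t")
    expr = t.group_by("a").agg(n=t.count())

    # public entry point: the dialect string resolves to a compiler module
    print(ibis.to_sql(expr, dialect="duckdb"))

    # roughly what to_sql does under the hood; to_sqlglot may return a list
    # of queries (e.g. BigQuery UDF definitions followed by the final SELECT)
    compiler = sc.duckdb.compiler
    out = compiler.to_sqlglot(expr.unbind())
    queries = out if isinstance(out, list) else [out]
    print(";\n".join(q.sql(dialect=compiler.dialect, pretty=True) for q in queries))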
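
A second sketch for the new `Schema.geospatial` attribute added in
ibis/expr/schema.py above, which the duckdb, postgres, and bigquery
to_sqlglot hooks use to decide when to wrap geo columns in WKB/EWKB
conversions (assumes the "geometry" dtype string parses as in current ibis):

    import ibis

    s = ibis.schema({"geom": "geometry", "n": "int64"})
    assert s.geospatial == ("geom",)  # only the geospatial column names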