diff --git a/ibis/backends/sql/__init__.py b/ibis/backends/sql/__init__.py index 67b256afec728..c8a37941ff054 100644 --- a/ibis/backends/sql/__init__.py +++ b/ibis/backends/sql/__init__.py @@ -35,8 +35,11 @@ def dialect(self) -> sg.Dialect: @classmethod def has_operation(cls, operation: type[ops.Value]) -> bool: compiler = cls.compiler + if operation in compiler.extra_supported_ops: + return True method = getattr(compiler, f"visit_{operation.__name__}", None) - return method is not None and method not in ( + return method not in ( + None, compiler.visit_Undefined, compiler.visit_Unsupported, ) diff --git a/ibis/backends/sql/compiler.py b/ibis/backends/sql/compiler.py index 8bb46ce21c895..a1642e3e50fa7 100644 --- a/ibis/backends/sql/compiler.py +++ b/ibis/backends/sql/compiler.py @@ -14,6 +14,7 @@ from public import public import ibis.common.exceptions as com +import ibis.common.patterns as pats import ibis.expr.datatypes as dt import ibis.expr.operations as ops from ibis.backends.sql.rewrites import ( @@ -175,6 +176,20 @@ class SQLGlotCompiler(abc.ABC): ) """A sequence of rewrites to apply to the expression tree before compilation.""" + extra_supported_ops: frozenset = frozenset( + ( + ops.Project, + ops.Filter, + ops.Sort, + ops.WindowFunction, + ops.RowsWindowFrame, + ops.RangeWindowFrame, + ) + ) + """A frozenset of ops classes that are supported, but don't have explicit + `visit_*` methods (usually due to being handled by rewrite rules). Used by + `has_operation`""" + no_limit_value: sge.Null | None = None """The value to use to indicate no limit.""" @@ -366,6 +381,15 @@ def impl(self, _, *, _name: str = target_name, **kw): if not hasattr(cls, name): setattr(cls, name, cls.visit_Undefined) + # Expand extra_supported_ops with any rewrite rules + extra_supported_ops = set(cls.extra_supported_ops) + for rule in cls.rewrites: + if isinstance(rule, pats.Replace) and isinstance( + rule.matcher, pats.InstanceOf + ): + extra_supported_ops.add(rule.matcher.type) + cls.extra_supported_ops = frozenset(extra_supported_ops) + @property @abc.abstractmethod def dialect(self) -> str: diff --git a/ibis/backends/sqlite/tests/test_client.py b/ibis/backends/sqlite/tests/test_client.py index 7e57bcb431454..5f32197c57ea1 100644 --- a/ibis/backends/sqlite/tests/test_client.py +++ b/ibis/backends/sqlite/tests/test_client.py @@ -8,6 +8,7 @@ from pytest import param import ibis +import ibis.expr.operations as ops from ibis.conftest import not_windows @@ -75,3 +76,14 @@ def test_connect(url, ext, tmp_path): con = ibis.connect(url(path)) one = ibis.literal(1) assert con.execute(one) == 1 + + +def test_has_operation(con): + # Handled by hardcoded rewrite + assert con.has_operation(ops.Project) + # Handled by base class rewrite + assert con.has_operation(ops.Capitalize) + # Handled by compiler-specific rewrite + assert con.has_operation(ops.Sample) + # Handled by visit_* method + assert con.has_operation(ops.Cast)