diff --git a/tests/unit/sqlalchemy/test_compiler.py b/tests/unit/sqlalchemy/test_compiler.py index 1e3ec770..00bd686c 100644 --- a/tests/unit/sqlalchemy/test_compiler.py +++ b/tests/unit/sqlalchemy/test_compiler.py @@ -19,6 +19,7 @@ String, Table, ) +from sqlalchemy.schema import CreateTable from trino.sqlalchemy.dialect import TrinoDialect @@ -26,9 +27,16 @@ table = Table( 'table', metadata, - Column('id', Integer, primary_key=True), + Column('id', Integer), Column('name', String), ) +table_with_catalog = Table( + 'table', + metadata, + Column('id', Integer), + schema='default', + trino_catalog='other' +) @pytest.fixture @@ -64,3 +72,20 @@ def test_cte_insert_order(dialect): 'FROM "table")\n'\ ' SELECT cte.id, cte.name \n'\ 'FROM cte' + + +def test_catalogs_argument(dialect): + statement = select(table_with_catalog) + query = statement.compile(dialect=dialect) + assert str(query) == 'SELECT default."table".id \nFROM "other".default."table"' + + +def test_catalogs_create_table(dialect): + statement = CreateTable(table_with_catalog) + query = statement.compile(dialect=dialect) + assert str(query) == \ + '\n'\ + 'CREATE TABLE "other".default."table" (\n'\ + '\tid INTEGER\n'\ + ')\n'\ + '\n' diff --git a/trino/sqlalchemy/compiler.py b/trino/sqlalchemy/compiler.py index 6ab84c59..a085fbf3 100644 --- a/trino/sqlalchemy/compiler.py +++ b/trino/sqlalchemy/compiler.py @@ -10,6 +10,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from sqlalchemy.sql import compiler +try: + from sqlalchemy.sql.expression import ( + Alias, + CTE, + Subquery, + ) +except ImportError: + # For SQLAlchemy versions < 1.4, the CTE and Subquery classes did not explicitly exist + from sqlalchemy.sql.expression import Alias + CTE = type(None) + Subquery = type(None) # https://trino.io/docs/current/language/reserved.html RESERVED_WORDS = { @@ -102,6 +113,31 @@ def limit_clause(self, select, **kw): text += "\nLIMIT " + self.process(select._limit_clause, **kw) return text + def visit_table(self, table, asfrom=False, iscrud=False, ashint=False, + fromhints=None, use_schema=True, **kwargs): + sql = super(TrinoSQLCompiler, self).visit_table( + table, asfrom, iscrud, ashint, fromhints, use_schema, **kwargs + ) + return self.add_catalog(sql, table) + + @staticmethod + def add_catalog(sql, table): + if table is None: + return sql + + if isinstance(table, (Alias, CTE, Subquery)): + return sql + + if ( + 'trino' not in table.dialect_options + or 'catalog' not in table.dialect_options['trino'] + ): + return sql + + catalog = table.dialect_options['trino']['catalog'] + sql = f'"{catalog}".{sql}' + return sql + class TrinoDDLCompiler(compiler.DDLCompiler): pass @@ -173,3 +209,7 @@ def visit_TIME(self, type_, **kw): class TrinoIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS + + def format_table(self, table, use_schema=True, name=None): + result = super(TrinoIdentifierPreparer, self).format_table(table, use_schema, name) + return TrinoSQLCompiler.add_catalog(result, table)