From aa10c0dabc8634d3e1cd0b9606527e5d07db57db Mon Sep 17 00:00:00 2001 From: KonstantAnxiety Date: Mon, 18 Dec 2023 20:35:41 +0300 Subject: [PATCH] Enable gp_recursive_cte during statement preparation --- .../dl_connector_greenplum/core/adapters.py | 29 ++++++++++++++++++- .../async_adapters_postgres.py | 7 ++++- 2 files changed, 34 insertions(+), 2 deletions(-) diff --git a/lib/dl_connector_greenplum/dl_connector_greenplum/core/adapters.py b/lib/dl_connector_greenplum/dl_connector_greenplum/core/adapters.py index d18208cd3..ffd3652bb 100644 --- a/lib/dl_connector_greenplum/dl_connector_greenplum/core/adapters.py +++ b/lib/dl_connector_greenplum/dl_connector_greenplum/core/adapters.py @@ -1,5 +1,10 @@ -from typing import List +from contextlib import asynccontextmanager +from typing import ( + AsyncIterator, + List, +) +import asyncpg import sqlalchemy as sa from dl_core.connection_executors.models.db_adapter_data import RawSchemaInfo @@ -72,3 +77,25 @@ class AsyncGreenplumAdapter(AsyncPostgresAdapter): async def get_table_info(self, table_def: TableDefinition, fetch_idx_info: bool) -> RawSchemaInfo: raise NotImplementedError() + + @asynccontextmanager + async def _query_preparation_context(self, connection: asyncpg.Connection) -> AsyncIterator[None]: + async with super()._query_preparation_context(connection): + # enable gp_recursive_cte during query execution if it is disabled + + db_version = await connection.fetchval("SELECT version()") + if "greenplum" not in db_version.lower(): + # to keep compatibility with postgres + yield + return + + gp_recursive_cte_initial = await connection.fetchval("SHOW gp_recursive_cte") + if gp_recursive_cte_initial == "on": + yield + return + + await connection.execute("SET gp_recursive_cte = on") + try: + yield + finally: + await connection.execute("SET gp_recursive_cte = off") diff --git a/lib/dl_connector_postgresql/dl_connector_postgresql/core/postgresql_base/async_adapters_postgres.py b/lib/dl_connector_postgresql/dl_connector_postgresql/core/postgresql_base/async_adapters_postgres.py index 48f129d28..edaf57d10 100644 --- a/lib/dl_connector_postgresql/dl_connector_postgresql/core/postgresql_base/async_adapters_postgres.py +++ b/lib/dl_connector_postgresql/dl_connector_postgresql/core/postgresql_base/async_adapters_postgres.py @@ -262,6 +262,10 @@ def execution_context(self) -> typing.Generator[None, None, None]: stack.enter_context(context) yield + @asynccontextmanager + async def _query_preparation_context(self, connection: asyncpg.Connection) -> AsyncIterator[None]: + yield + async def _execute_by_step(self, query: DBAdapterQuery) -> AsyncIterator[ExecutionStep]: def make_record(raw_rec: asyncpg.Record, query_attrs: Iterable[asyncpg.Attribute]) -> TBIDataRow: row_converters = self._get_row_converters(query_attrs=query_attrs) @@ -285,7 +289,8 @@ def make_record(raw_rec: asyncpg.Record, query_attrs: Iterable[asyncpg.Attribute async with self._get_connection(query.db_name) as conn: # type: ignore # 2024-01-24 # TODO: Argument 1 to "_get_connection" of "AsyncPostgresAdapter" has incompatible type "str | None"; expected "str" [arg-type] # prepare works only inside a transaction async with conn.transaction(): - prepared_query = await conn.prepare(compiled_query) + async with self._query_preparation_context(conn): + prepared_query = await conn.prepare(compiled_query) cursor_info = self._make_cursor_info(prepared_query.get_attributes()) yield ExecutionStepCursorInfo(cursor_info=cursor_info) cursor = await prepared_query.cursor(*params)