Skip to content

Commit

Permalink
Enable gp_recursive_cte during statement preparation
Browse files Browse the repository at this point in the history
  • Loading branch information
KonstantAnxiety committed Feb 7, 2024
1 parent 620f331 commit aa10c0d
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from typing import List
from contextlib import asynccontextmanager
from typing import (
AsyncIterator,
List,
)

import asyncpg
import sqlalchemy as sa

from dl_core.connection_executors.models.db_adapter_data import RawSchemaInfo
Expand Down Expand Up @@ -72,3 +77,25 @@ class AsyncGreenplumAdapter(AsyncPostgresAdapter):

async def get_table_info(self, table_def: TableDefinition, fetch_idx_info: bool) -> RawSchemaInfo:
raise NotImplementedError()

@asynccontextmanager
async def _query_preparation_context(self, connection: asyncpg.Connection) -> AsyncIterator[None]:
async with super()._query_preparation_context(connection):
# enable gp_recursive_cte during query execution if it is disabled

db_version = await connection.fetchval("SELECT version()")
if "greenplum" not in db_version.lower():
# to keep compatibility with postgres
yield
return

gp_recursive_cte_initial = await connection.fetchval("SHOW gp_recursive_cte")
if gp_recursive_cte_initial == "on":
yield
return

await connection.execute("SET gp_recursive_cte = on")
try:
yield
finally:
await connection.execute("SET gp_recursive_cte = off")
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,10 @@ def execution_context(self) -> typing.Generator[None, None, None]:
stack.enter_context(context)
yield

@asynccontextmanager
async def _query_preparation_context(self, connection: asyncpg.Connection) -> AsyncIterator[None]:
yield

async def _execute_by_step(self, query: DBAdapterQuery) -> AsyncIterator[ExecutionStep]:
def make_record(raw_rec: asyncpg.Record, query_attrs: Iterable[asyncpg.Attribute]) -> TBIDataRow:
row_converters = self._get_row_converters(query_attrs=query_attrs)
Expand All @@ -285,7 +289,8 @@ def make_record(raw_rec: asyncpg.Record, query_attrs: Iterable[asyncpg.Attribute
async with self._get_connection(query.db_name) as conn: # type: ignore # 2024-01-24 # TODO: Argument 1 to "_get_connection" of "AsyncPostgresAdapter" has incompatible type "str | None"; expected "str" [arg-type]
# prepare works only inside a transaction
async with conn.transaction():
prepared_query = await conn.prepare(compiled_query)
async with self._query_preparation_context(conn):
prepared_query = await conn.prepare(compiled_query)
cursor_info = self._make_cursor_info(prepared_query.get_attributes())
yield ExecutionStepCursorInfo(cursor_info=cursor_info)
cursor = await prepared_query.cursor(*params)
Expand Down

0 comments on commit aa10c0d

Please sign in to comment.