Skip to content

Commit

Permalink
feat: add ActionNonStreamExecuteQuery to RQE (#388)
Browse files Browse the repository at this point in the history
  • Loading branch information
ovsds authored Mar 21, 2024
1 parent 9961080 commit 8009f69
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 53 deletions.
11 changes: 11 additions & 0 deletions lib/dl_configs/dl_configs/rqe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import enum
import os
from typing import Optional

Expand All @@ -15,6 +16,11 @@
from dl_configs.utils import validate_one_of


class RQEExecuteRequestMode(enum.Enum):
    """How RQE serves query results: chunk-by-chunk streaming, or one whole response body."""

    STREAM = "stream"
    NON_STREAM = "non_stream"


@attr.s(frozen=True)
class RQEBaseURL(SettingsBase):
scheme: str = s_attrib("SCHEME", env_var_converter=validate_one_of({"http", "https"}), missing="http") # type: ignore # 2024-01-24 # TODO: Incompatible types in assignment (expression has type "Attribute[Any]", variable has type "str") [assignment]
Expand All @@ -36,6 +42,11 @@ class RQEConfig(SettingsBase):
sensitive=True,
env_var_converter=lambda s: s.encode("ascii"),
)
execute_request_mode: RQEExecuteRequestMode = s_attrib(  # type: ignore # 2024-01-24 # TODO: Incompatible types in assignment (expression has type "Attribute[Any]", variable has type "RQEExecuteRequestMode") [assignment]
    "EXECUTE_REQUEST_MODE",
    # Default to the original streamed protocol for backward compatibility.
    missing=RQEExecuteRequestMode.STREAM,
    # Convert the env value by enum *value* ("stream" / "non_stream").
    # The previous converter, `RQEExecuteRequestMode[s.lower()]`, performed a
    # case-sensitive *name* lookup with a lower-cased string; member names are
    # upper-case (STREAM / NON_STREAM), so it raised KeyError for every input.
    env_var_converter=lambda s: RQEExecuteRequestMode(s.lower()),
)

@classmethod
def get_default(cls) -> RQEConfig:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import json
import logging
import pickle
import typing
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -42,7 +43,10 @@
AsyncRawExecutionResult,
)
from dl_core.connection_executors.adapters.common_base import CommonBaseDirectAdapter
from dl_core.connection_executors.models.common import RemoteQueryExecutorData
from dl_core.connection_executors.models.common import (
RemoteQueryExecutorData,
RQEExecuteRequestMode,
)
from dl_core.connection_executors.models.constants import (
HEADER_BODY_SIGNATURE,
HEADER_REQUEST_ID,
Expand Down Expand Up @@ -232,10 +236,62 @@ async def _make_request_parse_response(
except Exception as resp_deserialization_exc:
raise QueryExecutorException("Unexpected response JSON schema") from resp_deserialization_exc

@staticmethod
def _parse_event(event: Any) -> Tuple[RQEEventType, Any]:
    """Validate a raw deserialized event and return it as an (RQEEventType, payload) pair.

    :raises QueryExecutorException: if the event is not a 2-tuple of
        (str event type name, payload), or the name is not a known RQEEventType.
    """
    # Guard clauses: shape first, then types, then name resolution.
    if not isinstance(event, tuple):
        raise QueryExecutorException(f"QE parse: unexpected event type: {type(event)}")
    if len(event) != 2:
        raise QueryExecutorException(f"QE parse: event is not a pair: length={len(event)}")

    type_name, payload = event
    if not isinstance(type_name, str):
        raise QueryExecutorException(f"QE parse: event_type is not a str: {type(type_name)}")

    try:
        resolved_type = RQEEventType[type_name]
    except KeyError:
        # `from None`: the KeyError adds no information beyond the message.
        raise QueryExecutorException(f"QE parse: unknown event_type: {type_name!r}") from None

    return resolved_type, payload

async def _get_execution_result(
    self,
    events: typing.AsyncGenerator[Tuple[RQEEventType, Any], None],
) -> AsyncRawExecutionResult:
    """Build an AsyncRawExecutionResult from a stream of parsed RQE events.

    Eagerly consumes exactly one event, which must be ``raw_cursor_info``;
    all remaining events are consumed lazily by the chunk generator of the
    returned result.

    :param events: async generator of (event_type, payload) pairs.
    :raises QueryExecutorException: if the first event is not
        ``raw_cursor_info``, a chunk payload is malformed, an ``error_dump``
        event arrives, or the stream ends without a ``finished`` event.
    """
    ev_type, ev_data = await events.__anext__()
    if ev_type == RQEEventType.raw_cursor_info:
        raw_cursor_info = ev_data
    else:
        raise QueryExecutorException(f"QE parse: first event type is not 'raw_cursor_info': {ev_type}")

    # 3-layer iterable: generator of chunks, chunks is a list of rows, row
    # is a list of column values.
    async def data_generator() -> AsyncGenerator[Sequence[Sequence[Any]], None]:
        # Track the last seen event type for the error message below;
        # stays None if the stream yields nothing at all.
        ev_type = None

        async for ev_type, ev_data in events:
            if ev_type == RQEEventType.raw_chunk:
                chunk = ev_data
                if not isinstance(chunk, (list, tuple)):
                    raise QueryExecutorException(f"QE parse: unexpected chunk type: {type(chunk)}")
                yield chunk
            elif ev_type == RQEEventType.error_dump:
                try:
                    exc = self._parse_exception(ev_data)
                except Exception as e:
                    raise QueryExecutorException(f"QE parse: failed to parse an error event: {ev_data!r}") from e
                raise exc
            elif ev_type == RQEEventType.finished:
                # Normal termination: the executor signalled end of data.
                return

        # The stream was exhausted without a `finished` event — treat as a protocol error.
        raise QueryExecutorException(f"QE parse: finish event was not received (last: {ev_type})")

    return AsyncRawExecutionResult(
        raw_cursor_info=raw_cursor_info,
        raw_chunk_generator=data_generator(),
    )

# Here we have some async problem:
# "qe" stage will be finished before some nested stages in RQE (e.g. "qe/fetch")
@generic_profiler_async("qe") # type: ignore # TODO: fix
async def execute(self, query: DBAdapterQuery) -> AsyncRawExecutionResult:
async def _execute_stream(self, query: DBAdapterQuery) -> AsyncRawExecutionResult:
resp = await self._make_request(
dba_actions.ActionExecuteQuery(
db_adapter_query=query,
Expand All @@ -250,7 +306,7 @@ async def execute(self, query: DBAdapterQuery) -> AsyncRawExecutionResult:
exc = self._parse_exception(resp_body_json)
raise exc

async def event_gen() -> AsyncGenerator[Tuple[str, Any], None]:
async def event_gen() -> AsyncGenerator[Tuple[RQEEventType, Any], None]:
buf = b""

while True:
Expand All @@ -265,61 +321,50 @@ async def event_gen() -> AsyncGenerator[Tuple[str, Any], None]:
# This isn't a very correct way, but it's hard to use pickle in async differently.
if end_of_chunk and buf:
try:
parsed_event = pickle.loads(buf)
event = pickle.loads(buf)
except Exception as err:
raise QueryExecutorException("QE parse: failed to unpickle") from err

yield self._parse_event(event)
buf = b""
if not isinstance(parsed_event, tuple):
raise QueryExecutorException(f"QE parse: unexpected event type: {type(parsed_event)}")
if len(parsed_event) != 2:
raise QueryExecutorException(f"QE parse: event is not a pair: length={len(parsed_event)}")
event_type_name, event_data = parsed_event
if not isinstance(event_type_name, str):
raise QueryExecutorException(f"QE parse: event_type is not a str: {type(event_type_name)}")
try:
event_type = RQEEventType[event_type_name]
except KeyError:
raise QueryExecutorException(f"QE parse: unknown event_type: {event_type_name!r}") from None

yield event_type, event_data # type: ignore # TODO: fix

if raw_chunk == b"" and not end_of_chunk:
return

events = event_gen()
ev_type, ev_data = await events.__anext__()
if ev_type == RQEEventType.raw_cursor_info:
raw_cursor_info = ev_data
else:
raise QueryExecutorException(f"QE parse: first event type is not 'raw_cursor_info': {ev_type}")
return await self._get_execution_result(events)

# 3-layer iterable: generator of chunks, chunks is a list of rows, row
# is a list of column values.
async def data_generator() -> AsyncGenerator[Sequence[Sequence[Any]], None]:
ev_type = None
async def _execute_non_stream(self, query: DBAdapterQuery) -> AsyncRawExecutionResult:
    """Execute `query` via the non-streamed RQE protocol.

    The server returns the complete event list in a single pickled response
    body; the events are then replayed through an async generator so that
    result handling is shared with the streamed mode via `_get_execution_result`.

    :raises QueryExecutorException: for error responses and malformed events.
    """
    response = await self._make_request(
        dba_actions.ActionNonStreamExecuteQuery(
            db_adapter_query=query,
            target_conn_dto=self._target_dto,
            dba_cls=self._dba_cls,
            req_ctx_info=self._req_ctx_info,
        ),
    )

    if response.status != 200:
        # Non-200 responses carry a JSON-serialized exception dump.
        resp_body_json = await self._read_body_json(response)
        exc = self._parse_exception(resp_body_json)
        raise exc

    raw_data = await response.read()
    # NOTE(review): pickle.loads on a response body is only acceptable because
    # RQE traffic is an internal channel (requests are HMAC-signed — see the
    # body-signature headers); never point this at untrusted input.
    raw_events = pickle.loads(raw_data)

    async def event_gen() -> AsyncGenerator[Tuple[RQEEventType, Any], None]:
        # Replay the already-materialized event list as an async stream.
        for raw_event in raw_events:
            yield self._parse_event(raw_event)

    return await self._get_execution_result(event_gen())

@generic_profiler_async("qe")  # type: ignore # TODO: fix
async def execute(self, query: DBAdapterQuery) -> AsyncRawExecutionResult:
    """Execute the query, dispatching on the configured RQE request mode."""
    handler = (
        self._execute_stream
        if self._rqe_data.execute_request_mode == RQEExecuteRequestMode.STREAM
        else self._execute_non_stream
    )
    return await handler(query)

async def get_db_version(self, db_ident: DBIdent) -> Optional[str]:
return await self._make_request_parse_response(
Expand Down
4 changes: 4 additions & 0 deletions lib/dl_core/dl_core/connection_executors/models/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import attr

from dl_configs.rqe import RQEExecuteRequestMode


@attr.s(frozen=True)
class RemoteQueryExecutorData:
Expand All @@ -14,3 +16,5 @@ class RemoteQueryExecutorData:
sync_protocol: str = attr.ib()
sync_host: str = attr.ib()
sync_port: int = attr.ib()

execute_request_mode: RQEExecuteRequestMode = attr.ib(default=RQEExecuteRequestMode.STREAM)
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ def deserialize_response(self, data: Dict) -> _RES_SCHEMA_TV:
return self.ResultSchema().load(data) # type: ignore # TODO: fix


@attr.s(frozen=True)
class ActionNonStreamExecuteQuery(RemoteDBAdapterAction):
    """RQE action: execute a query with the whole event list returned in one response body.

    Counterpart of ActionExecuteQuery, which streams events chunk by chunk.
    """

    # The query to execute on the remote adapter.
    db_adapter_query: DBAdapterQuery = attr.ib()


@attr.s(frozen=True)
class ActionTest(NonStreamAction[None]):
class ResultSchema(BaseQEAPISchema):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ def to_object(self, data: Dict[str, Any]) -> dba_actions.ActionExecuteQuery:
return dba_actions.ActionExecuteQuery(**data)


class ActionNonStreamExecuteQuerySchema(DBAdapterActionBaseSchema):
    """(De)serialization schema for ActionNonStreamExecuteQuery."""

    db_adapter_query = fields.Nested(GenericDBAQuerySchema)

    def to_object(self, data: Dict[str, Any]) -> dba_actions.ActionNonStreamExecuteQuery:
        """Rebuild the action object from already-deserialized field data."""
        return dba_actions.ActionNonStreamExecuteQuery(**data)


class ActionGetDBVersionSchema(DBAdapterActionBaseSchema):
db_ident = fields.Nested(DBIdentSchema)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
class ActionSerializer:
MAP_ACT_TYPE_SCHEMA_CLS: ClassVar[Dict[Type[dba_actions.RemoteDBAdapterAction], Type[Schema]]] = {
dba_actions.ActionExecuteQuery: schema_actions.ActionExecuteQuerySchema,
dba_actions.ActionNonStreamExecuteQuery: schema_actions.ActionNonStreamExecuteQuerySchema,
dba_actions.ActionTest: schema_actions.ActionTestSchema,
dba_actions.ActionGetDBVersion: schema_actions.ActionGetDBVersionSchema,
dba_actions.ActionGetSchemaNames: schema_actions.ActionGetSchemaNamesSchema,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,24 @@ async def handle_query_action(

return response

async def handle_non_stream_query_action(
    self,
    dba: AsyncDBAdapter,
    dba_query: DBAdapterQuery,
) -> web.Response:
    """Run the query to completion and return every event in one pickled response body."""
    try:
        result = await dba.execute(dba_query)
    except Exception:
        LOGGER.exception("Exception during execution")
        raise

    # Materialize the full event sequence: cursor info, then all chunks, then the terminator.
    head: list[tuple[str, Any]] = [(RQEEventType.raw_cursor_info.value, result.raw_cursor_info)]
    chunks = [(RQEEventType.raw_chunk.value, raw_chunk) async for raw_chunk in result.raw_chunk_generator]
    tail = [(RQEEventType.finished.value, None)]

    return web.Response(body=pickle.dumps(head + chunks + tail))

async def execute_non_streamed_action(
self,
dba: AsyncDBAdapter,
Expand Down Expand Up @@ -210,12 +228,16 @@ async def post(self) -> Union[web.Response, web.StreamResponse]:

if isinstance(action, act.ActionExecuteQuery):
return await self.handle_query_action(adapter, action.db_adapter_query)
elif isinstance(action, act.NonStreamAction):

if isinstance(action, act.ActionNonStreamExecuteQuery):
return await self.handle_non_stream_query_action(adapter, action.db_adapter_query)

if isinstance(action, act.NonStreamAction):
result = await self.execute_non_streamed_action(adapter, action)
resp_data = action.serialize_response(result)
return web.json_response(resp_data)
else:
raise NotImplementedError(f"Action {action} is not implemented in QE")

raise NotImplementedError(f"Action {action} is not implemented in QE")


def body_signature_validation_middleware(hmac_key: bytes) -> AIOHTTPMiddleware:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,25 @@ def execute_execute_action(
},
)

def execute_non_stream_execute_action(
    self,
    action: act.ActionNonStreamExecuteQuery,
) -> flask.Response:
    """Run the query to completion and return every event in one pickled response body."""
    dba = self.create_dba_for_action(action)

    try:
        db_result = dba.execute(action.db_adapter_query)
        # Full event sequence: cursor info first, then every chunk, then the terminator.
        events: list[tuple[str, Any]] = [
            (RQEEventType.raw_cursor_info.value, db_result.cursor_info),
            *((RQEEventType.raw_chunk.value, raw_chunk) for raw_chunk in db_result.data_chunks),
            (RQEEventType.finished.value, None),
        ]
        return flask.Response(response=pickle.dumps(events))
    except Exception:
        # noinspection PyBroadException
        self.try_close_dba(dba)
        raise

@staticmethod
def create_dba_for_action(action: act.RemoteDBAdapterAction) -> SyncDirectDBAdapter:
LOGGER.info("Creating DBA")
Expand All @@ -189,17 +208,21 @@ def dispatch_request(self) -> flask.Response:
action = self.get_action()
LOGGER.info("Got QE action request: %s", action)

if isinstance(action, act.ActionExecuteQuery):
return self.execute_execute_action(action)

if isinstance(action, act.ActionNonStreamExecuteQuery):
return self.execute_non_stream_execute_action(action)

if isinstance(action, act.NonStreamAction):
dba = self.create_dba_for_action(action)
try:
result = self.execute_non_streamed_action(dba, action)
return flask.jsonify(action.serialize_response(result))
finally:
dba.close()
elif isinstance(action, act.ActionExecuteQuery):
return self.execute_execute_action(action)
else:
raise NotImplementedError(f"Action {action} is not implemented in QE")

raise NotImplementedError(f"Action {action} is not implemented in QE")


@attr.s
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def _get_rqe_data(self, external: bool) -> RemoteQueryExecutorData:
sync_protocol=sync_rqe_netloc.scheme,
sync_host=sync_rqe_netloc.host,
sync_port=sync_rqe_netloc.port,
execute_request_mode=self.rqe_config.execute_request_mode,
)

@property
Expand Down

0 comments on commit 8009f69

Please sign in to comment.