From 8009f695fa779490d181f419e7652591e7f6a01f Mon Sep 17 00:00:00 2001 From: Ovsyannikov Dmitrii Date: Thu, 21 Mar 2024 10:58:23 +0100 Subject: [PATCH] feat: add ActionNonStreamExecuteQuery to RQE (#388) --- lib/dl_configs/dl_configs/rqe.py | 11 ++ .../adapters/async_adapters_remote.py | 137 ++++++++++++------ .../connection_executors/models/common.py | 4 + .../qe_serializer/dba_actions.py | 5 + .../qe_serializer/schema_actions.py | 7 + .../qe_serializer/serializer.py | 1 + .../remote_query_executor/app_async.py | 28 +++- .../remote_query_executor/app_sync.py | 31 +++- .../conn_executor_factory.py | 1 + 9 files changed, 172 insertions(+), 53 deletions(-) diff --git a/lib/dl_configs/dl_configs/rqe.py b/lib/dl_configs/dl_configs/rqe.py index 7773a9a0e..f73f941f2 100644 --- a/lib/dl_configs/dl_configs/rqe.py +++ b/lib/dl_configs/dl_configs/rqe.py @@ -1,5 +1,6 @@ from __future__ import annotations +import enum import os from typing import Optional @@ -15,6 +16,11 @@ from dl_configs.utils import validate_one_of +class RQEExecuteRequestMode(enum.Enum): + STREAM = "stream" + NON_STREAM = "non_stream" + + @attr.s(frozen=True) class RQEBaseURL(SettingsBase): scheme: str = s_attrib("SCHEME", env_var_converter=validate_one_of({"http", "https"}), missing="http") # type: ignore # 2024-01-24 # TODO: Incompatible types in assignment (expression has type "Attribute[Any]", variable has type "str") [assignment] @@ -36,6 +42,11 @@ class RQEConfig(SettingsBase): sensitive=True, env_var_converter=lambda s: s.encode("ascii"), ) + execute_request_mode: RQEExecuteRequestMode = s_attrib( # type: ignore # 2024-01-24 # TODO: Incompatible types in assignment (expression has type "Attribute[Any]", variable has type "RQEExecuteRequestMode") [assignment] + "EXECUTE_REQUEST_MODE", + missing=RQEExecuteRequestMode.STREAM, + env_var_converter=lambda s: RQEExecuteRequestMode(s.lower()), # look up by value ("stream" / "non_stream"): member names are upper-case, so a by-name lookup of a lower-cased string always fails + ) @classmethod def get_default(cls) -> RQEConfig: diff --git 
a/lib/dl_core/dl_core/connection_executors/adapters/async_adapters_remote.py b/lib/dl_core/dl_core/connection_executors/adapters/async_adapters_remote.py index cd3476299..5ed8278a7 100644 --- a/lib/dl_core/dl_core/connection_executors/adapters/async_adapters_remote.py +++ b/lib/dl_core/dl_core/connection_executors/adapters/async_adapters_remote.py @@ -4,6 +4,7 @@ import json import logging import pickle +import typing from typing import ( TYPE_CHECKING, Any, @@ -42,7 +43,10 @@ AsyncRawExecutionResult, ) from dl_core.connection_executors.adapters.common_base import CommonBaseDirectAdapter -from dl_core.connection_executors.models.common import RemoteQueryExecutorData +from dl_core.connection_executors.models.common import ( + RemoteQueryExecutorData, + RQEExecuteRequestMode, +) from dl_core.connection_executors.models.constants import ( HEADER_BODY_SIGNATURE, HEADER_REQUEST_ID, @@ -232,10 +236,62 @@ async def _make_request_parse_response( except Exception as resp_deserialization_exc: raise QueryExecutorException("Unexpected response JSON schema") from resp_deserialization_exc + @staticmethod + def _parse_event(event: Any) -> Tuple[RQEEventType, Any]: + if not isinstance(event, tuple): + raise QueryExecutorException(f"QE parse: unexpected event type: {type(event)}") + if len(event) != 2: + raise QueryExecutorException(f"QE parse: event is not a pair: length={len(event)}") + event_type_name, event_data = event + if not isinstance(event_type_name, str): + raise QueryExecutorException(f"QE parse: event_type is not a str: {type(event_type_name)}") + try: + event_type = RQEEventType[event_type_name] + except KeyError: + raise QueryExecutorException(f"QE parse: unknown event_type: {event_type_name!r}") from None + + return event_type, event_data + + async def _get_execution_result( + self, + events: typing.AsyncGenerator[Tuple[RQEEventType, Any], None], + ) -> AsyncRawExecutionResult: + ev_type, ev_data = await events.__anext__() + if ev_type == 
RQEEventType.raw_cursor_info: + raw_cursor_info = ev_data + else: + raise QueryExecutorException(f"QE parse: first event type is not 'raw_cursor_info': {ev_type}") + + # 3-layer iterable: generator of chunks, chunks is a list of rows, row + # is a list of column values. + async def data_generator() -> AsyncGenerator[Sequence[Sequence[Any]], None]: + ev_type = None + + async for ev_type, ev_data in events: + if ev_type == RQEEventType.raw_chunk: + chunk = ev_data + if not isinstance(chunk, (list, tuple)): + raise QueryExecutorException(f"QE parse: unexpected chunk type: {type(chunk)}") + yield chunk + elif ev_type == RQEEventType.error_dump: + try: + exc = self._parse_exception(ev_data) + except Exception as e: + raise QueryExecutorException(f"QE parse: failed to parse an error event: {ev_data!r}") from e + raise exc + elif ev_type == RQEEventType.finished: + return + + raise QueryExecutorException(f"QE parse: finish event was not received (last: {ev_type})") + + return AsyncRawExecutionResult( + raw_cursor_info=raw_cursor_info, + raw_chunk_generator=data_generator(), + ) + # Here we have some async problem: # "qe" stage will be finished before some nested stages in RQE (e.g. "qe/fetch") - @generic_profiler_async("qe") # type: ignore # TODO: fix - async def execute(self, query: DBAdapterQuery) -> AsyncRawExecutionResult: + async def _execute_stream(self, query: DBAdapterQuery) -> AsyncRawExecutionResult: resp = await self._make_request( dba_actions.ActionExecuteQuery( db_adapter_query=query, @@ -250,7 +306,7 @@ async def execute(self, query: DBAdapterQuery) -> AsyncRawExecutionResult: exc = self._parse_exception(resp_body_json) raise exc - async def event_gen() -> AsyncGenerator[Tuple[str, Any], None]: + async def event_gen() -> AsyncGenerator[Tuple[RQEEventType, Any], None]: buf = b"" while True: @@ -265,61 +321,53 @@ async def event_gen() -> AsyncGenerator[Tuple[str, Any], None]: # This isn't a very correct way, but it's hard to use pickle in async differently. 
if end_of_chunk and buf: try: - parsed_event = pickle.loads(buf) + event = pickle.loads(buf) except Exception as err: raise QueryExecutorException("QE parse: failed to unpickle") from err + yield self._parse_event(event) buf = b"" - if not isinstance(parsed_event, tuple): - raise QueryExecutorException(f"QE parse: unexpected event type: {type(parsed_event)}") - if len(parsed_event) != 2: - raise QueryExecutorException(f"QE parse: event is not a pair: length={len(parsed_event)}") - event_type_name, event_data = parsed_event - if not isinstance(event_type_name, str): - raise QueryExecutorException(f"QE parse: event_type is not a str: {type(event_type_name)}") - try: - event_type = RQEEventType[event_type_name] - except KeyError: - raise QueryExecutorException(f"QE parse: unknown event_type: {event_type_name!r}") from None - - yield event_type, event_data # type: ignore # TODO: fix if raw_chunk == b"" and not end_of_chunk: return events = event_gen() - ev_type, ev_data = await events.__anext__() - if ev_type == RQEEventType.raw_cursor_info: - raw_cursor_info = ev_data - else: - raise QueryExecutorException(f"QE parse: first event type is not 'raw_cursor_info': {ev_type}") + return await self._get_execution_result(events) - # 3-layer iterable: generator of chunks, chunks is a list of rows, row - # is a list of column values. 
- async def data_generator() -> AsyncGenerator[Sequence[Sequence[Any]], None]: - ev_type = None + async def _execute_non_stream(self, query: DBAdapterQuery) -> AsyncRawExecutionResult: + response = await self._make_request( + dba_actions.ActionNonStreamExecuteQuery( + db_adapter_query=query, + target_conn_dto=self._target_dto, + dba_cls=self._dba_cls, + req_ctx_info=self._req_ctx_info, + ), + ) - async for ev_type, ev_data in events: - if ev_type == RQEEventType.raw_chunk: - chunk = ev_data - if not isinstance(chunk, (list, tuple)): - raise QueryExecutorException(f"QE parse: unexpected chunk type: {type(chunk)}") - yield chunk - elif ev_type == RQEEventType.error_dump: - try: - exc = self._parse_exception(ev_data) - except Exception as e: - raise QueryExecutorException(f"QE parse: failed to parse an error event: {ev_data!r}") from e - raise exc - elif ev_type == RQEEventType.finished: - return + if response.status != 200: + resp_body_json = await self._read_body_json(response) + exc = self._parse_exception(resp_body_json) + raise exc - raise QueryExecutorException(f"QE parse: finish event was not received (last: {ev_type})") + raw_data = await response.read() + try: + raw_events = pickle.loads(raw_data) + except Exception as err: # same error contract as the streaming path: callers only see QueryExecutorException + raise QueryExecutorException("QE parse: failed to unpickle") from err - return AsyncRawExecutionResult( - raw_cursor_info=raw_cursor_info, - raw_chunk_generator=data_generator(), - ) + async def event_gen() -> AsyncGenerator[Tuple[RQEEventType, Any], None]: + for raw_event in raw_events: + yield self._parse_event(raw_event) + + events = event_gen() + return await self._get_execution_result(events) + + @generic_profiler_async("qe") # type: ignore # TODO: fix + async def execute(self, query: DBAdapterQuery) -> AsyncRawExecutionResult: + if self._rqe_data.execute_request_mode == RQEExecuteRequestMode.STREAM: + return await self._execute_stream(query) + + return 
await self._execute_non_stream(query) async def get_db_version(self, db_ident: DBIdent) -> Optional[str]: return await self._make_request_parse_response( diff --git 
a/lib/dl_core/dl_core/connection_executors/models/common.py b/lib/dl_core/dl_core/connection_executors/models/common.py index a82e97fba..92cb0cb79 100644 --- a/lib/dl_core/dl_core/connection_executors/models/common.py +++ b/lib/dl_core/dl_core/connection_executors/models/common.py @@ -2,6 +2,8 @@ import attr +from dl_configs.rqe import RQEExecuteRequestMode + @attr.s(frozen=True) class RemoteQueryExecutorData: @@ -14,3 +16,5 @@ class RemoteQueryExecutorData: sync_protocol: str = attr.ib() sync_host: str = attr.ib() sync_port: int = attr.ib() + + execute_request_mode: RQEExecuteRequestMode = attr.ib(default=RQEExecuteRequestMode.STREAM) diff --git a/lib/dl_core/dl_core/connection_executors/qe_serializer/dba_actions.py b/lib/dl_core/dl_core/connection_executors/qe_serializer/dba_actions.py index 2ac3d583e..4415000ea 100644 --- a/lib/dl_core/dl_core/connection_executors/qe_serializer/dba_actions.py +++ b/lib/dl_core/dl_core/connection_executors/qe_serializer/dba_actions.py @@ -66,6 +66,11 @@ def deserialize_response(self, data: Dict) -> _RES_SCHEMA_TV: return self.ResultSchema().load(data) # type: ignore # TODO: fix +@attr.s(frozen=True) +class ActionNonStreamExecuteQuery(RemoteDBAdapterAction): + db_adapter_query: DBAdapterQuery = attr.ib() + + @attr.s(frozen=True) class ActionTest(NonStreamAction[None]): class ResultSchema(BaseQEAPISchema): diff --git a/lib/dl_core/dl_core/connection_executors/qe_serializer/schema_actions.py b/lib/dl_core/dl_core/connection_executors/qe_serializer/schema_actions.py index ef4cea5b2..d6bf0db6e 100644 --- a/lib/dl_core/dl_core/connection_executors/qe_serializer/schema_actions.py +++ b/lib/dl_core/dl_core/connection_executors/qe_serializer/schema_actions.py @@ -62,6 +62,13 @@ def to_object(self, data: Dict[str, Any]) -> dba_actions.ActionExecuteQuery: return dba_actions.ActionExecuteQuery(**data) +class ActionNonStreamExecuteQuerySchema(DBAdapterActionBaseSchema): + db_adapter_query = fields.Nested(GenericDBAQuerySchema) + + def 
to_object(self, data: Dict[str, Any]) -> dba_actions.ActionNonStreamExecuteQuery: + return dba_actions.ActionNonStreamExecuteQuery(**data) + + class ActionGetDBVersionSchema(DBAdapterActionBaseSchema): db_ident = fields.Nested(DBIdentSchema) diff --git a/lib/dl_core/dl_core/connection_executors/qe_serializer/serializer.py b/lib/dl_core/dl_core/connection_executors/qe_serializer/serializer.py index 5263d876c..e338aeb60 100644 --- a/lib/dl_core/dl_core/connection_executors/qe_serializer/serializer.py +++ b/lib/dl_core/dl_core/connection_executors/qe_serializer/serializer.py @@ -27,6 +27,7 @@ class ActionSerializer: MAP_ACT_TYPE_SCHEMA_CLS: ClassVar[Dict[Type[dba_actions.RemoteDBAdapterAction], Type[Schema]]] = { dba_actions.ActionExecuteQuery: schema_actions.ActionExecuteQuerySchema, + dba_actions.ActionNonStreamExecuteQuery: schema_actions.ActionNonStreamExecuteQuerySchema, dba_actions.ActionTest: schema_actions.ActionTestSchema, dba_actions.ActionGetDBVersion: schema_actions.ActionGetDBVersionSchema, dba_actions.ActionGetSchemaNames: schema_actions.ActionGetSchemaNamesSchema, diff --git a/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_async.py b/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_async.py index 5011ad884..7072c9c4a 100644 --- a/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_async.py +++ b/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_async.py @@ -151,6 +151,24 @@ async def handle_query_action( return response + async def handle_non_stream_query_action( + self, + dba: AsyncDBAdapter, + dba_query: DBAdapterQuery, + ) -> web.Response: + try: + result = await dba.execute(dba_query) + except Exception: + LOGGER.exception("Exception during execution") + raise + + events: list[tuple[str, Any]] = [(RQEEventType.raw_cursor_info.value, result.raw_cursor_info)] + async for raw_chunk in result.raw_chunk_generator: + events.append((RQEEventType.raw_chunk.value, raw_chunk)) + 
events.append((RQEEventType.finished.value, None)) + + return web.Response(body=pickle.dumps(events)) + async def execute_non_streamed_action( self, dba: AsyncDBAdapter, @@ -210,12 +228,16 @@ async def post(self) -> Union[web.Response, web.StreamResponse]: if isinstance(action, act.ActionExecuteQuery): return await self.handle_query_action(adapter, action.db_adapter_query) - elif isinstance(action, act.NonStreamAction): + + if isinstance(action, act.ActionNonStreamExecuteQuery): + return await self.handle_non_stream_query_action(adapter, action.db_adapter_query) + + if isinstance(action, act.NonStreamAction): result = await self.execute_non_streamed_action(adapter, action) resp_data = action.serialize_response(result) return web.json_response(resp_data) - else: - raise NotImplementedError(f"Action {action} is not implemented in QE") + + raise NotImplementedError(f"Action {action} is not implemented in QE") def body_signature_validation_middleware(hmac_key: bytes) -> AIOHTTPMiddleware: diff --git a/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_sync.py b/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_sync.py index 416d88d52..2db46e6f1 100644 --- a/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_sync.py +++ b/lib/dl_core/dl_core/connection_executors/remote_query_executor/app_sync.py @@ -169,6 +169,25 @@ def execute_execute_action( }, ) + def execute_non_stream_execute_action( + self, + action: act.ActionNonStreamExecuteQuery, + ) -> flask.Response: + dba = self.create_dba_for_action(action) + + try: + db_result = dba.execute(action.db_adapter_query) + events: list[tuple[str, Any]] = [(RQEEventType.raw_cursor_info.value, db_result.cursor_info)] + for raw_chunk in db_result.data_chunks: + events.append((RQEEventType.raw_chunk.value, raw_chunk)) + events.append((RQEEventType.finished.value, None)) + + return flask.Response(response=pickle.dumps(events)) + finally: + # All chunks are already buffered in `events`, 
so the adapter must be + # released on success too, not only on failure (otherwise it leaks). + 
self.try_close_dba(dba) + @staticmethod def create_dba_for_action(action: act.RemoteDBAdapterAction) -> SyncDirectDBAdapter: LOGGER.info("Creating DBA") @@ -189,6 +208,12 @@ def dispatch_request(self) -> flask.Response: action = self.get_action() LOGGER.info("Got QE action request: %s", action) + if isinstance(action, act.ActionExecuteQuery): + return self.execute_execute_action(action) + + if isinstance(action, act.ActionNonStreamExecuteQuery): + return self.execute_non_stream_execute_action(action) + if isinstance(action, act.NonStreamAction): dba = self.create_dba_for_action(action) try: @@ -196,10 +221,8 @@ def dispatch_request(self) -> flask.Response: return flask.jsonify(action.serialize_response(result)) finally: dba.close()
- elif isinstance(action, act.ActionExecuteQuery): - return self.execute_execute_action(action) - else: - raise NotImplementedError(f"Action {action} is not implemented in QE") + + raise NotImplementedError(f"Action {action} is not implemented in QE") @attr.s diff --git a/lib/dl_core/dl_core/services_registry/conn_executor_factory.py b/lib/dl_core/dl_core/services_registry/conn_executor_factory.py index 83ad8020e..e5deb8856 100644 --- a/lib/dl_core/dl_core/services_registry/conn_executor_factory.py +++ b/lib/dl_core/dl_core/services_registry/conn_executor_factory.py @@ -226,6 +226,7 @@ def _get_rqe_data(self, external: bool) -> RemoteQueryExecutorData: sync_protocol=sync_rqe_netloc.scheme, sync_host=sync_rqe_netloc.host, sync_port=sync_rqe_netloc.port, + execute_request_mode=self.rqe_config.execute_request_mode, ) @property