Skip to content

Commit

Permalink
improve websocket error handling
Browse files Browse the repository at this point in the history
  • Loading branch information
ajshedivy committed Nov 26, 2024
1 parent 4c7ce02 commit 8788ec9
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 42 deletions.
10 changes: 9 additions & 1 deletion mapepire_python/base_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,17 @@ def _parse_connection_input(

if not isinstance(db2_server, DaemonServer):
raise TypeError("db2_server must be of type DaemonServer")

self.creds = db2_server
return db2_server

def __str__(self) -> str:
creds_str = self.creds
if isinstance(self.creds, DaemonServer):
creds_dict = self.creds.__dict__.copy()
creds_dict.pop("password", None) # Remove password if present
creds_str = str(creds_dict)
return f"BaseJob(creds={creds_str}, options={self.options})"

def connect(
self, db2_server: Union[DaemonServer, Dict[str, Any], Path], **kwargs
) -> Dict[str, Any]:
Expand Down
9 changes: 7 additions & 2 deletions mapepire_python/client/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from enum import Enum
from typing import Any, Dict, Generic, List, Optional, TypeVar

from mapepire_python.websocket import handle_ws_errors

from ..data_types import QueryOptions
from .sql_job import SQLJob

Expand Down Expand Up @@ -39,13 +41,14 @@ def __exit__(self, exc_type, exc_value, traceback):
self.close()

def __str__(self):
return f"Query(sql={self.sql}, parameters={self.parameters}, correlation_id={self._correlation_id})"
return f"Query(job={str(self.job)}, sql={self.sql}, parameters={self.parameters}, correlation_id={self._correlation_id})"

def _execute_query(self, qeury_object: Dict[str, Any]) -> Dict[str, Any]:
self.job.send(json.dumps(qeury_object))
query_result: Dict[str, Any] = json.loads(self.job._socket.recv())
return query_result

@handle_ws_errors
def prepare_sql_execute(self):
# check Query state first
if self.state == QueryState.RUN_DONE:
Expand All @@ -67,7 +70,6 @@ def prepare_sql_execute(self):
)

if not query_result.get("success", False) and not self.is_cl_command:
print(query_result)
self.state = QueryState.ERROR
error_keys = ["error", "sql_state", "sql_rc"]
error_list = {
Expand All @@ -82,6 +84,7 @@ def prepare_sql_execute(self):

return query_result

@handle_ws_errors
def run(self, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]:
if rows_to_fetch is None:
rows_to_fetch = self._rows_to_fetch
Expand Down Expand Up @@ -135,6 +138,7 @@ def run(self, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]:

return query_result

@handle_ws_errors
def fetch_more(self, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]:
if rows_to_fetch is None:
rows_to_fetch = self._rows_to_fetch
Expand Down Expand Up @@ -169,6 +173,7 @@ def fetch_more(self, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]:

return query_result

@handle_ws_errors
def close(self):
if not self.job._socket:
raise Exception("SQL Job not connected")
Expand Down
9 changes: 5 additions & 4 deletions mapepire_python/client/sql_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from websockets.sync.client import ClientConnection

from mapepire_python.client.websocket_client import WebsocketConnection
from mapepire_python.websocket import handle_ws_errors

from ..base_job import BaseJob
from ..data_types import DaemonServer, JobStatus, QueryOptions
Expand Down Expand Up @@ -70,17 +71,16 @@ def get_status(self) -> JobStatus:
"""
return self._status

@handle_ws_errors
def send(self, content: str) -> None:
"""sends content to the mapepire server
Args:
content (str): JSON content to be sent
"""
try:
self._socket.send(content)
except Exception as e:
raise e
self._socket.send(content)

@handle_ws_errors
def connect(
self, db2_server: Union[DaemonServer, Dict[str, Any], Path], **kwargs
) -> Dict[Any, Any]:
Expand Down Expand Up @@ -168,6 +168,7 @@ def query(

return Query(job=self, query=sql, opts=query_options)

@handle_ws_errors
def query_and_run(
self, sql: str, opts: Optional[Dict[str, Any]] = None, **kwargs
) -> Dict[str, Any]:
Expand Down
1 change: 1 addition & 0 deletions mapepire_python/client/websocket_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@


class WebsocketConnection(BaseConnection):

def __init__(self, db2_server: DaemonServer) -> None:
super().__init__(db2_server)

Expand Down
34 changes: 13 additions & 21 deletions mapepire_python/pool/pool_job.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import json
import logging
from pathlib import Path
from typing import Any, Dict, Optional, Union

Expand All @@ -14,6 +15,8 @@

__all__ = ["PoolJob"]

logger = logging.getLogger("websockets.client")


class PoolJob(BaseJob):
unique_id_counter = 0
Expand Down Expand Up @@ -65,10 +68,6 @@ def enable_local_trace_data(self):
def enable_local_channel_trace(self):
self.is_tracing_channel_data = True

def _local_log(self, level: bool, message: str) -> None:
if level:
print(message, flush=True)

async def get_channel(self, db2_server: DaemonServer) -> ClientConnection:
"""returns a websocket connection to the mapepire server
Expand All @@ -94,17 +93,17 @@ async def send(self, content: str) -> Dict[Any, Any]:
Returns:
str: response from the server
"""
self._local_log(self.enable_local_trace, f"sending data: {content}")
logger.debug(f"sending data: {content}")

req = json.loads(content)
if self.socket is None:
raise RuntimeError("Socket is not connected")
try:
await self.socket.send(content)
self.status = JobStatus.Busy
self._local_log(self.enable_local_trace, "wating for response ...")
logger.debug("waiting for response ...")
response = await self.wait_for_response(req["id"])
# self._local_log(self.enable_local_trace, f"recieved response: {response}")
# logger.debug(f"received response: {response}")
self.status = JobStatus.Ready if self.get_running_count() == 0 else JobStatus.Busy
return response # type: ignore
except Exception as e:
Expand All @@ -125,16 +124,14 @@ async def wait_for_response(self, req_id: str) -> str:
future = asyncio.Future()

def on_response(response):
self._local_log(
self.enable_local_trace, f"Received response for req_id: {req_id} - {response}"
)
logger.debug(f"Received response for req_id: {req_id} - {response}")
if not future.done():
future.set_result(response)
self.response_emitter.remove_listener(req_id, on_response)

try:
self.response_emitter.on(req_id, on_response)
self._local_log(self.enable_local_trace, f"Listener registered for req_id: {req_id}")
logger.debug(f"Listener registered for req_id: {req_id}")
return await future
except Exception as e:
self.response_emitter.remove_listener(req_id, on_response)
Expand All @@ -144,9 +141,8 @@ def get_status(self) -> JobStatus:
return self.status

def get_running_count(self) -> int:
self._local_log(
self.enable_local_trace,
f"--- running count {self.unique_id}: {len(self.response_emitter.event_names())}, status: {self.get_status()}",
logger.debug(
f"--- running count {self.unique_id}: {len(self.response_emitter.event_names())}, status: {self.get_status()}"
)
return len(self.response_emitter.event_names())

Expand Down Expand Up @@ -221,20 +217,16 @@ async def message_handler(self):
if self.socket is None:
raise RuntimeError("Socket is not connected")
async for message in self.socket:
self._local_log(self.enable_local_trace, f"Received raw message: {message}")
logger.debug(f"Received raw message: {message}")

try:
response = json.loads(message)
req_id = response.get("id")
if req_id:
self._local_log(
self.enable_local_trace, f"Emitting response for req_id: {req_id}"
)
logger.debug(f"Emitting response for req_id: {req_id}")
self.response_emitter.emit(req_id, response)
else:
self._local_log(
self.enable_local_trace, f"No req_id found in response: {response}"
)
logger.debug(f"No req_id found in response: {response}")
except json.JSONDecodeError as e:
raise ValueError(f"Error decoding JSON: {e}")
except Exception as e:
Expand Down
29 changes: 17 additions & 12 deletions mapepire_python/websocket.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import base64
import ssl
from functools import wraps
from typing import Callable, TypeVar
from typing import Any, Callable, TypeVar

from websockets import InvalidHandshake, InvalidURI
from websockets import ConcurrencyError, ConnectionClosed, InvalidHandshake, InvalidURI

from mapepire_python.data_types import DaemonServer

Expand All @@ -29,28 +29,33 @@ def _create_ssl_context(self, db2_server: DaemonServer):
return ssl_context


def _parse_ws_error(error: RuntimeError, connection: BaseConnection):
if not isinstance(error, RuntimeError):
return error
def _parse_ws_error(error: Exception, driver: Any = None):
to_str = str(driver)

if isinstance(error, InvalidURI):
raise InvalidURI("The provided URI is not a valid WebSocket URI.")
raise InvalidURI(f"The provided URI is not a valid WebSocket URI: {to_str}")
elif isinstance(error, OSError):
raise OSError("The TCP connection failed to connect to Mapepire server")
raise OSError(f"The TCP connection failed to connect to Mapepire server: {to_str}")
elif isinstance(error, InvalidHandshake):
raise InvalidHandshake("The opening handshake failed.")
elif isinstance(error, TimeoutError):
raise TimeoutError("The opening handshake timed out.")
elif isinstance(error, ConnectionClosed):
raise ConnectionClosed("The Conection was closed.")
elif isinstance(error, ConcurrencyError):
raise ConcurrencyError("Connection is sending a fragmented message")
elif isinstance(error, TypeError):
raise TypeError("Message doesn't have a supported type")
else:
return error


def handle_ws_errors(function: Callable[..., ReturnType]) -> Callable[..., ReturnType]:

@wraps(function)
def wrapper(*args, **kwargs):
def _impl(self, *args, **kwargs):
try:
return function(*args, **kwargs)
return function(self, *args, **kwargs)
except RuntimeError as err:
raise _parse_ws_error(err) from err
raise _parse_ws_error(err, driver=self) from err

return wrapper
return _impl
2 changes: 1 addition & 1 deletion tests/async_pool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ async def test_prepare_statement_invalid_params():
query = job.query("select * from sample.employee where bonus > ?", opts=opts)
with pytest.raises(Exception) as execinfo:
res = await query.run()
assert "Data type mismatch. (Infinite or NaN)" in str(execinfo.value)
assert "Data type mismatch." in str(execinfo.value)


@pytest.mark.asyncio
Expand Down
1 change: 1 addition & 0 deletions tests/simple_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def parse_sql_rc(message):

def test_connect_simple():
job = SQLJob()
print(job)
result = job.connect(creds)
assert result["success"]
job.close()
Expand Down
3 changes: 2 additions & 1 deletion tests/sql_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,8 @@ def test_prepare_statement_invalid_params():
query = job.query("select * from sample.employee where bonus > ?", opts=opts)
with pytest.raises(Exception) as execinfo:
query.run()
assert "Data type mismatch. (Infinite or NaN)" in str(execinfo.value)
print(execinfo)
assert "Data type mismatch." in str(execinfo.value)
job.close()


Expand Down

0 comments on commit 8788ec9

Please sign in to comment.