Skip to content

Commit

Permalink
use websockets library as ws backend
Browse files Browse the repository at this point in the history
  • Loading branch information
ajshedivy committed Nov 25, 2024
1 parent a4ce0c6 commit 12dd737
Show file tree
Hide file tree
Showing 15 changed files with 195 additions and 91 deletions.
2 changes: 1 addition & 1 deletion environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ dependencies:
- mypy
- pip:
- dataclasses-json>=0.6.4
- websocket-client>=1.2.1
- websockets>=14.0
- pyee
- websockets
- pep249abc
2 changes: 1 addition & 1 deletion mapepire_python/client/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def fetch_more(self, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]:
return query_result

def close(self):
if not self.job._socket.connected:
if not self.job._socket:
raise Exception("SQL Job not connected")
if self._correlation_id and self.state is not QueryState.RUN_DONE:
self.state = QueryState.RUN_DONE
Expand Down
14 changes: 9 additions & 5 deletions mapepire_python/client/sql_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@
from pathlib import Path
from typing import Any, Dict, Optional, Union

from websocket import WebSocket
from websockets.sync.client import ClientConnection

from mapepire_python.client.websocket_client import WebsocketConnection

from ..base_job import BaseJob
from ..data_types import DaemonServer, JobStatus, QueryOptions
from .websocket import WebsocketConnection

__all__ = ["SQLJob"]

Expand Down Expand Up @@ -49,7 +50,7 @@ def _get_unique_id(self, prefix: str = "id") -> str:
self._unique_id_counter += 1
return f"{prefix}{self._unique_id_counter}"

def _get_channel(self, db2_server: DaemonServer) -> WebSocket:
def _get_channel(self, db2_server: DaemonServer) -> ClientConnection:
"""returns a websocket connection to the mapepire server
Args:
Expand All @@ -75,7 +76,10 @@ def send(self, content: str) -> None:
Args:
content (str): JSON content to be sent
"""
self._socket.send(content)
try:
self._socket.send(content)
except Exception as e:
raise e

def connect(
self, db2_server: Union[DaemonServer, Dict[str, Any], Path], **kwargs
Expand All @@ -93,7 +97,7 @@ def connect(
"""
db2_server = self._parse_connection_input(db2_server, **kwargs)

self._socket: WebSocket = self._get_channel(db2_server)
self._socket: ClientConnection = self._get_channel(db2_server)

props = ";".join(
[
Expand Down
40 changes: 0 additions & 40 deletions mapepire_python/client/websocket.py

This file was deleted.

19 changes: 19 additions & 0 deletions mapepire_python/client/websocket_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from websockets.sync.client import ClientConnection, connect

from mapepire_python.websocket import BaseConnection, handle_ws_errors

from ..data_types import DaemonServer


class WebsocketConnection(BaseConnection):
def __init__(self, db2_server: DaemonServer) -> None:
super().__init__(db2_server)

@handle_ws_errors
def connect(self) -> ClientConnection:
return connect(
self.uri,
additional_headers=self.headers,
open_timeout=10,
ssl=self._create_ssl_context(self.db2_server),
)
5 changes: 2 additions & 3 deletions mapepire_python/core/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, connection: "Connection", job: SQLJob) -> None:
self._connection = weakref.proxy(connection)
self.job = job
self.query: Query = None
self.query_q = deque(maxlen=20)
self.query_q: deque[Query] = deque(maxlen=20)
self.__closed = False
self.__has_results = False

Expand Down Expand Up @@ -102,7 +102,6 @@ def execute(
query = Query(self.job, operation, create_opts)

prepare_result = query.prepare_sql_execute()
# print(prepare_result)

if prepare_result["has_results"]:
self.query = query
Expand Down Expand Up @@ -188,7 +187,7 @@ def nextset(self) -> Optional[bool]:
def close(self) -> None:
if self._closed:
return
if self.query and self.job._socket.connected:
if self.query:
for q in self.query_q:
q.close()
self.query_q.clear()
Expand Down
18 changes: 18 additions & 0 deletions mapepire_python/pool/async_websocket_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
from websockets.asyncio.client import ClientConnection, connect

from mapepire_python.websocket import BaseConnection, handle_ws_errors


class AsyncWebSocketConnection(BaseConnection):
def __init__(self, db2_server):
super().__init__(db2_server)

@handle_ws_errors
async def connect(self) -> ClientConnection:
websocket = await connect(
self.uri,
additional_headers=self.headers,
open_timeout=10,
ssl=self._create_ssl_context(self.db2_server),
)
return websocket
48 changes: 17 additions & 31 deletions mapepire_python/pool/pool_job.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import asyncio
import base64
import json
import ssl
from pathlib import Path
from typing import Any, Dict, Optional, Union

import websockets
from pyee.asyncio import AsyncIOEventEmitter
from websockets.asyncio.client import ClientConnection

from mapepire_python.pool.async_websocket_client import AsyncWebSocketConnection

from ..base_job import BaseJob
from ..data_types import DaemonServer, JobStatus, QueryOptions
Expand Down Expand Up @@ -68,7 +69,7 @@ def _local_log(self, level: bool, message: str) -> None:
if level:
print(message, flush=True)

async def get_channel(self, db2_server: DaemonServer) -> websockets.WebSocketClientProtocol:
async def get_channel(self, db2_server: DaemonServer) -> ClientConnection:
"""returns a websocket connection to the mapepire server
Args:
Expand All @@ -81,26 +82,8 @@ async def get_channel(self, db2_server: DaemonServer) -> websockets.WebSocketCli
Returns:
websockets.WebSocketClientProtocol: websocket connection
"""
uri = f"wss://{db2_server.host}:{db2_server.port}/db/"
headers = {
"Authorization": "Basic "
+ base64.b64encode(f"{db2_server.user}:{db2_server.password}".encode()).decode("ascii")
}

ssl_contest = ssl.create_default_context(cafile=db2_server.ca)
ssl_contest.check_hostname = False
ssl_contest.verify_mode = ssl.CERT_NONE

try:
socket = await websockets.connect(
uri=uri, extra_headers=headers, ssl=ssl_contest, ping_timeout=None, open_timeout=30
)
except TimeoutError as e:
raise TimeoutError("Failed to connect to server") from e
except Exception as e:
raise e

return socket
socket = AsyncWebSocketConnection(db2_server)
return await socket.connect()

async def send(self, content: str) -> Dict[Any, Any]:
"""sends content to the mapepire server
Expand All @@ -116,13 +99,16 @@ async def send(self, content: str) -> Dict[Any, Any]:
req = json.loads(content)
if self.socket is None:
raise RuntimeError("Socket is not connected")
await self.socket.send(content)
self.status = JobStatus.Busy
self._local_log(self.enable_local_trace, "wating for response ...")
response = await self.wait_for_response(req["id"])
self._local_log(self.enable_local_trace, f"recieved response: {response}")
self.status = JobStatus.Ready if self.get_running_count() == 0 else JobStatus.Busy
return response # type: ignore
try:
await self.socket.send(content)
self.status = JobStatus.Busy
self._local_log(self.enable_local_trace, "wating for response ...")
response = await self.wait_for_response(req["id"])
# self._local_log(self.enable_local_trace, f"recieved response: {response}")
self.status = JobStatus.Ready if self.get_running_count() == 0 else JobStatus.Busy
return response # type: ignore
except Exception as e:
raise e

async def wait_for_response(self, req_id: str) -> str:
"""when a request is sent to the server, this method waits for the response
Expand Down Expand Up @@ -253,7 +239,7 @@ async def message_handler(self):
raise ValueError(f"Error decoding JSON: {e}")
except Exception as e:
raise RuntimeError(f"Error: {e}")
except websockets.exceptions.ConnectionClosed:
except websockets.exceptions.ConnectionClosedError:
await self.dispose()

def query(
Expand Down
7 changes: 6 additions & 1 deletion mapepire_python/ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,9 @@ def get_certificate(creds: DaemonServer) -> Optional[bytes]:
)
with socket.create_connection((creds.host, creds.port)) as sock:
with context.wrap_socket(sock, server_hostname=creds.host) as ssock:
return ssock.getpeercert(binary_form=True)
try:
ssock.do_handshake()
cert = ssock.getpeercert(binary_form=True)
return ssl.DER_cert_to_PEM_cert(cert)
except ssl.SSLError as er:
raise er
58 changes: 58 additions & 0 deletions mapepire_python/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import base64
import ssl
from functools import wraps
from typing import Callable, TypeVar

from websockets import InvalidHandshake, InvalidURI

from mapepire_python.data_types import DaemonServer

ReturnType = TypeVar("ReturnType")


class BaseConnection:
def __init__(self, db2_server: DaemonServer) -> None:
self.uri = f"wss://{db2_server.host}:{db2_server.port}/db/"
self.headers = {
"Authorization": "Basic "
+ base64.b64encode(f"{db2_server.user}:{db2_server.password}".encode()).decode("ascii")
}
self.db2_server = db2_server

def _create_ssl_context(self, db2_server: DaemonServer):
ssl_context = ssl.create_default_context(purpose=ssl.Purpose.SERVER_AUTH)
if db2_server.ignoreUnauthorized:
ssl_context.check_hostname = False
ssl_context.verify_mode = ssl.CERT_NONE
elif db2_server.ca:
ssl_context.load_verify_locations(cadata=db2_server.ca)
return ssl_context


def _parse_ws_error(error: RuntimeError, connection: BaseConnection):
if not isinstance(error, RuntimeError):
return error
if isinstance(error, InvalidURI):
raise InvalidURI(f"The provided URI: {connection.uri} is not a valid WebSocket URI.")
elif isinstance(error, OSError):
raise OSError(
f"The TCP connection failed to connect to Mapepire server {connection.db2_server.host}:{connection.db2_server.port}"
)
elif isinstance(error, InvalidHandshake):
raise InvalidHandshake("The opening handshake failed.")
elif isinstance(error, TimeoutError):
raise TimeoutError("The opening handshake timed out.")
else:
return error


def handle_ws_errors(function: Callable[..., ReturnType]) -> Callable[..., ReturnType]:

@wraps(function)
def wrapper(*args, **kwargs):
try:
return function(*args, **kwargs)
except RuntimeError as err:
raise _parse_ws_error(err) from err

return wrapper
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ authors = [
requires-python = ">3.9"
dependencies = [
"dataclasses-json>=0.6.4",
"websocket-client>=1.2.1",
"websockets==13.1",
"websockets>=14.0",
"pyee",
"pep249abc"
]
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
dataclasses-json>=0.6.4
websocket-client>=1.2.1
websockets>=14.0
pep249abc
pytest
pytest-asyncio
Expand Down
12 changes: 6 additions & 6 deletions tests/pooling_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ async def test_simple_pool_cm():
assert len(job_names) == 3
assert pool.get_active_job_count() == 3
finally:
# Ensure all tasks are completed before exiting
await pool.end()
pending = asyncio.all_tasks()
if pending:
for task in pending:
for task in pending:
if task is not asyncio.current_task():
task.cancel()


Expand Down Expand Up @@ -176,10 +176,10 @@ async def test_pop_jobs_returns_free_job():
assert pool.get_active_job_count() == 4
await asyncio.gather(*executed_promises)
finally:
# Ensure all tasks are completed before exiting
await pool.end()
pending = asyncio.all_tasks()
if pending:
for task in pending:
for task in pending:
if task is not asyncio.current_task():
task.cancel()


Expand Down
Loading

0 comments on commit 12dd737

Please sign in to comment.