From 57e9f80ad2d9c69ea1102cc6e1567f359de299f7 Mon Sep 17 00:00:00 2001 From: ajshedivy Date: Tue, 16 Apr 2024 16:39:07 -0500 Subject: [PATCH] add ws connection logic --- python_sc/sql_job.py | 79 +++++++++++++++ python_sc/sql_runner.py | 15 +++ python_sc/tls.py | 42 ++++++++ python_sc/types.py | 214 ++++++++++++++++++++++++++++++++++++++++ tests/hello_test.py | 23 ++++- 5 files changed, 371 insertions(+), 2 deletions(-) create mode 100644 python_sc/sql_job.py create mode 100644 python_sc/sql_runner.py create mode 100644 python_sc/tls.py create mode 100644 python_sc/types.py diff --git a/python_sc/sql_job.py b/python_sc/sql_job.py new file mode 100644 index 0000000..1abd292 --- /dev/null +++ b/python_sc/sql_job.py @@ -0,0 +1,79 @@ +import asyncio +import base64 +import json +import websocket +import ssl +from typing import Any, Dict, List, Optional, Union +from websocket import create_connection, WebSocket + +from python_sc.types import ConnectionResult, DaemonServer, JobStatus + + +class SQLJob: + def __init__(self) -> None: + self._unique_id_counter: int = 0; + self._socket: Any = None + self._reponse_emitter = None + self._status: JobStatus = JobStatus.NotStarted + self._trace_file = None + self._is_tracing_channeldata: bool = True + + self.__unique_id = self._get_unique_id('sqljob') + self.id: Optional[str] = None + + + def _get_unique_id(self, prefix: str = 'id') -> str: + self._unique_id_counter += 1 + return f"{prefix}{self._unique_id_counter}" + + def _get_channel(self, db2_server: DaemonServer) -> WebSocket: + uri = f"wss://{db2_server.host}:{db2_server.port}/db/" + headers = {"Authorization": "Basic " + base64.b64encode(f"{db2_server.user}:{db2_server.password}".encode()).decode('ascii')} + + # Prepare SSL context if necessary + if db2_server.ca: + ssl_opts = {"cert_reqs": ssl.CERT_NONE} if not db2_server.ignoreUnauthorized else {"ca_certs": db2_server.ca} + else: + ssl_opts = {"cert_reqs": ssl.CERT_NONE} if db2_server.ignoreUnauthorized is False else {} + + # Create WebSocket connection + socket = create_connection(uri, header=headers, sslopt=ssl_opts) + # socket = websocket.WebSocketApp(uri, header=headers, sslopt=ssl_opts) + + # Register message handler + def on_message(ws, message): + if self._is_tracing_channeldata: + print(message) + try: + response = json.loads(message) + print(f"Received message with ID: {response['id']}") + except Exception as e: + print(f"Error parsing message: {e}") + + socket.on_message = on_message + + return socket + + def send(self, content): + self._socket.send(content) + + def connect(self, db2_server: DaemonServer) -> ConnectionResult: + self._socket: WebSocket = self._get_channel(db2_server) + + connection_props = { + 'id': self._get_unique_id(), + 'type': 'connect', + 'technique': 'tcp', + 'application': 'Python Client', + 'props': "" + } + + self.send(json.dumps(connection_props)) + print(self._socket.recv()) + + + + + + + diff --git a/python_sc/sql_runner.py b/python_sc/sql_runner.py new file mode 100644 index 0000000..d30db1f --- /dev/null +++ b/python_sc/sql_runner.py @@ -0,0 +1,15 @@ +import asyncio + +class SQLJobRunner: + def run(self, coro): + """ Run a coroutine and return the result, handling the event loop. """ + try: + loop = asyncio.get_running_loop() + except RuntimeError: # No running loop + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + result = loop.run_until_complete(coro) + loop.close() + return result + else: + return loop.run_until_complete(coro) diff --git a/python_sc/tls.py b/python_sc/tls.py new file mode 100644 index 0000000..193b993 --- /dev/null +++ b/python_sc/tls.py @@ -0,0 +1,42 @@ +import asyncio +import socket +import ssl +from dataclasses import dataclass +from typing import Optional, Union + +from python_sc.types import DaemonServer + + +async def get_certificate(server: DaemonServer): + """ Asynchronously get the peer's certificate from a secure TLS connection. """ + # Create a default SSL context + context = ssl.create_default_context() + + if server.ignoreUnauthorized is False: + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + + if server.ca: + context.load_verify_locations(server.ca) + + # Create a non-blocking socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setblocking(False) + + # Wrap the socket with the SSL context + wrapped_socket = context.wrap_socket(sock, server_hostname=server.host) + + # Connect using asyncio's event loop + await asyncio.get_event_loop().sock_connect(wrapped_socket, (server.host, server.port)) + + try: + # Perform the handshake to establish the secure connection + await asyncio.get_event_loop().sock_do_handshake(wrapped_socket) + # Obtain the certificate + cert = wrapped_socket.getpeercert(True) + return cert + except Exception as e: + print(f"Error: {e}") + raise e + finally: + wrapped_socket.close() \ No newline at end of file diff --git a/python_sc/types.py b/python_sc/types.py new file mode 100644 index 0000000..174d2ae --- /dev/null +++ b/python_sc/types.py @@ -0,0 +1,214 @@ +from enum import Enum +from dataclasses import dataclass, field +from typing import Optional, List, Dict, Any, Union + +class JobStatus(Enum): + NotStarted = "notStarted" + Ready = "ready" + Busy = "busy" + Ended = "ended" + +class ExplainType(Enum): + Run = 0 + DoNotRun = 1 + +class TransactionEndType(Enum): + COMMIT = 0 + ROLLBACK = 1 + +class ServerTraceLevel(Enum): + OFF = "OFF" + ON = "ON" + ERRORS = "ERRORS" + DATASTREAM = "DATASTREAM" + +class ServerTraceDest(Enum): + FILE = "FILE" + IN_MEM = "IN_MEM" + + +@dataclass +class DaemonServer: + host: str + user: str + password: str + port: Optional[int] = None + ignoreUnauthorized: Optional[bool] = None + ca: Optional[Union[str, bytes]] = None + +@dataclass +class ServerResponse: + id: str + success: bool + sql_rc: int + sql_state: str + error: Optional[str] = None + +@dataclass +class ConnectionResult(ServerResponse): + job: str + id: str = field(init=False) + success: bool = field(init=False) + sql_rc: int = field(init=False) + sql_state: str = field(init=False) + error: Optional[str] = field(default=None, init=False) + +@dataclass +class VersionCheckResult(ServerResponse): + build_date: str + version: str + id: str = field(init=False) + success: bool = field(init=False) + sql_rc: int = field(init=False) + sql_state: str = field(init=False) + error: Optional[str] = field(default=None, init=False) + +@dataclass +class ColumnMetaData: + display_size: int + label: str + name: str + type: str + +@dataclass +class QueryMetaData: + column_count: int + columns: List[ColumnMetaData] + job: str + +@dataclass +class QueryResult: + metadata: QueryMetaData + is_done: bool + has_results: bool + update_count: int + data: List[Any] + +@dataclass +class ExplainResults(QueryResult): + vemetadata: QueryMetaData + vedata: Any + +@dataclass +class GetTraceDataResult(ServerResponse): + tracedata: str + id: str = field(init=False) + success: bool = field(init=False) + sql_rc: int = field(init=False) + sql_state: str = field(init=False) + error: Optional[str] = field(default=None, init=False) + +@dataclass +class JobLogEntry: + MESSAGE_ID: str + SEVERITY: str + MESSAGE_TIMESTAMP: str + FROM_LIBRARY: str + FROM_PROGRAM: str + MESSAGE_TYPE: str + MESSAGE_TEXT: str + MESSAGE_SECOND_LEVEL_TEXT: str + +@dataclass +class CLCommandResult(ServerResponse): + joblog: List[JobLogEntry] + id: str = field(init=False) + success: bool = field(init=False) + sql_rc: int = field(init=False) + sql_state: str = field(init=False) + error: Optional[str] = field(default=None, init=False) + +@dataclass +class QueryOptions: + isTerseResults: Optional[bool] = None + isClCommand: Optional[bool] = None + parameters: Optional[List[Any]] = None + autoClose: Optional[bool] = None + +@dataclass +class SetConfigResult(ServerResponse): + tracedest: ServerTraceDest + tracelevel: ServerTraceLevel + id: str = field(init=False) + success: bool = field(init=False) + sql_rc: int = field(init=False) + sql_state: str = field(init=False) + error: Optional[str] = field(default=None, init=False) + +@dataclass +class JDBCOptions: + naming: Optional[str] = None + date_format: Optional[str] = None + date_separator: Optional[str] = None + decimal_separator: Optional[str] = None + time_format: Optional[str] = None + time_separator: Optional[str] = None + full_open: Optional[bool] = None + access: Optional[str] = None + autocommit_exception: Optional[bool] = None + bidi_string_type: Optional[str] = None + bidi_implicit_reordering: Optional[bool] = None + bidi_numeric_ordering: Optional[bool] = None + data_truncation: Optional[bool] = None + driver: Optional[str] = None + errors: Optional[str] = None + extended_metadata: Optional[bool] = None + hold_input_locators: Optional[bool] = None + hold_statements: Optional[bool] = None + ignore_warnings: Optional[str] = None + keep_alive: Optional[bool] = None + key_ring_name: Optional[str] = None + key_ring_password: Optional[str] = None + metadata_source: Optional[str] = None + proxy_server: Optional[str] = None + remarks: Optional[str] = None + secondary_URL: Optional[str] = None + secure: Optional[bool] = None + server_trace: Optional[str] = None + thread_used: Optional[bool] = None + toolbox_trace: Optional[str] = None + trace: Optional[bool] = None + translate_binary: Optional[bool] = None + translate_boolean: Optional[bool] = None + libraries: Optional[List[str]] = None + auto_commit: Optional[bool] = None + concurrent_access_resolution: Optional[str] = None + cursor_hold: Optional[bool] = None + cursor_sensitivity: Optional[str] = None + database_name: Optional[str] = None + decfloat_rounding_mode: Optional[str] = None + maximum_precision: Optional[str] = None + maximum_scale: Optional[str] = None + minimum_divide_scale: Optional[str] = None + package_ccsid: Optional[str] = None + transaction_isolation: Optional[str] = None + translate_hex: Optional[str] = None + true_autocommit: Optional[bool] = None + XA_loosely_coupled_support: Optional[str] = None + big_decimal: Optional[bool] = None + block_criteria: Optional[str] = None + block_size: Optional[str] = None + data_compression: Optional[bool] = None + extended_dynamic: Optional[bool] = None + lazy_close: Optional[bool] = None + lob_threshold: Optional[str] = None + maximum_blocked_input_rows: Optional[str] = None + package: Optional[str] = None + package_add: Optional[bool] = None + package_cache: Optional[bool] = None + package_criteria: Optional[str] = None + package_error: Optional[str] = None + package_library: Optional[str] = None + prefetch: Optional[bool] = None + qaqqinilib: Optional[str] = None + query_optimize_goal: Optional[str] = None + query_timeout_mechanism: Optional[str] = None + query_storage_limit: Optional[str] = None + receive_buffer_size: Optional[str] = None + send_buffer_size: Optional[str] = None + variable_field_compression: Optional[bool] = None + sort: Optional[str] = None + sort_language: Optional[str] = None + sort_table: Optional[str] = None + sort_weight: Optional[str] = None + diff --git a/tests/hello_test.py b/tests/hello_test.py index 23b9253..9bcacb9 100644 --- a/tests/hello_test.py +++ b/tests/hello_test.py @@ -1,2 +1,21 @@ -def test_hello(): - print("Hello, World!") +import asyncio +from python_sc.sql_job import SQLJob +from python_sc.sql_runner import SQLJobRunner +from python_sc.tls import get_certificate +from python_sc.types import DaemonServer + +creds = DaemonServer( + host="localhost", + port=8085, + user="ashedivy", + password="", + ignoreUnauthorized=False +) + + +def test_channel_connect(): + # ca = asyncio.run(get_certificate(creds)) + # creds.ca = ca.raw if ca else None + + job = SQLJob() + job.connect(creds) \ No newline at end of file