Skip to content

Commit

Permalink
add ws connection logic
Browse files Browse the repository at this point in the history
  • Loading branch information
ajshedivy committed Apr 16, 2024
1 parent fe7ee7c commit 57e9f80
Show file tree
Hide file tree
Showing 5 changed files with 371 additions and 2 deletions.
79 changes: 79 additions & 0 deletions python_sc/sql_job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import asyncio
import base64
import json
import websocket
import ssl
from typing import Any, Dict, List, Optional, Union
from websocket import create_connection, WebSocket

from python_sc.types import ConnectionResult, DaemonServer, JobStatus


class SQLJob:
def __init__(self) -> None:
self._unique_id_counter: int = 0;
self._socket: Any = None
self._reponse_emitter = None
self._status: JobStatus = JobStatus.NotStarted
self._trace_file = None
self._is_tracing_channeldata: bool = True

self.__unique_id = self._get_unique_id('sqljob')
self.id: Optional[str] = None


def _get_unique_id(self, prefix: str = 'id') -> str:
self._unique_id_counter += 1
return f"{prefix}{self._unique_id_counter}"

def _get_channel(self, db2_server: DaemonServer) -> WebSocket:
uri = f"wss://{db2_server.host}:{db2_server.port}/db/"
headers = {"Authorization": "Basic " + base64.b64encode(f"{db2_server.user}:{db2_server.password}".encode()).decode('ascii')}

# Prepare SSL context if necessary
if db2_server.ca:
ssl_opts = {"cert_reqs": ssl.CERT_NONE} if not db2_server.ignoreUnauthorized else {"ca_certs": db2_server.ca}
else:
ssl_opts = {"cert_reqs": ssl.CERT_NONE} if db2_server.ignoreUnauthorized is False else {}

# Create WebSocket connection
socket = create_connection(uri, header=headers, sslopt=ssl_opts)
# socket = websocket.WebSocketApp(uri, header=headers, sslopt=ssl_opts)

# Register message handler
def on_message(ws, message):
if self._is_tracing_channeldata:
print(message)
try:
response = json.loads(message)
print(f"Received message with ID: {response['id']}")
except Exception as e:
print(f"Error parsing message: {e}")

socket.on_message = on_message

return socket

def send(self, content):
self._socket.send(content)

def connect(self, db2_server: DaemonServer) -> ConnectionResult:
self._socket: WebSocket = self._get_channel(db2_server)

connection_props = {
'id': self._get_unique_id(),
'type': 'connect',
'technique': 'tcp',
'application': 'Python Client',
'props': ""
}

self.send(json.dumps(connection_props))
print(self._socket.recv())







15 changes: 15 additions & 0 deletions python_sc/sql_runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import asyncio

class SQLJobRunner:
def run(self, coro):
""" Run a coroutine and return the result, handling the event loop. """
try:
loop = asyncio.get_running_loop()
except RuntimeError: # No running loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
result = loop.run_until_complete(coro)
loop.close()
return result
else:
return loop.run_until_complete(coro)
42 changes: 42 additions & 0 deletions python_sc/tls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import asyncio
import socket
import ssl
from dataclasses import dataclass
from typing import Optional, Union

from python_sc.types import DaemonServer


async def get_certificate(server: DaemonServer):
""" Asynchronously get the peer's certificate from a secure TLS connection. """
# Create a default SSL context
context = ssl.create_default_context()

if server.ignoreUnauthorized is False:
context.check_hostname = False
context.verify_mode = ssl.CERT_NONE

if server.ca:
context.load_verify_locations(server.ca)

# Create a non-blocking socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.setblocking(False)

# Wrap the socket with the SSL context
wrapped_socket = context.wrap_socket(sock, server_hostname=server.host)

# Connect using asyncio's event loop
await asyncio.get_event_loop().sock_connect(wrapped_socket, (server.host, server.port))

try:
# Perform the handshake to establish the secure connection
await asyncio.get_event_loop().sock_do_handshake(wrapped_socket)
# Obtain the certificate
cert = wrapped_socket.getpeercert(True)
return cert
except Exception as e:
print(f"Error: {e}")
raise e
finally:
wrapped_socket.close()
214 changes: 214 additions & 0 deletions python_sc/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,214 @@
from enum import Enum
from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any, Union

class JobStatus(Enum):
NotStarted = "notStarted"
Ready = "ready"
Busy = "busy"
Ended = "ended"

class ExplainType(Enum):
Run = 0
DoNotRun = 1

class TransactionEndType(Enum):
COMMIT = 0
ROLLBACK = 1

class ServerTraceLevel(Enum):
OFF = "OFF"
ON = "ON"
ERRORS = "ERRORS"
DATASTREAM = "DATASTREAM"

class ServerTraceDest(Enum):
FILE = "FILE"
IN_MEM = "IN_MEM"


@dataclass
class DaemonServer:
host: str
user: str
password: str
port: Optional[int] = None
ignoreUnauthorized: Optional[bool] = None
ca: Optional[Union[str, bytes]] = None

@dataclass
class ServerResponse:
id: str
success: bool
sql_rc: int
sql_state: str
error: Optional[str] = None

@dataclass
class ConnectionResult(ServerResponse):
job: str
id: str = field(init=False)
success: bool = field(init=False)
sql_rc: int = field(init=False)
sql_state: str = field(init=False)
error: Optional[str] = field(default=None, init=False)

@dataclass
class VersionCheckResult(ServerResponse):
build_date: str
version: str
id: str = field(init=False)
success: bool = field(init=False)
sql_rc: int = field(init=False)
sql_state: str = field(init=False)
error: Optional[str] = field(default=None, init=False)

@dataclass
class ColumnMetaData:
display_size: int
label: str
name: str
type: str

@dataclass
class QueryMetaData:
column_count: int
columns: List[ColumnMetaData]
job: str

@dataclass
class QueryResult:
metadata: QueryMetaData
is_done: bool
has_results: bool
update_count: int
data: List[Any]

@dataclass
class ExplainResults(QueryResult):
vemetadata: QueryMetaData
vedata: Any

@dataclass
class GetTraceDataResult(ServerResponse):
tracedata: str
id: str = field(init=False)
success: bool = field(init=False)
sql_rc: int = field(init=False)
sql_state: str = field(init=False)
error: Optional[str] = field(default=None, init=False)

@dataclass
class JobLogEntry:
MESSAGE_ID: str
SEVERITY: str
MESSAGE_TIMESTAMP: str
FROM_LIBRARY: str
FROM_PROGRAM: str
MESSAGE_TYPE: str
MESSAGE_TEXT: str
MESSAGE_SECOND_LEVEL_TEXT: str

@dataclass
class CLCommandResult(ServerResponse):
joblog: List[JobLogEntry]
id: str = field(init=False)
success: bool = field(init=False)
sql_rc: int = field(init=False)
sql_state: str = field(init=False)
error: Optional[str] = field(default=None, init=False)

@dataclass
class QueryOptions:
isTerseResults: Optional[bool] = None
isClCommand: Optional[bool] = None
parameters: Optional[List[Any]] = None
autoClose: Optional[bool] = None

@dataclass
class SetConfigResult(ServerResponse):
tracedest: ServerTraceDest
tracelevel: ServerTraceLevel
id: str = field(init=False)
success: bool = field(init=False)
sql_rc: int = field(init=False)
sql_state: str = field(init=False)
error: Optional[str] = field(default=None, init=False)

@dataclass
class JDBCOptions:
naming: Optional[str] = None
date_format: Optional[str] = None
date_separator: Optional[str] = None
decimal_separator: Optional[str] = None
time_format: Optional[str] = None
time_separator: Optional[str] = None
full_open: Optional[bool] = None
access: Optional[str] = None
autocommit_exception: Optional[bool] = None
bidi_string_type: Optional[str] = None
bidi_implicit_reordering: Optional[bool] = None
bidi_numeric_ordering: Optional[bool] = None
data_truncation: Optional[bool] = None
driver: Optional[str] = None
errors: Optional[str] = None
extended_metadata: Optional[bool] = None
hold_input_locators: Optional[bool] = None
hold_statements: Optional[bool] = None
ignore_warnings: Optional[str] = None
keep_alive: Optional[bool] = None
key_ring_name: Optional[str] = None
key_ring_password: Optional[str] = None
metadata_source: Optional[str] = None
proxy_server: Optional[str] = None
remarks: Optional[str] = None
secondary_URL: Optional[str] = None
secure: Optional[bool] = None
server_trace: Optional[str] = None
thread_used: Optional[bool] = None
toolbox_trace: Optional[str] = None
trace: Optional[bool] = None
translate_binary: Optional[bool] = None
translate_boolean: Optional[bool] = None
libraries: Optional[List[str]] = None
auto_commit: Optional[bool] = None
concurrent_access_resolution: Optional[str] = None
cursor_hold: Optional[bool] = None
cursor_sensitivity: Optional[str] = None
database_name: Optional[str] = None
decfloat_rounding_mode: Optional[str] = None
maximum_precision: Optional[str] = None
maximum_scale: Optional[str] = None
minimum_divide_scale: Optional[str] = None
package_ccsid: Optional[str] = None
transaction_isolation: Optional[str] = None
translate_hex: Optional[str] = None
true_autocommit: Optional[bool] = None
XA_loosely_coupled_support: Optional[str] = None
big_decimal: Optional[bool] = None
block_criteria: Optional[str] = None
block_size: Optional[str] = None
data_compression: Optional[bool] = None
extended_dynamic: Optional[bool] = None
lazy_close: Optional[bool] = None
lob_threshold: Optional[str] = None
maximum_blocked_input_rows: Optional[str] = None
package: Optional[str] = None
package_add: Optional[bool] = None
package_cache: Optional[bool] = None
package_criteria: Optional[str] = None
package_error: Optional[str] = None
package_library: Optional[str] = None
prefetch: Optional[bool] = None
qaqqinilib: Optional[str] = None
query_optimize_goal: Optional[str] = None
query_timeout_mechanism: Optional[str] = None
query_storage_limit: Optional[str] = None
receive_buffer_size: Optional[str] = None
send_buffer_size: Optional[str] = None
variable_field_compression: Optional[bool] = None
sort: Optional[str] = None
sort_language: Optional[str] = None
sort_table: Optional[str] = None
sort_weight: Optional[str] = None

23 changes: 21 additions & 2 deletions tests/hello_test.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,21 @@
def test_hello():
print("Hello, World!")
import asyncio
from python_sc.sql_job import SQLJob
from python_sc.sql_runner import SQLJobRunner
from python_sc.tls import get_certificate
from python_sc.types import DaemonServer

creds = DaemonServer(
host="localhost",
port=8085,
user="ashedivy",
password="",
ignoreUnauthorized=False
)


def test_channel_connect():
# ca = asyncio.run(get_certificate(creds))
# creds.ca = ca.raw if ca else None

job = SQLJob()
job.connect(creds)

0 comments on commit 57e9f80

Please sign in to comment.