Skip to content

Commit

Permalink
add query manager (#52)
Browse files Browse the repository at this point in the history
* add query manager

* simplify sql_job context manager

* update type hints

* update changelog
  • Loading branch information
ajshedivy authored Aug 28, 2024
1 parent c5389b9 commit ed40370
Show file tree
Hide file tree
Showing 7 changed files with 221 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased
- add query manager

## [v0.1.4](https://github.com/Mapepire-IBMi/mapepire-python/releases/tag/v0.1.4) - 2024-08-23

Expand Down
49 changes: 31 additions & 18 deletions mapepire_python/client/query.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from enum import Enum
from typing import Any, Dict, Generic, List, Optional, TypeVar, Union
from typing import Any, Dict, Generic, List, Optional, TypeVar

from ..types import QueryOptions
from .sql_job import SQLJob
Expand All @@ -13,25 +13,15 @@ class QueryState(Enum):
RUN_MORE_DATA_AVAIL = (2,)
RUN_DONE = (3,)
ERROR = 4


def get_query_options(opts: Optional[Union[Dict[str, Any], QueryOptions]] = None) -> QueryOptions:
if isinstance(opts, QueryOptions):
return opts
elif opts:
return QueryOptions(**opts)
else:
return QueryOptions(isClCommand=False, parameters=None, autoClose=False)



class Query(Generic[T]):
global_query_list: List["Query[Any]"] = []

def __init__(self, job: SQLJob, query: str, opts: QueryOptions) -> None:
self.job = job
self.sql: str = query
self.is_prepared: bool = True if opts.parameters is not None else False
self.parameters: Optional[List[str]] = opts.parameters
self.sql: str = query
self.is_cl_command: Optional[bool] = opts.isClCommand
self.should_auto_close: Optional[bool] = opts.autoClose
self.is_terse_results: Optional[bool] = opts.isTerseResults
Expand All @@ -40,6 +30,17 @@ def __init__(self, job: SQLJob, query: str, opts: QueryOptions) -> None:
self.state: QueryState = QueryState.NOT_YET_RUN

Query.global_query_list.append(self)

def __enter__(self):
return self

def __exit__(self, exc_type, exc_value, traceback):
self.close()

def _execute_query(self, qeury_object: Dict[str, Any]) -> Dict[str, Any]:
self.job.send(json.dumps(qeury_object))
query_result: Dict[str, Any] = json.loads(self.job._socket.recv())
return query_result

def run(self, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]:
if rows_to_fetch is None:
Expand Down Expand Up @@ -71,8 +72,7 @@ def run(self, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]:
"parameters": self.parameters,
}

self.job.send(json.dumps(query_object))
query_result: Dict[str, Any] = json.loads(self.job._socket.recv())
query_result: Dict[str, Any] = self._execute_query(query_object)

self.state = (
QueryState.RUN_DONE
Expand All @@ -81,7 +81,6 @@ def run(self, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]:
)

if not query_result.get("success", False) and not self.is_cl_command:
print(query_result)
self.state = QueryState.ERROR
error_keys = ["error", "sql_state", "sql_rc"]
error_list = {
Expand Down Expand Up @@ -116,8 +115,7 @@ def fetch_more(self, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]:
}

self._rows_to_fetch = rows_to_fetch
self.job.send(json.dumps(query_object))
query_result: Dict[str, Any] = json.loads(self.job._socket.recv())
query_result: Dict[str, Any] = self._execute_query(query_object)

self.state = (
QueryState.RUN_DONE
Expand All @@ -130,3 +128,18 @@ def fetch_more(self, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]:
raise Exception(query_result["error"] or "Failed to run Query (unknown error)")

return query_result

def close(self):
if not self.job._socket.connected:
raise Exception('SQL Job not connected')
if self._correlation_id and self.state is not QueryState.RUN_DONE:
self.state = QueryState.RUN_DONE
query_object = {
'id': self.job._get_unique_id('sqlclose'),
'cont_id': self._correlation_id,
'type': 'sqlclose'
}

return self._execute_query(query_object)
elif not self._correlation_id:
self.state = QueryState.RUN_DONE
47 changes: 47 additions & 0 deletions mapepire_python/client/query_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import Any, Dict, Optional, Union

from ..types import QueryOptions
from .query import Query
from .sql_job import SQLJob


class QueryManager:
def __init__(self, job: SQLJob) -> None:
self.job = job

def get_query_options(
self, opts: Optional[Union[Dict[str, Any], QueryOptions]] = None
) -> QueryOptions:
query_options = (
opts
if isinstance(opts, QueryOptions)
else (
QueryOptions(**opts)
if isinstance(opts, dict)
else QueryOptions(isClCommand=False, parameters=None, autoClose=False)
)
)

return query_options

def create_query(
self,
query: str,
opts: Optional[Union[Dict[str, Any], QueryOptions]] = None,
) -> Query:

if opts and not isinstance(opts, (dict, QueryOptions)):
raise Exception("opts must be a dictionary, a QueryOptions object, or None")

query_options = self.get_query_options(opts)

return Query(self.job, query, opts=query_options)

def run_query(self, query: Query, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]:
return query.run(rows_to_fetch=rows_to_fetch)

def query_and_run(
self, query: str, opts: Optional[Union[Dict[str, Any], QueryOptions]] = None, **kwargs
) -> Dict[str, Any]:
with self.create_query(query, opts) as query:
return query.run(**kwargs)
39 changes: 15 additions & 24 deletions mapepire_python/client/sql_job.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import base64
import json
import ssl
from typing import Any, Dict, Optional, Union

from websocket import WebSocket, create_connection
from websocket import WebSocket

from ..types import DaemonServer, JobStatus, QueryOptions
from .websocket import WebsocketConnection


class SQLJob:
def __init__(self, options: Dict[Any, Any] = {}) -> None:
def __init__(self, creds: DaemonServer = None, options: Dict[Any, Any] = {}) -> None:
self.options = options
self._unique_id_counter: int = 0
self._reponse_emitter = None
Expand All @@ -19,32 +18,23 @@ def __init__(self, options: Dict[Any, Any] = {}) -> None:

self.__unique_id = self._get_unique_id("sqljob")
self.id: Optional[str] = None
self.creds = creds

def __enter__(self):
if self.creds:
self.connect(self.creds)
return self

def __exit__(self, exc_type, exc_value, traceback):
self.close()

def _get_unique_id(self, prefix: str = "id") -> str:
self._unique_id_counter += 1
return f"{prefix}{self._unique_id_counter}"

def _get_channel(self, db2_server: DaemonServer) -> WebSocket:
uri = f"wss://{db2_server.host}:{db2_server.port}/db/"
headers = {
"Authorization": "Basic "
+ base64.b64encode(f"{db2_server.user}:{db2_server.password}".encode()).decode("ascii")
}

# Prepare SSL context if necessary
ssl_opts: Dict[str, Any] = {}
if db2_server.ignoreUnauthorized:
ssl_opts["cert_reqs"] = ssl.CERT_NONE
if db2_server.ca:
ssl_context = ssl.create_default_context(cadata=db2_server.ca)
ssl_context.check_hostname = False
ssl_opts["ssl_context"] = ssl_context
ssl_opts["cert_reqs"] = ssl.CERT_NONE # ignore certs for now

# Create WebSocket connection
socket = create_connection(uri, header=headers, sslopt=ssl_opts)

return socket
socket = WebsocketConnection(db2_server)
return socket.connect()

def send(self, content):
self._socket.send(content)
Expand Down Expand Up @@ -127,4 +117,5 @@ def query_and_run(
return query.run(**kwargs)

def close(self):
self._status = JobStatus.Ended
self._socket.close()
32 changes: 32 additions & 0 deletions mapepire_python/client/websocket.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import base64
import ssl
from typing import Any, Dict

from websocket import WebSocket, create_connection

from ..types import DaemonServer


class WebsocketConnection:
def __init__(self, db2_server: DaemonServer) -> None:
self.uri = f"wss://{db2_server.host}:{db2_server.port}/db/"
self.headers = {
"Authorization": "Basic "
+ base64.b64encode(f"{db2_server.user}:{db2_server.password}".encode()).decode("ascii")
}

self.ssl_opts = self._build_ssl_options(db2_server)

def _build_ssl_options(self, db2_server: DaemonServer) -> Dict[str, Any]:
ssl_opts: Dict[str, Any] = {}
if db2_server.ignoreUnauthorized:
ssl_opts["cert_reqs"] = ssl.CERT_NONE
if db2_server.ca:
ssl_context = ssl.create_default_context(cadata=db2_server.ca)
ssl_context.check_hostname = False
ssl_opts["ssl_context"] = ssl_context
ssl_opts["cert_reqs"] = ssl.CERT_NONE
return ssl_opts

def connect(self) -> WebSocket:
return create_connection(self.uri, header=self.headers, sslopt=self.ssl_opts)
9 changes: 7 additions & 2 deletions mapepire_python/types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
from dataclasses import dataclass, field
from dataclasses import dataclass, field, fields
from enum import Enum
from typing import Any, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

from dataclasses_json import dataclass_json


def dict_to_dataclass(data: Dict[str, Any], dataclass_type: Any) -> Any:
field_names = {f.name for f in fields(dataclass_type)}
filtered_data = {k: v for k, v in data.items() if k in field_names}
return dataclass_type(**filtered_data)

class JobStatus(Enum):
NotStarted = "notStarted"
Ready = "ready"
Expand Down
88 changes: 88 additions & 0 deletions tests/query_manager_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import os
import re

import pytest

from mapepire_python.client.query_manager import QueryManager
from mapepire_python.client.sql_job import SQLJob
from mapepire_python.types import DaemonServer, QueryOptions

# Fetch environment variables
server = os.getenv('VITE_SERVER')
user = os.getenv('VITE_DB_USER')
password = os.getenv('VITE_DB_PASS')
port = os.getenv('VITE_DB_PORT')

# Check if environment variables are set
if not server or not user or not password:
raise ValueError('One or more environment variables are missing.')


creds = DaemonServer(
host=server,
port=port,
user=user,
password=password,
ignoreUnauthorized=True,
)

def test_query_manager():
# connection logic
job = SQLJob()
job.connect(creds)

# Query Manager
query_manager = QueryManager(job)

# create a unique query
query = query_manager.create_query("select * from sample.employee")

# run query
result = query_manager.run_query(query)

assert result['success']
job.close()



def test_context_manager():
with SQLJob() as job:
job.connect(creds)

query_manager = QueryManager(job)
query = query_manager.create_query("select * from sample.department")
result = query_manager.run_query(query)
assert result['success']

def test_simple_v2():
with SQLJob(creds) as job:
query_manager = QueryManager(job)
query = query_manager.create_query('select * from sample.employee')
result = query_manager.run_query(query, rows_to_fetch=5)
assert result['success'] == True
assert result['is_done'] == False
assert result['has_results'] == True
query.close()

def test_query_large_dataset():
job = SQLJob()
_ = job.connect(creds)
query_manager = QueryManager(job)
query = query_manager.create_query('select * from sample.employee')

result = query_manager.run_query(query, rows_to_fetch=30)
query.close()
job.close()

assert result['success'] == True
assert result['is_done'] == False
assert result['has_results'] == True
assert len(result['data']) == 30

def test_query_and_run():
with SQLJob(creds) as job:
query_manager = QueryManager(job)
result = query_manager.query_and_run('select * from sample.employee', rows_to_fetch=5)
assert result['success'] == True
assert result['is_done'] == False
assert result['has_results'] == True

0 comments on commit ed40370

Please sign in to comment.