diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 0ff2b65..d3d565e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -4,6 +4,11 @@ concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true +permissions: + contents: write + pull-requests: write + packages: write + on: pull_request: branches: @@ -139,6 +144,8 @@ jobs: - name: Publish GitHub release uses: softprops/action-gh-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN}} with: body_path: ${{ github.workspace }}-RELEASE_NOTES.md prerelease: ${{ contains(env.TAG, 'rc') }} diff --git a/.gitignore b/.gitignore index 96b4c51..e0251ad 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ doc/_build/ .DS_Store *.pem local_test/ +local_test.py # python diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..c92121c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,33 @@ +default_language_version: + python: python3.8 + +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.1 + hooks: + - id: ruff + types: [python] + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.9.0 + hooks: + - id: mypy + types: [python] + additional_dependencies: [ dataclasses-json>=0.6.4, websocket-client>=1.2.1] + exclude: '^tests/' + + - repo: https://github.com/ambv/black + rev: 24.4.0 + hooks: + - id: black + args: [--check] + types: [python] + + - repo: local + hooks: + - id: isort + name: isort + entry: isort + language: system + types: [python] + args: [--check,--profile=black] diff --git a/CHANGELOG.md b/CHANGELOG.md index d5bf518..ea648d4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,3 +11,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [v0.1.0](https://github.com/ajshedivy/python-wsdb/releases/tag/v0.1.0) - 2024-04-19 Add initial release + +## [0.1.2] +### Added +- pre-commit hooks +- repo formatting +- PEP PEP 563 style annotations diff --git a/pyproject.toml b/pyproject.toml index 8b72954..2a9706a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,8 @@ dev = [ "sphinx-copybutton==0.5.2", "sphinx-autobuild==2021.3.14", "sphinx-autodoc-typehints==1.23.3", - "packaging" + "packaging", + "pre-commit" ] [tool.setuptools.packages.find] @@ -106,6 +107,7 @@ target-version = "py39" ignore_missing_imports = true no_site_packages = true check_untyped_defs = true +exclude="(tests/)" [[tool.mypy.overrides]] module = "tests.*" diff --git a/python_wsdb/client/query.py b/python_wsdb/client/query.py index d5e5578..49388be 100644 --- a/python_wsdb/client/query.py +++ b/python_wsdb/client/query.py @@ -1,6 +1,6 @@ import json from enum import Enum -from typing import Any, Dict, Generic, List, Optional, TypeVar +from typing import Any, Dict, Generic, List, Optional, TypeVar, Union from python_wsdb.client.sql_job import SQLJob from python_wsdb.types import QueryOptions @@ -15,43 +15,43 @@ class QueryState(Enum): ERROR = 4 +def get_query_options(opts: Optional[Union[Dict[str, Any], QueryOptions]] = None) -> QueryOptions: + if isinstance(opts, QueryOptions): + return opts + elif opts: + return QueryOptions(**opts) + else: + return QueryOptions(isClCommand=False, parameters=None, autoClose=False) + + class Query(Generic[T]): global_query_list: List["Query[Any]"] = [] - def __init__( - self, - job: SQLJob, - query: str, - opts: QueryOptions = QueryOptions( - isClCommand=False, parameters=None, autoClose=False - ), - ) -> None: + def __init__(self, job: SQLJob, query: str, opts: QueryOptions) -> None: self.job = job self.is_prepared: bool = True if opts.parameters is not None else False self.parameters: Optional[List[str]] = opts.parameters self.sql: str = query - self.is_cl_command: bool | None = opts.isClCommand - self.should_auto_close: bool | None = opts.autoClose - self.is_terse_results: bool | None = opts.isTerseResults + self.is_cl_command: Optional[bool] = opts.isClCommand + self.should_auto_close: Optional[bool] = opts.autoClose + self.is_terse_results: Optional[bool] = opts.isTerseResults self._rows_to_fetch: int = 100 self.state: QueryState = QueryState.NOT_YET_RUN Query.global_query_list.append(self) - def run(self, rows_to_fetch: int | Any = None) -> Dict[str, Any]: + def run(self, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]: if rows_to_fetch is None: rows_to_fetch = self._rows_to_fetch else: self._rows_to_fetch = rows_to_fetch - - # fmt: off - match self.state: - case QueryState.RUN_MORE_DATA_AVAIL: - raise Exception("Statement has already been run") - case QueryState.RUN_DONE: - raise Exception("Statement has already been fully run") - # fmt: on + + # check Query state first + if self.state == QueryState.RUN_MORE_DATA_AVAIL: + raise Exception("Statement has already been run") + elif self.state == QueryState.RUN_DONE: + raise Exception("Statement has already been fully run") query_object: Dict[str, Any] = {} if self.is_cl_command: @@ -84,9 +84,11 @@ def run(self, rows_to_fetch: int | Any = None) -> Dict[str, Any]: print(query_result) self.state = QueryState.ERROR error_keys = ["error", "sql_state", "sql_rc"] - error_list = {key:query_result[key] for key in error_keys if key in query_result.keys()} + error_list = { + key: query_result[key] for key in error_keys if key in query_result.keys() + } if len(error_list) == 0: - error_list['error'] = "failed to run query for unknown reason" + error_list["error"] = "failed to run query for unknown reason" raise Exception(error_list) @@ -94,42 +96,37 @@ def run(self, rows_to_fetch: int | Any = None) -> Dict[str, Any]: return query_result - def fetch_more(self, rows_to_fetch: int | Any = None) -> Dict[str, Any]: + def fetch_more(self, rows_to_fetch: Optional[int] = None) -> Dict[str, Any]: if rows_to_fetch is None: rows_to_fetch = self._rows_to_fetch else: self._rows_to_fetch = rows_to_fetch - match self.state: - case QueryState.NOT_YET_RUN: - raise Exception("Statement has not been run") - case QueryState.RUN_DONE: - raise Exception("Statement has already been fully run") - + if self.state == QueryState.NOT_YET_RUN: + raise Exception("Statement has not been run") + elif self.state == QueryState.RUN_DONE: + raise Exception("Statement has already been fully run") + query_object = { - 'id': self.job._get_unique_id('fetchMore'), - 'cont_id': self._correlation_id, - 'type': 'sqlmore', - 'sql': self.sql, - 'rows': rows_to_fetch + "id": self.job._get_unique_id("fetchMore"), + "cont_id": self._correlation_id, + "type": "sqlmore", + "sql": self.sql, + "rows": rows_to_fetch, } - + self._rows_to_fetch = rows_to_fetch self.job.send(json.dumps(query_object)) query_result: Dict[str, Any] = json.loads(self.job._socket.recv()) - + self.state = ( QueryState.RUN_DONE if query_result.get("is_done", False) else QueryState.RUN_MORE_DATA_AVAIL ) - - if not query_result['success']: + + if not query_result["success"]: self.state = QueryState.ERROR - raise Exception(query_result['error'] or "Failed to run Query (unknown error)") - + raise Exception(query_result["error"] or "Failed to run Query (unknown error)") + return query_result - - - - diff --git a/python_wsdb/client/sql_job.py b/python_wsdb/client/sql_job.py index 419bfad..6447b68 100644 --- a/python_wsdb/client/sql_job.py +++ b/python_wsdb/client/sql_job.py @@ -1,7 +1,7 @@ import base64 import json import ssl -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, Union from websocket import WebSocket, create_connection @@ -33,7 +33,6 @@ def _get_channel(self, db2_server: DaemonServer) -> WebSocket: # Prepare SSL context if necessary ssl_opts: Dict[str, Any] = {} - if db2_server.ignoreUnauthorized: ssl_opts["cert_reqs"] = ssl.CERT_NONE if db2_server.ca: @@ -45,18 +44,6 @@ def _get_channel(self, db2_server: DaemonServer) -> WebSocket: # Create WebSocket connection socket = create_connection(uri, header=headers, sslopt=ssl_opts) - # Register message handler - def on_message(ws, message): - if self._is_tracing_channeldata: - print(message) - try: - response = json.loads(message) - print(f"Received message with ID: {response['id']}") - except Exception as e: - print(f"Error parsing message: {e}") - - socket.on_message = on_message - return socket def send(self, content): @@ -99,18 +86,43 @@ def connect(self, db2_server: DaemonServer) -> Dict[Any, Any]: return result def query( - self, sql: str, opts: Optional[Dict[str, Any]] = None, + self, + sql: str, + opts: Optional[Union[Dict[str, Any], QueryOptions]] = None, ): + """ + Create a Query object using provided SQL and options. If opts is None, + the default options defined in Query constructor are used. opts can be a + dictionary to be converted to QueryOptions, or a QueryOptions object directly. + + Args: + sql (str): The SQL query string. + opts (Optional[Union[Dict[str, Any], QueryOptions]]): Additional options + for the query which can be a dictionary or a QueryOptions object. + + Returns: + Query: A configured Query object. + """ from python_wsdb.client.query import Query - if isinstance(opts, dict): - opts = QueryOptions(**opts) - return Query(job=self, query=sql, opts=opts) - return Query(job=self, query=sql) + if opts is not None and not isinstance(opts, (dict, QueryOptions)): + raise ValueError("opts must be a dictionary, a QueryOptions object, or None") + + query_options = ( + opts + if isinstance(opts, QueryOptions) + else ( + QueryOptions(**opts) + if opts + else QueryOptions(isClCommand=False, parameters=None, autoClose=False) + ) + ) + + return Query(job=self, query=sql, opts=query_options) def query_and_run( self, sql: str, opts: Optional[Dict[str, Any]] = None, **kwargs - ): + ) -> Dict[str, Any]: query = self.query(sql, opts) return query.run(**kwargs) diff --git a/python_wsdb/ssl.py b/python_wsdb/ssl.py index 1c71cb0..1877f77 100644 --- a/python_wsdb/ssl.py +++ b/python_wsdb/ssl.py @@ -1,10 +1,11 @@ import socket import ssl +from typing import Optional from python_wsdb.types import DaemonServer -def get_certificate(creds: DaemonServer) -> (bytes | None): +def get_certificate(creds: DaemonServer) -> Optional[bytes]: context = ssl.create_default_context() context.check_hostname = False context.verify_mode = ssl.CERT_NONE diff --git a/scripts/personalize.py b/scripts/personalize.py index c9fa2fe..9f5a4b5 100644 --- a/scripts/personalize.py +++ b/scripts/personalize.py @@ -72,7 +72,10 @@ default=False, ) @click.option( - "--dry-run", is_flag=True, hidden=True, default=False, + "--dry-run", + is_flag=True, + hidden=True, + default=False, ) def main( github_org: str, github_repo: str, package_name: str, yes: bool = False, dry_run: bool = False