Merge pull request #1267 from phenobarbital/new-drivers
New drivers
phenobarbital authored Sep 15, 2024
2 parents 1413ee1 + 43938dd commit 334fc18
Showing 16 changed files with 1,174 additions and 117 deletions.
457 changes: 457 additions & 0 deletions asyncdb/drivers/aioch.py

Large diffs are not rendered by default.

267 changes: 195 additions & 72 deletions asyncdb/drivers/clickhouse.py
@@ -2,16 +2,16 @@
from typing import Union, Any, Optional
from collections.abc import Iterable, Sequence, Awaitable
from pathlib import Path
from aiochclient import ChClient
from aiohttp import ClientSession
from clickhouse_driver import Client
from asyncdb.meta.record import Record
from .sql import SQLDriver
from ..exceptions import DriverError


class clickhouse(SQLDriver):
"""
clickhouse class for Connecting to a Clickhouse Cluster.
This class provides a consistent interface using aiochclient.
Native ClickHouse driver for connecting to a ClickHouse cluster.
This class provides a consistent interface using the native clickhouse_driver package.
Attributes:
-----------
@@ -29,8 +29,8 @@ class clickhouse(SQLDriver):

_provider: str = "clickhouse"
_syntax: str = "sql"
_dsn: str = "{database}"
_test_query: str = "SELECT version()"
_dsn: str = ""
_test_query: str = "SELECT now(), version()"

def __init__(
self,
@@ -46,16 +46,29 @@ def __init__(
Parameters:
-----------
dsn : str, optional
The Data Source Name for the database connection. Defaults to an empty string.
The Data Source Name for the database connection.
Defaults to an empty string.
loop : asyncio.AbstractEventLoop, optional
The event loop to use for asynchronous operations. Defaults to None, which uses the current event loop.
params : dict, optional
Additional connection parameters as a dictionary. Defaults to None.
kwargs : dict
Additional keyword arguments to pass to the base SQLDriver.
"""
self._session: Awaitable = None
SQLDriver.__init__(self, dsn, loop, params, **kwargs)
server_args = {
"secure": False,
"verify": False,
"compression": True
}
SQLDriver.__init__(
self,
dsn=dsn,
loop=loop,
params=params,
**kwargs
)
self.params.update(server_args)
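
As a usage sketch (not part of the diff): the driver can be built either from a params dict or from a DSN. The host/port/user values below are placeholders, and the secure/verify/compression defaults set in __init__ are merged on top of whatever is passed.

import asyncio
from asyncdb.drivers.clickhouse import clickhouse

params = {
    "host": "localhost",   # placeholder for a real ClickHouse node
    "port": 9000,          # native-protocol port used by clickhouse_driver
    "user": "default",
    "password": "",
    "database": "default",
}

async def main():
    db = clickhouse(params=params)
    await db.connection()
    await db.close()

    # Alternatively, a DSN is handed to Client.from_url() by connection():
    # db = clickhouse(dsn="clickhouse://default:@localhost:9000/default")

asyncio.run(main())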


async def connection(self, **kwargs):
"""
@@ -76,12 +89,19 @@ async def connection(self, **kwargs):
"""
self._connection = None
self._connected = False
if not self._session:
self._session = ClientSession()
self._executor = self.get_executor(
executor="thread", max_workers=10
)
try:
self._connection = ChClient(self._session, **self.params)
print(self._connection, await self._connection.is_alive())
if await self._connection.is_alive():
if self._dsn:
self._connection = await self._thread_func(
Client.from_url, self._dsn, executor=self._executor
)
else:
self._connection = await self._thread_func(
Client, **self.params, executor=self._executor
)
if self._connection.connection:
self._connected = True
return self
except Exception as exc:
@@ -104,13 +124,9 @@ async def close(self, timeout: int = 5) -> None:
--------
None
"""
try:
if self._session:
await self._session.close()
finally:
self._connection = None
self._connected = False
self._session = None
self._connection = None # Clickhouse does not have a close method.
self._connected = False
self._session = None

async def __aenter__(self) -> Any:
"""
@@ -128,8 +144,6 @@ async def __aenter__(self) -> Any:
If an error occurs during connection establishment.
"""
try:
if not self._session:
self._session = ClientSession()
if not self._connection:
await self.connection()
except Exception as err:
@@ -140,7 +154,102 @@ async def __aexit__(self, exc_type, exc, tb):
async def __aexit__(self, exc_type, exc, tb):
await self.close()
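
A sketch of the async context-manager path wired up by __aenter__/__aexit__; it assumes __aenter__ returns the driver instance (the usual asyncdb convention) and that the default serializer yields a (result, error) pair.

async def run():
    async with clickhouse(params=params) as db:
        result, error = await db.query("SELECT 1")
        print(result, error)
    # On exit, close() has already reset _connection and _connected.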

async def query(self, sentence: Any, *args, **kwargs) -> Iterable[Any]:
async def execute(
self,
sentence: Any,
params: Optional[Iterable] = None,
**kwargs
) -> Optional[Any]:
"""
Executes a transaction or command that does not necessarily
return a result asynchronously.
Parameters:
-----------
sentence : Any
The SQL command or transaction to execute.
kwargs : dict
Additional keyword arguments to be passed to the execution.
Returns:
--------
Optional[Any]
The result of the execution, if any.
"""
error = None
result = None
await self.valid_operation(sentence)
try:
if not self._executor:
self._executor = self.get_executor(
executor="thread", max_workers=2
)
new_args = {
"with_column_types": True,
"columnar": False,
"params": params,
**kwargs
}
if params:
new_args['params'] = params
result = await self._thread_func(
self._connection.execute,
sentence,
**new_args,
executor=self._executor
)
except Exception as exc:
error = exc
finally:
return [result, error]
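
A hedged example of execute(), reusing the connected db from the earlier sketch inside a coroutine. The events table and its columns are hypothetical, and the unpacking mirrors the [result, error] pair returned in the finally block above; the second call passes a batch through clickhouse_driver's params argument for a VALUES insert.

from datetime import datetime

result, error = await db.execute(
    "CREATE TABLE IF NOT EXISTS events (ts DateTime, msg String) "
    "ENGINE = MergeTree ORDER BY ts"
)
if error:
    raise error

result, error = await db.execute(
    "INSERT INTO events (ts, msg) VALUES",
    params=[(datetime(2024, 9, 15), "new drivers merged")],
)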

async def execute_many(
self,
sentence: Union[str, list],
params: Optional[Iterable] = None
) -> Optional[Any]:
"""
Executes multiple transactions or commands asynchronously.
This method is similar to `execute`, but accepts multiple commands to be executed.
Parameters:
-----------
sentence : Union[str, list]
A single SQL command or a list of commands to execute.
params : iterable
A list of arguments to pass to each command.
Returns:
--------
Optional[Any]
The result of the executions, if any.
"""
error = None
result = None
if isinstance(sentence, str):
sentences = [sentence]
else:
sentences = sentence
results = []
for sentence in sentences:
await self.valid_operation(sentence)
result = await self.execute(sentence, params=params)
results.append(result)
return (results, error)

executemany = execute_many
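
execute_many() simply iterates over the sentences and delegates each one to execute(); a small sketch with hypothetical statements, again reusing the connected db:

statements = [
    "CREATE DATABASE IF NOT EXISTS staging",
    "DROP TABLE IF EXISTS staging.tmp_events",
]
results, error = await db.execute_many(statements)
# results holds one [result, error] pair per statement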

def _construct_record(self, row, column_names):
return Record(dict(zip(column_names, row)), column_names)

async def query(
self,
sentence: Any,
*args,
row_format: str = None,
**kwargs
) -> Iterable[Any]:
"""
Executes a query to retrieve data from the database asynchronously.
@@ -161,15 +270,43 @@ async def query(self, sentence: Any, *args, **kwargs) -> Iterable[Any]:
error = None
self._result = None
await self.valid_operation(sentence)
if not row_format:
row_format = self._row_format
try:
result = await self._connection.fetch(sentence)
if not self._executor:
self._executor = self.get_executor(
executor="thread", max_workers=2
)
new_args = {
"with_column_types": True,
"columnar": False,
**kwargs
}
result, columns_info = await self._thread_func(
self._connection.execute,
sentence, *args, **new_args,
executor=self._executor
)
if result:
self._result = result
if row_format == 'record':
self._result = result
elif row_format in ('dict', 'iterable'):
# Get the column names from the executed query
columns = [col[0] for col in columns_info]
self._result = [dict(zip(columns, row)) for row in result]
else:
self._result = result
except Exception as exc:
error = exc
return await self._serializer(self._result, error)
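
A sketch of query() showing the two row formats handled above; system.numbers is a built-in ClickHouse table, and the (result, error) unpacking assumes the default output serializer.

# Native tuples
result, error = await db.query("SELECT number FROM system.numbers LIMIT 3")

# Dictionaries keyed by column name
result, error = await db.query(
    "SELECT number FROM system.numbers LIMIT 3",
    row_format="dict",
)
# e.g. [{'number': 0}, {'number': 1}, {'number': 2}]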

async def queryrow(self, sentence: Any = None) -> Iterable[Any]:
async def queryrow(
self,
sentence: Any,
*args,
params: Optional[Iterable] = None,
**kwargs
) -> Iterable[Any]:
"""
Executes a query to retrieve a single row of data from the database asynchronously.
@@ -186,6 +323,37 @@ async def queryrow(self, sentence: Any = None) -> Iterable[Any]:
error = None
self._result = None
await self.valid_operation(sentence)
try:
if not self._executor:
self._executor = self.get_executor(
executor="thread", max_workers=2
)
new_args = {
"with_column_types": True,
"settings": {
'max_block_size': 100000
},
"chunk_size": 1,
**kwargs
}
rows_gen = await self._thread_func(
self._connection.execute_iter,
sentence, *args, **new_args,
executor=self._executor
)
# Extract the first element (column info) using next()
column_info = next(rows_gen)
# Extract column names
column_names = [col[0] for col in column_info]
result = []
for row in rows_gen:
row_dict = dict(zip(column_names, row))
result.append(row_dict)
self._result = result
except Exception as exc:
error = exc
return await self._serializer(self._result, error)
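
queryrow() streams rows through execute_iter() and zips each one with the column names; a minimal sketch, with the shape of the serialized result an assumption:

result, error = await db.queryrow("SELECT version() AS server_version")
# With the loop above, result is a list of dicts, e.g. [{'server_version': '...'}]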

async def fetch_all(self, sentence: str, *args, **kwargs) -> Sequence:
@@ -261,51 +429,6 @@ async def fetch_one(self, sentence: str, *args, **kwargs) -> Optional[dict]:
fetchone = fetch_one
fetchrow = fetch_one

async def execute(self, sentence: Any, **kwargs) -> Optional[Any]:
"""
Executes a transaction or command that does not necessarily return a result asynchronously.
Parameters:
-----------
sentence : Any
The SQL command or transaction to execute.
kwargs : dict
Additional keyword arguments to be passed to the execution.
Returns:
--------
Optional[Any]
The result of the execution, if any.
"""
error = None
result = None
return (result, error)

async def execute_many(self, sentence: Union[str, list], args: list) -> Optional[Any]:
"""
Executes multiple transactions or commands asynchronously.
This method is similar to `execute`, but accepts multiple commands to be executed.
Parameters:
-----------
sentence : Union[str, list]
A single SQL command or a list of commands to execute.
args : list
A list of arguments to pass to each command.
Returns:
--------
Optional[Any]
The result of the executions, if any.
"""
error = None
result = None
await self.valid_operation(sentence)
return (result, error)

executemany = execute_many

async def copy_to(self, sentence: Union[str, Path], destination: str, **kwargs) -> bool:
"""
Copies the result of a query to a file asynchronously.
2 changes: 1 addition & 1 deletion asyncdb/drivers/delta.py
@@ -445,7 +445,7 @@ async def file_to_parquet(
pq.write_table(atable, parquet, compression="snappy")
except Exception as exc:
raise DriverError(
f"Query Error: {exc}"
f"Delta File To Parquet Error: {exc}"
) from exc

async def write(
(Diffs for the remaining changed files are not shown.)
