diff --git a/asyncdb/drivers/delta.py b/asyncdb/drivers/delta.py index b54ea5aa..2c8ca5bb 100644 --- a/asyncdb/drivers/delta.py +++ b/asyncdb/drivers/delta.py @@ -445,7 +445,7 @@ async def file_to_parquet( pq.write_table(atable, parquet, compression="snappy") except Exception as exc: raise DriverError( - f"Query Error: {exc}" + f"Delta File To Parquet Error: {exc}" ) from exc async def write( diff --git a/asyncdb/drivers/mysqlclient.py b/asyncdb/drivers/mysqlclient.py index 02932290..f0347df4 100644 --- a/asyncdb/drivers/mysqlclient.py +++ b/asyncdb/drivers/mysqlclient.py @@ -144,7 +144,7 @@ async def close(self): await self._thread_func(self._pool.close) self._connected = False self._logger.debug( - f"MySQL Connection Closed." + "MySQL Connection Closed." ) disconnect = close diff --git a/asyncdb/drivers/oracle.py b/asyncdb/drivers/oracle.py index 2ff7af0a..5872470e 100644 --- a/asyncdb/drivers/oracle.py +++ b/asyncdb/drivers/oracle.py @@ -29,9 +29,16 @@ def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params except KeyError: self._lib_dir = None try: - super(oracle, self).__init__(dsn=dsn, loop=loop, params=params, **kwargs) + super(oracle, self).__init__( + dsn=dsn, + loop=loop, + params=params, + **kwargs + ) _generated = datetime.now() - _starttime - print(f"Oracle Started in: {_generated}") + print( + f"Oracle Started in: {_generated}" + ) except Exception as err: raise DriverError(f"Oracle Error: {err}") from err # set the JSON encoder: diff --git a/asyncdb/drivers/redis.py b/asyncdb/drivers/redis.py index fb17aad1..a5802dc0 100644 --- a/asyncdb/drivers/redis.py +++ b/asyncdb/drivers/redis.py @@ -241,8 +241,14 @@ async def execute(self, sentence, *args, **kwargs) -> Any: if self._connection: try: return await self._connection.execute_command(sentence, *args) - except (RedisError,) as err: - raise DriverError(f"Connection Error: {err}") from err + except RedisError as err: + raise DriverError( + f"Connection Error: {err}" + ) from err + except Exception as err: + raise DriverError( + f"Unknown Redis Error: {err}" + ) from err execute_many = execute diff --git a/asyncdb/drivers/sa.py b/asyncdb/drivers/sa.py index 0457f7f7..fcc44bb1 100644 --- a/asyncdb/drivers/sa.py +++ b/asyncdb/drivers/sa.py @@ -42,8 +42,8 @@ async def __anext__(self): # Retrieve the next item in the result set try: return next(self._result_iter) - except StopIteration: - raise StopAsyncIteration + except StopIteration as e: + raise StopAsyncIteration from e class sa(SQLDriver, DBCursorBackend): _provider = "sa" @@ -84,8 +84,7 @@ def __init__( if params: self._driver = params.get('driver', "postgresql+asyncpg") else: - params = {} - params["driver"] = "postgresql+asyncpg" + params = {"driver": "postgresql+asyncpg"} SQLDriver.__init__( self, dsn=dsn, loop=loop, @@ -226,7 +225,7 @@ async def query( self, sentence: Any, params: List = None, - format: str = None + query_format: str = None ): """ Running Query. @@ -234,8 +233,8 @@ async def query( self._result = None error = None await self.valid_operation(sentence) - if not format: - format = self._row_format + if not query_format: + query_format = self._row_format try: self.start_timing() async with self._connection.connect() as conn: @@ -245,11 +244,11 @@ async def query( rows = result.fetchall() # Get the column names from the result metadata column_names = result.keys() - if format in ("dict", "iterable"): + if query_format in ("dict", "iterable"): self._result = [ dict(zip(column_names, row)) for row in rows ] - elif format == "record": + elif query_format == "record": self._result = [ self._construct_record(row, column_names) for row in rows ] @@ -275,7 +274,7 @@ async def queryrow( self, sentence: Any, params: Any = None, - format: Optional[str] = None + query_format: Optional[str] = None ): """ Running Query and return only one row. @@ -284,8 +283,8 @@ async def queryrow( error = None await self.valid_operation(sentence) try: - if not format: - format = self._row_format + if not query_format: + query_format = self._row_format result = None async with self._connection.connect() as conn: if isinstance(sentence, str): @@ -293,9 +292,9 @@ async def queryrow( result = await conn.execute(sentence, params) column_names = result.keys() row = result.fetchone() - if format in ("dict", 'iterable'): + if query_format in ("dict", 'iterable'): self._result = dict(zip(column_names, row)) - elif format == "record": + elif query_format == "record": self._result = self._construct_record(row, column_names) else: self._result = row @@ -317,7 +316,7 @@ async def fetch_all( self, sentence: Any, params: List = None, - format: Optional[str] = None + query_format: Optional[str] = None ): """ Fetch All Rows in a Query. @@ -325,8 +324,8 @@ async def fetch_all( result = None await self.valid_operation(sentence) try: - if not format: - format = self._row_format + if not query_format: + query_format = self._row_format async with self._connection.connect() as conn: if isinstance(sentence, str): sentence = text(sentence) @@ -335,11 +334,11 @@ async def fetch_all( rows = rst.fetchall() if rows is None: return None - if format in ("dict", 'iterable'): + if query_format in ("dict", 'iterable'): result = [ dict(zip(column_names, row)) for row in rows ] - elif format == "record": + elif query_format == "record": result = [ self._construct_record(row, column_names) for row in rows ] @@ -362,7 +361,7 @@ async def fetch_many( sentence: Any, size: int = 1, params: List = None, - format: Optional[str] = None + query_format: Optional[str] = None ): """ Fetch Many Rows from a Query as requested. @@ -370,8 +369,8 @@ async def fetch_many( result = None await self.valid_operation(sentence) try: - if not format: - format = self._row_format + if not query_format: + query_format = self._row_format async with self._connection.connect() as conn: if isinstance(sentence, str): sentence = text(sentence) @@ -380,11 +379,11 @@ async def fetch_many( rows = rst.fetchmany(size) if rows is None: return None - if format in ("dict", 'iterable'): + if query_format in ("dict", 'iterable'): result = [ dict(zip(column_names, row)) for row in rows ] - elif format == "record": + elif query_format == "record": result = [ self._construct_record(row, column_names) for row in rows ] @@ -408,7 +407,7 @@ async def fetch_one( self, sentence: Any, params: List = None, - format: Optional[str] = None + query_format: Optional[str] = None ): """ Running Query and return only one row. @@ -416,8 +415,8 @@ async def fetch_one( result = None await self.valid_operation(sentence) try: - if not format: - format = self._row_format + if not query_format: + query_format = self._row_format async with self._connection.connect() as conn: if isinstance(sentence, str): sentence = text(sentence) @@ -426,9 +425,9 @@ async def fetch_one( row = rst.fetchone() if row is None: return None - if format in ("dict", 'iterable'): + if query_format in ("dict", 'iterable'): result = dict(zip(column_names, row)) - elif format == "record": + elif query_format == "record": result = Record( dict(zip(column_names, row)), column_names diff --git a/asyncdb/interfaces/abstract.py b/asyncdb/interfaces/abstract.py index 9667c4a9..7e0cfd10 100644 --- a/asyncdb/interfaces/abstract.py +++ b/asyncdb/interfaces/abstract.py @@ -84,10 +84,10 @@ def __init__( try: self._loop = asyncio.new_event_loop() asyncio.set_event_loop(self._loop) - except RuntimeError: + except RuntimeError as e: raise RuntimeError( "No Event Loop is running. Please, run this driver inside an asyncio loop." - ) + ) from e if self._loop.is_closed(): self._loop = asyncio.get_running_loop() asyncio.set_event_loop(self._loop) @@ -100,7 +100,7 @@ def get_loop(self): event_loop = get_loop def event_loop_is_closed(self): - return True if not self._loop else bool(self._loop.is_closed()) + return bool(self._loop.is_closed()) if self._loop else True class PoolContextManager(Awaitable, AbstractAsyncContextManager): diff --git a/examples/test_mysql.py b/examples/test_mysql.py index e9530ee3..e57da2a5 100644 --- a/examples/test_mysql.py +++ b/examples/test_mysql.py @@ -23,7 +23,6 @@ async def pooler(loop): async with await pool.acquire() as conn: # execute a sentence result, error = await conn.test_connection() - print(result, 'Error: ', error) print('Is closed: ', {db.is_connected()}) await pool.close() diff --git a/tests/test_mysql.py b/tests/test_mysql.py index 1374616f..7c9082a6 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -46,7 +46,7 @@ async def test_connect(driver, event_loop): await db.connection() pytest.assume(db.is_connected() is True) result, error = await db.test_connection() - pytest.assume(type(result) == list) + pytest.assume(isinstance(result, list) and len(result) > 0) pytest.assume(error is None) await db.close()