Skip to content

Commit

Permalink
fix some syntax issues
Browse files Browse the repository at this point in the history
  • Loading branch information
phenobarbital committed Sep 12, 2024
1 parent fcb7642 commit 29b2c10
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 40 deletions.
2 changes: 1 addition & 1 deletion asyncdb/drivers/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ async def file_to_parquet(
pq.write_table(atable, parquet, compression="snappy")
except Exception as exc:
raise DriverError(
f"Query Error: {exc}"
f"Delta File To Parquet Error: {exc}"
) from exc

async def write(
Expand Down
2 changes: 1 addition & 1 deletion asyncdb/drivers/mysqlclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ async def close(self):
await self._thread_func(self._pool.close)
self._connected = False
self._logger.debug(
f"MySQL Connection Closed."
"MySQL Connection Closed."
)

disconnect = close
Expand Down
11 changes: 9 additions & 2 deletions asyncdb/drivers/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,16 @@ def __init__(self, dsn: str = "", loop: asyncio.AbstractEventLoop = None, params
except KeyError:
self._lib_dir = None
try:
super(oracle, self).__init__(dsn=dsn, loop=loop, params=params, **kwargs)
super(oracle, self).__init__(
dsn=dsn,
loop=loop,
params=params,
**kwargs
)
_generated = datetime.now() - _starttime
print(f"Oracle Started in: {_generated}")
print(
f"Oracle Started in: {_generated}"
)
except Exception as err:
raise DriverError(f"Oracle Error: {err}") from err
# set the JSON encoder:
Expand Down
10 changes: 8 additions & 2 deletions asyncdb/drivers/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,14 @@ async def execute(self, sentence, *args, **kwargs) -> Any:
if self._connection:
try:
return await self._connection.execute_command(sentence, *args)
except (RedisError,) as err:
raise DriverError(f"Connection Error: {err}") from err
except RedisError as err:
raise DriverError(
f"Connection Error: {err}"
) from err
except Exception as err:
raise DriverError(
f"Unknown Redis Error: {err}"
) from err

execute_many = execute

Expand Down
57 changes: 28 additions & 29 deletions asyncdb/drivers/sa.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ async def __anext__(self):
# Retrieve the next item in the result set
try:
return next(self._result_iter)
except StopIteration:
raise StopAsyncIteration
except StopIteration as e:
raise StopAsyncIteration from e

class sa(SQLDriver, DBCursorBackend):
_provider = "sa"
Expand Down Expand Up @@ -84,8 +84,7 @@ def __init__(
if params:
self._driver = params.get('driver', "postgresql+asyncpg")
else:
params = {}
params["driver"] = "postgresql+asyncpg"
params = {"driver": "postgresql+asyncpg"}
SQLDriver.__init__(
self, dsn=dsn,
loop=loop,
Expand Down Expand Up @@ -226,16 +225,16 @@ async def query(
self,
sentence: Any,
params: List = None,
format: str = None
query_format: str = None
):
"""
Running Query.
"""
self._result = None
error = None
await self.valid_operation(sentence)
if not format:
format = self._row_format
if not query_format:
query_format = self._row_format
try:
self.start_timing()
async with self._connection.connect() as conn:
Expand All @@ -245,11 +244,11 @@ async def query(
rows = result.fetchall()
# Get the column names from the result metadata
column_names = result.keys()
if format in ("dict", "iterable"):
if query_format in ("dict", "iterable"):
self._result = [
dict(zip(column_names, row)) for row in rows
]
elif format == "record":
elif query_format == "record":
self._result = [
self._construct_record(row, column_names) for row in rows
]
Expand All @@ -275,7 +274,7 @@ async def queryrow(
self,
sentence: Any,
params: Any = None,
format: Optional[str] = None
query_format: Optional[str] = None
):
"""
Running Query and return only one row.
Expand All @@ -284,18 +283,18 @@ async def queryrow(
error = None
await self.valid_operation(sentence)
try:
if not format:
format = self._row_format
if not query_format:
query_format = self._row_format
result = None
async with self._connection.connect() as conn:
if isinstance(sentence, str):
sentence = text(sentence)
result = await conn.execute(sentence, params)
column_names = result.keys()
row = result.fetchone()
if format in ("dict", 'iterable'):
if query_format in ("dict", 'iterable'):
self._result = dict(zip(column_names, row))
elif format == "record":
elif query_format == "record":
self._result = self._construct_record(row, column_names)
else:
self._result = row
Expand All @@ -317,16 +316,16 @@ async def fetch_all(
self,
sentence: Any,
params: List = None,
format: Optional[str] = None
query_format: Optional[str] = None
):
"""
Fetch All Rows in a Query.
"""
result = None
await self.valid_operation(sentence)
try:
if not format:
format = self._row_format
if not query_format:
query_format = self._row_format
async with self._connection.connect() as conn:
if isinstance(sentence, str):
sentence = text(sentence)
Expand All @@ -335,11 +334,11 @@ async def fetch_all(
rows = rst.fetchall()
if rows is None:
return None
if format in ("dict", 'iterable'):
if query_format in ("dict", 'iterable'):
result = [
dict(zip(column_names, row)) for row in rows
]
elif format == "record":
elif query_format == "record":
result = [
self._construct_record(row, column_names) for row in rows
]
Expand All @@ -362,16 +361,16 @@ async def fetch_many(
sentence: Any,
size: int = 1,
params: List = None,
format: Optional[str] = None
query_format: Optional[str] = None
):
"""
Fetch Many Rows from a Query as requested.
"""
result = None
await self.valid_operation(sentence)
try:
if not format:
format = self._row_format
if not query_format:
query_format = self._row_format
async with self._connection.connect() as conn:
if isinstance(sentence, str):
sentence = text(sentence)
Expand All @@ -380,11 +379,11 @@ async def fetch_many(
rows = rst.fetchmany(size)
if rows is None:
return None
if format in ("dict", 'iterable'):
if query_format in ("dict", 'iterable'):
result = [
dict(zip(column_names, row)) for row in rows
]
elif format == "record":
elif query_format == "record":
result = [
self._construct_record(row, column_names) for row in rows
]
Expand All @@ -408,16 +407,16 @@ async def fetch_one(
self,
sentence: Any,
params: List = None,
format: Optional[str] = None
query_format: Optional[str] = None
):
"""
Running Query and return only one row.
"""
result = None
await self.valid_operation(sentence)
try:
if not format:
format = self._row_format
if not query_format:
query_format = self._row_format
async with self._connection.connect() as conn:
if isinstance(sentence, str):
sentence = text(sentence)
Expand All @@ -426,9 +425,9 @@ async def fetch_one(
row = rst.fetchone()
if row is None:
return None
if format in ("dict", 'iterable'):
if query_format in ("dict", 'iterable'):
result = dict(zip(column_names, row))
elif format == "record":
elif query_format == "record":
result = Record(
dict(zip(column_names, row)),
column_names
Expand Down
6 changes: 3 additions & 3 deletions asyncdb/interfaces/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ def __init__(
try:
self._loop = asyncio.new_event_loop()
asyncio.set_event_loop(self._loop)
except RuntimeError:
except RuntimeError as e:
raise RuntimeError(
"No Event Loop is running. Please, run this driver inside an asyncio loop."
)
) from e
if self._loop.is_closed():
self._loop = asyncio.get_running_loop()
asyncio.set_event_loop(self._loop)
Expand All @@ -100,7 +100,7 @@ def get_loop(self):
event_loop = get_loop

def event_loop_is_closed(self):
return True if not self._loop else bool(self._loop.is_closed())
return bool(self._loop.is_closed()) if self._loop else True


class PoolContextManager(Awaitable, AbstractAsyncContextManager):
Expand Down
1 change: 0 additions & 1 deletion examples/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ async def pooler(loop):
async with await pool.acquire() as conn:
# execute a sentence
result, error = await conn.test_connection()
print(result, 'Error: ', error)
print('Is closed: ', {db.is_connected()})
await pool.close()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ async def test_connect(driver, event_loop):
await db.connection()
pytest.assume(db.is_connected() is True)
result, error = await db.test_connection()
pytest.assume(type(result) == list)
pytest.assume(isinstance(result, list) and len(result) > 0)
pytest.assume(error is None)
await db.close()

Expand Down

0 comments on commit 29b2c10

Please sign in to comment.