Skip to content

Commit

Permalink
Working pytest suite for the BigQuery driver
Browse files Browse the repository at this point in the history
  • Loading branch information
phenobarbital committed Dec 14, 2023
1 parent f407bda commit 892905d
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 6 deletions.
14 changes: 8 additions & 6 deletions asyncdb/drivers/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ async def connection(self):
credentials=self.credentials,
project=self._project_id
)
self._connected = True
else:
self.credentials = self._account
self._connection = bq.Client(
Expand All @@ -59,7 +60,7 @@ async def connection(self):

async def close(self):
# BigQuery client does not maintain persistent connections, so nothing to close here.
pass
self._connected = False

disconnect = close

Expand Down Expand Up @@ -110,9 +111,9 @@ def get_load_config(self, **kwargs):
args = {**kwargs, **args}
return bq.LoadJobConfig(**args)

async def create_dataset(self, dataset: str):
async def create_dataset(self, dataset_id: str):
try:
dataset_ref = self._connection.dataset(dataset)
dataset_ref = bq.DatasetReference(self._connection.project, dataset_id)
dataset_obj = bq.Dataset(dataset_ref)
dataset_obj = self._connection.create_dataset(dataset_obj)
return dataset_obj
Expand All @@ -139,7 +140,7 @@ async def create_table(self, dataset_id, table_id, schema):
if not self._connection:
await self.connection()

dataset_ref = self._connection.dataset(dataset_id)
dataset_ref = bq.DatasetReference(self._connection.project, dataset_id)
table_ref = dataset_ref.table(table_id)
table = bq.Table(table_ref, schema=schema)
try:
Expand All @@ -165,7 +166,8 @@ async def truncate_table(self, table_id: str, dataset_id: str):
if not self._connection:
await self.connection()

dataset_ref = self._connection.dataset(dataset_id)
# Construct a reference to the dataset
dataset_ref = bq.DatasetReference(self._connection.project, dataset_id)
table_ref = dataset_ref.table(table_id)
table = self._connection.get_table(table_ref) # API request to fetch the table schema

Expand Down Expand Up @@ -369,7 +371,7 @@ def connected(self):
return self._connection is not None

def is_connected(self):
return self.connected
return self._connected

def tables(self, schema: str = "") -> Iterable[Any]:
raise NotImplementedError
Expand Down
1 change: 1 addition & 0 deletions examples/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ async def connect(loop):
print(
f"Connected: {bq.is_connected()}"
)
print('TEST ', await bq.test_connection())
query = """
SELECT corpus AS title, COUNT(word) AS unique_words
FROM `bigquery-public-data.samples.shakespeare`
Expand Down
106 changes: 106 additions & 0 deletions tests/test_bigquery.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
import pytest
import asyncio
from asyncdb import AsyncDB
from google.cloud import bigquery as gbq
from google.cloud.bigquery.table import RowIterator
from asyncdb.drivers.bigquery import bigquery

# create a pool with parameters
# Driver name understood by the AsyncDB factory.
DRIVER='bigquery'
# Connection parameters handed to the BigQuery driver.
# NOTE(review): the credentials path and project id are machine/user
# specific — consider sourcing them from environment variables so the
# test suite runs outside this developer's workstation.
PARAMS = {
    "credentials": "~/proyectos/navigator/asyncdb/env/key.json",
    "project_id": "unique-decker-385015"
}

@pytest.fixture
async def conn(event_loop):
    """Yield a connected AsyncDB BigQuery driver; close it on teardown."""
    driver = AsyncDB(DRIVER, params=PARAMS, loop=event_loop)
    await driver.connection()
    yield driver
    await driver.close()

# Run every coroutine test in this module under pytest-asyncio.
pytestmark = pytest.mark.asyncio

@pytest.mark.parametrize("driver", [DRIVER])
async def test_connect(driver, event_loop):
    """Open a BigQuery connection, probe it, and confirm it is released."""
    database = AsyncDB(driver, params=PARAMS, loop=event_loop)
    async with await database.connection() as conn:
        pytest.assume(conn.is_connected() is True)
        rows, err = await conn.test_connection()
        pytest.assume(isinstance(rows, RowIterator))
        pytest.assume(not err)
    # Leaving the context manager must mark the driver as disconnected.
    pytest.assume(database.is_connected() is False)

@pytest.mark.asyncio
async def test_bigquery_operations(event_loop):
    """End-to-end exercise of the BigQuery driver against a live project.

    Covers: ad-hoc query, dataset and table creation, table truncation,
    loading rows from a public GCS sample, querying the loaded rows, and
    verifying the connection is marked closed on context exit.

    NOTE(review): requires valid GCP credentials (see PARAMS) and network
    access; it is an integration test, not a unit test.
    """
    bq = bigquery(loop=event_loop, params=PARAMS)
    async with await bq.connection() as conn:
        assert conn is not None, "Connection failed"
        assert bq.is_connected(), "Connection failed"

        # Test query against a public dataset.
        query = """
        SELECT corpus AS title, COUNT(word) AS unique_words
        FROM `bigquery-public-data.samples.shakespeare`
        GROUP BY title
        ORDER BY unique_words
        DESC LIMIT 10
        """
        results, error = await bq.query(query)
        assert error is None, f"Query failed with error: {error}"
        assert results is not None, "Query returned no results"

        # Test dataset creation.
        dataset = await bq.create_dataset('us_states_dataset')
        assert dataset is not None, "Dataset creation failed"

        # Test table creation.
        schema = [
            gbq.SchemaField("name", "STRING", mode="REQUIRED"),
            gbq.SchemaField("post_abbr", "STRING", mode="REQUIRED"),
        ]
        table = await bq.create_table(
            dataset_id='us_states_dataset',
            table_id='us_states',
            schema=schema
        )
        assert table is not None, "Table creation failed"

        # Test table truncation. Fix: the original bound the result to an
        # unused `truncated` local and never checked it; success is
        # signalled by the call not raising.
        await bq.truncate_table(
            dataset_id='us_states_dataset',
            table_id='us_states'
        )

        # Test data loading from a public GCS sample file.
        gcs_uri = 'gs://cloud-samples-data/bigquery/us-states/us-states.json'
        job_config = gbq.job.LoadJobConfig(
            autodetect=True,
            source_format=gbq.SourceFormat.NEWLINE_DELIMITED_JSON,
        )
        job_config.schema = schema
        load_job = await bq.load_table_from_uri(
            source_uri=gcs_uri,
            table=table,
            job_config=job_config
        )
        assert load_job, "Data loading failed"

        # Test fetching the freshly loaded data back.
        query = """
        SELECT name AS state, post_abbr as state_code
        FROM `unique-decker-385015.us_states_dataset.us_states`
        """
        job_config = bq.get_query_config(
            use_legacy_sql=False
        )
        results, error = await bq.query(query, job_config=job_config)
        assert error is None, f"Query failed with error: {error}"
        assert results is not None, "Query returned no results"

    # Test closing connection: leaving the context manager must flip the
    # driver's connected flag off.
    assert not bq.is_connected(), "Closing connection failed"

def pytest_sessionfinish(session, exitstatus):
    """Close the thread's asyncio event loop once the pytest session ends.

    NOTE(review): pytest only discovers this hook in conftest.py or an
    installed plugin; defined inside a test module it is never invoked —
    move it to conftest.py if the explicit cleanup is still wanted.
    """
    try:
        # asyncio.get_event_loop() is deprecated when no loop is running;
        # ask the policy directly and guard against there being none.
        loop = asyncio.get_event_loop_policy().get_event_loop()
    except RuntimeError:
        # No event loop was ever created for this thread — nothing to do.
        return
    if not loop.is_closed():
        loop.close()

0 comments on commit 892905d

Please sign in to comment.