diff --git a/asyncdb/drivers/bigquery.py b/asyncdb/drivers/bigquery.py index 02988385..20d559a7 100644 --- a/asyncdb/drivers/bigquery.py +++ b/asyncdb/drivers/bigquery.py @@ -48,6 +48,7 @@ async def connection(self): credentials=self.credentials, project=self._project_id ) + self._connected = True else: self.credentials = self._account self._connection = bq.Client( @@ -59,7 +60,7 @@ async def connection(self): async def close(self): # BigQuery client does not maintain persistent connections, so nothing to close here. - pass + self._connected = False disconnect = close @@ -110,9 +111,9 @@ def get_load_config(self, **kwargs): args = {**kwargs, **args} return bq.LoadJobConfig(**args) - async def create_dataset(self, dataset: str): + async def create_dataset(self, dataset_id: str): try: - dataset_ref = self._connection.dataset(dataset) + dataset_ref = bq.DatasetReference(self._connection.project, dataset_id) dataset_obj = bq.Dataset(dataset_ref) dataset_obj = self._connection.create_dataset(dataset_obj) return dataset_obj @@ -139,7 +140,7 @@ async def create_table(self, dataset_id, table_id, schema): if not self._connection: await self.connection() - dataset_ref = self._connection.dataset(dataset_id) + dataset_ref = bq.DatasetReference(self._connection.project, dataset_id) table_ref = dataset_ref.table(table_id) table = bq.Table(table_ref, schema=schema) try: @@ -165,7 +166,8 @@ async def truncate_table(self, table_id: str, dataset_id: str): if not self._connection: await self.connection() - dataset_ref = self._connection.dataset(dataset_id) + # Construct a reference to the dataset + dataset_ref = bq.DatasetReference(self._connection.project, dataset_id) table_ref = dataset_ref.table(table_id) table = self._connection.get_table(table_ref) # API request to fetch the table schema @@ -369,7 +371,7 @@ def connected(self): return self._connection is not None def is_connected(self): - return self.connected + return self._connected def tables(self, schema: str = "") -> Iterable[Any]: raise NotImplementedError diff --git a/examples/test_bigquery.py b/examples/test_bigquery.py index 6eabf195..2b0a79a3 100644 --- a/examples/test_bigquery.py +++ b/examples/test_bigquery.py @@ -15,6 +15,7 @@ async def connect(loop): print( f"Connected: {bq.is_connected()}" ) + print('TEST ', await bq.test_connection()) query = """ SELECT corpus AS title, COUNT(word) AS unique_words FROM `bigquery-public-data.samples.shakespeare` diff --git a/tests/test_bigquery.py b/tests/test_bigquery.py new file mode 100644 index 00000000..ff7cb749 --- /dev/null +++ b/tests/test_bigquery.py @@ -0,0 +1,106 @@ +import pytest +import asyncio +from asyncdb import AsyncDB +from google.cloud import bigquery as gbq +from google.cloud.bigquery.table import RowIterator +from asyncdb.drivers.bigquery import bigquery + +# create a pool with parameters +DRIVER='bigquery' +PARAMS = { + "credentials": "~/proyectos/navigator/asyncdb/env/key.json", + "project_id": "unique-decker-385015" +} + +@pytest.fixture +async def conn(event_loop): + db = AsyncDB(DRIVER, params=PARAMS, loop=event_loop) + await db.connection() + yield db + await db.close() + +pytestmark = pytest.mark.asyncio + +@pytest.mark.parametrize("driver", [ + (DRIVER) +]) +async def test_connect(driver, event_loop): + db = AsyncDB(driver, params=PARAMS, loop=event_loop) + async with await db.connection() as conn: + pytest.assume(conn.is_connected() is True) + result, error = await conn.test_connection() + pytest.assume(isinstance(result, RowIterator)) + pytest.assume(not error) + 
+    pytest.assume(db.is_connected() is False)
+
+@pytest.mark.asyncio
+async def test_bigquery_operations(event_loop):
+    bq = bigquery(loop=event_loop, params=PARAMS)
+    async with await bq.connection() as conn:
+        assert conn is not None, "Connection failed"
+        assert bq.is_connected(), "Connection failed"
+        # Test query
+        query = """
+        SELECT corpus AS title, COUNT(word) AS unique_words
+        FROM `bigquery-public-data.samples.shakespeare`
+        GROUP BY title
+        ORDER BY unique_words
+        DESC LIMIT 10
+        """
+        results, error = await bq.query(query)
+        assert error is None, f"Query failed with error: {error}"
+        assert results is not None, "Query returned no results"
+
+        # Test dataset creation
+        dataset = await bq.create_dataset('us_states_dataset')
+        assert dataset is not None, "Dataset creation failed"
+
+        # Test table creation
+        schema = [
+            gbq.SchemaField("name", "STRING", mode="REQUIRED"),
+            gbq.SchemaField("post_abbr", "STRING", mode="REQUIRED"),
+        ]
+        table = await bq.create_table(
+            dataset_id='us_states_dataset',
+            table_id='us_states',
+            schema=schema
+        )
+        assert table is not None, "Table creation failed"
+
+        # Test table truncation
+        truncated = await bq.truncate_table(
+            dataset_id='us_states_dataset',
+            table_id='us_states'
+        )
+
+        # Test data loading
+        gcs_uri = 'gs://cloud-samples-data/bigquery/us-states/us-states.json'
+        job_config = gbq.job.LoadJobConfig(
+            autodetect=True,
+            source_format=gbq.SourceFormat.NEWLINE_DELIMITED_JSON,
+        )
+        job_config.schema = schema
+        load_job = await bq.load_table_from_uri(
+            source_uri=gcs_uri,
+            table=table,
+            job_config=job_config
+        )
+        assert load_job, "Data loading failed"
+
+        # Test fetching data
+        query = """
+        SELECT name AS state, post_abbr as state_code
+        FROM `unique-decker-385015.us_states_dataset.us_states`
+        """
+        job_config = bq.get_query_config(
+            use_legacy_sql=False
+        )
+        results, error = await bq.query(query, job_config=job_config)
+        assert error is None, f"Query failed with error: {error}"
+        assert results is not None, "Query returned no results"
+
+    # Test closing connection
+    assert not bq.is_connected(), "Closing connection failed"
+
+def pytest_sessionfinish(session, exitstatus):
+    asyncio.get_event_loop().close()
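
Reviewer note: the driver hunks above swap the deprecated Client.dataset() helper for explicit bq.DatasetReference objects and track connection state in self._connected. The following is a minimal standalone sketch of that reference-building pattern against google-cloud-bigquery directly, not the asyncdb driver itself; it assumes application-default credentials are available, and the project, dataset and table names are placeholders rather than values from this repository.

from google.cloud import bigquery as bq

def ensure_states_table(project_id: str) -> bq.Table:
    # Placeholder names -- substitute your own dataset and table.
    dataset_id = "my_dataset"
    table_id = "us_states"

    client = bq.Client(project=project_id)

    # Explicit reference objects replace the deprecated client.dataset() shortcut.
    dataset_ref = bq.DatasetReference(client.project, dataset_id)
    table_ref = dataset_ref.table(table_id)

    # Create the dataset and table only if they do not already exist.
    client.create_dataset(bq.Dataset(dataset_ref), exists_ok=True)
    schema = [
        bq.SchemaField("name", "STRING", mode="REQUIRED"),
        bq.SchemaField("post_abbr", "STRING", mode="REQUIRED"),
    ]
    return client.create_table(bq.Table(table_ref, schema=schema), exists_ok=True)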