Skip to content

Commit eea4994

Browse files
Authored with Linchin (the primary author's name is garbled in this capture; the only surviving fragment is "nale")
feat: Allow jobs to be run in a different project (#1180)
* feat: Allow jobs to be run in a different project
* Update test_sqlalchemy_bigquery_remote.py

---------

Co-authored-by: Lingqing Gan <[email protected]>
1 parent 05148cd commit eea4994

File tree

4 files changed: +157 -25 lines changed

sqlalchemy_bigquery/base.py

+31 -17
Original file line number | Diff line number | Diff line change
@@ -27,7 +27,7 @@
2727

2828
from google import auth
2929
import google.api_core.exceptions
30-
from google.cloud.bigquery import dbapi
30+
from google.cloud.bigquery import dbapi, ConnectionProperty
3131
from google.cloud.bigquery.table import (
3232
RangePartitioning,
3333
TableReference,
@@ -61,6 +61,7 @@
6161
from .parse_url import parse_url
6262
from . import _helpers, _struct, _types
6363
import sqlalchemy_bigquery_vendored.sqlalchemy.postgresql.base as vendored_postgresql
64+
from google.cloud.bigquery import QueryJobConfig
6465

6566
# Illegal characters is intended to be all characters that are not explicitly
6667
# allowed as part of the flexible column names.
@@ -1080,6 +1081,7 @@ def __init__(
10801081
self,
10811082
arraysize=5000,
10821083
credentials_path=None,
1084+
billing_project_id=None,
10831085
location=None,
10841086
credentials_info=None,
10851087
credentials_base64=None,
@@ -1092,6 +1094,8 @@ def __init__(
10921094
self.credentials_path = credentials_path
10931095
self.credentials_info = credentials_info
10941096
self.credentials_base64 = credentials_base64
1097+
self.project_id = None
1098+
self.billing_project_id = billing_project_id
10951099
self.location = location
10961100
self.identifier_preparer = self.preparer(self)
10971101
self.dataset_id = None
@@ -1114,15 +1118,20 @@ def _build_formatted_table_id(table):
11141118
"""Build '<dataset_id>.<table_id>' string using given table."""
11151119
return "{}.{}".format(table.reference.dataset_id, table.table_id)
11161120

1117-
@staticmethod
1118-
def _add_default_dataset_to_job_config(job_config, project_id, dataset_id):
1119-
# If dataset_id is set, then we know the job_config isn't None
1120-
if dataset_id:
1121-
# If project_id is missing, use default project_id for the current environment
1121+
def create_job_config(self, provided_config: QueryJobConfig):
1122+
project_id = self.project_id
1123+
if self.dataset_id is None and project_id == self.billing_project_id:
1124+
return provided_config
1125+
job_config = provided_config or QueryJobConfig()
1126+
if project_id != self.billing_project_id:
1127+
job_config.connection_properties = [
1128+
ConnectionProperty(key="dataset_project_id", value=project_id)
1129+
]
1130+
if self.dataset_id:
11221131
if not project_id:
11231132
_, project_id = auth.default()
1124-
1125-
job_config.default_dataset = "{}.{}".format(project_id, dataset_id)
1133+
job_config.default_dataset = "{}.{}".format(project_id, self.dataset_id)
1134+
return job_config
11261135

11271136
def do_execute(self, cursor, statement, parameters, context=None):
11281137
kwargs = {}
@@ -1132,13 +1141,13 @@ def do_execute(self, cursor, statement, parameters, context=None):
11321141

11331142
def create_connect_args(self, url):
11341143
(
1135-
project_id,
1144+
self.project_id,
11361145
location,
11371146
dataset_id,
11381147
arraysize,
11391148
credentials_path,
11401149
credentials_base64,
1141-
default_query_job_config,
1150+
provided_job_config,
11421151
list_tables_page_size,
11431152
user_supplied_client,
11441153
) = parse_url(url)
@@ -1149,9 +1158,9 @@ def create_connect_args(self, url):
11491158
self.credentials_path = credentials_path or self.credentials_path
11501159
self.credentials_base64 = credentials_base64 or self.credentials_base64
11511160
self.dataset_id = dataset_id
1152-
self._add_default_dataset_to_job_config(
1153-
default_query_job_config, project_id, dataset_id
1154-
)
1161+
self.billing_project_id = self.billing_project_id or self.project_id
1162+
1163+
default_query_job_config = self.create_job_config(provided_job_config)
11551164

11561165
if user_supplied_client:
11571166
# The user is expected to supply a client with
@@ -1162,10 +1171,14 @@ def create_connect_args(self, url):
11621171
credentials_path=self.credentials_path,
11631172
credentials_info=self.credentials_info,
11641173
credentials_base64=self.credentials_base64,
1165-
project_id=project_id,
1174+
project_id=self.billing_project_id,
11661175
location=self.location,
11671176
default_query_job_config=default_query_job_config,
11681177
)
1178+
# If the user specified `bigquery://` we need to set the project_id
1179+
# from the client
1180+
self.project_id = self.project_id or client.project
1181+
self.billing_project_id = self.billing_project_id or client.project
11691182
return ([], {"client": client})
11701183

11711184
def _get_table_or_view_names(self, connection, item_types, schema=None):
@@ -1177,7 +1190,7 @@ def _get_table_or_view_names(self, connection, item_types, schema=None):
11771190
)
11781191

11791192
client = connection.connection._client
1180-
datasets = client.list_datasets()
1193+
datasets = client.list_datasets(self.project_id)
11811194

11821195
result = []
11831196
for dataset in datasets:
@@ -1278,7 +1291,8 @@ def _get_table(self, connection, table_name, schema=None):
12781291

12791292
client = connection.connection._client
12801293

1281-
table_ref = self._table_reference(schema, table_name, client.project)
1294+
# table_ref = self._table_reference(schema, table_name, client.project)
1295+
table_ref = self._table_reference(schema, table_name, self.project_id)
12821296
try:
12831297
table = client.get_table(table_ref)
12841298
except NotFound:
@@ -1332,7 +1346,7 @@ def get_schema_names(self, connection, **kw):
13321346
if isinstance(connection, Engine):
13331347
connection = connection.connect()
13341348

1335-
datasets = connection.connection._client.list_datasets()
1349+
datasets = connection.connection._client.list_datasets(self.project_id)
13361350
return [d.dataset_id for d in datasets]
13371351

13381352
def get_table_names(self, connection, schema=None, **kw):
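test_sqlalchemy_bigquery_remote.py (new file; name taken from the commit message above, its directory is not shown in this capture)

+107
Original file line number | Diff line number | Diff line change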
@@ -0,0 +1,107 @@
1+
# Copyright (c) 2017 The sqlalchemy-bigquery Authors
2+
#
3+
# Permission is hereby granted, free of charge, to any person obtaining a copy of
4+
# this software and associated documentation files (the "Software"), to deal in
5+
# the Software without restriction, including without limitation the rights to
6+
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
7+
# the Software, and to permit persons to whom the Software is furnished to do so,
8+
# subject to the following conditions:
9+
#
10+
# The above copyright notice and this permission notice shall be included in all
11+
# copies or substantial portions of the Software.
12+
#
13+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
15+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
16+
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
17+
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
18+
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
19+
20+
# -*- coding: utf-8 -*-
21+
22+
from sqlalchemy.engine import create_engine
23+
from sqlalchemy.exc import DatabaseError
24+
from sqlalchemy.schema import Table, MetaData
25+
import pytest
26+
import sqlalchemy
27+
import google.api_core.exceptions as core_exceptions
28+
29+
30+
EXPECTED_STATES = ["AL", "CA", "FL", "KY"]
31+
32+
REMOTE_TESTS = [
33+
("bigquery-public-data", "bigquery-public-data.usa_names.usa_1910_2013"),
34+
("bigquery-public-data", "usa_names.usa_1910_2013"),
35+
("bigquery-public-data/usa_names", "bigquery-public-data.usa_names.usa_1910_2013"),
36+
("bigquery-public-data/usa_names", "usa_1910_2013"),
37+
("bigquery-public-data/usa_names", "usa_names.usa_1910_2013"),
38+
]
39+
40+
41+
@pytest.fixture(scope="session")
42+
def engine_using_remote_dataset(bigquery_client):
43+
engine = create_engine(
44+
"bigquery://bigquery-public-data/usa_names",
45+
billing_project_id=bigquery_client.project,
46+
echo=True,
47+
)
48+
return engine
49+
50+
51+
def test_remote_tables_list(engine_using_remote_dataset):
52+
tables = sqlalchemy.inspect(engine_using_remote_dataset).get_table_names()
53+
assert "usa_1910_2013" in tables
54+
55+
56+
@pytest.mark.parametrize(
57+
["urlpath", "table_name"],
58+
REMOTE_TESTS,
59+
ids=[f"test_engine_remote_sql_{x}" for x in range(len(REMOTE_TESTS))],
60+
)
61+
def test_engine_remote_sql(bigquery_client, urlpath, table_name):
62+
engine = create_engine(
63+
f"bigquery://{urlpath}", billing_project_id=bigquery_client.project, echo=True
64+
)
65+
with engine.connect() as conn:
66+
rows = conn.execute(
67+
sqlalchemy.text(f"SELECT DISTINCT(state) FROM `{table_name}`")
68+
).fetchall()
69+
states = set(map(lambda row: row[0], rows))
70+
assert set(EXPECTED_STATES).issubset(states)
71+
72+
73+
@pytest.mark.parametrize(
74+
["urlpath", "table_name"],
75+
REMOTE_TESTS,
76+
ids=[f"test_engine_remote_table_{x}" for x in range(len(REMOTE_TESTS))],
77+
)
78+
def test_engine_remote_table(bigquery_client, urlpath, table_name):
79+
engine = create_engine(
80+
f"bigquery://{urlpath}", billing_project_id=bigquery_client.project, echo=True
81+
)
82+
with engine.connect() as conn:
83+
table = Table(table_name, MetaData(), autoload_with=engine)
84+
prepared = sqlalchemy.select(
85+
sqlalchemy.distinct(table.c.state)
86+
).set_label_style(sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL)
87+
rows = conn.execute(prepared).fetchall()
88+
states = set(map(lambda row: row[0], rows))
89+
assert set(EXPECTED_STATES).issubset(states)
90+
91+
92+
@pytest.mark.parametrize(
93+
["urlpath", "table_name"],
94+
REMOTE_TESTS,
95+
ids=[f"test_engine_remote_table_fail_{x}" for x in range(len(REMOTE_TESTS))],
96+
)
97+
def test_engine_remote_table_fail(urlpath, table_name):
98+
engine = create_engine(f"bigquery://{urlpath}", echo=True)
99+
with pytest.raises(
100+
(DatabaseError, core_exceptions.Forbidden), match="Access Denied"
101+
):
102+
with engine.connect() as conn:
103+
table = Table(table_name, MetaData(), autoload_with=engine)
104+
prepared = sqlalchemy.select(
105+
sqlalchemy.distinct(table.c.state)
106+
).set_label_style(sqlalchemy.LABEL_STYLE_TABLENAME_PLUS_COL)
107+
conn.execute(prepared).fetchall()

tests/unit/fauxdbi.py

+13 -8
Original file line number | Diff line number | Diff line change
@@ -327,10 +327,12 @@ def _fix_pickled(self, row):
327327
pickle.loads(v.encode("latin1"))
328328
# \x80\x04 is latin-1 encoded prefix for Pickle protocol 4.
329329
if isinstance(v, str) and v[:2] == "\x80\x04" and v[-1] == "."
330-
else pickle.loads(base64.b16decode(v))
331-
# 8004 is base64 encoded prefix for Pickle protocol 4.
332-
if isinstance(v, str) and v[:4] == "8004" and v[-2:] == "2E"
333-
else v
330+
else (
331+
pickle.loads(base64.b16decode(v))
332+
# 8004 is base64 encoded prefix for Pickle protocol 4.
333+
if isinstance(v, str) and v[:4] == "8004" and v[-2:] == "2E"
334+
else v
335+
)
334336
)
335337
for d, v in zip(self.description, row)
336338
]
@@ -355,7 +357,10 @@ def __getattr__(self, name):
355357
class FauxClient:
356358
def __init__(self, project_id=None, default_query_job_config=None, *args, **kw):
357359
if project_id is None:
358-
if default_query_job_config is not None:
360+
if (
361+
default_query_job_config is not None
362+
and default_query_job_config.default_dataset
363+
):
359364
project_id = default_query_job_config.default_dataset.project
360365
else:
361366
project_id = "authproj" # we would still have gotten it from auth.
@@ -469,10 +474,10 @@ def get_table(self, table_ref):
469474
else:
470475
raise google.api_core.exceptions.NotFound(table_ref)
471476

472-
def list_datasets(self):
477+
def list_datasets(self, project="myproject"):
473478
return [
474-
google.cloud.bigquery.Dataset("myproject.mydataset"),
475-
google.cloud.bigquery.Dataset("myproject.yourdataset"),
479+
google.cloud.bigquery.Dataset(f"{project}.mydataset"),
480+
google.cloud.bigquery.Dataset(f"{project}.yourdataset"),
476481
]
477482

478483
def list_tables(self, dataset, page_size):

tests/unit/test_engine.py

+6
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,12 @@ def test_engine_dataset_but_no_project(faux_conn):
2727
assert conn.connection._client.project == "authproj"
2828

2929

30+
def test_engine_dataset_with_billing_project(faux_conn):
31+
engine = sqlalchemy.create_engine("bigquery://foo", billing_project_id="bar")
32+
conn = engine.connect()
33+
assert conn.connection._client.project == "bar"
34+
35+
3036
def test_engine_no_dataset_no_project(faux_conn):
3137
engine = sqlalchemy.create_engine("bigquery://")
3238
conn = engine.connect()

0 commit comments

Comments
 (0)