27
27
28
28
from google import auth
29
29
import google .api_core .exceptions
30
- from google .cloud .bigquery import dbapi
30
+ from google .cloud .bigquery import dbapi , ConnectionProperty
31
31
from google .cloud .bigquery .table import (
32
32
RangePartitioning ,
33
33
TableReference ,
61
61
from .parse_url import parse_url
62
62
from . import _helpers , _struct , _types
63
63
import sqlalchemy_bigquery_vendored .sqlalchemy .postgresql .base as vendored_postgresql
64
+ from google .cloud .bigquery import QueryJobConfig
64
65
65
66
# Illegal characters is intended to be all characters that are not explicitly
66
67
# allowed as part of the flexible column names.
@@ -1080,6 +1081,7 @@ def __init__(
1080
1081
self ,
1081
1082
arraysize = 5000 ,
1082
1083
credentials_path = None ,
1084
+ billing_project_id = None ,
1083
1085
location = None ,
1084
1086
credentials_info = None ,
1085
1087
credentials_base64 = None ,
@@ -1092,6 +1094,8 @@ def __init__(
1092
1094
self .credentials_path = credentials_path
1093
1095
self .credentials_info = credentials_info
1094
1096
self .credentials_base64 = credentials_base64
1097
+ self .project_id = None
1098
+ self .billing_project_id = billing_project_id
1095
1099
self .location = location
1096
1100
self .identifier_preparer = self .preparer (self )
1097
1101
self .dataset_id = None
@@ -1114,15 +1118,20 @@ def _build_formatted_table_id(table):
1114
1118
"""Build '<dataset_id>.<table_id>' string using given table."""
1115
1119
return "{}.{}" .format (table .reference .dataset_id , table .table_id )
1116
1120
1117
- @staticmethod
1118
- def _add_default_dataset_to_job_config (job_config , project_id , dataset_id ):
1119
- # If dataset_id is set, then we know the job_config isn't None
1120
- if dataset_id :
1121
- # If project_id is missing, use default project_id for the current environment
1121
+ def create_job_config (self , provided_config : QueryJobConfig ):
1122
+ project_id = self .project_id
1123
+ if self .dataset_id is None and project_id == self .billing_project_id :
1124
+ return provided_config
1125
+ job_config = provided_config or QueryJobConfig ()
1126
+ if project_id != self .billing_project_id :
1127
+ job_config .connection_properties = [
1128
+ ConnectionProperty (key = "dataset_project_id" , value = project_id )
1129
+ ]
1130
+ if self .dataset_id :
1122
1131
if not project_id :
1123
1132
_ , project_id = auth .default ()
1124
-
1125
- job_config . default_dataset = "{}.{}" . format ( project_id , dataset_id )
1133
+ job_config . default_dataset = "{}.{}" . format ( project_id , self . dataset_id )
1134
+ return job_config
1126
1135
1127
1136
def do_execute (self , cursor , statement , parameters , context = None ):
1128
1137
kwargs = {}
@@ -1132,13 +1141,13 @@ def do_execute(self, cursor, statement, parameters, context=None):
1132
1141
1133
1142
def create_connect_args (self , url ):
1134
1143
(
1135
- project_id ,
1144
+ self . project_id ,
1136
1145
location ,
1137
1146
dataset_id ,
1138
1147
arraysize ,
1139
1148
credentials_path ,
1140
1149
credentials_base64 ,
1141
- default_query_job_config ,
1150
+ provided_job_config ,
1142
1151
list_tables_page_size ,
1143
1152
user_supplied_client ,
1144
1153
) = parse_url (url )
@@ -1149,9 +1158,9 @@ def create_connect_args(self, url):
1149
1158
self .credentials_path = credentials_path or self .credentials_path
1150
1159
self .credentials_base64 = credentials_base64 or self .credentials_base64
1151
1160
self .dataset_id = dataset_id
1152
- self ._add_default_dataset_to_job_config (
1153
- default_query_job_config , project_id , dataset_id
1154
- )
1161
+ self .billing_project_id = self . billing_project_id or self . project_id
1162
+
1163
+ default_query_job_config = self . create_job_config ( provided_job_config )
1155
1164
1156
1165
if user_supplied_client :
1157
1166
# The user is expected to supply a client with
@@ -1162,10 +1171,14 @@ def create_connect_args(self, url):
1162
1171
credentials_path = self .credentials_path ,
1163
1172
credentials_info = self .credentials_info ,
1164
1173
credentials_base64 = self .credentials_base64 ,
1165
- project_id = project_id ,
1174
+ project_id = self . billing_project_id ,
1166
1175
location = self .location ,
1167
1176
default_query_job_config = default_query_job_config ,
1168
1177
)
1178
+ # If the user specified `bigquery://` we need to set the project_id
1179
+ # from the client
1180
+ self .project_id = self .project_id or client .project
1181
+ self .billing_project_id = self .billing_project_id or client .project
1169
1182
return ([], {"client" : client })
1170
1183
1171
1184
def _get_table_or_view_names (self , connection , item_types , schema = None ):
@@ -1177,7 +1190,7 @@ def _get_table_or_view_names(self, connection, item_types, schema=None):
1177
1190
)
1178
1191
1179
1192
client = connection .connection ._client
1180
- datasets = client .list_datasets ()
1193
+ datasets = client .list_datasets (self . project_id )
1181
1194
1182
1195
result = []
1183
1196
for dataset in datasets :
@@ -1278,7 +1291,8 @@ def _get_table(self, connection, table_name, schema=None):
1278
1291
1279
1292
client = connection .connection ._client
1280
1293
1281
- table_ref = self ._table_reference (schema , table_name , client .project )
1294
+ # table_ref = self._table_reference(schema, table_name, client.project)
1295
+ table_ref = self ._table_reference (schema , table_name , self .project_id )
1282
1296
try :
1283
1297
table = client .get_table (table_ref )
1284
1298
except NotFound :
@@ -1332,7 +1346,7 @@ def get_schema_names(self, connection, **kw):
1332
1346
if isinstance (connection , Engine ):
1333
1347
connection = connection .connect ()
1334
1348
1335
- datasets = connection .connection ._client .list_datasets ()
1349
+ datasets = connection .connection ._client .list_datasets (self . project_id )
1336
1350
return [d .dataset_id for d in datasets ]
1337
1351
1338
1352
def get_table_names (self , connection , schema = None , ** kw ):
0 commit comments