Commit
Merge pull request #9288 from OpenMined/scenario-sync-v2
Moving helpers into a better location
Showing 28 changed files with 1,221 additions and 270 deletions.
@@ -0,0 +1,23 @@
# stdlib
import os

# syft absolute
from syft.util.util import str_to_bool

# relative
from .submit_query import make_submit_query

env_var = "TEST_BIGQUERY_APIS_LIVE"
use_live = str_to_bool(str(os.environ.get(env_var, "False")))
env_name = "Live" if use_live else "Mock"
print(f"Using {env_name} API Code, this will query BigQuery. ${env_var}=={use_live}")


if use_live:
    # relative
    from .live.schema import make_schema
    from .live.test_query import make_test_query
else:
    # relative
    from .mock.schema import make_schema
    from .mock.test_query import make_test_query
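The switch above is evaluated at import time, so the environment variable has to be set before this package is imported. A minimal usage sketch follows; the package path `apis` is an assumption, since file names are not shown in this view:

```python
# Hypothetical usage sketch: the package path "apis" is an assumption,
# as the file names are not visible in this diff.
import os

# Must be set before the package is imported, otherwise the mock code path is used.
os.environ["TEST_BIGQUERY_APIS_LIVE"] = "True"

from apis import make_schema, make_test_query  # resolves to the .live implementations

# With the variable unset (or "False"), the same imports resolve to the .mock implementations.
```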
Empty file.
@@ -0,0 +1,108 @@
# stdlib
from collections.abc import Callable

# syft absolute
import syft as sy
from syft import test_settings

# relative
from ..rate_limiter import is_within_rate_limit


def make_schema(settings: dict, worker_pool: str) -> Callable:
    updated_settings = {
        "calls_per_min": 5,
        "rate_limiter_enabled": True,
        "credentials": test_settings.gce_service_account.to_dict(),
        "region": test_settings.gce_region,
        "project_id": test_settings.gce_project_id,
        "dataset_1": test_settings.dataset_1,
        "table_1": test_settings.table_1,
        "table_2": test_settings.table_2,
    } | settings

    @sy.api_endpoint(
        path="bigquery.schema",
        description="This endpoint allows for visualising the metadata of tables available in BigQuery.",
        settings=updated_settings,
        helper_functions=[
            is_within_rate_limit
        ],  # Adds a rate limit, as this is also a method available to data scientists
        worker_pool=worker_pool,
    )
    def live_schema(
        context,
    ) -> str:
        # stdlib
        import datetime

        # third party
        from google.cloud import bigquery  # noqa: F811
        from google.oauth2 import service_account
        import pandas as pd

        # syft absolute
        from syft import SyftException

        # Auth for BigQuery based on the workload identity
        credentials = service_account.Credentials.from_service_account_info(
            context.settings["credentials"]
        )
        scoped_credentials = credentials.with_scopes(
            ["https://www.googleapis.com/auth/cloud-platform"]
        )

        client = bigquery.Client(
            credentials=scoped_credentials,
            location=context.settings["region"],
        )

        # Store a dict with the call times for each user, keyed by email.
        if context.settings["rate_limiter_enabled"]:
            if context.user.email not in context.state.keys():
                context.state[context.user.email] = []

            if not context.code.is_within_rate_limit(context):
                raise SyftException(
                    public_message="Rate limit of calls per minute has been reached."
                )
            context.state[context.user.email].append(datetime.datetime.now())

        try:
            # Format the data schema as a data frame.
            # Warning: the only supported return types are primitives, np.ndarrays and pd.DataFrames

            data_schema = []
            for table_id in [
                f"{context.settings['dataset_1']}.{context.settings['table_1']}",
                f"{context.settings['dataset_1']}.{context.settings['table_2']}",
            ]:
                table = client.get_table(table_id)
                for schema in table.schema:
                    data_schema.append(
                        {
                            "project": str(table.project),
                            "dataset_id": str(table.dataset_id),
                            "table_id": str(table.table_id),
                            "schema_name": str(schema.name),
                            "schema_field": str(schema.field_type),
                            "description": str(table.description),
                            "num_rows": str(table.num_rows),
                        }
                    )
            return pd.DataFrame(data_schema)

        except Exception as e:
            # not a BigQuery exception
            if not hasattr(e, "_errors"):
                output = f"got exception e: {type(e)} {str(e)}"
                raise SyftException(
                    public_message=f"An error occurred executing the API call {output}"
                )

            # Should add appropriate error handling for what should be exposed to the data scientists.
            raise SyftException(
                public_message="An error occurred executing the API call, please contact the domain owner."
            )

    return live_schema
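The helper `is_within_rate_limit` is imported from `..rate_limiter` but is not part of this diff. Based on how it is called above (via `context.code.is_within_rate_limit(context)` against `context.state` and the `calls_per_min` setting), a plausible implementation would look roughly like the sketch below; this is an assumption, not the actual helper from the repository:

```python
def is_within_rate_limit(context) -> bool:
    """Sketch of the rate-limit helper (assumed, not the real ..rate_limiter code):
    returns True while the calling user has made fewer than `calls_per_min`
    calls within the last 60 seconds."""
    # stdlib
    import datetime

    state = context.state[context.user.email]
    one_minute_ago = datetime.datetime.now() - datetime.timedelta(minutes=1)

    # Count only the calls recorded within the last minute.
    recent_calls = [t for t in state if t > one_minute_ago]
    return len(recent_calls) < context.settings["calls_per_min"]
```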
@@ -0,0 +1,113 @@
# stdlib
from collections.abc import Callable

# syft absolute
import syft as sy
from syft import test_settings

# relative
from ..rate_limiter import is_within_rate_limit


def make_test_query(settings) -> Callable:
    updated_settings = {
        "calls_per_min": 10,
        "rate_limiter_enabled": True,
        "credentials": test_settings.gce_service_account.to_dict(),
        "region": test_settings.gce_region,
        "project_id": test_settings.gce_project_id,
    } | settings

    # these are the same if you allow the rate limiter to be turned on and off
    @sy.api_endpoint_method(
        settings=updated_settings,
        helper_functions=[is_within_rate_limit],
    )
    def live_test_query(
        context,
        sql_query: str,
    ) -> str:
        # stdlib
        import datetime

        # third party
        from google.cloud import bigquery  # noqa: F811
        from google.oauth2 import service_account

        # syft absolute
        from syft import SyftException

        # Auth for BigQuery based on the workload identity
        credentials = service_account.Credentials.from_service_account_info(
            context.settings["credentials"]
        )
        scoped_credentials = credentials.with_scopes(
            ["https://www.googleapis.com/auth/cloud-platform"]
        )

        client = bigquery.Client(
            credentials=scoped_credentials,
            location=context.settings["region"],
        )

        # Store a dict with the call times for each user, keyed by email.
        if context.settings["rate_limiter_enabled"]:
            if context.user.email not in context.state.keys():
                context.state[context.user.email] = []

            if not context.code.is_within_rate_limit(context):
                raise SyftException(
                    public_message="Rate limit of calls per minute has been reached."
                )
            context.state[context.user.email].append(datetime.datetime.now())

        try:
            rows = client.query_and_wait(
                sql_query,
                project=context.settings["project_id"],
            )

            if rows.total_rows > 1_000_000:
                raise SyftException(
                    public_message="Please only write queries that gather aggregate statistics"
                )

            return rows.to_dataframe()

        except Exception as e:
            # not a BigQuery exception
            if not hasattr(e, "_errors"):
                output = f"got exception e: {type(e)} {str(e)}"
                raise SyftException(
                    public_message=f"An error occurred executing the API call {output}"
                )

            # Forward only the errors that we would like the data scientists to see.
            # By default, any exception is only visible to the data owner.

            if e._errors[0]["reason"] in [
                "badRequest",
                "blocked",
                "duplicate",
                "invalidQuery",
                "invalid",
                "jobBackendError",
                "jobInternalError",
                "notFound",
                "notImplemented",
                "rateLimitExceeded",
                "resourceInUse",
                "resourcesExceeded",
                "tableUnavailable",
                "timeout",
            ]:
                raise SyftException(
                    public_message="Error occurred during the call: "
                    + e._errors[0]["message"]
                )
            else:
                raise SyftException(
                    public_message="An error occurred executing the API call, please contact the domain owner."
                )

    return live_test_query
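Both factories return decorated endpoint objects rather than registering anything themselves. A hedged sketch of how they might be wired into a Syft server follows; the `admin_client` variable, worker pool name, and settings overrides are assumptions, and the exact registration calls may differ between syft versions:

```python
# Hedged sketch: wiring the factories into a Syft server.
# `admin_client`, the worker pool name, and the settings overrides are assumptions;
# the exact registration API may differ between syft versions.
import syft as sy

# make_test_query returns an @sy.api_endpoint_method function, so a pair of them
# can back a twin endpoint (mock for data scientists, private for real queries).
mock_query = make_test_query(settings={"rate_limiter_enabled": False})
private_query = make_test_query(settings={"rate_limiter_enabled": True})

test_query_endpoint = sy.TwinAPIEndpoint(
    path="bigquery.test_query",
    description="Run a SQL query against BigQuery.",
    private_function=private_query,
    mock_function=mock_query,
    worker_pool="bigquery-pool",
)
admin_client.custom_api.add(endpoint=test_query_endpoint)

# make_schema already returns a complete endpoint (created via @sy.api_endpoint),
# so it can be registered directly.
schema_endpoint = make_schema(settings={}, worker_pool="bigquery-pool")
admin_client.custom_api.add(endpoint=schema_endpoint)
```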
Empty file.