Commit

Merge pull request #9288 from OpenMined/scenario-sync-v2
Moving helpers into a better location
snwagh authored Sep 12, 2024
2 parents 1701735 + e3aa3d9 commit 4ad090c
Showing 28 changed files with 1,221 additions and 270 deletions.
23 changes: 23 additions & 0 deletions notebooks/notebook_helpers/apis/__init__.py
@@ -0,0 +1,23 @@
# stdlib
import os

# syft absolute
from syft.util.util import str_to_bool

# relative
from .submit_query import make_submit_query

env_var = "TEST_BIGQUERY_APIS_LIVE"
use_live = str_to_bool(str(os.environ.get(env_var, "False")))
env_name = "Live" if use_live else "Mock"
print(f"Using {env_name} API Code, this will query BigQuery. ${env_var}=={use_live}")


if use_live:
    # relative
    from .live.schema import make_schema
    from .live.test_query import make_test_query
else:
    # relative
    from .mock.schema import make_schema
    from .mock.test_query import make_test_query
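
For context, a notebook would typically consume the factories exposed by this __init__.py by launching a Syft server and registering the returned endpoints. A minimal sketch of that flow follows; the server name, login credentials, worker pool name, and import path are illustrative assumptions, not part of this diff.

# Hypothetical notebook usage of the helpers above; the server name, credentials,
# and worker pool name are assumptions for illustration only.
import syft as sy

from apis import make_schema  # assumes notebook_helpers/ is on sys.path

server = sy.orchestra.launch(name="bigquery-dev", reset=True)
admin_client = server.login(email="info@openmined.org", password="changethis")

schema_endpoint = make_schema(settings={}, worker_pool="default-pool")
admin_client.custom_api.add(endpoint=schema_endpoint)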
Empty file.
108 changes: 108 additions & 0 deletions notebooks/notebook_helpers/apis/live/schema.py
@@ -0,0 +1,108 @@
# stdlib
from collections.abc import Callable

# syft absolute
import syft as sy
from syft import test_settings

# relative
from ..rate_limiter import is_within_rate_limit


def make_schema(settings: dict, worker_pool: str) -> Callable:
    updated_settings = {
        "calls_per_min": 5,
        "rate_limiter_enabled": True,
        "credentials": test_settings.gce_service_account.to_dict(),
        "region": test_settings.gce_region,
        "project_id": test_settings.gce_project_id,
        "dataset_1": test_settings.dataset_1,
        "table_1": test_settings.table_1,
        "table_2": test_settings.table_2,
    } | settings

    @sy.api_endpoint(
        path="bigquery.schema",
        description="This endpoint allows for visualising the metadata of tables available in BigQuery.",
        settings=updated_settings,
        helper_functions=[
            is_within_rate_limit
        ],  # Adds rate limiting, as this is also a method available to data scientists
        worker_pool=worker_pool,
    )
    def live_schema(
        context,
    ) -> str:
        # stdlib
        import datetime

        # third party
        from google.cloud import bigquery  # noqa: F811
        from google.oauth2 import service_account
        import pandas as pd

        # syft absolute
        from syft import SyftException

        # Auth for BigQuery based on the workload identity
        credentials = service_account.Credentials.from_service_account_info(
            context.settings["credentials"]
        )
        scoped_credentials = credentials.with_scopes(
            ["https://www.googleapis.com/auth/cloud-platform"]
        )

        client = bigquery.Client(
            credentials=scoped_credentials,
            location=context.settings["region"],
        )

        # Store a dict with the call times for each user, keyed by email.
        if context.settings["rate_limiter_enabled"]:
            if context.user.email not in context.state.keys():
                context.state[context.user.email] = []

            if not context.code.is_within_rate_limit(context):
                raise SyftException(
                    public_message="Rate limit of calls per minute has been reached."
                )
            context.state[context.user.email].append(datetime.datetime.now())

        try:
            # Format the table schemas as a data frame
            # Warning: the only supported return types are primitives, np.ndarrays and pd.DataFrames

            data_schema = []
            for table_id in [
                f"{context.settings['dataset_1']}.{context.settings['table_1']}",
                f"{context.settings['dataset_1']}.{context.settings['table_2']}",
            ]:
                table = client.get_table(table_id)
                for schema in table.schema:
                    data_schema.append(
                        {
                            "project": str(table.project),
                            "dataset_id": str(table.dataset_id),
                            "table_id": str(table.table_id),
                            "schema_name": str(schema.name),
                            "schema_field": str(schema.field_type),
                            "description": str(table.description),
                            "num_rows": str(table.num_rows),
                        }
                    )
            return pd.DataFrame(data_schema)

        except Exception as e:
            # not a BigQuery exception
            if not hasattr(e, "_errors"):
                output = f"got exception e: {type(e)} {str(e)}"
                raise SyftException(
                    public_message=f"An error occurred executing the API call {output}"
                )

            # Should add appropriate error handling for what should be exposed to the data scientists.
            raise SyftException(
                public_message="An error occurred executing the API call, please contact the domain owner."
            )

    return live_schema
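
Both endpoints import is_within_rate_limit from ..rate_limiter, a module that is not shown in this excerpt. Judging from how it is called above (it receives the endpoint context and reads calls_per_min together with the per-user timestamp list kept in context.state), a helper along the following lines would fit; this is a hedged sketch, and the actual rate_limiter.py in this PR may differ.

# Hypothetical sketch of the helper imported from ..rate_limiter; the real
# implementation in this PR may differ.
def is_within_rate_limit(context) -> bool:
    # stdlib
    import datetime

    # Timestamps of this user's previous calls, recorded by the endpoints above.
    call_times = context.state[context.user.email]

    one_minute_ago = datetime.datetime.now() - datetime.timedelta(minutes=1)
    calls_last_minute = [t for t in call_times if t > one_minute_ago]

    return len(calls_last_minute) < context.settings["calls_per_min"]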
113 changes: 113 additions & 0 deletions notebooks/notebook_helpers/apis/live/test_query.py
@@ -0,0 +1,113 @@
# stdlib
from collections.abc import Callable

# syft absolute
import syft as sy
from syft import test_settings

# relative
from ..rate_limiter import is_within_rate_limit


def make_test_query(settings) -> Callable:
    updated_settings = {
        "calls_per_min": 10,
        "rate_limiter_enabled": True,
        "credentials": test_settings.gce_service_account.to_dict(),
        "region": test_settings.gce_region,
        "project_id": test_settings.gce_project_id,
    } | settings

    # these are the same if you allow the rate limiter to be turned on and off
    @sy.api_endpoint_method(
        settings=updated_settings,
        helper_functions=[is_within_rate_limit],
    )
    def live_test_query(
        context,
        sql_query: str,
    ) -> str:
        # stdlib
        import datetime

        # third party
        from google.cloud import bigquery  # noqa: F811
        from google.oauth2 import service_account

        # syft absolute
        from syft import SyftException

        # Auth for BigQuery based on the workload identity
        credentials = service_account.Credentials.from_service_account_info(
            context.settings["credentials"]
        )
        scoped_credentials = credentials.with_scopes(
            ["https://www.googleapis.com/auth/cloud-platform"]
        )

        client = bigquery.Client(
            credentials=scoped_credentials,
            location=context.settings["region"],
        )

        # Store a dict with the call times for each user, keyed by email.
        if context.settings["rate_limiter_enabled"]:
            if context.user.email not in context.state.keys():
                context.state[context.user.email] = []

            if not context.code.is_within_rate_limit(context):
                raise SyftException(
                    public_message="Rate limit of calls per minute has been reached."
                )
            context.state[context.user.email].append(datetime.datetime.now())

        try:
            rows = client.query_and_wait(
                sql_query,
                project=context.settings["project_id"],
            )

            if rows.total_rows > 1_000_000:
                raise SyftException(
                    public_message="Please only write queries that gather aggregate statistics"
                )

            return rows.to_dataframe()

        except Exception as e:
            # not a BigQuery exception
            if not hasattr(e, "_errors"):
                output = f"got exception e: {type(e)} {str(e)}"
                raise SyftException(
                    public_message=f"An error occurred executing the API call {output}"
                )

            # Forward the errors that the data scientists should see;
            # by default, any exception is only visible to the data owner.

            if e._errors[0]["reason"] in [
                "badRequest",
                "blocked",
                "duplicate",
                "invalidQuery",
                "invalid",
                "jobBackendError",
                "jobInternalError",
                "notFound",
                "notImplemented",
                "rateLimitExceeded",
                "resourceInUse",
                "resourcesExceeded",
                "tableUnavailable",
                "timeout",
            ]:
                raise SyftException(
                    public_message="Error occurred during the call: "
                    + e._errors[0]["message"]
                )
            else:
                raise SyftException(
                    public_message="An error occurred executing the API call, please contact the domain owner."
                )

    return live_test_query
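
To show where make_test_query fits, the scenario notebooks typically pair a private and a mock version of this function into a TwinAPIEndpoint, register it on the server, and let a data scientist call it through their own client. The sketch below assumes an admin client and a data scientist client already exist; the endpoint path, settings, and SQL text are illustrative assumptions, not taken from this diff.

# Hypothetical registration and call; the clients, endpoint path, and SQL text
# are assumptions for illustration only.
import syft as sy

private_query_fn = make_test_query(settings={"rate_limiter_enabled": True})
mock_query_fn = make_test_query(settings={"rate_limiter_enabled": False})

test_query_endpoint = sy.TwinAPIEndpoint(
    path="bigquery.test_query",
    private_function=private_query_fn,
    mock_function=mock_query_fn,
    description="Run a test query against the configured BigQuery tables.",
)
admin_client.custom_api.add(endpoint=test_query_endpoint)

# A data scientist client can then invoke the endpoint directly:
result = ds_client.api.services.bigquery.test_query(
    sql_query="SELECT COUNT(*) AS row_count FROM dataset_1.table_1"
)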
Empty file.
