Commit
Merge pull request #9299 from OpenMined/test_upgrades
Test upgradability for scenarios
koenvanderveen authored Sep 24, 2024
2 parents f5c0e8c + 8e9f2ec commit 91a539c
Showing 58 changed files with 9,882 additions and 4 deletions.
60 changes: 60 additions & 0 deletions .github/workflows/pr-tests-stack.yml
@@ -620,6 +620,66 @@ jobs:
        if: steps.changes.outputs.syft == 'true'
        run: |
          tox -e migration.test
  pr-tests-scenarios-migrations:
    strategy:
      max-parallel: 99
      matrix:
        os: [ubuntu-latest]
        python-version: ["3.12"]

    runs-on: ${{ matrix.os }}
    steps:
      - name: "clean .git/config"
        if: matrix.os == 'windows-latest'
        continue-on-error: true
        shell: bash
        run: |
          echo "deleting ${GITHUB_WORKSPACE}/.git/config"
          rm ${GITHUB_WORKSPACE}/.git/config
      - uses: actions/checkout@v4

      - name: Check for file changes
        uses: dorny/paths-filter@v3
        id: changes
        with:
          base: ${{ github.ref }}
          token: ${{ github.token }}
          filters: .github/file-filters.yml

      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        if: steps.changes.outputs.syft == 'true'
        with:
          python-version: ${{ matrix.python-version }}

      - name: Install pip packages
        if: steps.changes.outputs.syft == 'true'
        run: |
          python -m pip install --upgrade pip
          pip install uv==0.4.1 tox==4.18.0 tox-uv==1.11.2
          uv --version
      - name: Get uv cache dir
        id: pip-cache
        if: steps.changes.outputs.syft == 'true'
        shell: bash
        run: |
          echo "dir=$(uv cache dir)" >> $GITHUB_OUTPUT
      - name: Load github cache
        uses: actions/cache@v4
        if: steps.changes.outputs.syft == 'true'
        with:
          path: ${{ steps.pip-cache.outputs.dir }}
          key: ${{ runner.os }}-uv-py${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }}
          restore-keys: |
            ${{ runner.os }}-uv-py${{ matrix.python-version }}-
      - name: Run migration tests
        if: steps.changes.outputs.syft == 'true'
        run: |
          tox -e migration.scenarios.test
  pr-tests-migrations-k8s:
    strategy:
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -14,7 +14,7 @@ repos:
        exclude: ^(packages/grid/frontend/|.vscode)
      - id: check-added-large-files
        always_run: true
        exclude: ^(packages/grid/backend/wheels/.*|docs/img/header.png|docs/img/terminalizer.gif)
        exclude: ^(packages/grid/backend/wheels/.*|docs/img/header.png|docs/img/terminalizer.gif|^notebooks/scenarios/bigquery/upgradability/sync/migration_.*\.blob)
      - id: check-yaml
        always_run: true
        exclude: ^(packages/grid/k8s/rendered/|packages/grid/helm/)
7 changes: 6 additions & 1 deletion notebooks/api/0.8/00-load-data.ipynb
@@ -711,6 +711,11 @@
}
],
"metadata": {
"kernelspec": {
"display_name": "syft_3.12",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
@@ -721,7 +726,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.5"
"version": "3.12.4"
},
"toc": {
"base_numbering": 1,
Empty file.
@@ -0,0 +1,23 @@
# stdlib
import os

# syft absolute
from syft.util.util import str_to_bool

# relative
from .submit_query import make_submit_query

env_var = "TEST_BIGQUERY_APIS_LIVE"
use_live = str_to_bool(str(os.environ.get(env_var, "False")))
env_name = "Live" if use_live else "Mock"
print(f"Using {env_name} API Code, this will query BigQuery. ${env_var}=={use_live}")


if use_live:
    # relative
    from .live.schema import make_schema
    from .live.test_query import make_test_query
else:
    # relative
    from .mock.schema import make_schema
    from .mock.test_query import make_test_query
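
A minimal usage sketch for the toggle above, assuming the package is importable as "apis" (the real package path is not shown in this view); because the selection happens at import time, the environment variable must be set before the first import.

# Hypothetical sketch: select the live BigQuery helpers instead of the mocks.
# TEST_BIGQUERY_APIS_LIVE must be set before this package is first imported.
import os

os.environ["TEST_BIGQUERY_APIS_LIVE"] = "True"  # "False" (the default) selects the mocks

from apis import make_schema, make_test_query  # "apis" is an assumed package name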
Empty file.
@@ -0,0 +1,106 @@
# stdlib
from collections.abc import Callable

# syft absolute
import syft as sy
from syft import test_settings
from syft.rate_limiter import is_within_rate_limit


def make_schema(settings: dict, worker_pool: str) -> Callable:
    updated_settings = {
        "calls_per_min": 5,
        "rate_limiter_enabled": True,
        "credentials": test_settings.gce_service_account.to_dict(),
        "region": test_settings.gce_region,
        "project_id": test_settings.gce_project_id,
        "dataset_1": test_settings.dataset_1,
        "table_1": test_settings.table_1,
        "table_2": test_settings.table_2,
    } | settings

    @sy.api_endpoint(
        path="bigquery.schema",
        description="This endpoint allows for visualising the metadata of tables available in BigQuery.",
        settings=updated_settings,
        helper_functions=[
            is_within_rate_limit
        ],  # Adds rate limiting, as this is also a method available to data scientists
        worker_pool=worker_pool,
    )
    def live_schema(
        context,
    ) -> str:
        # stdlib
        import datetime

        # third party
        from google.cloud import bigquery  # noqa: F811
        from google.oauth2 import service_account
        import pandas as pd

        # syft absolute
        from syft import SyftException

        # Auth for BigQuery based on the workload identity
        credentials = service_account.Credentials.from_service_account_info(
            context.settings["credentials"]
        )
        scoped_credentials = credentials.with_scopes(
            ["https://www.googleapis.com/auth/cloud-platform"]
        )

        client = bigquery.Client(
            credentials=scoped_credentials,
            location=context.settings["region"],
        )

        # Store a dict with the call times for each user, keyed by email.
        if context.settings["rate_limiter_enabled"]:
            if context.user.email not in context.state.keys():
                context.state[context.user.email] = []

            if not context.code.is_within_rate_limit(context):
                raise SyftException(
                    public_message="Rate limit of calls per minute has been reached."
                )
            context.state[context.user.email].append(datetime.datetime.now())

        try:
            # Format the data schema as a data frame
            # Warning: the only supported return types are primitives, np.ndarrays and pd.DataFrames

            data_schema = []
            for table_id in [
                f"{context.settings['dataset_1']}.{context.settings['table_1']}",
                f"{context.settings['dataset_1']}.{context.settings['table_2']}",
            ]:
                table = client.get_table(table_id)
                for schema in table.schema:
                    data_schema.append(
                        {
                            "project": str(table.project),
                            "dataset_id": str(table.dataset_id),
                            "table_id": str(table.table_id),
                            "schema_name": str(schema.name),
                            "schema_field": str(schema.field_type),
                            "description": str(table.description),
                            "num_rows": str(table.num_rows),
                        }
                    )
            return pd.DataFrame(data_schema)

        except Exception as e:
            # Not a BigQuery exception
            if not hasattr(e, "_errors"):
                output = f"got exception e: {type(e)} {str(e)}"
                raise SyftException(
                    public_message=f"An error occurred while executing the API call: {output}"
                )

            # Should add appropriate error handling for what should be exposed to the data scientists.
            raise SyftException(
                public_message="An error occurred while executing the API call, please contact the domain owner."
            )

    return live_schema
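
A hedged usage sketch for the factory above: make_schema merges caller-supplied settings over the defaults before decorating live_schema, so individual values can be overridden per deployment; the worker pool name below is illustrative, not taken from this commit.

# Illustrative only: build the schema endpoint with a stricter rate limit and a
# hypothetical worker pool name; all other settings keep the defaults above.
schema_endpoint = make_schema(
    settings={"calls_per_min": 2, "rate_limiter_enabled": True},
    worker_pool="bigquery-pool",
)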
@@ -0,0 +1,111 @@
# stdlib
from collections.abc import Callable

# syft absolute
import syft as sy
from syft import test_settings
from syft.rate_limiter import is_within_rate_limit


def make_test_query(settings) -> Callable:
    updated_settings = {
        "calls_per_min": 10,
        "rate_limiter_enabled": True,
        "credentials": test_settings.gce_service_account.to_dict(),
        "region": test_settings.gce_region,
        "project_id": test_settings.gce_project_id,
    } | settings

    # These are the same whether the rate limiter is turned on or off
    @sy.api_endpoint_method(
        settings=updated_settings,
        helper_functions=[is_within_rate_limit],
    )
    def live_test_query(
        context,
        sql_query: str,
    ) -> str:
        # stdlib
        import datetime

        # third party
        from google.cloud import bigquery  # noqa: F811
        from google.oauth2 import service_account

        # syft absolute
        from syft import SyftException

        # Auth for BigQuery based on the workload identity
        credentials = service_account.Credentials.from_service_account_info(
            context.settings["credentials"]
        )
        scoped_credentials = credentials.with_scopes(
            ["https://www.googleapis.com/auth/cloud-platform"]
        )

        client = bigquery.Client(
            credentials=scoped_credentials,
            location=context.settings["region"],
        )

        # Store a dict with the call times for each user, keyed by email.
        if context.settings["rate_limiter_enabled"]:
            if context.user.email not in context.state.keys():
                context.state[context.user.email] = []

            if not context.code.is_within_rate_limit(context):
                raise SyftException(
                    public_message="Rate limit of calls per minute has been reached."
                )
            context.state[context.user.email].append(datetime.datetime.now())

        try:
            rows = client.query_and_wait(
                sql_query,
                project=context.settings["project_id"],
            )

            if rows.total_rows > 1_000_000:
                raise SyftException(
                    public_message="Please only write queries that gather aggregate statistics"
                )

            return rows.to_dataframe()

        except Exception as e:
            # Not a BigQuery exception
            if not hasattr(e, "_errors"):
                output = f"got exception e: {type(e)} {str(e)}"
                raise SyftException(
                    public_message=f"An error occurred while executing the API call: {output}"
                )

            # Forward only the error reasons below to the data scientists.
            # By default, any exception is visible only to the data owner.

            if e._errors[0]["reason"] in [
                "badRequest",
                "blocked",
                "duplicate",
                "invalidQuery",
                "invalid",
                "jobBackendError",
                "jobInternalError",
                "notFound",
                "notImplemented",
                "rateLimitExceeded",
                "resourceInUse",
                "resourcesExceeded",
                "tableUnavailable",
                "timeout",
            ]:
                raise SyftException(
                    public_message="Error occurred during the call: "
                    + e._errors[0]["message"]
                )
            else:
                raise SyftException(
                    public_message="An error occurred while executing the API call, please contact the domain owner."
                )

    return live_test_query
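
Similarly, a sketch of how make_test_query might be used to build two variants of the query endpoint, one rate limited and one unrestricted (for example a public and a private endpoint); the variable names are illustrative and not part of this commit.

# Illustrative only: a rate-limited "public" endpoint and an unrestricted
# "private" one; everything else falls back to the defaults above.
public_query = make_test_query(settings={"rate_limiter_enabled": True})
private_query = make_test_query(settings={"rate_limiter_enabled": False})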
Empty file.
