From e7179e62b006d9b9f63e1d182ff6a715f6e2487e Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Fri, 13 Sep 2024 10:27:12 -0400 Subject: [PATCH 01/11] Moving helpers into its own module without helpers Co-authored-by: Brendan Schell --- .pre-commit-config.yaml | 1 + ...tart-and-configure-server-and-admins.ipynb | 16 +- .../001-scale-delete-worker-pools.ipynb | 67 ++--- .../bigquery/010-setup-bigquery-pool.ipynb | 14 +- .../bigquery/011-users-emails-passwords.ipynb | 16 +- .../bigquery/020-configure-api.ipynb | 13 +- .../scenarios/bigquery/021-create-jobs.ipynb | 7 +- .../bigquery/040-do-review-requests.ipynb | 13 +- .../bigquery/050-ds-get-results.ipynb | 13 +- .../sync/01-setup-high-low-datasites.ipynb | 12 +- .../sync/02-configure-api-and-sync.ipynb | 19 +- packages/syft/src/syft/__init__.py | 8 - packages/syft/src/syft/util/util.py | 15 - test_helpers/apis/__init__.py | 23 -- test_helpers/apis/live/schema.py | 108 ------- test_helpers/apis/live/test_query.py | 113 -------- test_helpers/apis/mock/data.py | 268 ------------------ test_helpers/apis/mock/schema.py | 52 ---- test_helpers/apis/mock/test_query.py | 138 --------- test_helpers/apis/rate_limiter.py | 16 -- test_helpers/apis/submit_query.py | 42 --- .../scenarios/bigquery/level_2_basic_test.py | 7 + 22 files changed, 77 insertions(+), 904 deletions(-) delete mode 100644 test_helpers/apis/__init__.py delete mode 100644 test_helpers/apis/live/schema.py delete mode 100644 test_helpers/apis/live/test_query.py delete mode 100644 test_helpers/apis/mock/data.py delete mode 100644 test_helpers/apis/mock/schema.py delete mode 100644 test_helpers/apis/mock/test_query.py delete mode 100644 test_helpers/apis/rate_limiter.py delete mode 100644 test_helpers/apis/submit_query.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e1c50cd3b96..56bc5340ddb 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -161,6 +161,7 @@ repos: "--non-interactive", "--config-file=tox.ini", ] + excludes: ^packages/syft/src/syft/util/test_helpers - repo: https://github.com/kynan/nbstripout rev: 0.7.1 diff --git a/notebooks/scenarios/bigquery/000-start-and-configure-server-and-admins.ipynb b/notebooks/scenarios/bigquery/000-start-and-configure-server-and-admins.ipynb index 8035f5e61e1..86cbdc836c6 100644 --- a/notebooks/scenarios/bigquery/000-start-and-configure-server-and-admins.ipynb +++ b/notebooks/scenarios/bigquery/000-start-and-configure-server-and-admins.ipynb @@ -20,24 +20,12 @@ "metadata": {}, "outputs": [], "source": [ - "# isort: off\n", "# stdlib\n", "from os import environ as env\n", "\n", "# syft absolute\n", "import syft as sy\n", - "from syft import test_helpers # noqa: F401\n", - "\n", - "# third party\n", - "from email_helpers import get_email_server\n", - "# isort: on" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Launch & login" + "from syft.util.test_helpers.email_helpers import get_email_server" ] }, { @@ -249,7 +237,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.4" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb b/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb index be9579059eb..a8299b5cdcd 100644 --- a/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb +++ b/notebooks/scenarios/bigquery/001-scale-delete-worker-pools.ipynb @@ -20,18 +20,13 @@ "metadata": {}, "outputs": [], "source": [ - "# isort: off\n", "# stdlib\n", 
"import os\n", "\n", "# syft absolute\n", "import syft as sy\n", - "from syft import test_helpers # noqa: F401\n", - "\n", - "# third party\n", - "from email_helpers import Timeout\n", - "from email_helpers import get_email_server\n", - "# isort: on" + "from syft.util.test_helpers.email_helpers import Timeout\n", + "from syft.util.test_helpers.email_helpers import get_email_server" ] }, { @@ -40,14 +35,6 @@ "id": "2", "metadata": {}, "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "3", - "metadata": {}, - "outputs": [], "source": [ "environment = os.environ.get(\"ORCHESTRA_DEPLOYMENT_TYPE\", \"python\")\n", "\n", @@ -60,7 +47,7 @@ }, { "cell_type": "markdown", - "id": "4", + "id": "3", "metadata": {}, "source": [ "### Launch server & login" @@ -69,7 +56,7 @@ { "cell_type": "code", "execution_count": null, - "id": "5", + "id": "4", "metadata": {}, "outputs": [], "source": [ @@ -86,7 +73,7 @@ { "cell_type": "code", "execution_count": null, - "id": "6", + "id": "5", "metadata": {}, "outputs": [], "source": [ @@ -96,7 +83,7 @@ { "cell_type": "code", "execution_count": null, - "id": "7", + "id": "6", "metadata": {}, "outputs": [], "source": [ @@ -108,7 +95,7 @@ { "cell_type": "code", "execution_count": null, - "id": "8", + "id": "7", "metadata": {}, "outputs": [], "source": [ @@ -118,7 +105,7 @@ { "cell_type": "code", "execution_count": null, - "id": "9", + "id": "8", "metadata": {}, "outputs": [], "source": [ @@ -128,7 +115,7 @@ }, { "cell_type": "markdown", - "id": "10", + "id": "9", "metadata": {}, "source": [ "### Scale Worker pool" @@ -136,7 +123,7 @@ }, { "cell_type": "markdown", - "id": "11", + "id": "10", "metadata": {}, "source": [ "##### Scale up" @@ -145,7 +132,7 @@ { "cell_type": "code", "execution_count": null, - "id": "12", + "id": "11", "metadata": {}, "outputs": [], "source": [ @@ -159,7 +146,7 @@ { "cell_type": "code", "execution_count": null, - "id": "13", + "id": "12", "metadata": {}, "outputs": [], "source": [ @@ -169,7 +156,7 @@ { "cell_type": "code", "execution_count": null, - "id": "14", + "id": "13", "metadata": {}, "outputs": [], "source": [ @@ -189,7 +176,7 @@ }, { "cell_type": "markdown", - "id": "15", + "id": "14", "metadata": {}, "source": [ "##### Scale down" @@ -198,7 +185,7 @@ { "cell_type": "code", "execution_count": null, - "id": "16", + "id": "15", "metadata": {}, "outputs": [], "source": [ @@ -213,7 +200,7 @@ { "cell_type": "code", "execution_count": null, - "id": "17", + "id": "16", "metadata": {}, "outputs": [], "source": [ @@ -232,7 +219,7 @@ { "cell_type": "code", "execution_count": null, - "id": "18", + "id": "17", "metadata": {}, "outputs": [], "source": [ @@ -245,7 +232,7 @@ }, { "cell_type": "markdown", - "id": "19", + "id": "18", "metadata": {}, "source": [ "#### Delete Worker Pool" @@ -254,7 +241,7 @@ { "cell_type": "code", "execution_count": null, - "id": "20", + "id": "19", "metadata": {}, "outputs": [], "source": [ @@ -267,7 +254,7 @@ { "cell_type": "code", "execution_count": null, - "id": "21", + "id": "20", "metadata": {}, "outputs": [], "source": [ @@ -277,7 +264,7 @@ }, { "cell_type": "markdown", - "id": "22", + "id": "21", "metadata": {}, "source": [ "#### Re-launch the default worker pool" @@ -286,7 +273,7 @@ { "cell_type": "code", "execution_count": null, - "id": "23", + "id": "22", "metadata": {}, "outputs": [], "source": [ @@ -296,7 +283,7 @@ { "cell_type": "code", "execution_count": null, - "id": "24", + "id": "23", "metadata": {}, "outputs": [], "source": [ @@ -310,7 +297,7 @@ { 
"cell_type": "code", "execution_count": null, - "id": "25", + "id": "24", "metadata": {}, "outputs": [], "source": [ @@ -324,7 +311,7 @@ { "cell_type": "code", "execution_count": null, - "id": "26", + "id": "25", "metadata": {}, "outputs": [], "source": [ @@ -334,7 +321,7 @@ { "cell_type": "code", "execution_count": null, - "id": "27", + "id": "26", "metadata": {}, "outputs": [], "source": [ @@ -344,7 +331,7 @@ { "cell_type": "code", "execution_count": null, - "id": "28", + "id": "27", "metadata": {}, "outputs": [], "source": [] diff --git a/notebooks/scenarios/bigquery/010-setup-bigquery-pool.ipynb b/notebooks/scenarios/bigquery/010-setup-bigquery-pool.ipynb index 22f6dfaa977..d72a82f9eb1 100644 --- a/notebooks/scenarios/bigquery/010-setup-bigquery-pool.ipynb +++ b/notebooks/scenarios/bigquery/010-setup-bigquery-pool.ipynb @@ -18,18 +18,13 @@ "metadata": {}, "outputs": [], "source": [ - "# isort: off\n", "# stdlib\n", "import os\n", "\n", "# syft absolute\n", "import syft as sy\n", - "from syft import test_helpers # noqa: F401\n", "from syft import test_settings\n", - "\n", - "# third party\n", - "from email_helpers import get_email_server\n", - "# isort: on" + "from syft.util.test_helpers.email_helpers import get_email_server" ] }, { @@ -526,6 +521,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "syft", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -536,7 +536,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/notebooks/scenarios/bigquery/011-users-emails-passwords.ipynb b/notebooks/scenarios/bigquery/011-users-emails-passwords.ipynb index b87fd2a7731..7978234b07e 100644 --- a/notebooks/scenarios/bigquery/011-users-emails-passwords.ipynb +++ b/notebooks/scenarios/bigquery/011-users-emails-passwords.ipynb @@ -22,21 +22,17 @@ "metadata": {}, "outputs": [], "source": [ - "# isort: off\n", "# stdlib\n", "import os\n", "\n", "# syft absolute\n", "import syft as sy\n", - "from syft import test_helpers # noqa: F401\n", - "\n", - "# third party\n", - "from email_helpers import SENDER\n", - "from email_helpers import create_user\n", - "from email_helpers import get_email_server\n", - "from email_helpers import make_user\n", - "from email_helpers import save_users\n", - "# isort: on" + "from syft import get_helpers # noqa: F401\n", + "from syft.util.test_helpers.email_helpers import SENDER\n", + "from syft.util.test_helpers.email_helpers import create_user\n", + "from syft.util.test_helpers.email_helpers import get_email_server\n", + "from syft.util.test_helpers.email_helpers import make_user\n", + "from syft.util.test_helpers.email_helpers import save_users" ] }, { diff --git a/notebooks/scenarios/bigquery/020-configure-api.ipynb b/notebooks/scenarios/bigquery/020-configure-api.ipynb index 83abef20ff7..1371dfaf5c7 100644 --- a/notebooks/scenarios/bigquery/020-configure-api.ipynb +++ b/notebooks/scenarios/bigquery/020-configure-api.ipynb @@ -28,22 +28,17 @@ "metadata": {}, "outputs": [], "source": [ - "# isort: off\n", "# stdlib\n", "\n", "# syft absolute\n", "import syft as sy\n", - "from syft import test_helpers # noqa: F401\n", "from syft import test_settings\n", - "\n", - "# third party\n", - "from apis import make_schema\n", - "from apis import make_submit_query\n", - "from apis import make_test_query\n", + "from syft.util.test_helpers.apis import make_schema\n", + "from syft.util.test_helpers.apis import 
make_submit_query\n", + "from syft.util.test_helpers.apis import make_test_query\n", "\n", "# run email server\n", - "from email_helpers import get_email_server\n", - "# isort: on" + "from syft.util.test_helpers.email_helpers import get_email_server" ] }, { diff --git a/notebooks/scenarios/bigquery/021-create-jobs.ipynb b/notebooks/scenarios/bigquery/021-create-jobs.ipynb index 5a14895133a..e576d45dff8 100644 --- a/notebooks/scenarios/bigquery/021-create-jobs.ipynb +++ b/notebooks/scenarios/bigquery/021-create-jobs.ipynb @@ -33,19 +33,14 @@ "metadata": {}, "outputs": [], "source": [ - "# isort: off\n", "# stdlib\n", "from collections import Counter\n", "import os\n", "\n", "# syft absolute\n", "import syft as sy\n", - "from syft import test_helpers # noqa: F401\n", "from syft.service.job.job_stash import JobStatus\n", - "\n", - "# third party\n", - "from email_helpers import get_email_server\n", - "# isort: on" + "from syft.util.test_helpers.email_helpers import get_email_server" ] }, { diff --git a/notebooks/scenarios/bigquery/040-do-review-requests.ipynb b/notebooks/scenarios/bigquery/040-do-review-requests.ipynb index aa4a7b0c2a1..2f5b9fc00e5 100644 --- a/notebooks/scenarios/bigquery/040-do-review-requests.ipynb +++ b/notebooks/scenarios/bigquery/040-do-review-requests.ipynb @@ -18,21 +18,16 @@ "metadata": {}, "outputs": [], "source": [ - "# isort: off\n", "# stdlib\n", "import random\n", "\n", "# syft absolute\n", "import syft as sy\n", - "from syft import test_helpers # noqa: F401\n", "from syft.service.job.job_stash import Job\n", - "\n", - "# third party\n", - "from email_helpers import get_email_server\n", - "from job_helpers import approve_by_running\n", - "from job_helpers import get_job_emails\n", - "from job_helpers import get_request_for_job_info\n", - "# isort: on" + "from syft.util.test_helpers.email_helpers import get_email_server\n", + "from syft.util.test_helpers.job_helpers import approve_by_running\n", + "from syft.util.test_helpers.job_helpers import get_job_emails\n", + "from syft.util.test_helpers.job_helpers import get_request_for_job_info" ] }, { diff --git a/notebooks/scenarios/bigquery/050-ds-get-results.ipynb b/notebooks/scenarios/bigquery/050-ds-get-results.ipynb index 9a9bc1ef588..35791771b2f 100644 --- a/notebooks/scenarios/bigquery/050-ds-get-results.ipynb +++ b/notebooks/scenarios/bigquery/050-ds-get-results.ipynb @@ -18,17 +18,12 @@ "metadata": {}, "outputs": [], "source": [ - "# isort: off\n", "# syft absolute\n", "import syft as sy\n", - "from syft import test_helpers # noqa: F401\n", - "\n", - "# third party\n", - "from email_helpers import get_email_server\n", - "from email_helpers import load_users\n", - "from job_helpers import load_jobs\n", - "from job_helpers import save_jobs\n", - "# isort: on" + "from syft.util.test_helpers.email_helpers import get_email_server\n", + "from syft.util.test_helpers.email_helpers import load_users\n", + "from syft.util.test_helpers.job_helpers import load_jobs\n", + "from syft.util.test_helpers.job_helpers import save_jobs" ] }, { diff --git a/notebooks/scenarios/bigquery/sync/01-setup-high-low-datasites.ipynb b/notebooks/scenarios/bigquery/sync/01-setup-high-low-datasites.ipynb index 633a73c38e4..691c58d4b00 100644 --- a/notebooks/scenarios/bigquery/sync/01-setup-high-low-datasites.ipynb +++ b/notebooks/scenarios/bigquery/sync/01-setup-high-low-datasites.ipynb @@ -42,15 +42,15 @@ "metadata": {}, "outputs": [], "source": [ - "# isort: off\n", "# syft absolute\n", "import syft as sy\n", - "from syft import 
test_helpers # noqa: F401\n", "from syft import test_settings\n", - "\n", - "from worker_helpers import build_and_launch_worker_pool_from_docker_str\n", - "from worker_helpers import launch_worker_pool_from_docker_tag_and_registry\n", - "# isort: on" + "from syft.util.test_helpers.worker_helpers import (\n", + " build_and_launch_worker_pool_from_docker_str,\n", + ")\n", + "from syft.util.test_helpers.worker_helpers import (\n", + " launch_worker_pool_from_docker_tag_and_registry,\n", + ")" ] }, { diff --git a/notebooks/scenarios/bigquery/sync/02-configure-api-and-sync.ipynb b/notebooks/scenarios/bigquery/sync/02-configure-api-and-sync.ipynb index 094841ef58e..99274aba2a8 100644 --- a/notebooks/scenarios/bigquery/sync/02-configure-api-and-sync.ipynb +++ b/notebooks/scenarios/bigquery/sync/02-configure-api-and-sync.ipynb @@ -36,24 +36,21 @@ "metadata": {}, "outputs": [], "source": [ - "# isort: off\n", "# stdlib\n", "\n", - "# syft absolute\n", - "import syft as sy\n", - "from syft import test_helpers # noqa: F401\n", - "from syft import test_settings\n", - "from syft.client.syncing import compare_clients\n", - "\n", "# set to use the live APIs\n", "# import os\n", "# os.environ[\"TEST_BIGQUERY_APIS_LIVE\"] = \"True\"\n", "# third party\n", - "from apis import make_schema\n", - "from apis import make_submit_query\n", - "from apis import make_test_query\n", "import pandas as pd\n", - "# isort: on" + "\n", + "# syft absolute\n", + "import syft as sy\n", + "from syft import test_settings\n", + "from syft.client.syncing import compare_clients\n", + "from syft.util.test_helpers.apis import make_schema\n", + "from syft.util.test_helpers.apis import make_submit_query\n", + "from syft.util.test_helpers.apis import make_test_query" ] }, { diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index fb0fdfa69b1..2534f22077e 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -154,14 +154,6 @@ def _test_settings() -> Any: return test_settings() -@module_property -def _test_helpers() -> None: - # relative - from .util.util import add_helper_path_to_python_path - - add_helper_path_to_python_path() - - @module_property def hello_baby() -> None: print("Hello baby!") diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index 83efaa196e7..fa20c3fc2c2 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -1143,21 +1143,6 @@ def test_settings() -> Any: return test_settings -def add_helper_path_to_python_path() -> None: - current_path = "." - - # jupyter uses "." 
which resolves to the notebook - if not is_interpreter_jupyter(): - # python uses the file which has from syft import test_settings in it - import_path = get_caller_file_path() - if import_path: - current_path = import_path - - base_dir = find_base_dir_with_tox_ini(current_path) - notebook_helper_path = os.path.join(base_dir, "test_helpers") - sys.path.append(notebook_helper_path) - - class CustomRepr(reprlib.Repr): def repr_str(self, obj: Any, level: int = 0) -> str: if len(obj) <= self.maxstring: diff --git a/test_helpers/apis/__init__.py b/test_helpers/apis/__init__.py deleted file mode 100644 index 7231b580696..00000000000 --- a/test_helpers/apis/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -# stdlib -import os - -# syft absolute -from syft.util.util import str_to_bool - -# relative -from .submit_query import make_submit_query - -env_var = "TEST_BIGQUERY_APIS_LIVE" -use_live = str_to_bool(str(os.environ.get(env_var, "False"))) -env_name = "Live" if use_live else "Mock" -print(f"Using {env_name} API Code, this will query BigQuery. ${env_var}=={use_live}") - - -if use_live: - # relative - from .live.schema import make_schema - from .live.test_query import make_test_query -else: - # relative - from .mock.schema import make_schema - from .mock.test_query import make_test_query diff --git a/test_helpers/apis/live/schema.py b/test_helpers/apis/live/schema.py deleted file mode 100644 index 5b39d9d9066..00000000000 --- a/test_helpers/apis/live/schema.py +++ /dev/null @@ -1,108 +0,0 @@ -# stdlib -from collections.abc import Callable - -# syft absolute -import syft as sy -from syft import test_settings - -# relative -from ..rate_limiter import is_within_rate_limit - - -def make_schema(settings: dict, worker_pool: str) -> Callable: - updated_settings = { - "calls_per_min": 5, - "rate_limiter_enabled": True, - "credentials": test_settings.gce_service_account.to_dict(), - "region": test_settings.gce_region, - "project_id": test_settings.gce_project_id, - "dataset_1": test_settings.dataset_1, - "table_1": test_settings.table_1, - "table_2": test_settings.table_2, - } | settings - - @sy.api_endpoint( - path="bigquery.schema", - description="This endpoint allows for visualising the metadata of tables available in BigQuery.", - settings=updated_settings, - helper_functions=[ - is_within_rate_limit - ], # Adds ratelimit as this is also a method available to data scientists - worker_pool=worker_pool, - ) - def live_schema( - context, - ) -> str: - # stdlib - import datetime - - # third party - from google.cloud import bigquery # noqa: F811 - from google.oauth2 import service_account - import pandas as pd - - # syft absolute - from syft import SyftException - - # Auth for Bigquer based on the workload identity - credentials = service_account.Credentials.from_service_account_info( - context.settings["credentials"] - ) - scoped_credentials = credentials.with_scopes( - ["https://www.googleapis.com/auth/cloud-platform"] - ) - - client = bigquery.Client( - credentials=scoped_credentials, - location=context.settings["region"], - ) - - # Store a dict with the calltimes for each user, via the email. - if context.settings["rate_limiter_enabled"]: - if context.user.email not in context.state.keys(): - context.state[context.user.email] = [] - - if not context.code.is_within_rate_limit(context): - raise SyftException( - public_message="Rate limit of calls per minute has been reached." 
- ) - context.state[context.user.email].append(datetime.datetime.now()) - - try: - # Formats the data schema in a data frame format - # Warning: the only supported format types are primitives, np.ndarrays and pd.DataFrames - - data_schema = [] - for table_id in [ - f"{context.settings['dataset_1']}.{context.settings['table_1']}", - f"{context.settings['dataset_1']}.{context.settings['table_2']}", - ]: - table = client.get_table(table_id) - for schema in table.schema: - data_schema.append( - { - "project": str(table.project), - "dataset_id": str(table.dataset_id), - "table_id": str(table.table_id), - "schema_name": str(schema.name), - "schema_field": str(schema.field_type), - "description": str(table.description), - "num_rows": str(table.num_rows), - } - ) - return pd.DataFrame(data_schema) - - except Exception as e: - # not a bigquery exception - if not hasattr(e, "_errors"): - output = f"got exception e: {type(e)} {str(e)}" - raise SyftException( - public_message=f"An error occured executing the API call {output}" - ) - - # Should add appropriate error handling for what should be exposed to the data scientists. - raise SyftException( - public_message="An error occured executing the API call, please contact the domain owner." - ) - - return live_schema diff --git a/test_helpers/apis/live/test_query.py b/test_helpers/apis/live/test_query.py deleted file mode 100644 index 344879dcb62..00000000000 --- a/test_helpers/apis/live/test_query.py +++ /dev/null @@ -1,113 +0,0 @@ -# stdlib -from collections.abc import Callable - -# syft absolute -import syft as sy -from syft import test_settings - -# relative -from ..rate_limiter import is_within_rate_limit - - -def make_test_query(settings) -> Callable: - updated_settings = { - "calls_per_min": 10, - "rate_limiter_enabled": True, - "credentials": test_settings.gce_service_account.to_dict(), - "region": test_settings.gce_region, - "project_id": test_settings.gce_project_id, - } | settings - - # these are the same if you allow the rate limiter to be turned on and off - @sy.api_endpoint_method( - settings=updated_settings, - helper_functions=[is_within_rate_limit], - ) - def live_test_query( - context, - sql_query: str, - ) -> str: - # stdlib - import datetime - - # third party - from google.cloud import bigquery # noqa: F811 - from google.oauth2 import service_account - - # syft absolute - from syft import SyftException - - # Auth for Bigquer based on the workload identity - credentials = service_account.Credentials.from_service_account_info( - context.settings["credentials"] - ) - scoped_credentials = credentials.with_scopes( - ["https://www.googleapis.com/auth/cloud-platform"] - ) - - client = bigquery.Client( - credentials=scoped_credentials, - location=context.settings["region"], - ) - - # Store a dict with the calltimes for each user, via the email. - if context.settings["rate_limiter_enabled"]: - if context.user.email not in context.state.keys(): - context.state[context.user.email] = [] - - if not context.code.is_within_rate_limit(context): - raise SyftException( - public_message="Rate limit of calls per minute has been reached." 
- ) - context.state[context.user.email].append(datetime.datetime.now()) - - try: - rows = client.query_and_wait( - sql_query, - project=context.settings["project_id"], - ) - - if rows.total_rows > 1_000_000: - raise SyftException( - public_message="Please only write queries that gather aggregate statistics" - ) - - return rows.to_dataframe() - - except Exception as e: - # not a bigquery exception - if not hasattr(e, "_errors"): - output = f"got exception e: {type(e)} {str(e)}" - raise SyftException( - public_message=f"An error occured executing the API call {output}" - ) - - # Treat all errors that we would like to be forwarded to the data scientists - # By default, any exception is only visible to the data owner. - - if e._errors[0]["reason"] in [ - "badRequest", - "blocked", - "duplicate", - "invalidQuery", - "invalid", - "jobBackendError", - "jobInternalError", - "notFound", - "notImplemented", - "rateLimitExceeded", - "resourceInUse", - "resourcesExceeded", - "tableUnavailable", - "timeout", - ]: - raise SyftException( - public_message="Error occured during the call: " - + e._errors[0]["message"] - ) - else: - raise SyftException( - public_message="An error occured executing the API call, please contact the domain owner." - ) - - return live_test_query diff --git a/test_helpers/apis/mock/data.py b/test_helpers/apis/mock/data.py deleted file mode 100644 index 82262bf7a01..00000000000 --- a/test_helpers/apis/mock/data.py +++ /dev/null @@ -1,268 +0,0 @@ -# stdlib -from math import nan - -schema_dict = { - "project": { - 0: "example-project", - 1: "example-project", - 2: "example-project", - 3: "example-project", - 4: "example-project", - 5: "example-project", - 6: "example-project", - 7: "example-project", - 8: "example-project", - 9: "example-project", - 10: "example-project", - 11: "example-project", - 12: "example-project", - 13: "example-project", - 14: "example-project", - 15: "example-project", - 16: "example-project", - 17: "example-project", - 18: "example-project", - 19: "example-project", - 20: "example-project", - 21: "example-project", - 22: "example-project", - }, - "dataset_id": { - 0: "test_1gb", - 1: "test_1gb", - 2: "test_1gb", - 3: "test_1gb", - 4: "test_1gb", - 5: "test_1gb", - 6: "test_1gb", - 7: "test_1gb", - 8: "test_1gb", - 9: "test_1gb", - 10: "test_1gb", - 11: "test_1gb", - 12: "test_1gb", - 13: "test_1gb", - 14: "test_1gb", - 15: "test_1gb", - 16: "test_1gb", - 17: "test_1gb", - 18: "test_1gb", - 19: "test_1gb", - 20: "test_1gb", - 21: "test_1gb", - 22: "test_1gb", - }, - "table_id": { - 0: "posts", - 1: "posts", - 2: "posts", - 3: "posts", - 4: "posts", - 5: "posts", - 6: "posts", - 7: "comments", - 8: "comments", - 9: "comments", - 10: "comments", - 11: "comments", - 12: "comments", - 13: "comments", - 14: "comments", - 15: "comments", - 16: "comments", - 17: "comments", - 18: "comments", - 19: "comments", - 20: "comments", - 21: "comments", - 22: "comments", - }, - "schema_name": { - 0: "int64_field_0", - 1: "id", - 2: "name", - 3: "subscribers_count", - 4: "permalink", - 5: "nsfw", - 6: "spam", - 7: "int64_field_0", - 8: "id", - 9: "body", - 10: "parent_id", - 11: "created_at", - 12: "last_modified_at", - 13: "gilded", - 14: "permalink", - 15: "score", - 16: "comment_id", - 17: "post_id", - 18: "author_id", - 19: "spam", - 20: "deleted", - 21: "upvote_raio", - 22: "collapsed_in_crowd_control", - }, - "schema_field": { - 0: "INTEGER", - 1: "STRING", - 2: "STRING", - 3: "INTEGER", - 4: "STRING", - 5: "FLOAT", - 6: "BOOLEAN", - 7: "INTEGER", - 8: "STRING", - 9: 
"STRING", - 10: "STRING", - 11: "INTEGER", - 12: "INTEGER", - 13: "BOOLEAN", - 14: "STRING", - 15: "INTEGER", - 16: "STRING", - 17: "STRING", - 18: "STRING", - 19: "BOOLEAN", - 20: "BOOLEAN", - 21: "FLOAT", - 22: "BOOLEAN", - }, - "description": { - 0: "None", - 1: "None", - 2: "None", - 3: "None", - 4: "None", - 5: "None", - 6: "None", - 7: "None", - 8: "None", - 9: "None", - 10: "None", - 11: "None", - 12: "None", - 13: "None", - 14: "None", - 15: "None", - 16: "None", - 17: "None", - 18: "None", - 19: "None", - 20: "None", - 21: "None", - 22: "None", - }, - "num_rows": { - 0: "2000000", - 1: "2000000", - 2: "2000000", - 3: "2000000", - 4: "2000000", - 5: "2000000", - 6: "2000000", - 7: "2000000", - 8: "2000000", - 9: "2000000", - 10: "2000000", - 11: "2000000", - 12: "2000000", - 13: "2000000", - 14: "2000000", - 15: "2000000", - 16: "2000000", - 17: "2000000", - 18: "2000000", - 19: "2000000", - 20: "2000000", - 21: "2000000", - 22: "2000000", - }, -} - - -query_dict = { - "int64_field_0": { - 0: 4, - 1: 5, - 2: 10, - 3: 16, - 4: 17, - 5: 23, - 6: 24, - 7: 25, - 8: 27, - 9: 40, - }, - "id": { - 0: "t5_via1x", - 1: "t5_cv9gn", - 2: "t5_8p2tq", - 3: "t5_8fcro", - 4: "t5_td5of", - 5: "t5_z01fv", - 6: "t5_hmqjk", - 7: "t5_1flyj", - 8: "t5_5rwej", - 9: "t5_uurcv", - }, - "name": { - 0: "/channel/mylittlepony", - 1: "/channel/polyamory", - 2: "/channel/Catholicism", - 3: "/channel/cordcutters", - 4: "/channel/stevenuniverse", - 5: "/channel/entitledbitch", - 6: "/channel/engineering", - 7: "/channel/nottheonion", - 8: "/channel/FoodPorn", - 9: "/channel/puppysmiles", - }, - "subscribers_count": { - 0: 4323081, - 1: 2425929, - 2: 4062607, - 3: 7543226, - 4: 2692168, - 5: 2709080, - 6: 8766144, - 7: 2580984, - 8: 7784809, - 9: 3715991, - }, - "permalink": { - 0: "/channel//channel/mylittlepony", - 1: "/channel//channel/polyamory", - 2: "/channel//channel/Catholicism", - 3: "/channel//channel/cordcutters", - 4: "/channel//channel/stevenuniverse", - 5: "/channel//channel/entitledbitch", - 6: "/channel//channel/engineering", - 7: "/channel//channel/nottheonion", - 8: "/channel//channel/FoodPorn", - 9: "/channel//channel/puppysmiles", - }, - "nsfw": { - 0: nan, - 1: nan, - 2: nan, - 3: nan, - 4: nan, - 5: nan, - 6: nan, - 7: nan, - 8: nan, - 9: nan, - }, - "spam": { - 0: False, - 1: False, - 2: False, - 3: False, - 4: False, - 5: False, - 6: False, - 7: False, - 8: False, - 9: False, - }, -} diff --git a/test_helpers/apis/mock/schema.py b/test_helpers/apis/mock/schema.py deleted file mode 100644 index a95e04f2f1d..00000000000 --- a/test_helpers/apis/mock/schema.py +++ /dev/null @@ -1,52 +0,0 @@ -# stdlib -from collections.abc import Callable - -# syft absolute -import syft as sy - -# relative -from ..rate_limiter import is_within_rate_limit -from .data import schema_dict - - -def make_schema(settings, worker_pool) -> Callable: - updated_settings = { - "calls_per_min": 5, - "rate_limiter_enabled": True, - "schema_dict": schema_dict, - } | settings - - @sy.api_endpoint( - path="bigquery.schema", - description="This endpoint allows for visualising the metadata of tables available in BigQuery.", - settings=updated_settings, - helper_functions=[is_within_rate_limit], - worker_pool=worker_pool, - ) - def mock_schema( - context, - ) -> str: - # syft absolute - from syft import SyftException - - # Store a dict with the calltimes for each user, via the email. 
- if context.settings["rate_limiter_enabled"]: - # stdlib - import datetime - - if context.user.email not in context.state.keys(): - context.state[context.user.email] = [] - - if not context.code.is_within_rate_limit(context): - raise SyftException( - public_message="Rate limit of calls per minute has been reached." - ) - context.state[context.user.email].append(datetime.datetime.now()) - - # third party - import pandas as pd - - df = pd.DataFrame(context.settings["schema_dict"]) - return df - - return mock_schema diff --git a/test_helpers/apis/mock/test_query.py b/test_helpers/apis/mock/test_query.py deleted file mode 100644 index ae028a8cf36..00000000000 --- a/test_helpers/apis/mock/test_query.py +++ /dev/null @@ -1,138 +0,0 @@ -# stdlib -from collections.abc import Callable - -# syft absolute -import syft as sy - -# relative -from ..rate_limiter import is_within_rate_limit -from .data import query_dict - - -def extract_limit_value(sql_query: str) -> int: - # stdlib - import re - - limit_pattern = re.compile(r"\bLIMIT\s+(\d+)\b", re.IGNORECASE) - match = limit_pattern.search(sql_query) - if match: - return int(match.group(1)) - return None - - -def is_valid_sql(query: str) -> bool: - # stdlib - import sqlite3 - - # Prepare an in-memory SQLite database - conn = sqlite3.connect(":memory:") - cursor = conn.cursor() - - try: - # Use the EXPLAIN QUERY PLAN command to get the query plan - cursor.execute(f"EXPLAIN QUERY PLAN {query}") - except sqlite3.Error as e: - if "no such table" in str(e).lower(): - return True - return False - finally: - conn.close() - - -def adjust_dataframe_rows(df, target_rows: int): - # third party - import pandas as pd - - current_rows = len(df) - - if target_rows > current_rows: - # Repeat rows to match target_rows - repeat_times = (target_rows + current_rows - 1) // current_rows - df_expanded = pd.concat([df] * repeat_times, ignore_index=True).head( - target_rows - ) - else: - # Truncate rows to match target_rows - df_expanded = df.head(target_rows) - - return df_expanded - - -def make_test_query(settings: dict) -> Callable: - updated_settings = { - "calls_per_min": 10, - "rate_limiter_enabled": True, - "query_dict": query_dict, - } | settings - - # these are the same if you allow the rate limiter to be turned on and off - @sy.api_endpoint_method( - settings=updated_settings, - helper_functions=[ - is_within_rate_limit, - extract_limit_value, - is_valid_sql, - adjust_dataframe_rows, - ], - ) - def mock_test_query( - context, - sql_query: str, - ) -> str: - # stdlib - import datetime - - # third party - from google.api_core.exceptions import BadRequest - - # syft absolute - from syft import SyftException - - # Store a dict with the calltimes for each user, via the email. - if context.settings["rate_limiter_enabled"]: - if context.user.email not in context.state.keys(): - context.state[context.user.email] = [] - - if not context.code.is_within_rate_limit(context): - raise SyftException( - public_message="Rate limit of calls per minute has been reached." - ) - context.state[context.user.email].append(datetime.datetime.now()) - - bad_table = "invalid_table" - bad_post = ( - "BadRequest: 400 POST " - "https://bigquery.googleapis.com/bigquery/v2/projects/project-id/" - "queries?prettyPrint=false: " - ) - if bad_table in sql_query: - try: - raise BadRequest( - f'{bad_post} Table "{bad_table}" must be qualified ' - "with a dataset (e.g. dataset.table)." - ) - except Exception as e: - raise SyftException( - public_message=f"*must be qualified with a dataset*. 
{e}" - ) - - if not context.code.is_valid_sql(sql_query): - raise BadRequest( - f'{bad_post} Syntax error: Unexpected identifier "{sql_query}" at [1:1]' - ) - - # third party - import pandas as pd - - limit = context.code.extract_limit_value(sql_query) - if limit > 1_000_000: - raise SyftException( - public_message="Please only write queries that gather aggregate statistics" - ) - - base_df = pd.DataFrame(context.settings["query_dict"]) - - df = context.code.adjust_dataframe_rows(base_df, limit) - return df - - return mock_test_query diff --git a/test_helpers/apis/rate_limiter.py b/test_helpers/apis/rate_limiter.py deleted file mode 100644 index 8ce319b61f4..00000000000 --- a/test_helpers/apis/rate_limiter.py +++ /dev/null @@ -1,16 +0,0 @@ -def is_within_rate_limit(context) -> bool: - """Rate limiter for custom API calls made by users.""" - # stdlib - import datetime - - state = context.state - settings = context.settings - email = context.user.email - - current_time = datetime.datetime.now() - calls_last_min = [ - 1 if (current_time - call_time).seconds < 60 else 0 - for call_time in state[email] - ] - - return sum(calls_last_min) < settings.get("calls_per_min", 5) diff --git a/test_helpers/apis/submit_query.py b/test_helpers/apis/submit_query.py deleted file mode 100644 index a0125ee009b..00000000000 --- a/test_helpers/apis/submit_query.py +++ /dev/null @@ -1,42 +0,0 @@ -# syft absolute -import syft as sy - - -def make_submit_query(settings, worker_pool): - updated_settings = {"user_code_worker": worker_pool} | settings - - @sy.api_endpoint( - path="bigquery.submit_query", - description="API endpoint that allows you to submit SQL queries to run on the private data.", - worker_pool=worker_pool, - settings=updated_settings, - ) - def submit_query( - context, - func_name: str, - query: str, - ) -> str: - # syft absolute - import syft as sy - - @sy.syft_function( - name=func_name, - input_policy=sy.MixedInputPolicy( - endpoint=sy.Constant( - val=context.admin_client.api.services.bigquery.test_query - ), - query=sy.Constant(val=query), - client=context.admin_client, - ), - worker_pool_name=context.settings["user_code_worker"], - ) - def execute_query(query: str, endpoint): - res = endpoint(sql_query=query) - return res - - request = context.user_client.code.request_code_execution(execute_query) - context.admin_client.requests.set_tags(request, ["autosync"]) - - return f"Query submitted {request}. 
Use `client.code.{func_name}()` to run your query" - - return submit_query diff --git a/tests/scenarios/bigquery/level_2_basic_test.py b/tests/scenarios/bigquery/level_2_basic_test.py index 6f4f4372ad7..e3c6f18379a 100644 --- a/tests/scenarios/bigquery/level_2_basic_test.py +++ b/tests/scenarios/bigquery/level_2_basic_test.py @@ -41,6 +41,13 @@ from syft.service.job.job_stash import Job +def test_check_test_helper_module_import(): + # syft absolute + from syft.util.test_helpers.email_helpers import SENDER + + assert SENDER == "noreply@openmined.org" + + @unsync async def get_prebuilt_worker_image(events, client, expected_tag, event_name): await events.await_for(event_name=event_name, show=True) From bdf2aa8ebfb97c3d5e9490748dcc0a74fbbaff95 Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Fri, 13 Sep 2024 10:42:56 -0400 Subject: [PATCH 02/11] Adding the helper files to the test_helpers directory --- .pre-commit-config.yaml | 2 +- .../syft/util/test_helpers/apis/__init__.py | 21 + .../util/test_helpers/apis/live/__init__.py | 0 .../util/test_helpers/apis/live/schema.py | 108 +++++ .../util/test_helpers/apis/live/test_query.py | 113 +++++ .../util/test_helpers/apis/mock/__init__.py | 0 .../syft/util/test_helpers/apis/mock/data.py | 268 ++++++++++++ .../util/test_helpers/apis/mock/schema.py | 52 +++ .../util/test_helpers/apis/mock/test_query.py | 138 ++++++ .../util/test_helpers/apis/rate_limiter.py | 16 + .../util/test_helpers/apis/submit_query.py | 42 ++ .../syft/util/test_helpers/email_helpers.py | 338 +++++++++++++++ .../src/syft/util/test_helpers/job_helpers.py | 398 ++++++++++++++++++ .../syft/util/test_helpers/sync_helpers.py | 192 +++++++++ 14 files changed, 1687 insertions(+), 1 deletion(-) create mode 100644 packages/syft/src/syft/util/test_helpers/apis/__init__.py create mode 100644 packages/syft/src/syft/util/test_helpers/apis/live/__init__.py create mode 100644 packages/syft/src/syft/util/test_helpers/apis/live/schema.py create mode 100644 packages/syft/src/syft/util/test_helpers/apis/live/test_query.py create mode 100644 packages/syft/src/syft/util/test_helpers/apis/mock/__init__.py create mode 100644 packages/syft/src/syft/util/test_helpers/apis/mock/data.py create mode 100644 packages/syft/src/syft/util/test_helpers/apis/mock/schema.py create mode 100644 packages/syft/src/syft/util/test_helpers/apis/mock/test_query.py create mode 100644 packages/syft/src/syft/util/test_helpers/apis/rate_limiter.py create mode 100644 packages/syft/src/syft/util/test_helpers/apis/submit_query.py create mode 100644 packages/syft/src/syft/util/test_helpers/email_helpers.py create mode 100644 packages/syft/src/syft/util/test_helpers/job_helpers.py create mode 100644 packages/syft/src/syft/util/test_helpers/sync_helpers.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 56bc5340ddb..c4f5669130d 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -161,7 +161,7 @@ repos: "--non-interactive", "--config-file=tox.ini", ] - excludes: ^packages/syft/src/syft/util/test_helpers + exclude: ^(packages/syft/src/syft/util/test_helpers) - repo: https://github.com/kynan/nbstripout rev: 0.7.1 diff --git a/packages/syft/src/syft/util/test_helpers/apis/__init__.py b/packages/syft/src/syft/util/test_helpers/apis/__init__.py new file mode 100644 index 00000000000..e8221857fba --- /dev/null +++ b/packages/syft/src/syft/util/test_helpers/apis/__init__.py @@ -0,0 +1,21 @@ +# stdlib +import os + +# relative +from ...util import str_to_bool +from .submit_query import make_submit_query + 
+env_var = "TEST_BIGQUERY_APIS_LIVE" +use_live = str_to_bool(str(os.environ.get(env_var, "False"))) +env_name = "Live" if use_live else "Mock" +print(f"Using {env_name} API Code, this will query BigQuery. ${env_var}=={use_live}") + + +if use_live: + # relative + from .live.schema import make_schema + from .live.test_query import make_test_query +else: + # relative + from .mock.schema import make_schema + from .mock.test_query import make_test_query diff --git a/packages/syft/src/syft/util/test_helpers/apis/live/__init__.py b/packages/syft/src/syft/util/test_helpers/apis/live/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/packages/syft/src/syft/util/test_helpers/apis/live/schema.py b/packages/syft/src/syft/util/test_helpers/apis/live/schema.py new file mode 100644 index 00000000000..8b9e753fe47 --- /dev/null +++ b/packages/syft/src/syft/util/test_helpers/apis/live/schema.py @@ -0,0 +1,108 @@ +# stdlib +from collections.abc import Callable + +# syft absolute +import syft as sy + +# relative +from ..... import test_settings +from ..rate_limiter import is_within_rate_limit + + +def make_schema(settings: dict, worker_pool: str) -> Callable: + updated_settings = { + "calls_per_min": 5, + "rate_limiter_enabled": True, + "credentials": test_settings.gce_service_account.to_dict(), + "region": test_settings.gce_region, + "project_id": test_settings.gce_project_id, + "dataset_1": test_settings.dataset_1, + "table_1": test_settings.table_1, + "table_2": test_settings.table_2, + } | settings + + @sy.api_endpoint( + path="bigquery.schema", + description="This endpoint allows for visualising the metadata of tables available in BigQuery.", + settings=updated_settings, + helper_functions=[ + is_within_rate_limit + ], # Adds ratelimit as this is also a method available to data scientists + worker_pool=worker_pool, + ) + def live_schema( + context, + ) -> str: + # stdlib + import datetime + + # third party + from google.cloud import bigquery # noqa: F811 + from google.oauth2 import service_account + import pandas as pd + + # relative + from ..... import SyftException + + # Auth for Bigquer based on the workload identity + credentials = service_account.Credentials.from_service_account_info( + context.settings["credentials"] + ) + scoped_credentials = credentials.with_scopes( + ["https://www.googleapis.com/auth/cloud-platform"] + ) + + client = bigquery.Client( + credentials=scoped_credentials, + location=context.settings["region"], + ) + + # Store a dict with the calltimes for each user, via the email. + if context.settings["rate_limiter_enabled"]: + if context.user.email not in context.state.keys(): + context.state[context.user.email] = [] + + if not context.code.is_within_rate_limit(context): + raise SyftException( + public_message="Rate limit of calls per minute has been reached." 
+ ) + context.state[context.user.email].append(datetime.datetime.now()) + + try: + # Formats the data schema in a data frame format + # Warning: the only supported format types are primitives, np.ndarrays and pd.DataFrames + + data_schema = [] + for table_id in [ + f"{context.settings['dataset_1']}.{context.settings['table_1']}", + f"{context.settings['dataset_1']}.{context.settings['table_2']}", + ]: + table = client.get_table(table_id) + for schema in table.schema: + data_schema.append( + { + "project": str(table.project), + "dataset_id": str(table.dataset_id), + "table_id": str(table.table_id), + "schema_name": str(schema.name), + "schema_field": str(schema.field_type), + "description": str(table.description), + "num_rows": str(table.num_rows), + } + ) + return pd.DataFrame(data_schema) + + except Exception as e: + # not a bigquery exception + if not hasattr(e, "_errors"): + output = f"got exception e: {type(e)} {str(e)}" + raise SyftException( + public_message=f"An error occured executing the API call {output}" + ) + + # Should add appropriate error handling for what should be exposed to the data scientists. + raise SyftException( + public_message="An error occured executing the API call, please contact the domain owner." + ) + + return live_schema diff --git a/packages/syft/src/syft/util/test_helpers/apis/live/test_query.py b/packages/syft/src/syft/util/test_helpers/apis/live/test_query.py new file mode 100644 index 00000000000..6384dfca452 --- /dev/null +++ b/packages/syft/src/syft/util/test_helpers/apis/live/test_query.py @@ -0,0 +1,113 @@ +# stdlib +from collections.abc import Callable + +# syft absolute +import syft as sy + +# relative +from ..... import test_settings +from ..rate_limiter import is_within_rate_limit + + +def make_test_query(settings) -> Callable: + updated_settings = { + "calls_per_min": 10, + "rate_limiter_enabled": True, + "credentials": test_settings.gce_service_account.to_dict(), + "region": test_settings.gce_region, + "project_id": test_settings.gce_project_id, + } | settings + + # these are the same if you allow the rate limiter to be turned on and off + @sy.api_endpoint_method( + settings=updated_settings, + helper_functions=[is_within_rate_limit], + ) + def live_test_query( + context, + sql_query: str, + ) -> str: + # stdlib + import datetime + + # third party + from google.cloud import bigquery # noqa: F811 + from google.oauth2 import service_account + + # relative + from ..... import SyftException + + # Auth for Bigquer based on the workload identity + credentials = service_account.Credentials.from_service_account_info( + context.settings["credentials"] + ) + scoped_credentials = credentials.with_scopes( + ["https://www.googleapis.com/auth/cloud-platform"] + ) + + client = bigquery.Client( + credentials=scoped_credentials, + location=context.settings["region"], + ) + + # Store a dict with the calltimes for each user, via the email. + if context.settings["rate_limiter_enabled"]: + if context.user.email not in context.state.keys(): + context.state[context.user.email] = [] + + if not context.code.is_within_rate_limit(context): + raise SyftException( + public_message="Rate limit of calls per minute has been reached." 
+ ) + context.state[context.user.email].append(datetime.datetime.now()) + + try: + rows = client.query_and_wait( + sql_query, + project=context.settings["project_id"], + ) + + if rows.total_rows > 1_000_000: + raise SyftException( + public_message="Please only write queries that gather aggregate statistics" + ) + + return rows.to_dataframe() + + except Exception as e: + # not a bigquery exception + if not hasattr(e, "_errors"): + output = f"got exception e: {type(e)} {str(e)}" + raise SyftException( + public_message=f"An error occured executing the API call {output}" + ) + + # Treat all errors that we would like to be forwarded to the data scientists + # By default, any exception is only visible to the data owner. + + if e._errors[0]["reason"] in [ + "badRequest", + "blocked", + "duplicate", + "invalidQuery", + "invalid", + "jobBackendError", + "jobInternalError", + "notFound", + "notImplemented", + "rateLimitExceeded", + "resourceInUse", + "resourcesExceeded", + "tableUnavailable", + "timeout", + ]: + raise SyftException( + public_message="Error occured during the call: " + + e._errors[0]["message"] + ) + else: + raise SyftException( + public_message="An error occured executing the API call, please contact the domain owner." + ) + + return live_test_query diff --git a/packages/syft/src/syft/util/test_helpers/apis/mock/__init__.py b/packages/syft/src/syft/util/test_helpers/apis/mock/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/packages/syft/src/syft/util/test_helpers/apis/mock/data.py b/packages/syft/src/syft/util/test_helpers/apis/mock/data.py new file mode 100644 index 00000000000..82262bf7a01 --- /dev/null +++ b/packages/syft/src/syft/util/test_helpers/apis/mock/data.py @@ -0,0 +1,268 @@ +# stdlib +from math import nan + +schema_dict = { + "project": { + 0: "example-project", + 1: "example-project", + 2: "example-project", + 3: "example-project", + 4: "example-project", + 5: "example-project", + 6: "example-project", + 7: "example-project", + 8: "example-project", + 9: "example-project", + 10: "example-project", + 11: "example-project", + 12: "example-project", + 13: "example-project", + 14: "example-project", + 15: "example-project", + 16: "example-project", + 17: "example-project", + 18: "example-project", + 19: "example-project", + 20: "example-project", + 21: "example-project", + 22: "example-project", + }, + "dataset_id": { + 0: "test_1gb", + 1: "test_1gb", + 2: "test_1gb", + 3: "test_1gb", + 4: "test_1gb", + 5: "test_1gb", + 6: "test_1gb", + 7: "test_1gb", + 8: "test_1gb", + 9: "test_1gb", + 10: "test_1gb", + 11: "test_1gb", + 12: "test_1gb", + 13: "test_1gb", + 14: "test_1gb", + 15: "test_1gb", + 16: "test_1gb", + 17: "test_1gb", + 18: "test_1gb", + 19: "test_1gb", + 20: "test_1gb", + 21: "test_1gb", + 22: "test_1gb", + }, + "table_id": { + 0: "posts", + 1: "posts", + 2: "posts", + 3: "posts", + 4: "posts", + 5: "posts", + 6: "posts", + 7: "comments", + 8: "comments", + 9: "comments", + 10: "comments", + 11: "comments", + 12: "comments", + 13: "comments", + 14: "comments", + 15: "comments", + 16: "comments", + 17: "comments", + 18: "comments", + 19: "comments", + 20: "comments", + 21: "comments", + 22: "comments", + }, + "schema_name": { + 0: "int64_field_0", + 1: "id", + 2: "name", + 3: "subscribers_count", + 4: "permalink", + 5: "nsfw", + 6: "spam", + 7: "int64_field_0", + 8: "id", + 9: "body", + 10: "parent_id", + 11: "created_at", + 12: "last_modified_at", + 13: "gilded", + 14: "permalink", + 15: "score", + 16: "comment_id", + 17: "post_id", + 
18: "author_id", + 19: "spam", + 20: "deleted", + 21: "upvote_raio", + 22: "collapsed_in_crowd_control", + }, + "schema_field": { + 0: "INTEGER", + 1: "STRING", + 2: "STRING", + 3: "INTEGER", + 4: "STRING", + 5: "FLOAT", + 6: "BOOLEAN", + 7: "INTEGER", + 8: "STRING", + 9: "STRING", + 10: "STRING", + 11: "INTEGER", + 12: "INTEGER", + 13: "BOOLEAN", + 14: "STRING", + 15: "INTEGER", + 16: "STRING", + 17: "STRING", + 18: "STRING", + 19: "BOOLEAN", + 20: "BOOLEAN", + 21: "FLOAT", + 22: "BOOLEAN", + }, + "description": { + 0: "None", + 1: "None", + 2: "None", + 3: "None", + 4: "None", + 5: "None", + 6: "None", + 7: "None", + 8: "None", + 9: "None", + 10: "None", + 11: "None", + 12: "None", + 13: "None", + 14: "None", + 15: "None", + 16: "None", + 17: "None", + 18: "None", + 19: "None", + 20: "None", + 21: "None", + 22: "None", + }, + "num_rows": { + 0: "2000000", + 1: "2000000", + 2: "2000000", + 3: "2000000", + 4: "2000000", + 5: "2000000", + 6: "2000000", + 7: "2000000", + 8: "2000000", + 9: "2000000", + 10: "2000000", + 11: "2000000", + 12: "2000000", + 13: "2000000", + 14: "2000000", + 15: "2000000", + 16: "2000000", + 17: "2000000", + 18: "2000000", + 19: "2000000", + 20: "2000000", + 21: "2000000", + 22: "2000000", + }, +} + + +query_dict = { + "int64_field_0": { + 0: 4, + 1: 5, + 2: 10, + 3: 16, + 4: 17, + 5: 23, + 6: 24, + 7: 25, + 8: 27, + 9: 40, + }, + "id": { + 0: "t5_via1x", + 1: "t5_cv9gn", + 2: "t5_8p2tq", + 3: "t5_8fcro", + 4: "t5_td5of", + 5: "t5_z01fv", + 6: "t5_hmqjk", + 7: "t5_1flyj", + 8: "t5_5rwej", + 9: "t5_uurcv", + }, + "name": { + 0: "/channel/mylittlepony", + 1: "/channel/polyamory", + 2: "/channel/Catholicism", + 3: "/channel/cordcutters", + 4: "/channel/stevenuniverse", + 5: "/channel/entitledbitch", + 6: "/channel/engineering", + 7: "/channel/nottheonion", + 8: "/channel/FoodPorn", + 9: "/channel/puppysmiles", + }, + "subscribers_count": { + 0: 4323081, + 1: 2425929, + 2: 4062607, + 3: 7543226, + 4: 2692168, + 5: 2709080, + 6: 8766144, + 7: 2580984, + 8: 7784809, + 9: 3715991, + }, + "permalink": { + 0: "/channel//channel/mylittlepony", + 1: "/channel//channel/polyamory", + 2: "/channel//channel/Catholicism", + 3: "/channel//channel/cordcutters", + 4: "/channel//channel/stevenuniverse", + 5: "/channel//channel/entitledbitch", + 6: "/channel//channel/engineering", + 7: "/channel//channel/nottheonion", + 8: "/channel//channel/FoodPorn", + 9: "/channel//channel/puppysmiles", + }, + "nsfw": { + 0: nan, + 1: nan, + 2: nan, + 3: nan, + 4: nan, + 5: nan, + 6: nan, + 7: nan, + 8: nan, + 9: nan, + }, + "spam": { + 0: False, + 1: False, + 2: False, + 3: False, + 4: False, + 5: False, + 6: False, + 7: False, + 8: False, + 9: False, + }, +} diff --git a/packages/syft/src/syft/util/test_helpers/apis/mock/schema.py b/packages/syft/src/syft/util/test_helpers/apis/mock/schema.py new file mode 100644 index 00000000000..c4c7216b20a --- /dev/null +++ b/packages/syft/src/syft/util/test_helpers/apis/mock/schema.py @@ -0,0 +1,52 @@ +# stdlib +from collections.abc import Callable + +# syft absolute +import syft as sy + +# relative +from ..rate_limiter import is_within_rate_limit +from .data import schema_dict + + +def make_schema(settings, worker_pool) -> Callable: + updated_settings = { + "calls_per_min": 5, + "rate_limiter_enabled": True, + "schema_dict": schema_dict, + } | settings + + @sy.api_endpoint( + path="bigquery.schema", + description="This endpoint allows for visualising the metadata of tables available in BigQuery.", + settings=updated_settings, + 
helper_functions=[is_within_rate_limit], + worker_pool=worker_pool, + ) + def mock_schema( + context, + ) -> str: + # relative + from ..... import SyftException + + # Store a dict with the calltimes for each user, via the email. + if context.settings["rate_limiter_enabled"]: + # stdlib + import datetime + + if context.user.email not in context.state.keys(): + context.state[context.user.email] = [] + + if not context.code.is_within_rate_limit(context): + raise SyftException( + public_message="Rate limit of calls per minute has been reached." + ) + context.state[context.user.email].append(datetime.datetime.now()) + + # third party + import pandas as pd + + df = pd.DataFrame(context.settings["schema_dict"]) + return df + + return mock_schema diff --git a/packages/syft/src/syft/util/test_helpers/apis/mock/test_query.py b/packages/syft/src/syft/util/test_helpers/apis/mock/test_query.py new file mode 100644 index 00000000000..1937dcecb80 --- /dev/null +++ b/packages/syft/src/syft/util/test_helpers/apis/mock/test_query.py @@ -0,0 +1,138 @@ +# stdlib +from collections.abc import Callable + +# syft absolute +import syft as sy + +# relative +from ..rate_limiter import is_within_rate_limit +from .data import query_dict + + +def extract_limit_value(sql_query: str) -> int: + # stdlib + import re + + limit_pattern = re.compile(r"\bLIMIT\s+(\d+)\b", re.IGNORECASE) + match = limit_pattern.search(sql_query) + if match: + return int(match.group(1)) + return None + + +def is_valid_sql(query: str) -> bool: + # stdlib + import sqlite3 + + # Prepare an in-memory SQLite database + conn = sqlite3.connect(":memory:") + cursor = conn.cursor() + + try: + # Use the EXPLAIN QUERY PLAN command to get the query plan + cursor.execute(f"EXPLAIN QUERY PLAN {query}") + except sqlite3.Error as e: + if "no such table" in str(e).lower(): + return True + return False + finally: + conn.close() + + +def adjust_dataframe_rows(df, target_rows: int): + # third party + import pandas as pd + + current_rows = len(df) + + if target_rows > current_rows: + # Repeat rows to match target_rows + repeat_times = (target_rows + current_rows - 1) // current_rows + df_expanded = pd.concat([df] * repeat_times, ignore_index=True).head( + target_rows + ) + else: + # Truncate rows to match target_rows + df_expanded = df.head(target_rows) + + return df_expanded + + +def make_test_query(settings: dict) -> Callable: + updated_settings = { + "calls_per_min": 10, + "rate_limiter_enabled": True, + "query_dict": query_dict, + } | settings + + # these are the same if you allow the rate limiter to be turned on and off + @sy.api_endpoint_method( + settings=updated_settings, + helper_functions=[ + is_within_rate_limit, + extract_limit_value, + is_valid_sql, + adjust_dataframe_rows, + ], + ) + def mock_test_query( + context, + sql_query: str, + ) -> str: + # stdlib + import datetime + + # third party + from google.api_core.exceptions import BadRequest + + # relative + from ..... import SyftException + + # Store a dict with the calltimes for each user, via the email. + if context.settings["rate_limiter_enabled"]: + if context.user.email not in context.state.keys(): + context.state[context.user.email] = [] + + if not context.code.is_within_rate_limit(context): + raise SyftException( + public_message="Rate limit of calls per minute has been reached." 
+ ) + context.state[context.user.email].append(datetime.datetime.now()) + + bad_table = "invalid_table" + bad_post = ( + "BadRequest: 400 POST " + "https://bigquery.googleapis.com/bigquery/v2/projects/project-id/" + "queries?prettyPrint=false: " + ) + if bad_table in sql_query: + try: + raise BadRequest( + f'{bad_post} Table "{bad_table}" must be qualified ' + "with a dataset (e.g. dataset.table)." + ) + except Exception as e: + raise SyftException( + public_message=f"*must be qualified with a dataset*. {e}" + ) + + if not context.code.is_valid_sql(sql_query): + raise BadRequest( + f'{bad_post} Syntax error: Unexpected identifier "{sql_query}" at [1:1]' + ) + + # third party + import pandas as pd + + limit = context.code.extract_limit_value(sql_query) + if limit > 1_000_000: + raise SyftException( + public_message="Please only write queries that gather aggregate statistics" + ) + + base_df = pd.DataFrame(context.settings["query_dict"]) + + df = context.code.adjust_dataframe_rows(base_df, limit) + return df + + return mock_test_query diff --git a/packages/syft/src/syft/util/test_helpers/apis/rate_limiter.py b/packages/syft/src/syft/util/test_helpers/apis/rate_limiter.py new file mode 100644 index 00000000000..8ce319b61f4 --- /dev/null +++ b/packages/syft/src/syft/util/test_helpers/apis/rate_limiter.py @@ -0,0 +1,16 @@ +def is_within_rate_limit(context) -> bool: + """Rate limiter for custom API calls made by users.""" + # stdlib + import datetime + + state = context.state + settings = context.settings + email = context.user.email + + current_time = datetime.datetime.now() + calls_last_min = [ + 1 if (current_time - call_time).seconds < 60 else 0 + for call_time in state[email] + ] + + return sum(calls_last_min) < settings.get("calls_per_min", 5) diff --git a/packages/syft/src/syft/util/test_helpers/apis/submit_query.py b/packages/syft/src/syft/util/test_helpers/apis/submit_query.py new file mode 100644 index 00000000000..a0125ee009b --- /dev/null +++ b/packages/syft/src/syft/util/test_helpers/apis/submit_query.py @@ -0,0 +1,42 @@ +# syft absolute +import syft as sy + + +def make_submit_query(settings, worker_pool): + updated_settings = {"user_code_worker": worker_pool} | settings + + @sy.api_endpoint( + path="bigquery.submit_query", + description="API endpoint that allows you to submit SQL queries to run on the private data.", + worker_pool=worker_pool, + settings=updated_settings, + ) + def submit_query( + context, + func_name: str, + query: str, + ) -> str: + # syft absolute + import syft as sy + + @sy.syft_function( + name=func_name, + input_policy=sy.MixedInputPolicy( + endpoint=sy.Constant( + val=context.admin_client.api.services.bigquery.test_query + ), + query=sy.Constant(val=query), + client=context.admin_client, + ), + worker_pool_name=context.settings["user_code_worker"], + ) + def execute_query(query: str, endpoint): + res = endpoint(sql_query=query) + return res + + request = context.user_client.code.request_code_execution(execute_query) + context.admin_client.requests.set_tags(request, ["autosync"]) + + return f"Query submitted {request}. 
Use `client.code.{func_name}()` to run your query" + + return submit_query diff --git a/packages/syft/src/syft/util/test_helpers/email_helpers.py b/packages/syft/src/syft/util/test_helpers/email_helpers.py new file mode 100644 index 00000000000..e9aa83037fc --- /dev/null +++ b/packages/syft/src/syft/util/test_helpers/email_helpers.py @@ -0,0 +1,338 @@ +# stdlib +import asyncio +from dataclasses import dataclass +from dataclasses import field +import json +import re +import time +from typing import Any + +# third party +from aiosmtpd.controller import Controller +from faker import Faker + +# relative +from ...service.user.user_roles import ServiceRole + +fake = Faker() + + +@dataclass +class Email: + email_from: str + email_to: str + email_content: str + + def to_dict(self) -> dict: + output = {} + for k, v in self.__dict__.items(): + output[k] = v + return output + + def __iter__(self): + yield from self.to_dict().items() + + def __getitem__(self, key): + return self.to_dict()[key] + + def __repr__(self) -> str: + return f"{self.email_to}\n{self.email_from}\n\n{self.email_content}" + + +class EmailServer: + def __init__(self, filepath="./emails.json"): + self.filepath = filepath + self._emails: dict[str, list[Email]] = self.load_emails() + + def load_emails(self) -> dict[str, list[Email]]: + try: + with open(self.filepath) as f: + data = json.load(f) + return {k: [Email(**email) for email in v] for k, v in data.items()} + except Exception as e: + print("Issues reading email file", e) + return {} + + def save_emails(self) -> None: + with open(self.filepath, "w") as f: + data = { + k: [email.to_dict() for email in v] for k, v in self._emails.items() + } + f.write(json.dumps(data)) + + def add_email_for_user(self, user_email: str, email: Email) -> None: + if user_email not in self._emails: + self._emails[user_email] = [] + self._emails[user_email].append(email) + self.save_emails() + + def get_emails_for_user(self, user_email: str) -> list[Email]: + self._emails: dict[str, list[Email]] = self.load_emails() + return self._emails.get(user_email, []) + + def reset_emails(self) -> None: + self._emails = {} + self.save_emails() + + +SENDER = "noreply@openmined.org" + + +def get_token(email) -> str: + # stdlib + import re + + pattern = r"syft_client\.reset_password\(token='(.*?)', new_password=.*?\)" + try: + token = re.search(pattern, email.email_content).group(1) + except Exception: + raise Exception(f"No token found in email: {email.email_content}") + return token + + +@dataclass +class TestUser: + name: str + email: str + password: str + role: ServiceRole + new_password: str | None = None + email_disabled: bool = False + reset_password: bool = False + reset_token: str | None = None + _client_cache: Any | None = field(default=None, repr=False, init=False) + _email_server: EmailServer | None = None + + @property + def latest_password(self) -> str: + if self.new_password: + return self.new_password + return self.password + + def make_new_password(self) -> str: + self.new_password = fake.password() + return self.new_password + + @property + def client(self): + return self._client_cache + + def relogin(self) -> None: + self.client = self.client + + @client.setter + def client(self, client): + client = client.login(email=self.email, password=self.latest_password) + self._client_cache = client + + def to_dict(self) -> dict: + output = {} + for k, v in self.__dict__.items(): + if k.startswith("_"): + continue + if k == "role": + v = str(v) + output[k] = v + return output + + def __iter__(self): + for 
key, val in self.to_dict().items(): + if not key.startswith("_"): + yield key, val + + def __getitem__(self, key): + if key.startswith("_"): + return None + return self.to_dict()[key] + + def update_password(self): + self.password = self.new_password + self.new_password = None + + @property + def emails(self) -> list[Email]: + if not self._email_server: + print("Not connected to email server object") + return [] + return self._email_server.get_emails_for_user(self.email) + + def get_token(self) -> str: + for email in reversed(self.emails): + token = None + try: + token = get_token(email) + break + except Exception: + pass + self.reset_token = token + return token + + +def save_users(users): + user_dicts = [] + for user in users: + user_dicts.append(user.to_dict()) + print(user_dicts) + with open("./users.json", "w") as f: + f.write(json.dumps(user_dicts)) + + +def load_users(high_client: None, path="./users.json"): + users = [] + with open(path) as f: + data = f.read() + user_dicts = json.loads(data) + for user in user_dicts: + test_user = TestUser(**user) + if high_client: + test_user.client = high_client + users.append(test_user) + return users + + +def make_user( + name: str | None = None, + email: str | None = None, + password: str | None = None, + role: ServiceRole = ServiceRole.DATA_SCIENTIST, +): + fake = Faker() + if name is None: + name = fake.name() + if email is None: + ascii_string = re.sub(r"[^a-zA-Z\s]", "", name).lower() + dashed_string = ascii_string.replace(" ", "-") + email = f"{dashed_string}-fake@openmined.org" + if password is None: + password = fake.password() + + return TestUser(name=name, email=email, password=password, role=role) + + +def user_exists(root_client, email: str) -> bool: + users = root_client.api.services.user + for user in users: + if user.email == email: + return True + return False + + +class SMTPTestServer: + def __init__(self, email_server): + self.port = 9025 + self.hostname = "0.0.0.0" + self._stop_event = asyncio.Event() + + # Simple email handler class + class SimpleHandler: + async def handle_DATA(self, server, session, envelope): + try: + print(f"> SMTPTestServer got an email for {envelope.rcpt_tos}") + email = Email( + email_from=envelope.mail_from, + email_to=envelope.rcpt_tos, + email_content=envelope.content.decode( + "utf-8", errors="replace" + ), + ) + email_server.add_email_for_user(envelope.rcpt_tos[0], email) + email_server.save_emails() + return "250 Message accepted for delivery" + except Exception as e: + print(f"> Error handling email: {e}") + return "550 Internal Server Error" + + try: + self.handler = SimpleHandler() + self.controller = Controller( + self.handler, hostname=self.hostname, port=self.port + ) + except Exception as e: + print(f"> Error initializing SMTPTestServer Controller: {e}") + + def start(self): + print(f"> Starting SMTPTestServer on: {self.hostname}:{self.port}") + asyncio.create_task(self.async_loop()) + + async def async_loop(self): + try: + print(f"> Starting SMTPTestServer on: {self.hostname}:{self.port}") + self.controller.start() + await ( + self._stop_event.wait() + ) # Wait until the event is set to stop the server + except Exception as e: + print(f"> Error with SMTPTestServer: {e}") + + def stop(self): + try: + print("> Stopping SMTPTestServer") + loop = asyncio.get_running_loop() + if loop.is_running(): + loop.create_task(self.async_stop()) + else: + asyncio.run(self.async_stop()) + except Exception as e: + print(f"> Error stopping SMTPTestServer: {e}") + + async def async_stop(self): + 
self.controller.stop() + self._stop_event.set() # Stop the server by setting the event + + +class TimeoutError(Exception): + pass + + +class Timeout: + def __init__(self, timeout_duration): + if timeout_duration > 60: + raise ValueError("Timeout duration cannot exceed 60 seconds.") + self.timeout_duration = timeout_duration + + def run_with_timeout(self, condition_func, *args, **kwargs): + start_time = time.time() + result = None + + while True: + elapsed_time = time.time() - start_time + if elapsed_time > self.timeout_duration: + raise TimeoutError( + f"Function execution exceeded {self.timeout_duration} seconds." + ) + + # Check if the condition is met + try: + if condition_func(): + print("Condition met, exiting early.") + break + except Exception as e: + print(f"Exception in target function: {e}") + break # Exit the loop if an exception occurs in the function + time.sleep(1) + + return result + + +def get_email_server(reset=False): + email_server = EmailServer() + if reset: + email_server.reset_emails() + smtp_server = SMTPTestServer(email_server) + smtp_server.start() + return email_server, smtp_server + + +def create_user(root_client, test_user): + if not user_exists(root_client, test_user.email): + fake = Faker() + root_client.register( + name=test_user.name, + email=test_user.email, + password=test_user.password, + password_verify=test_user.password, + institution=fake.company(), + website=fake.url(), + ) + else: + print("User already exists", test_user) diff --git a/packages/syft/src/syft/util/test_helpers/job_helpers.py b/packages/syft/src/syft/util/test_helpers/job_helpers.py new file mode 100644 index 00000000000..ac26ad5f8ff --- /dev/null +++ b/packages/syft/src/syft/util/test_helpers/job_helpers.py @@ -0,0 +1,398 @@ +# stdlib +from collections import defaultdict +from collections.abc import Callable +from dataclasses import dataclass +from dataclasses import field +import json +import random +import re +import secrets +import textwrap +from typing import Any + +# relative +from ... 
import test_settings +from .email_helpers import TestUser + +from ...client.client import SyftClient # noqa + +dataset_1 = test_settings.get("dataset_1", default="dataset_1") +dataset_2 = test_settings.get("dataset_2", default="dataset_2") +table_1 = test_settings.get("table_1", default="table_1") +table_2 = test_settings.get("table_2", default="table_2") +table_1_col_id = test_settings.get("table_1_col_id", default="table_id") +table_1_col_score = test_settings.get("table_1_col_score", default="colname") +table_2_col_id = test_settings.get("table_2_col_id", default="table_id") +table_2_col_score = test_settings.get("table_2_col_score", default="colname") + + +@dataclass +class TestJob: + user_email: str + func_name: str + query: str + job_type: str + settings: dict # make a type so we can rely on attributes + should_succeed: bool + should_submit: bool = True + code_path: str | None = field(default=None) + admin_reviewed: bool = False + result_as_expected: bool | None = None + + _client_cache: SyftClient | None = field(default=None, repr=False, init=False) + + @property + def is_submitted(self) -> bool: + return self.code_path is not None + + @property + def client(self): + return self._client_cache + + @client.setter + def client(self, client): + self._client_cache = client + + def to_dict(self) -> dict: + output = {} + for k, v in self.__dict__.items(): + if k.startswith("_"): + continue + output[k] = v + return output + + def __iter__(self): + for key, val in self.to_dict().items(): + if key.startswith("_"): + yield key, val + + def __getitem__(self, key): + if key.startswith("_"): + return None + return self.to_dict()[key] + + @property + def code_method(self) -> None | Callable: + try: + return getattr(self.client.code, self.func_name, None) + except Exception as e: + print(f"Cant find code method. 
{e}") + return None + + +def make_query(settings: dict) -> str: + query = f""" + SELECT {settings['groupby_col']}, AVG({settings['score_col']}) AS average_score + FROM {settings['dataset']}.{settings['table']} + GROUP BY {settings['groupby_col']} + LIMIT {settings['limit']}""".strip() + + return textwrap.dedent(query) + + +def create_simple_query_job(user: TestUser) -> TestJob: + job_type = "simple_query" + func_name = f"{job_type}_{secrets.token_hex(3)}" + + dataset = random.choice([dataset_1, dataset_2]) + table, groupby_col, score_col = random.choice( + [ + (table_1, table_1_col_id, table_1_col_score), + (table_2, table_2_col_id, table_2_col_score), + ] + ) + limit = random.randint(1, 1_000_000) + + settings = { + "dataset": dataset, + "table": table, + "groupby_col": groupby_col, + "score_col": score_col, + "limit": limit, + } + query = make_query(settings) + + result = TestJob( + user_email=user.email, + func_name=func_name, + query=query, + job_type=job_type, + settings=settings, + should_succeed=True, + ) + + result.client = user.client + return result + + +def create_wrong_asset_query(user: TestUser) -> TestJob: + job_type = "wrong_asset_query" + func_name = f"{job_type}_{secrets.token_hex(3)}" + + valid_job = create_simple_query_job(user) + settings = valid_job.settings + corrupted_asset = random.choice(["dataset", "table"]) + settings[corrupted_asset] = "wrong_asset" + query = make_query(settings) + + result = TestJob( + user_email=user.email, + func_name=func_name, + query=query, + job_type=job_type, + settings=settings, + should_succeed=False, + ) + + result.client = user.client + return result + + +def create_wrong_syntax_query(user: TestUser) -> TestJob: + job_type = "wrong_syntax_query" + func_name = f"{job_type}_{secrets.token_hex(3)}" + + query = "SELECT * FROM table INCORRECT SYNTAX" + + result = TestJob( + user_email=user.email, + func_name=func_name, + query=query, + job_type=job_type, + settings={}, + should_succeed=False, + ) + + result.client = user.client + return result + + +def create_long_query_job(user: TestUser) -> TestJob: + job_type = "job_too_much_text" + func_name = f"{job_type}_{secrets.token_hex(3)}" + + query = "a" * 1_000 + + result = TestJob( + user_email=user.email, + func_name=func_name, + query=query, + job_type=job_type, + settings={}, + should_succeed=False, + ) + + result.client = user.client + return result + + +def create_query_long_name(user: TestUser) -> TestJob: + job_type = "job_long_name" + func_name = f"{job_type}_{secrets.token_hex(3)}" + + job = create_simple_query_job(user) + + job.job_type = job_type + job.func_name = func_name + "a" * 1_000 + + return job + + +def create_job_funcname_xss(user: TestUser) -> TestJob: + job_type = "job_funcname_xss" + func_name = f"{job_type}_{secrets.token_hex(3)}" + func_name += "" + + job = create_simple_query_job(user) + job.job_type = job_type + job.func_name = func_name + job.should_submit = False + return job + + +def get_request_for_job_info(requests, job): + job_requests = [r for r in requests if r.code.service_func_name == job.func_name] + if len(job_requests) != 1: + raise Exception(f"Too many or too few requests: {job} in requests: {requests}") + return job_requests[0] + + +def create_job_query_xss(user: TestUser) -> TestJob: + job_type = "job_query_xss" + func_name = f"{job_type}_{secrets.token_hex(3)}" + + job = create_simple_query_job(user) + job.job_type = job_type + job.func_name = func_name + job.query += "" + job.should_succeed = False + + return job + + +def 
create_job_many_columns(user: TestUser) -> TestJob: + job_type = "job_many_columns" + func_name = f"{job_type}_{secrets.token_hex(3)}" + + job = create_simple_query_job(user) + job.job_type = job_type + job.func_name = func_name + settings = job.settings + job.settings["num_extra_cols"] = random.randint(100, 1000) + + new_columns_string = ", ".join( + f"{settings['score_col']} as col_{i}" for i in range(settings["num_extra_cols"]) + ) + + job.query = f""" + SELECT {settings['groupby_col']}, AVG({settings['score_col']}) AS average_score, {new_columns_string} + FROM {settings['dataset']}.{settings['table']} + GROUP BY {settings['groupby_col']} + LIMIT {settings['limit']}""".strip() + + return job + + +def create_random_job(user: TestUser) -> TestJob: + job_func = random.choice(create_job_functions) + return job_func(user) + + +def create_jobs(users: list[TestUser], total_jobs: int = 10) -> list[TestJob]: + jobs = [] + num_users = len(users) + user_index = 0 + each_count = 0 + # keep making jobs until we have enough + while len(jobs) < total_jobs: + # if we havent used each job type yet keep getting the next one + if each_count < len(create_job_functions): + job_func = create_job_functions[each_count] + each_count += 1 + else: + # otherwise lets get a random one + job_func = create_random_job + # use the current index of user + jobs.append(job_func(users[user_index])) + + # only go as high as the last user index + if user_index < num_users - 1: + user_index += 1 + else: + # reset back to the first user + user_index = 0 + + # in case we stuffed up + if len(jobs) > total_jobs: + jobs = jobs[:total_jobs] + return jobs + + +def submit_job(job: TestJob) -> tuple[Any, str]: + client = job.client + response = client.api.services.bigquery.submit_query( + func_name=job.func_name, query=job.query + ) + job.code_path = extract_code_path(response) + return response + + +def extract_code_path(response) -> str | None: + pattern = r"client\.code\.(\w+)\(\)" + match = re.search(pattern, str(response)) + if match: + extracted_code = match.group(1) + return extracted_code + return None + + +def approve_by_running(request): + job = request.code(blocking=False) + result = job.wait() + print("got result of type", type(result), "bool", bool(result)) + # got result of type bool False + # assert result won't work unless we know what type is coming back + job_info = job.info(result=True) + # need force when running multiple times + # todo check and dont run if its already done + response = request.deposit_result(job_info, approve=True, force=True) + return response + + +def get_job_emails(jobs, client, email_server): + all_requests = client.requests + res = {} + for job in jobs: + request = get_request_for_job_info(all_requests, job) + emails = email_server.get_emails_for_user(request.requesting_user_email) + res[request.requesting_user_email] = emails + return res + + +def resolve_request(request): + service_func_name = request.code.service_func_name + if service_func_name.startswith("simple_query"): + request.approve() # approve because it is good + if service_func_name.startswith("wrong_asset_query"): + request.approve() # approve because it is bad + if service_func_name.startswith("wrong_syntax_query"): + request.approve() # approve because it is bad + if service_func_name.startswith("job_too_much_text"): + request.deny(reason="too long, boring!") # deny because it is bad + if service_func_name.startswith("job_long_name"): + request.approve() + if service_func_name.startswith("job_funcname_xss"): + 
request.deny(reason="too long, boring!") # never reach doesnt matter + if service_func_name.startswith("job_query_xss"): + request.approve() # approve because it is bad + if service_func_name.startswith("job_many_columns"): + request.approve() # approve because it is bad + + return (request.id, request.status) + + +create_job_functions = [ + create_simple_query_job, # quick way to increase the odds + create_simple_query_job, + create_simple_query_job, + create_simple_query_job, + create_simple_query_job, + create_simple_query_job, + create_wrong_syntax_query, + create_long_query_job, + create_query_long_name, + create_job_funcname_xss, + create_job_query_xss, + create_job_many_columns, +] + + +def save_jobs(jobs, filepath="./jobs.json"): + user_jobs = defaultdict(list) + for job in jobs: + user_jobs[job.user_email].append(job.to_dict()) + with open(filepath, "w") as f: + f.write(json.dumps(user_jobs)) + + +def load_jobs(users, high_client, filepath="./jobs.json"): + data = {} + try: + with open(filepath) as f: + data = json.loads(f.read()) + except Exception as e: + print(f"cant read file: {filepath}: {e}") + data = {} + jobs_list = [] + for user in users: + if user.email not in data: + print(f"{user.email} missing from jobs") + continue + user_jobs = data[user.email] + for user_job in user_jobs: + test_job = TestJob(**user_job) + if user._client_cache is None: + user.client = high_client + test_job.client = user.client + jobs_list.append(test_job) + return jobs_list diff --git a/packages/syft/src/syft/util/test_helpers/sync_helpers.py b/packages/syft/src/syft/util/test_helpers/sync_helpers.py new file mode 100644 index 00000000000..7252b896ea2 --- /dev/null +++ b/packages/syft/src/syft/util/test_helpers/sync_helpers.py @@ -0,0 +1,192 @@ +# third party +from tqdm import tqdm + +# syft absolute +import syft as sy + +# relative +from ...client.datasite_client import DatasiteClient +from ...client.syncing import compare_clients +from ...service.code.user_code import UserCode +from ...service.job.job_stash import Job +from ...service.job.job_stash import JobStatus +from ...service.request.request import Request +from ...service.request.request import RequestStatus +from ...service.sync.diff_state import ObjectDiffBatch +from ...types.result import Err + + +def deny_requests_without_autosync_tag(client_low: DatasiteClient): + # Deny all requests that are not autosync + requests = client_low.requests.get_all() + if isinstance(requests, sy.SyftError): + print(requests) + return + + denied_requests = [] + for request in tqdm(requests): + if request.status != RequestStatus.PENDING: + continue + if "autosync" not in request.tags: + request.deny( + reason="This request has been denied automatically. " + "Please use the designated API to submit your request." 
+ ) + denied_requests.append(request.id) + print(f"Denied {len(denied_requests)} requests without autosync tag") + + +def is_request_to_sync(batch: ObjectDiffBatch) -> bool: + # True if this is a new low-side request + # TODO add condition for sql requests/usercodes + low_request = batch.root.low_obj + return ( + isinstance(low_request, Request) + and batch.status == "NEW" + and "autosync" in low_request.tags + ) + + +def is_job_to_sync(batch: ObjectDiffBatch): + # True if this is a new high-side job that is either COMPLETED or ERRORED + if batch.status != "NEW": + return False + if not isinstance(batch.root.high_obj, Job): + return False + job = batch.root.high_obj + return job.status in (JobStatus.ERRORED, JobStatus.COMPLETED) + + +def execute_requests( + client_high: DatasiteClient, request_ids: list[sy.UID] +) -> dict[sy.UID, Job]: + jobs_by_request_id = {} + for request_id in request_ids: + request = client_high.requests.get_by_uid(request_id) + if not isinstance(request, Request): + continue + + code = request.code + if not isinstance(code, UserCode): + continue + + func_name = request.code.service_func_name + api_func = getattr(client_high.code, func_name, None) + if api_func is None: + continue + + job = api_func(blocking=False) + jobs_by_request_id[request_id] = job + + return jobs_by_request_id + + +def deny_failed_jobs( + client_low: DatasiteClient, + jobs: list[Job], +) -> None: + # NOTE no syncing is needed, requests are denied on the low side + denied_requests = [] + + for job in jobs: + if job.status != JobStatus.ERRORED: + continue + + error_result = job.result + if isinstance(error_result, Err): + error_msg = error_result.err_value + else: + error_msg = "An unknown error occurred, please check the Job logs for more information." + + code_id = job.user_code_id + if code_id is None: + continue + requests = client_low.requests.get_by_usercode_id(code_id) + if isinstance(requests, list) and len(requests) > 0: + request = requests[0] + request.deny(reason=f"Execution failed: {error_msg}") + denied_requests.append(request.id) + else: + print(f"Failed to deny request for job {job.id}") + + print(f"Denied {len(denied_requests)} failed requests") + + +def sync_finished_jobs( + client_low: DatasiteClient, + client_high: DatasiteClient, +) -> dict[sy.UID, sy.SyftError | sy.SyftSuccess] | sy.SyftError: + sync_job_results = {} + synced_jobs = [] + diff = compare_clients( + from_client=client_high, to_client=client_low, include_types=["job"] + ) + if isinstance(diff, sy.SyftError): + print(diff) + return diff + + for batch in diff.batches: + if is_job_to_sync(batch): + job = batch.root.high_obj + + w = batch.resolve(build_state=False) + share_result = w.click_share_all_private_data() + if isinstance(share_result, sy.SyftError): + sync_job_results[job.id] = share_result + continue + sync_result = w.click_sync() + + synced_jobs.append(job) + sync_job_results[job.id] = sync_result + + print(f"Sharing {len(sync_job_results)} new results") + deny_failed_jobs(client_low, synced_jobs) + return sync_job_results + + +def sync_new_requests( + client_low: DatasiteClient, + client_high: DatasiteClient, +) -> dict[sy.UID, sy.SyftSuccess | sy.SyftError] | sy.SyftError: + sync_request_results = {} + diff = compare_clients( + from_client=client_low, to_client=client_high, include_types=["request"] + ) + if isinstance(diff, sy.SyftError): + print(diff) + return sync_request_results + print(f"{len(diff.batches)} request batches found") + for batch in tqdm(diff.batches): + if is_request_to_sync(batch): + 
request_id = batch.root.low_obj.id + w = batch.resolve(build_state=False) + result = w.click_sync() + sync_request_results[request_id] = result + return sync_request_results + + +def sync_and_execute_new_requests( + client_low: DatasiteClient, client_high: DatasiteClient +) -> None: + sync_results = sync_new_requests(client_low, client_high) + if isinstance(sync_results, sy.SyftError): + print(sync_results) + return + + request_ids = [ + uid for uid, res in sync_results.items() if isinstance(res, sy.SyftSuccess) + ] + print(f"Synced {len(request_ids)} new requests") + + jobs_by_request = execute_requests(client_high, request_ids) + print(f"Started {len(jobs_by_request)} new jobs") + + +def auto_sync(client_low: DatasiteClient, client_high: DatasiteClient) -> None: + print("Starting auto sync") + print("Denying non tagged jobs") + deny_requests_without_autosync_tag(client_low) + print("Syncing and executing") + sync_and_execute_new_requests(client_low, client_high) + sync_finished_jobs(client_low, client_high) + print("Finished auto sync") From 23bcb421000d5de9b0fd731340aa4ea8322f7309 Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Fri, 13 Sep 2024 10:49:41 -0400 Subject: [PATCH 03/11] Fixed spurious import --- notebooks/scenarios/bigquery/011-users-emails-passwords.ipynb | 1 - 1 file changed, 1 deletion(-) diff --git a/notebooks/scenarios/bigquery/011-users-emails-passwords.ipynb b/notebooks/scenarios/bigquery/011-users-emails-passwords.ipynb index 7978234b07e..9a8bfdcdf9c 100644 --- a/notebooks/scenarios/bigquery/011-users-emails-passwords.ipynb +++ b/notebooks/scenarios/bigquery/011-users-emails-passwords.ipynb @@ -27,7 +27,6 @@ "\n", "# syft absolute\n", "import syft as sy\n", - "from syft import get_helpers # noqa: F401\n", "from syft.util.test_helpers.email_helpers import SENDER\n", "from syft.util.test_helpers.email_helpers import create_user\n", "from syft.util.test_helpers.email_helpers import get_email_server\n", From 1b174179c2cb877fffb1a5e0bdcf478d66f1aa65 Mon Sep 17 00:00:00 2001 From: Sameer Wagh Date: Fri, 13 Sep 2024 10:59:17 -0400 Subject: [PATCH 04/11] Added init --- .../src/syft/util/test_helpers/__init__.py | 0 test_helpers/email_helpers.py | 338 --------------- test_helpers/job_helpers.py | 400 ------------------ test_helpers/sync_helpers.py | 190 --------- 4 files changed, 928 deletions(-) create mode 100644 packages/syft/src/syft/util/test_helpers/__init__.py delete mode 100644 test_helpers/email_helpers.py delete mode 100644 test_helpers/job_helpers.py delete mode 100644 test_helpers/sync_helpers.py diff --git a/packages/syft/src/syft/util/test_helpers/__init__.py b/packages/syft/src/syft/util/test_helpers/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/test_helpers/email_helpers.py b/test_helpers/email_helpers.py deleted file mode 100644 index f58d41a20f8..00000000000 --- a/test_helpers/email_helpers.py +++ /dev/null @@ -1,338 +0,0 @@ -# stdlib -import asyncio -from dataclasses import dataclass -from dataclasses import field -import json -import re -import time -from typing import Any - -# third party -from aiosmtpd.controller import Controller -from faker import Faker - -# syft absolute -from syft.service.user.user_roles import ServiceRole - -fake = Faker() - - -@dataclass -class Email: - email_from: str - email_to: str - email_content: str - - def to_dict(self) -> dict: - output = {} - for k, v in self.__dict__.items(): - output[k] = v - return output - - def __iter__(self): - yield from self.to_dict().items() - - def 
__getitem__(self, key): - return self.to_dict()[key] - - def __repr__(self) -> str: - return f"{self.email_to}\n{self.email_from}\n\n{self.email_content}" - - -class EmailServer: - def __init__(self, filepath="./emails.json"): - self.filepath = filepath - self._emails: dict[str, list[Email]] = self.load_emails() - - def load_emails(self) -> dict[str, list[Email]]: - try: - with open(self.filepath) as f: - data = json.load(f) - return {k: [Email(**email) for email in v] for k, v in data.items()} - except Exception as e: - print("Issues reading email file", e) - return {} - - def save_emails(self) -> None: - with open(self.filepath, "w") as f: - data = { - k: [email.to_dict() for email in v] for k, v in self._emails.items() - } - f.write(json.dumps(data)) - - def add_email_for_user(self, user_email: str, email: Email) -> None: - if user_email not in self._emails: - self._emails[user_email] = [] - self._emails[user_email].append(email) - self.save_emails() - - def get_emails_for_user(self, user_email: str) -> list[Email]: - self._emails: dict[str, list[Email]] = self.load_emails() - return self._emails.get(user_email, []) - - def reset_emails(self) -> None: - self._emails = {} - self.save_emails() - - -SENDER = "noreply@openmined.org" - - -def get_token(email) -> str: - # stdlib - import re - - pattern = r"syft_client\.reset_password\(token='(.*?)', new_password=.*?\)" - try: - token = re.search(pattern, email.email_content).group(1) - except Exception: - raise Exception(f"No token found in email: {email.email_content}") - return token - - -@dataclass -class TestUser: - name: str - email: str - password: str - role: ServiceRole - new_password: str | None = None - email_disabled: bool = False - reset_password: bool = False - reset_token: str | None = None - _client_cache: Any | None = field(default=None, repr=False, init=False) - _email_server: EmailServer | None = None - - @property - def latest_password(self) -> str: - if self.new_password: - return self.new_password - return self.password - - def make_new_password(self) -> str: - self.new_password = fake.password() - return self.new_password - - @property - def client(self): - return self._client_cache - - def relogin(self) -> None: - self.client = self.client - - @client.setter - def client(self, client): - client = client.login(email=self.email, password=self.latest_password) - self._client_cache = client - - def to_dict(self) -> dict: - output = {} - for k, v in self.__dict__.items(): - if k.startswith("_"): - continue - if k == "role": - v = str(v) - output[k] = v - return output - - def __iter__(self): - for key, val in self.to_dict().items(): - if not key.startswith("_"): - yield key, val - - def __getitem__(self, key): - if key.startswith("_"): - return None - return self.to_dict()[key] - - def update_password(self): - self.password = self.new_password - self.new_password = None - - @property - def emails(self) -> list[Email]: - if not self._email_server: - print("Not connected to email server object") - return [] - return self._email_server.get_emails_for_user(self.email) - - def get_token(self) -> str: - for email in reversed(self.emails): - token = None - try: - token = get_token(email) - break - except Exception: - pass - self.reset_token = token - return token - - -def save_users(users): - user_dicts = [] - for user in users: - user_dicts.append(user.to_dict()) - print(user_dicts) - with open("./users.json", "w") as f: - f.write(json.dumps(user_dicts)) - - -def load_users(high_client: None, path="./users.json"): - users = [] - 
with open(path) as f: - data = f.read() - user_dicts = json.loads(data) - for user in user_dicts: - test_user = TestUser(**user) - if high_client: - test_user.client = high_client - users.append(test_user) - return users - - -def make_user( - name: str | None = None, - email: str | None = None, - password: str | None = None, - role: ServiceRole = ServiceRole.DATA_SCIENTIST, -): - fake = Faker() - if name is None: - name = fake.name() - if email is None: - ascii_string = re.sub(r"[^a-zA-Z\s]", "", name).lower() - dashed_string = ascii_string.replace(" ", "-") - email = f"{dashed_string}-fake@openmined.org" - if password is None: - password = fake.password() - - return TestUser(name=name, email=email, password=password, role=role) - - -def user_exists(root_client, email: str) -> bool: - users = root_client.api.services.user - for user in users: - if user.email == email: - return True - return False - - -class SMTPTestServer: - def __init__(self, email_server): - self.port = 9025 - self.hostname = "0.0.0.0" - self._stop_event = asyncio.Event() - - # Simple email handler class - class SimpleHandler: - async def handle_DATA(self, server, session, envelope): - try: - print(f"> SMTPTestServer got an email for {envelope.rcpt_tos}") - email = Email( - email_from=envelope.mail_from, - email_to=envelope.rcpt_tos, - email_content=envelope.content.decode( - "utf-8", errors="replace" - ), - ) - email_server.add_email_for_user(envelope.rcpt_tos[0], email) - email_server.save_emails() - return "250 Message accepted for delivery" - except Exception as e: - print(f"> Error handling email: {e}") - return "550 Internal Server Error" - - try: - self.handler = SimpleHandler() - self.controller = Controller( - self.handler, hostname=self.hostname, port=self.port - ) - except Exception as e: - print(f"> Error initializing SMTPTestServer Controller: {e}") - - def start(self): - print(f"> Starting SMTPTestServer on: {self.hostname}:{self.port}") - asyncio.create_task(self.async_loop()) - - async def async_loop(self): - try: - print(f"> Starting SMTPTestServer on: {self.hostname}:{self.port}") - self.controller.start() - await ( - self._stop_event.wait() - ) # Wait until the event is set to stop the server - except Exception as e: - print(f"> Error with SMTPTestServer: {e}") - - def stop(self): - try: - print("> Stopping SMTPTestServer") - loop = asyncio.get_running_loop() - if loop.is_running(): - loop.create_task(self.async_stop()) - else: - asyncio.run(self.async_stop()) - except Exception as e: - print(f"> Error stopping SMTPTestServer: {e}") - - async def async_stop(self): - self.controller.stop() - self._stop_event.set() # Stop the server by setting the event - - -class TimeoutError(Exception): - pass - - -class Timeout: - def __init__(self, timeout_duration): - if timeout_duration > 60: - raise ValueError("Timeout duration cannot exceed 60 seconds.") - self.timeout_duration = timeout_duration - - def run_with_timeout(self, condition_func, *args, **kwargs): - start_time = time.time() - result = None - - while True: - elapsed_time = time.time() - start_time - if elapsed_time > self.timeout_duration: - raise TimeoutError( - f"Function execution exceeded {self.timeout_duration} seconds." 
- ) - - # Check if the condition is met - try: - if condition_func(): - print("Condition met, exiting early.") - break - except Exception as e: - print(f"Exception in target function: {e}") - break # Exit the loop if an exception occurs in the function - time.sleep(1) - - return result - - -def get_email_server(reset=False): - email_server = EmailServer() - if reset: - email_server.reset_emails() - smtp_server = SMTPTestServer(email_server) - smtp_server.start() - return email_server, smtp_server - - -def create_user(root_client, test_user): - if not user_exists(root_client, test_user.email): - fake = Faker() - root_client.register( - name=test_user.name, - email=test_user.email, - password=test_user.password, - password_verify=test_user.password, - institution=fake.company(), - website=fake.url(), - ) - else: - print("User already exists", test_user) diff --git a/test_helpers/job_helpers.py b/test_helpers/job_helpers.py deleted file mode 100644 index 78494d381e7..00000000000 --- a/test_helpers/job_helpers.py +++ /dev/null @@ -1,400 +0,0 @@ -# stdlib -from collections import defaultdict -from collections.abc import Callable -from dataclasses import dataclass -from dataclasses import field -import json -import random -import re -import secrets -import textwrap -from typing import Any - -# third party -from email_helpers import TestUser - -# syft absolute -from syft import test_settings - -from syft.client.client import SyftClient # noqa - -dataset_1 = test_settings.get("dataset_1", default="dataset_1") -dataset_2 = test_settings.get("dataset_2", default="dataset_2") -table_1 = test_settings.get("table_1", default="table_1") -table_2 = test_settings.get("table_2", default="table_2") -table_1_col_id = test_settings.get("table_1_col_id", default="table_id") -table_1_col_score = test_settings.get("table_1_col_score", default="colname") -table_2_col_id = test_settings.get("table_2_col_id", default="table_id") -table_2_col_score = test_settings.get("table_2_col_score", default="colname") - - -@dataclass -class TestJob: - user_email: str - func_name: str - query: str - job_type: str - settings: dict # make a type so we can rely on attributes - should_succeed: bool - should_submit: bool = True - code_path: str | None = field(default=None) - admin_reviewed: bool = False - result_as_expected: bool | None = None - - _client_cache: SyftClient | None = field(default=None, repr=False, init=False) - - @property - def is_submitted(self) -> bool: - return self.code_path is not None - - @property - def client(self): - return self._client_cache - - @client.setter - def client(self, client): - self._client_cache = client - - def to_dict(self) -> dict: - output = {} - for k, v in self.__dict__.items(): - if k.startswith("_"): - continue - output[k] = v - return output - - def __iter__(self): - for key, val in self.to_dict().items(): - if key.startswith("_"): - yield key, val - - def __getitem__(self, key): - if key.startswith("_"): - return None - return self.to_dict()[key] - - @property - def code_method(self) -> None | Callable: - try: - return getattr(self.client.code, self.func_name, None) - except Exception as e: - print(f"Cant find code method. 
{e}") - return None - - -def make_query(settings: dict) -> str: - query = f""" - SELECT {settings['groupby_col']}, AVG({settings['score_col']}) AS average_score - FROM {settings['dataset']}.{settings['table']} - GROUP BY {settings['groupby_col']} - LIMIT {settings['limit']}""".strip() - - return textwrap.dedent(query) - - -def create_simple_query_job(user: TestUser) -> TestJob: - job_type = "simple_query" - func_name = f"{job_type}_{secrets.token_hex(3)}" - - dataset = random.choice([dataset_1, dataset_2]) - table, groupby_col, score_col = random.choice( - [ - (table_1, table_1_col_id, table_1_col_score), - (table_2, table_2_col_id, table_2_col_score), - ] - ) - limit = random.randint(1, 1_000_000) - - settings = { - "dataset": dataset, - "table": table, - "groupby_col": groupby_col, - "score_col": score_col, - "limit": limit, - } - query = make_query(settings) - - result = TestJob( - user_email=user.email, - func_name=func_name, - query=query, - job_type=job_type, - settings=settings, - should_succeed=True, - ) - - result.client = user.client - return result - - -def create_wrong_asset_query(user: TestUser) -> TestJob: - job_type = "wrong_asset_query" - func_name = f"{job_type}_{secrets.token_hex(3)}" - - valid_job = create_simple_query_job(user) - settings = valid_job.settings - corrupted_asset = random.choice(["dataset", "table"]) - settings[corrupted_asset] = "wrong_asset" - query = make_query(settings) - - result = TestJob( - user_email=user.email, - func_name=func_name, - query=query, - job_type=job_type, - settings=settings, - should_succeed=False, - ) - - result.client = user.client - return result - - -def create_wrong_syntax_query(user: TestUser) -> TestJob: - job_type = "wrong_syntax_query" - func_name = f"{job_type}_{secrets.token_hex(3)}" - - query = "SELECT * FROM table INCORRECT SYNTAX" - - result = TestJob( - user_email=user.email, - func_name=func_name, - query=query, - job_type=job_type, - settings={}, - should_succeed=False, - ) - - result.client = user.client - return result - - -def create_long_query_job(user: TestUser) -> TestJob: - job_type = "job_too_much_text" - func_name = f"{job_type}_{secrets.token_hex(3)}" - - query = "a" * 1_000 - - result = TestJob( - user_email=user.email, - func_name=func_name, - query=query, - job_type=job_type, - settings={}, - should_succeed=False, - ) - - result.client = user.client - return result - - -def create_query_long_name(user: TestUser) -> TestJob: - job_type = "job_long_name" - func_name = f"{job_type}_{secrets.token_hex(3)}" - - job = create_simple_query_job(user) - - job.job_type = job_type - job.func_name = func_name + "a" * 1_000 - - return job - - -def create_job_funcname_xss(user: TestUser) -> TestJob: - job_type = "job_funcname_xss" - func_name = f"{job_type}_{secrets.token_hex(3)}" - func_name += "" - - job = create_simple_query_job(user) - job.job_type = job_type - job.func_name = func_name - job.should_submit = False - return job - - -def get_request_for_job_info(requests, job): - job_requests = [r for r in requests if r.code.service_func_name == job.func_name] - if len(job_requests) != 1: - raise Exception(f"Too many or too few requests: {job} in requests: {requests}") - return job_requests[0] - - -def create_job_query_xss(user: TestUser) -> TestJob: - job_type = "job_query_xss" - func_name = f"{job_type}_{secrets.token_hex(3)}" - - job = create_simple_query_job(user) - job.job_type = job_type - job.func_name = func_name - job.query += "" - job.should_succeed = False - - return job - - -def 
create_job_many_columns(user: TestUser) -> TestJob: - job_type = "job_many_columns" - func_name = f"{job_type}_{secrets.token_hex(3)}" - - job = create_simple_query_job(user) - job.job_type = job_type - job.func_name = func_name - settings = job.settings - job.settings["num_extra_cols"] = random.randint(100, 1000) - - new_columns_string = ", ".join( - f"{settings['score_col']} as col_{i}" for i in range(settings["num_extra_cols"]) - ) - - job.query = f""" - SELECT {settings['groupby_col']}, AVG({settings['score_col']}) AS average_score, {new_columns_string} - FROM {settings['dataset']}.{settings['table']} - GROUP BY {settings['groupby_col']} - LIMIT {settings['limit']}""".strip() - - return job - - -def create_random_job(user: TestUser) -> TestJob: - job_func = random.choice(create_job_functions) - return job_func(user) - - -def create_jobs(users: list[TestUser], total_jobs: int = 10) -> list[TestJob]: - jobs = [] - num_users = len(users) - user_index = 0 - each_count = 0 - # keep making jobs until we have enough - while len(jobs) < total_jobs: - # if we havent used each job type yet keep getting the next one - if each_count < len(create_job_functions): - job_func = create_job_functions[each_count] - each_count += 1 - else: - # otherwise lets get a random one - job_func = create_random_job - # use the current index of user - jobs.append(job_func(users[user_index])) - - # only go as high as the last user index - if user_index < num_users - 1: - user_index += 1 - else: - # reset back to the first user - user_index = 0 - - # in case we stuffed up - if len(jobs) > total_jobs: - jobs = jobs[:total_jobs] - return jobs - - -def submit_job(job: TestJob) -> tuple[Any, str]: - client = job.client - response = client.api.services.bigquery.submit_query( - func_name=job.func_name, query=job.query - ) - job.code_path = extract_code_path(response) - return response - - -def extract_code_path(response) -> str | None: - pattern = r"client\.code\.(\w+)\(\)" - match = re.search(pattern, str(response)) - if match: - extracted_code = match.group(1) - return extracted_code - return None - - -def approve_by_running(request): - job = request.code(blocking=False) - result = job.wait() - print("got result of type", type(result), "bool", bool(result)) - # got result of type bool False - # assert result won't work unless we know what type is coming back - job_info = job.info(result=True) - # need force when running multiple times - # todo check and dont run if its already done - response = request.deposit_result(job_info, approve=True, force=True) - return response - - -def get_job_emails(jobs, client, email_server): - all_requests = client.requests - res = {} - for job in jobs: - request = get_request_for_job_info(all_requests, job) - emails = email_server.get_emails_for_user(request.requesting_user_email) - res[request.requesting_user_email] = emails - return res - - -def resolve_request(request): - service_func_name = request.code.service_func_name - if service_func_name.startswith("simple_query"): - request.approve() # approve because it is good - if service_func_name.startswith("wrong_asset_query"): - request.approve() # approve because it is bad - if service_func_name.startswith("wrong_syntax_query"): - request.approve() # approve because it is bad - if service_func_name.startswith("job_too_much_text"): - request.deny(reason="too long, boring!") # deny because it is bad - if service_func_name.startswith("job_long_name"): - request.approve() - if service_func_name.startswith("job_funcname_xss"): - 
request.deny(reason="too long, boring!") # never reach doesnt matter - if service_func_name.startswith("job_query_xss"): - request.approve() # approve because it is bad - if service_func_name.startswith("job_many_columns"): - request.approve() # approve because it is bad - - return (request.id, request.status) - - -create_job_functions = [ - create_simple_query_job, # quick way to increase the odds - create_simple_query_job, - create_simple_query_job, - create_simple_query_job, - create_simple_query_job, - create_simple_query_job, - create_wrong_syntax_query, - create_long_query_job, - create_query_long_name, - create_job_funcname_xss, - create_job_query_xss, - create_job_many_columns, -] - - -def save_jobs(jobs, filepath="./jobs.json"): - user_jobs = defaultdict(list) - for job in jobs: - user_jobs[job.user_email].append(job.to_dict()) - with open(filepath, "w") as f: - f.write(json.dumps(user_jobs)) - - -def load_jobs(users, high_client, filepath="./jobs.json"): - data = {} - try: - with open(filepath) as f: - data = json.loads(f.read()) - except Exception as e: - print(f"cant read file: {filepath}: {e}") - data = {} - jobs_list = [] - for user in users: - if user.email not in data: - print(f"{user.email} missing from jobs") - continue - user_jobs = data[user.email] - for user_job in user_jobs: - test_job = TestJob(**user_job) - if user._client_cache is None: - user.client = high_client - test_job.client = user.client - jobs_list.append(test_job) - return jobs_list diff --git a/test_helpers/sync_helpers.py b/test_helpers/sync_helpers.py deleted file mode 100644 index e1d558016ba..00000000000 --- a/test_helpers/sync_helpers.py +++ /dev/null @@ -1,190 +0,0 @@ -# third party -from tqdm import tqdm - -# syft absolute -import syft as sy -from syft.client.datasite_client import DatasiteClient -from syft.client.syncing import compare_clients -from syft.service.code.user_code import UserCode -from syft.service.job.job_stash import Job -from syft.service.job.job_stash import JobStatus -from syft.service.request.request import Request -from syft.service.request.request import RequestStatus -from syft.service.sync.diff_state import ObjectDiffBatch -from syft.types.result import Err - - -def deny_requests_without_autosync_tag(client_low: DatasiteClient): - # Deny all requests that are not autosync - requests = client_low.requests.get_all() - if isinstance(requests, sy.SyftError): - print(requests) - return - - denied_requests = [] - for request in tqdm(requests): - if request.status != RequestStatus.PENDING: - continue - if "autosync" not in request.tags: - request.deny( - reason="This request has been denied automatically. " - "Please use the designated API to submit your request." 
- ) - denied_requests.append(request.id) - print(f"Denied {len(denied_requests)} requests without autosync tag") - - -def is_request_to_sync(batch: ObjectDiffBatch) -> bool: - # True if this is a new low-side request - # TODO add condition for sql requests/usercodes - low_request = batch.root.low_obj - return ( - isinstance(low_request, Request) - and batch.status == "NEW" - and "autosync" in low_request.tags - ) - - -def is_job_to_sync(batch: ObjectDiffBatch): - # True if this is a new high-side job that is either COMPLETED or ERRORED - if batch.status != "NEW": - return False - if not isinstance(batch.root.high_obj, Job): - return False - job = batch.root.high_obj - return job.status in (JobStatus.ERRORED, JobStatus.COMPLETED) - - -def execute_requests( - client_high: DatasiteClient, request_ids: list[sy.UID] -) -> dict[sy.UID, Job]: - jobs_by_request_id = {} - for request_id in request_ids: - request = client_high.requests.get_by_uid(request_id) - if not isinstance(request, Request): - continue - - code = request.code - if not isinstance(code, UserCode): - continue - - func_name = request.code.service_func_name - api_func = getattr(client_high.code, func_name, None) - if api_func is None: - continue - - job = api_func(blocking=False) - jobs_by_request_id[request_id] = job - - return jobs_by_request_id - - -def deny_failed_jobs( - client_low: DatasiteClient, - jobs: list[Job], -) -> None: - # NOTE no syncing is needed, requests are denied on the low side - denied_requests = [] - - for job in jobs: - if job.status != JobStatus.ERRORED: - continue - - error_result = job.result - if isinstance(error_result, Err): - error_msg = error_result.err_value - else: - error_msg = "An unknown error occurred, please check the Job logs for more information." - - code_id = job.user_code_id - if code_id is None: - continue - requests = client_low.requests.get_by_usercode_id(code_id) - if isinstance(requests, list) and len(requests) > 0: - request = requests[0] - request.deny(reason=f"Execution failed: {error_msg}") - denied_requests.append(request.id) - else: - print(f"Failed to deny request for job {job.id}") - - print(f"Denied {len(denied_requests)} failed requests") - - -def sync_finished_jobs( - client_low: DatasiteClient, - client_high: DatasiteClient, -) -> dict[sy.UID, sy.SyftError | sy.SyftSuccess] | sy.SyftError: - sync_job_results = {} - synced_jobs = [] - diff = compare_clients( - from_client=client_high, to_client=client_low, include_types=["job"] - ) - if isinstance(diff, sy.SyftError): - print(diff) - return diff - - for batch in diff.batches: - if is_job_to_sync(batch): - job = batch.root.high_obj - - w = batch.resolve(build_state=False) - share_result = w.click_share_all_private_data() - if isinstance(share_result, sy.SyftError): - sync_job_results[job.id] = share_result - continue - sync_result = w.click_sync() - - synced_jobs.append(job) - sync_job_results[job.id] = sync_result - - print(f"Sharing {len(sync_job_results)} new results") - deny_failed_jobs(client_low, synced_jobs) - return sync_job_results - - -def sync_new_requests( - client_low: DatasiteClient, - client_high: DatasiteClient, -) -> dict[sy.UID, sy.SyftSuccess | sy.SyftError] | sy.SyftError: - sync_request_results = {} - diff = compare_clients( - from_client=client_low, to_client=client_high, include_types=["request"] - ) - if isinstance(diff, sy.SyftError): - print(diff) - return sync_request_results - print(f"{len(diff.batches)} request batches found") - for batch in tqdm(diff.batches): - if is_request_to_sync(batch): - 
request_id = batch.root.low_obj.id
-        w = batch.resolve(build_state=False)
-        result = w.click_sync()
-        sync_request_results[request_id] = result
-    return sync_request_results
-
-
-def sync_and_execute_new_requests(
-    client_low: DatasiteClient, client_high: DatasiteClient
-) -> None:
-    sync_results = sync_new_requests(client_low, client_high)
-    if isinstance(sync_results, sy.SyftError):
-        print(sync_results)
-        return
-
-    request_ids = [
-        uid for uid, res in sync_results.items() if isinstance(res, sy.SyftSuccess)
-    ]
-    print(f"Synced {len(request_ids)} new requests")
-
-    jobs_by_request = execute_requests(client_high, request_ids)
-    print(f"Started {len(jobs_by_request)} new jobs")
-
-
-def auto_sync(client_low: DatasiteClient, client_high: DatasiteClient) -> None:
-    print("Starting auto sync")
-    print("Denying non tagged jobs")
-    deny_requests_without_autosync_tag(client_low)
-    print("Syncing and executing")
-    sync_and_execute_new_requests(client_low, client_high)
-    sync_finished_jobs(client_low, client_high)
-    print("Finished auto sync")

From 145c63a0cdc7efcbf830d8b51d2b332ad66644ad Mon Sep 17 00:00:00 2001
From: Sameer Wagh
Date: Fri, 13 Sep 2024 11:15:20 -0400
Subject: [PATCH 05/11] Changing pre-commit hook for absolutify

Co-authored-by: Brendan Schell
---
 .pre-commit-config.yaml                                      | 3 ++-
 packages/syft/src/syft/util/test_helpers/apis/live/schema.py | 4 ++--
 .../syft/src/syft/util/test_helpers/apis/live/test_query.py  | 4 ++--
 packages/syft/src/syft/util/test_helpers/apis/mock/schema.py | 4 ++--
 .../syft/src/syft/util/test_helpers/apis/mock/test_query.py  | 4 ++--
 5 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index c4f5669130d..521e3f9b60c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -47,7 +47,8 @@ repos:
           packages/syft/src/syft/proto.*|
           packages/syft/tests/syft/lib/python.*|
           packages/grid.*|
-          packages/syft/src/syft/federated/model_serialization/protos.py
+          packages/syft/src/syft/federated/model_serialization/protos.py|
+          packages/syft/src/syft/util/test_helpers/.*|
         )$

 - repo: https://github.com/MarcoGorelli/absolufy-imports

diff --git a/packages/syft/src/syft/util/test_helpers/apis/live/schema.py b/packages/syft/src/syft/util/test_helpers/apis/live/schema.py
index 8b9e753fe47..7a63ab467d1 100644
--- a/packages/syft/src/syft/util/test_helpers/apis/live/schema.py
+++ b/packages/syft/src/syft/util/test_helpers/apis/live/schema.py
@@ -41,8 +41,8 @@ def live_schema(
     from google.oauth2 import service_account
     import pandas as pd

-    # relative
-    from ..... import SyftException
+    # syft absolute
+    from syft import SyftException

     # Auth for Bigquer based on the workload identity
     credentials = service_account.Credentials.from_service_account_info(

diff --git a/packages/syft/src/syft/util/test_helpers/apis/live/test_query.py b/packages/syft/src/syft/util/test_helpers/apis/live/test_query.py
index 6384dfca452..cca61eae533 100644
--- a/packages/syft/src/syft/util/test_helpers/apis/live/test_query.py
+++ b/packages/syft/src/syft/util/test_helpers/apis/live/test_query.py
@@ -34,8 +34,8 @@ def live_test_query(
     from google.cloud import bigquery  # noqa: F811
     from google.oauth2 import service_account

-    # relative
-    from ..... import SyftException
+    # syft absolute
+    from syft import SyftException

     # Auth for Bigquer based on the workload identity
     credentials = service_account.Credentials.from_service_account_info(

diff --git a/packages/syft/src/syft/util/test_helpers/apis/mock/schema.py b/packages/syft/src/syft/util/test_helpers/apis/mock/schema.py
index c4c7216b20a..a95e04f2f1d 100644
--- a/packages/syft/src/syft/util/test_helpers/apis/mock/schema.py
+++ b/packages/syft/src/syft/util/test_helpers/apis/mock/schema.py
@@ -26,8 +26,8 @@ def make_schema(settings, worker_pool) -> Callable:
     def mock_schema(
         context,
     ) -> str:
-        # relative
-        from ..... import SyftException
+        # syft absolute
+        from syft import SyftException

         # Store a dict with the calltimes for each user, via the email.
         if context.settings["rate_limiter_enabled"]:

diff --git a/packages/syft/src/syft/util/test_helpers/apis/mock/test_query.py b/packages/syft/src/syft/util/test_helpers/apis/mock/test_query.py
index 1937dcecb80..ae028a8cf36 100644
--- a/packages/syft/src/syft/util/test_helpers/apis/mock/test_query.py
+++ b/packages/syft/src/syft/util/test_helpers/apis/mock/test_query.py
@@ -85,7 +85,7 @@ def mock_test_query(
     # third party
     from google.api_core.exceptions import BadRequest
-    # relative
-    from ..... import SyftException
+    # syft absolute
+    from syft import SyftException

     # Store a dict with the calltimes for each user, via the email.
     if context.settings["rate_limiter_enabled"]:

From ed795877956b59be33449708a5d3bb374e404654 Mon Sep 17 00:00:00 2001
From: Sameer Wagh
Date: Fri, 13 Sep 2024 11:23:45 -0400
Subject: [PATCH 06/11] Changed imports

---
 .../scenarios/bigquery/021-create-jobs.ipynb | 16 ++++++++--------
 .../bigquery/040-do-review-requests.ipynb    |  8 ++++----
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/notebooks/scenarios/bigquery/021-create-jobs.ipynb b/notebooks/scenarios/bigquery/021-create-jobs.ipynb
index e576d45dff8..392103a751c 100644
--- a/notebooks/scenarios/bigquery/021-create-jobs.ipynb
+++ b/notebooks/scenarios/bigquery/021-create-jobs.ipynb
@@ -120,8 +120,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# third party\n",
-    "from email_helpers import load_users"
+    "# syft absolute\n",
+    "from syft.util.test_helpers.email_helpers import load_users"
    ]
   },
   {
@@ -149,10 +149,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# third party\n",
-    "from job_helpers import TestJob\n",
-    "from job_helpers import create_jobs\n",
-    "from job_helpers import extract_code_path"
+    "# syft absolute\n",
+    "from syft.util.test_helpers.job_helpers import TestJob\n",
+    "from syft.util.test_helpers.job_helpers import create_jobs\n",
+    "from syft.util.test_helpers.job_helpers import extract_code_path"
    ]
   },
   {
@@ -194,8 +194,8 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# third party\n",
-    "from job_helpers import save_jobs"
+    "# syft absolute\n",
+    "from syft.util.test_helpers.job_helpers import save_jobs"
    ]
   },
   {

diff --git a/notebooks/scenarios/bigquery/040-do-review-requests.ipynb b/notebooks/scenarios/bigquery/040-do-review-requests.ipynb
index 2f5b9fc00e5..8acc4e55274 100644
--- a/notebooks/scenarios/bigquery/040-do-review-requests.ipynb
+++ b/notebooks/scenarios/bigquery/040-do-review-requests.ipynb
@@ -95,10 +95,10 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# third party\n",
-    "from email_helpers import load_users\n",
-    "from job_helpers import load_jobs\n",
-    "from job_helpers import save_jobs"
+    "# syft absolute\n",
+    "from syft.util.test_helpers.email_helpers import load_users\n",
+    "from syft.util.test_helpers.job_helpers import load_jobs\n",
+    "from syft.util.test_helpers.job_helpers import save_jobs"
    ]
   },
   {

From 2a6009e9910869db2c04f1438c1140da552dc682 Mon Sep 17 00:00:00 2001
From: Brendan Schell
Date: Wed, 18 Sep 2024 10:32:23 -0400
Subject: [PATCH 07/11] Add security ignore comments to helpers

Co-authored-by: Sameer Wagh
---
 .../syft/util/test_helpers/email_helpers.py   |  4 +-
 .../src/syft/util/test_helpers/job_helpers.py | 16 ++--
 .../syft/util/test_helpers/worker_helpers.py  | 86 +++++++++++++++++++
 3 files changed, 96 insertions(+), 10 deletions(-)
 create mode 100644 packages/syft/src/syft/util/test_helpers/worker_helpers.py

diff --git a/packages/syft/src/syft/util/test_helpers/email_helpers.py b/packages/syft/src/syft/util/test_helpers/email_helpers.py
index e9aa83037fc..66a476c566d 100644
--- a/packages/syft/src/syft/util/test_helpers/email_helpers.py
+++ b/packages/syft/src/syft/util/test_helpers/email_helpers.py
@@ -162,7 +162,7 @@ def get_token(self) -> str:
         try:
             token = get_token(email)
             break
-        except Exception:
+        except Exception:  # nosec
            pass
         self.reset_token = token
         return token
@@ -220,7 +220,7 @@ def user_exists(root_client, email: str) -> bool:
 class SMTPTestServer:
     def __init__(self, email_server):
         self.port = 9025
-        self.hostname = "0.0.0.0"
+        self.hostname = "127.0.0.1"
         self._stop_event = asyncio.Event()

         # Simple email handler class

diff --git a/packages/syft/src/syft/util/test_helpers/job_helpers.py b/packages/syft/src/syft/util/test_helpers/job_helpers.py
index ac26ad5f8ff..bac08bad5d6 100644
--- a/packages/syft/src/syft/util/test_helpers/job_helpers.py
+++ b/packages/syft/src/syft/util/test_helpers/job_helpers.py
@@ -85,7 +85,7 @@ def make_query(settings: dict) -> str:
     SELECT {settings['groupby_col']}, AVG({settings['score_col']}) AS average_score
     FROM {settings['dataset']}.{settings['table']}
     GROUP BY {settings['groupby_col']}
-    LIMIT {settings['limit']}""".strip()
+    LIMIT {settings['limit']}""".strip()  # nosec: B608

     return textwrap.dedent(query)

@@ -94,14 +94,14 @@ def create_simple_query_job(user: TestUser) -> TestJob:
     job_type = "simple_query"
     func_name = f"{job_type}_{secrets.token_hex(3)}"

-    dataset = random.choice([dataset_1, dataset_2])
-    table, groupby_col, score_col = random.choice(
+    dataset = random.choice([dataset_1, dataset_2])  # nosec: B311
+    table, groupby_col, score_col = random.choice(  # nosec: B311
         [
             (table_1, table_1_col_id, table_1_col_score),
             (table_2, table_2_col_id, table_2_col_score),
         ]
     )
-    limit = random.randint(1, 1_000_000)
+    limit = random.randint(1, 1_000_000)  # nosec: B311

     settings = {
         "dataset": dataset,
@@ -131,7 +131,7 @@ def create_wrong_asset_query(user: TestUser) -> TestJob:
     valid_job = create_simple_query_job(user)
     settings = valid_job.settings

-    corrupted_asset = random.choice(["dataset", "table"])
+    corrupted_asset = random.choice(["dataset", "table"])  # nosec: B311
     settings[corrupted_asset] = "wrong_asset"

     query = make_query(settings)
@@ -238,7 +238,7 @@ def create_job_many_columns(user: TestUser) -> TestJob:
     job.job_type = job_type
     job.func_name = func_name
     settings = job.settings
-    job.settings["num_extra_cols"] = random.randint(100, 1000)
+    job.settings["num_extra_cols"] = random.randint(100, 1000)  # nosec: B311
     new_columns_string = ", ".join(
         f"{settings['score_col']} as col_{i}"
         for i in range(settings["num_extra_cols"])
@@ -248,13 +248,13 @@ def create_job_many_columns(user: TestUser) -> TestJob:
     SELECT {settings['groupby_col']}, AVG({settings['score_col']}) AS average_score, {new_columns_string}
     FROM {settings['dataset']}.{settings['table']}
     GROUP BY {settings['groupby_col']}
-    LIMIT {settings['limit']}""".strip()
+    LIMIT {settings['limit']}""".strip()  # nosec: B608

     return job


 def create_random_job(user: TestUser) -> TestJob:
-    job_func = random.choice(create_job_functions)
+    job_func = random.choice(create_job_functions)  # nosec: B311
     return job_func(user)

diff --git a/packages/syft/src/syft/util/test_helpers/worker_helpers.py b/packages/syft/src/syft/util/test_helpers/worker_helpers.py
new file mode 100644
index 00000000000..0b848503514
--- /dev/null
+++ b/packages/syft/src/syft/util/test_helpers/worker_helpers.py
@@ -0,0 +1,86 @@
+# syft absolute
+import syft as sy
+
+
+def build_and_launch_worker_pool_from_docker_str(
+    environment: str,
+    client: sy.DatasiteClient,
+    worker_pool_name: str,
+    custom_pool_pod_annotations: dict,
+    custom_pool_pod_labels: dict,
+    worker_dockerfile: str,
+    external_registry: str,
+    docker_tag: str,
+    scale_to: int,
+):
+    result = client.api.services.image_registry.add(external_registry)
+    assert "success" in result.message
+
+    # For some reason, when using k9s, result.value is empty so can't use the below line
+    # local_registry = result.value
+    local_registry = client.api.services.image_registry[0]
+
+    docker_config = sy.DockerWorkerConfig(dockerfile=worker_dockerfile)
+    assert docker_config.dockerfile == worker_dockerfile
+    submit_result = client.api.services.worker_image.submit(worker_config=docker_config)
+    print(submit_result.message)
+    assert "success" in submit_result.message
+
+    worker_image = submit_result.value
+
+    if environment == "remote":
+        docker_build_result = client.api.services.worker_image.build(
+            image_uid=worker_image.id,
+            tag=docker_tag,
+            registry_uid=local_registry.id,
+        )
+        print(docker_build_result)
+
+    if environment == "remote":
+        push_result = client.api.services.worker_image.push(worker_image.id)
+        print(push_result)
+
+    result = client.api.services.worker_pool.launch(
+        pool_name=worker_pool_name,
+        image_uid=worker_image.id,
+        num_workers=1,
+        pod_annotations=custom_pool_pod_annotations,
+        pod_labels=custom_pool_pod_labels,
+    )
+    print(result)
+    # assert 'success' in str(result.message)
+
+    if environment == "remote":
+        result = client.worker_pools.scale(number=scale_to, pool_name=worker_pool_name)
+        print(result)
+
+
+def launch_worker_pool_from_docker_tag_and_registry(
+    environment: str,
+    client: sy.DatasiteClient,
+    worker_pool_name: str,
+    custom_pool_pod_annotations: dict,
+    custom_pool_pod_labels: dict,
+    docker_tag: str,
+    external_registry: str,
+    scale_to: int = 1,
+):
+    res = client.api.services.image_registry.add(external_registry)
+    assert "success" in res.message
+    docker_config = sy.PrebuiltWorkerConfig(tag=docker_tag)
+    image_result = client.api.services.worker_image.submit(worker_config=docker_config)
+    assert "success" in res.message
+    worker_image = image_result.value
+
+    launch_result = client.api.services.worker_pool.launch(
+        pool_name=worker_pool_name,
+        image_uid=worker_image.id,
+        num_workers=1,
+        pod_annotations=custom_pool_pod_annotations,
+        pod_labels=custom_pool_pod_labels,
+    )
+    if environment == "remote" and scale_to > 1:
+        result = client.worker_pools.scale(number=scale_to, pool_name=worker_pool_name)
+        print(result)
+
+    return launch_result
\ No newline at end of file

From 3efb2505576a23bd42794442095fb5c9290b2278 Mon Sep 17 00:00:00 2001
From: Brendan Schell
Date: Wed, 18 Sep 2024 10:38:47 -0400
Subject: [PATCH 08/11] add security skips on worker_helpers

Co-authored-by: Sameer Wagh
---
 .../src/syft/util/test_helpers/worker_helpers.py | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/packages/syft/src/syft/util/test_helpers/worker_helpers.py b/packages/syft/src/syft/util/test_helpers/worker_helpers.py
index 0b848503514..3c2667fecc8 100644
--- a/packages/syft/src/syft/util/test_helpers/worker_helpers.py
+++ b/packages/syft/src/syft/util/test_helpers/worker_helpers.py
@@ -14,17 +14,17 @@ def build_and_launch_worker_pool_from_docker_str(
     scale_to: int,
 ):
     result = client.api.services.image_registry.add(external_registry)
-    assert "success" in result.message
+    assert "success" in result.message  # nosec: B101

     # For some reason, when using k9s, result.value is empty so can't use the below line
     # local_registry = result.value
     local_registry = client.api.services.image_registry[0]

     docker_config = sy.DockerWorkerConfig(dockerfile=worker_dockerfile)
-    assert docker_config.dockerfile == worker_dockerfile
+    assert docker_config.dockerfile == worker_dockerfile  # nosec: B101
     submit_result = client.api.services.worker_image.submit(worker_config=docker_config)
     print(submit_result.message)
-    assert "success" in submit_result.message
+    assert "success" in submit_result.message  # nosec: B101

     worker_image = submit_result.value

     if environment == "remote":
         docker_build_result = client.api.services.worker_image.build(
@@ -66,10 +66,10 @@ def launch_worker_pool_from_docker_tag_and_registry(
     scale_to: int = 1,
 ):
     res = client.api.services.image_registry.add(external_registry)
-    assert "success" in res.message
+    assert "success" in res.message  # nosec: B101
     docker_config = sy.PrebuiltWorkerConfig(tag=docker_tag)
     image_result = client.api.services.worker_image.submit(worker_config=docker_config)
-    assert "success" in res.message
+    assert "success" in res.message  # nosec: B101
     worker_image = image_result.value

     launch_result = client.api.services.worker_pool.launch(
@@ -83,4 +83,4 @@ def launch_worker_pool_from_docker_tag_and_registry(
     result = client.worker_pools.scale(number=scale_to, pool_name=worker_pool_name)
     print(result)

-    return launch_result
\ No newline at end of file
+    return launch_result

From ba216dc51ed4536c657f198836fd3ac4e437a69f Mon Sep 17 00:00:00 2001
From: Brendan Schell
Date: Wed, 18 Sep 2024 15:15:20 -0400
Subject: [PATCH 09/11] revert host back to 0.0.0.0 - fix this later

Co-authored-by: Sameer Wagh
---
 packages/syft/src/syft/util/test_helpers/email_helpers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/packages/syft/src/syft/util/test_helpers/email_helpers.py b/packages/syft/src/syft/util/test_helpers/email_helpers.py
index 66a476c566d..16904a6ccfd 100644
--- a/packages/syft/src/syft/util/test_helpers/email_helpers.py
+++ b/packages/syft/src/syft/util/test_helpers/email_helpers.py
@@ -220,7 +220,7 @@ def user_exists(root_client, email: str) -> bool:
 class SMTPTestServer:
     def __init__(self, email_server):
         self.port = 9025
-        self.hostname = "127.0.0.1"
+        self.hostname = "0.0.0.0"
         self._stop_event = asyncio.Event()

         # Simple email handler class

From 27132df4cbf6a3b0d5a39a010dca4b6454044015 Mon Sep 17 00:00:00 2001
From: Brendan Schell
Date: Wed, 18 Sep 2024 15:19:41 -0400
Subject: [PATCH 10/11] Ignore unbound port sec issue in helper

Co-authored-by: Sameer Wagh
---
 packages/syft/src/syft/util/test_helpers/email_helpers.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/packages/syft/src/syft/util/test_helpers/email_helpers.py b/packages/syft/src/syft/util/test_helpers/email_helpers.py
index 16904a6ccfd..ddfee82fef3 100644
--- a/packages/syft/src/syft/util/test_helpers/email_helpers.py
+++ b/packages/syft/src/syft/util/test_helpers/email_helpers.py
@@ -220,7 +220,7 @@ def user_exists(root_client, email: str) -> bool:
 class SMTPTestServer:
     def __init__(self, email_server):
         self.port = 9025
-        self.hostname = "0.0.0.0"
+        self.hostname = "0.0.0.0"  # nosec: B104
         self._stop_event = asyncio.Event()

         # Simple email handler class

From ca9dee6877d0306f68e9c243b3874a3060d3742b Mon Sep 17 00:00:00 2001
From: Brendan Schell
Date: Wed, 18 Sep 2024 17:05:44 -0400
Subject: [PATCH 11/11] add aiosmtpd to deps

---
 tox.ini | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/tox.ini b/tox.ini
index 15353917ac2..6a54b437098 100644
--- a/tox.ini
+++ b/tox.ini
@@ -380,6 +380,7 @@ deps =
     nbmake
     db-dtypes
     google-cloud-bigquery
+    aiosmtpd
 changedir = {toxinidir}/notebooks
 allowlist_externals =
     bash
@@ -494,6 +495,7 @@ deps =
     nbmake
     db-dtypes
     google-cloud-bigquery
+    aiosmtpd
 changedir = {toxinidir}/notebooks
 allowlist_externals =
     bash
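
After patch 11/11 the notebooks and tests resolve every helper from the installed syft package rather than a local test_helpers directory. The lines below are only a minimal sketch of that import style; the orchestra launch arguments and admin credentials are illustrative assumptions, not something introduced by this series, and the exact helper signatures are defined in the modules patched above.

# syft absolute
import syft as sy

# helpers now live inside the package (see the import rewrites in patches 05 and 06)
from syft.util.test_helpers.email_helpers import get_email_server  # noqa: F401
from syft.util.test_helpers.email_helpers import load_users  # noqa: F401
from syft.util.test_helpers.job_helpers import create_jobs  # noqa: F401

# assumption: a disposable local dev datasite, purely for illustration
server = sy.orchestra.launch(name="bigquery-test", dev_mode=True, reset=True)
admin_client = server.login(email="info@openmined.org", password="changethis")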