From 6b1df3e73a27d7ab0825674e2aa1b5ea20e865ec Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Tue, 20 Aug 2024 17:33:45 +1000 Subject: [PATCH 01/13] Added initial scenario testing prototype --- tests/scenario/bigquery/1_setup_test.py | 108 ++++++++ tests/scenario/bigquery/events.json | 3 + tests/scenario/bigquery/fixtures.py | 329 ++++++++++++++++++++++++ tests/scenario/bigquery/prototype.ipynb | 277 ++++++++++++++++++++ 4 files changed, 717 insertions(+) create mode 100644 tests/scenario/bigquery/1_setup_test.py create mode 100644 tests/scenario/bigquery/events.json create mode 100644 tests/scenario/bigquery/fixtures.py create mode 100644 tests/scenario/bigquery/prototype.ipynb diff --git a/tests/scenario/bigquery/1_setup_test.py b/tests/scenario/bigquery/1_setup_test.py new file mode 100644 index 00000000000..e23a8e3fed9 --- /dev/null +++ b/tests/scenario/bigquery/1_setup_test.py @@ -0,0 +1,108 @@ +# pip install pytest-asyncio pytest-timeout +# stdlib +import asyncio + +# third party +from faker import Faker +from fixtures import * +import pytest + + +# An async function that returns "Hello, World!" +async def hello_world(): + await asyncio.sleep(1) # Simulate some async work + return "Hello, World!" + + +# # An async test function using pytest-asyncio +# @pytest.mark.asyncio +# async def test_hello_world(): +# result = await hello_world() +# assert result == "Hello, World!" + + +@pytest.mark.asyncio +async def run_mock_dataframe_scenario(manager, set_event: bool = True): + manager.reset_test_state() + + USERS_CREATED = "users_created" + MOCK_READABLE = "mock_readable" + + fake = Faker() + + admin = make_admin() + + server = make_server(admin) + + root_client = admin.client(server) + + dataset_name = fake.name() + dataset = create_dataset(name=dataset_name) + + result = await hello_world() + assert result == "Hello, World!" 
+ + upload_dataset(root_client, dataset) + + users = [make_user() for i in range(2)] + + def create_users(root_client, manager, users): + for test_user in users: + create_user(root_client, test_user) + manager.register_event(USERS_CREATED) + + def user_can_read_mock_dataset(server, manager, user, dataset_name): + print("waiting ", USERS_CREATED) + with WaitForEvent(manager, USERS_CREATED, retry_secs=1): + print("logging in user") + user_client = user.client(server) + print("getting dataset", dataset_name) + mock = user_client.api.services.dataset[dataset_name].assets[0].mock + df = trade_flow_df_mock(trade_flow_df()) + assert df.equals(mock) + if set_event: + manager.register_event(MOCK_READABLE) + + user = users[0] + + asyncit( + user_can_read_mock_dataset, + server=server, + manager=manager, + user=user, + dataset_name=dataset_name, + ) + + asyncit(create_users, root_client=root_client, manager=manager, users=users) + + server.land() + + +@pytest.mark.asyncio +async def test_can_read_mock_dataframe(): + manager = TestEventManager() + MOCK_READABLE = "mock_readable" + await run_mock_dataframe_scenario(manager) + + async with AsyncWaitForEvent(manager, MOCK_READABLE, retry_secs=1, timeout_secs=10): + print("Test Complete") + result = manager.get_event_or_raise(MOCK_READABLE) + assert result + + loop = asyncio.get_event_loop() + loop.stop() + + +# @pytest.mark.asyncio +# async def test_cant_read_mock_dataframe(): +# manager = TestEventManager() +# MOCK_READABLE = "mock_readable" +# await run_mock_dataframe_scenario(manager, set_event=False) + +# async with AsyncWaitForEvent(manager, MOCK_READABLE, retry_secs=1, timeout_secs=10): +# print("Test Complete") +# with pytest.raises(Exception): +# result = manager.get_event_or_raise(MOCK_READABLE) + +# loop = asyncio.get_event_loop() +# loop.stop() diff --git a/tests/scenario/bigquery/events.json b/tests/scenario/bigquery/events.json new file mode 100644 index 00000000000..dd3bd15a160 --- /dev/null +++ b/tests/scenario/bigquery/events.json @@ -0,0 +1,3 @@ +{ + "events": {} +} diff --git a/tests/scenario/bigquery/fixtures.py b/tests/scenario/bigquery/fixtures.py new file mode 100644 index 00000000000..ae3fec3e7bb --- /dev/null +++ b/tests/scenario/bigquery/fixtures.py @@ -0,0 +1,329 @@ +# stdlib +import asyncio +from collections.abc import Callable +from dataclasses import asdict +from dataclasses import dataclass +from dataclasses import field +from datetime import datetime +import json +import os +from threading import Event as ThreadingEvent +import time +from typing import Any + +# third party +from faker import Faker +import pandas as pd + +# syft absolute +import syft as sy +from syft import autocache +from syft.service.user.user_roles import ServiceRole + +loop = None + + +def get_or_create_event_loop(): + try: + # Try to get the current running event loop + loop = asyncio.get_running_loop() + except RuntimeError: + # No running event loop, so create a new one + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + return loop + + +loop = get_or_create_event_loop() +loop.set_debug(True) + + +def make_server(test_user_admin): + server = sy.orchestra.launch( + name="test-datasite-1", port="auto", dev_mode=True, reset=True + ) + return server + + +def get_root_client(server, test_user_admin): + return server.login(email=test_user_admin.email, password=test_user_admin.password) + + +def trade_flow_df(): + canada_dataset_url = "https://github.com/OpenMined/datasets/blob/main/trade_flow/ca%20-%20feb%202021.csv?raw=True" + df = 
pd.read_csv(autocache(canada_dataset_url)) + return df + + +def trade_flow_df_mock(df): + return df[10:20] + + +def create_dataset(name: str): + df = trade_flow_df() + ca_data = df[0:10] + mock_ca_data = trade_flow_df_mock(df) + dataset = sy.Dataset(name=name) + dataset.set_description("Canada Trade Data Markdown Description") + dataset.set_summary("Canada Trade Data Short Summary") + dataset.add_citation("Person, place or thing") + dataset.add_url("https://github.com/OpenMined/datasets/tree/main/trade_flow") + dataset.add_contributor( + name="Andrew Trask", + email="andrew@openmined.org", + note="Andrew runs this datasite and prepared the dataset metadata.", + ) + dataset.add_contributor( + name="Madhava Jay", + email="madhava@openmined.org", + note="Madhava tweaked the description to add the URL because Andrew forgot.", + ) + ctf = sy.Asset(name="canada_trade_flow") + ctf.set_description( + "Canada trade flow represents export & import of different commodities to other countries" + ) + ctf.add_contributor( + name="Andrew Trask", + email="andrew@openmined.org", + note="Andrew runs this datasite and prepared the asset.", + ) + ctf.set_obj(ca_data) + ctf.set_shape(ca_data.shape) + ctf.set_mock(mock_ca_data, mock_is_real=False) + dataset.add_asset(ctf) + return dataset + + +def dataset_exists(root_client, dataset_name: str) -> bool: + datasets = root_client.api.services.dataset + for dataset in datasets: + if dataset.name == dataset_name: + return True + return False + + +def upload_dataset(user_client, dataset): + if not dataset_exists(user_client, dataset): + user_client.upload_dataset(dataset) + else: + print("Dataset already exists") + + +@dataclass +class TestUser: + name: str + email: str + password: str + role: ServiceRole + server_cache: Any | None = None + + def client(self, server=None): + if server is None: + server = self.server_cache + else: + self.server_cache = server + + return server.login(email=self.email, password=self.password) + + +def user_exists(root_client, email: str) -> bool: + users = root_client.api.services.user + for user in users: + if user.email == email: + return True + return False + + +def make_user( + name: str | None = None, + email: str | None = None, + password: str | None = None, + role: ServiceRole = ServiceRole.DATA_SCIENTIST, +): + fake = Faker() + if name is None: + name = fake.name() + if email is None: + email = fake.email() + if password is None: + password = fake.password() + + return TestUser(name=name, email=email, password=password, role=role) + + +def make_admin(email="info@openmined.org", password="changethis"): + fake = Faker() + return make_user( + email=email, password=password, name=fake.name(), role=ServiceRole.ADMIN + ) + + +def create_user(root_client, test_user): + if not user_exists(root_client, test_user.email): + fake = Faker() + root_client.register( + name=test_user.name, + email=test_user.email, + password=test_user.password, + password_verify=test_user.password, + institution=fake.company(), + website=fake.url(), + ) + else: + print("User already exists", test_user) + + +@dataclass +class TestEvent: + name: str + event_time: float = field(default_factory=lambda: time.time()) + + def __post_init__(self): + self.event_time = float(self.event_time) + + def __repr__(self): + formatted_time = datetime.fromtimestamp(self.event_time).strftime( + "%Y-%m-%d %H:%M:%S" + ) + return f"TestEvent(name={self.name}, event_time={formatted_time})" + + +class TestEventManager: + def __init__(self, file_path: str = "events.json"): + 
self.file_path = file_path + self._load_events() + + def _load_events(self): + if os.path.exists(self.file_path): + with open(self.file_path) as f: + self.data = json.load(f) + else: + self.data = {"events": {}} + + def _save_events(self): + with open(self.file_path, "w") as f: + json.dump(self.data, f, indent=4) + + def reset_test_state(self): + """Resets the state by clearing all events and saving the empty structure.""" + self.data = {"events": {}} + self._save_events() + + def register_event(self, event: TestEvent | str): + if isinstance(event, str): + event = TestEvent(name=event) + """Registers a new event, adding it to the file.""" + if event.name not in self.data["events"]: + self.data["events"][event.name] = asdict(event) + self._save_events() + else: + print( + f"Event '{event.name}' already exists. Use register_event_once or reset_test_state first." + ) + + def register_event_once(self, event: TestEvent): + """Registers the event only if it does not already exist. Raises an error if it does.""" + if event.name in self.data["events"]: + raise ValueError(f"Event '{event.name}' already exists.") + print(f"Event: {event}") + self.register_event(event) + + def get_event(self, event_name: str) -> TestEvent | None: + """Retrieves an event by name. Returns None if it does not exist.""" + event_data = self.data["events"].get(event_name) + if event_data: + return TestEvent(**event_data) + return None + + def get_event_or_raise(self, event_name: str) -> TestEvent: + """Retrieves an event by name. Returns None if it does not exist.""" + event_data = self.data["events"].get(event_name) + if event_data: + return TestEvent(**event_data) + raise Exception(f"No event: {event_name}") + + +def asyncit(func: Callable, *args, **kwargs): + print("Got kwargs", kwargs.keys()) + """Wrap a non-async function to run in the background as an asyncio task.""" + + async def async_func(*args, **kwargs): + # Run the function in a background thread using asyncio.to_thread + try: + return await asyncio.to_thread(func, *args, **kwargs) + except Exception as e: + print(f"An error occurred in asyncit: {e}") + + loop = get_or_create_event_loop() + + # Schedule the async function to run as a background task + return loop.create_task(async_func(*args, **kwargs)) + + +class AsyncWaitForEvent: + def __init__( + self, + event_manager, + event_name: str, + retry_secs: int = 1, + timeout_secs: int = None, + ): + self.event_manager = event_manager + self.event_name = event_name + self.retry_secs = retry_secs + self.timeout_secs = timeout_secs + self.event_occurred = asyncio.Event() + + async def _event_waiter(self): + """Internal method that runs asynchronously to wait for the event.""" + elapsed_time = 0 + while not self.event_manager.get_event(self.event_name): + await asyncio.sleep(self.retry_secs) + elapsed_time += self.retry_secs + if self.timeout_secs is not None and elapsed_time >= self.timeout_secs: + break + self.event_occurred.set() + + async def __aenter__(self): + """Starts the event waiter task and waits for the event to occur before returning.""" + self.waiter_task = get_or_create_event_loop().create_task(self._event_waiter()) + await self.event_occurred.wait() + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + """Ensure the event waiter task completes before exiting the context.""" + await self.waiter_task + + +class WaitForEvent: + def __init__( + self, + event_manager, + event_name: str, + retry_secs: int = 1, + timeout_secs: int = None, + ): + self.event_manager = event_manager + 
self.event_name = event_name + self.retry_secs = retry_secs + self.timeout_secs = timeout_secs + self.event_occurred = ThreadingEvent() + + def _event_waiter(self): + """Internal method that runs synchronously to wait for the event.""" + elapsed_time = 0 + while not self.event_manager.get_event(self.event_name): + time.sleep(self.retry_secs) + elapsed_time += self.retry_secs + if self.timeout_secs is not None and elapsed_time >= self.timeout_secs: + break + self.event_occurred.set() + + def __enter__(self): + """Starts the event waiter and waits for the event to occur.""" + self._event_waiter() + return self + + def __exit__(self, exc_type, exc_value, traceback): + """Nothing specific to do for synchronous exit.""" + pass diff --git a/tests/scenario/bigquery/prototype.ipynb b/tests/scenario/bigquery/prototype.ipynb new file mode 100644 index 00000000000..d68a8cb4f39 --- /dev/null +++ b/tests/scenario/bigquery/prototype.ipynb @@ -0,0 +1,277 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# third party\n", + "from faker import Faker\n", + "from fixtures import *" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "manager = TestEventManager()\n", + "manager.reset_test_state()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "USERS_CREATED = \"users_created\"\n", + "MOCK_READABLE = \"mock_readable\"" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "fake = Faker()" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "TestUser(name='Eduardo Edwards', email='info@openmined.org', password='changethis', role=, server_cache=None)" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "admin = make_admin()\n", + "admin" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test Complete\n", + "None\n" + ] + }, + { + "ename": "AssertionError", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m--------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[7], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m result \u001b[38;5;241m=\u001b[39m manager\u001b[38;5;241m.\u001b[39mget_event(MOCK_READABLE)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(result)\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m result\n", + "\u001b[0;31mAssertionError\u001b[0m: " + ] + } + ], + "source": [ + "async with AsyncWaitForEvent(manager, MOCK_READABLE, retry_secs=1, timeout_secs=10):\n", + " print(\"Test Complete\")\n", + " result = manager.get_event(MOCK_READABLE)\n", + " print(result)\n", + " assert result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "server = make_server(admin)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + 
"source": [ + "root_client = admin.client(server)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_name = fake.name()\n", + "dataset = create_dataset(name=dataset_name)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "upload_dataset(root_client, dataset)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "users = [make_user() for i in range(2)]\n", + "users" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def create_users(root_client, manager, users):\n", + " for test_user in users:\n", + " create_user(root_client, test_user)\n", + " manager.register_event(USERS_CREATED)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def user_can_read_mock_dataset(server, manager, user, dataset_name):\n", + " print(\"waiting \", USERS_CREATED)\n", + " with WaitForEvent(manager, USERS_CREATED, retry_secs=1):\n", + " print(\"logging in user\")\n", + " user_client = user.client(server)\n", + " print(\"getting dataset\", dataset_name)\n", + " mock = user_client.api.services.dataset[dataset_name].assets[0].mock\n", + " df = trade_flow_df_mock(trade_flow_df())\n", + " assert df.equals(mock)\n", + " manager.register_event(MOCK_READABLE)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "user = users[0]\n", + "user" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "asyncit(\n", + " user_can_read_mock_dataset,\n", + " server=server,\n", + " manager=manager,\n", + " user=user,\n", + " dataset_name=dataset_name,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "asyncit(create_users, root_client=root_client, manager=manager, users=users)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with WaitForEvent(manager, MOCK_READABLE, retry_secs=1, timeout_secs=10):\n", + " print(\"Test Complete\")\n", + " result = manager.get_event(MOCK_READABLE)\n", + " print(result)\n", + " assert result" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.2" + }, + "toc": { + "base_numbering": 1, + "nav_menu": {}, + "number_sections": true, + "sideBar": true, + "skip_h1_title": false, + "title_cell": "Table of Contents", + "title_sidebar": "Contents", + "toc_cell": false, + "toc_position": {}, + "toc_section_display": true, + "toc_window_display": true + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} From 168e12e534baadfdb0f9ff6d49bbd82cb16b0e74 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Thu, 22 Aug 2024 17:40:31 +1000 Subject: [PATCH 02/13] Tweaks --- tests/scenario/bigquery/1_setup_test.py | 49 ++++++++++++------------- tests/scenario/bigquery/events.json | 9 ++++- 
tests/scenario/bigquery/fixtures.py | 10 +---- 3 files changed, 32 insertions(+), 36 deletions(-) diff --git a/tests/scenario/bigquery/1_setup_test.py b/tests/scenario/bigquery/1_setup_test.py index e23a8e3fed9..0e1048060cd 100644 --- a/tests/scenario/bigquery/1_setup_test.py +++ b/tests/scenario/bigquery/1_setup_test.py @@ -22,7 +22,7 @@ async def hello_world(): @pytest.mark.asyncio -async def run_mock_dataframe_scenario(manager, set_event: bool = True): +async def run_mock_dataframe_scenario(manager, admin, server, set_event: bool = True): manager.reset_test_state() USERS_CREATED = "users_created" @@ -30,10 +30,6 @@ async def run_mock_dataframe_scenario(manager, set_event: bool = True): fake = Faker() - admin = make_admin() - - server = make_server(admin) - root_client = admin.client(server) dataset_name = fake.name() @@ -61,6 +57,7 @@ def user_can_read_mock_dataset(server, manager, user, dataset_name): df = trade_flow_df_mock(trade_flow_df()) assert df.equals(mock) if set_event: + print("REGISTERING EVENT", MOCK_READABLE) manager.register_event(MOCK_READABLE) user = users[0] @@ -75,34 +72,34 @@ def user_can_read_mock_dataset(server, manager, user, dataset_name): asyncit(create_users, root_client=root_client, manager=manager, users=users) - server.land() - @pytest.mark.asyncio -async def test_can_read_mock_dataframe(): - manager = TestEventManager() +async def test_can_read_mock_dataframe(request): + manager = TestEventManager(test_name=request.node.name) + admin = make_admin() + server = make_server(admin) MOCK_READABLE = "mock_readable" - await run_mock_dataframe_scenario(manager) - - async with AsyncWaitForEvent(manager, MOCK_READABLE, retry_secs=1, timeout_secs=10): + await run_mock_dataframe_scenario(manager, admin, server) + async with AsyncWaitForEvent(manager, MOCK_READABLE, retry_secs=1, timeout_secs=15): print("Test Complete") result = manager.get_event_or_raise(MOCK_READABLE) assert result - + server.land() loop = asyncio.get_event_loop() loop.stop() -# @pytest.mark.asyncio -# async def test_cant_read_mock_dataframe(): -# manager = TestEventManager() -# MOCK_READABLE = "mock_readable" -# await run_mock_dataframe_scenario(manager, set_event=False) - -# async with AsyncWaitForEvent(manager, MOCK_READABLE, retry_secs=1, timeout_secs=10): -# print("Test Complete") -# with pytest.raises(Exception): -# result = manager.get_event_or_raise(MOCK_READABLE) - -# loop = asyncio.get_event_loop() -# loop.stop() +@pytest.mark.asyncio +async def test_cant_read_mock_dataframe(request): + manager = TestEventManager(test_name=request.node.name) + admin = make_admin() + server = make_server(admin) + MOCK_READABLE = "mock_readable" + await run_mock_dataframe_scenario(manager, admin, server, set_event=False) + async with AsyncWaitForEvent(manager, MOCK_READABLE, retry_secs=1, timeout_secs=15): + print("Test Complete") + with pytest.raises(Exception): + result = manager.get_event_or_raise(MOCK_READABLE) + server.land() + loop = asyncio.get_event_loop() + loop.stop() diff --git a/tests/scenario/bigquery/events.json b/tests/scenario/bigquery/events.json index dd3bd15a160..051ff52246d 100644 --- a/tests/scenario/bigquery/events.json +++ b/tests/scenario/bigquery/events.json @@ -1,3 +1,8 @@ { - "events": {} -} + "events": { + "users_created": { + "name": "users_created", + "event_time": 1724311872.780498 + } + } +} \ No newline at end of file diff --git a/tests/scenario/bigquery/fixtures.py b/tests/scenario/bigquery/fixtures.py index ae3fec3e7bb..fa0e5226b81 100644 --- 
a/tests/scenario/bigquery/fixtures.py +++ b/tests/scenario/bigquery/fixtures.py @@ -189,8 +189,8 @@ def __repr__(self): class TestEventManager: - def __init__(self, file_path: str = "events.json"): - self.file_path = file_path + def __init__(self, test_name: str): + self.file_path = f"events_{test_name}.json" self._load_events() def _load_events(self): @@ -205,14 +205,12 @@ def _save_events(self): json.dump(self.data, f, indent=4) def reset_test_state(self): - """Resets the state by clearing all events and saving the empty structure.""" self.data = {"events": {}} self._save_events() def register_event(self, event: TestEvent | str): if isinstance(event, str): event = TestEvent(name=event) - """Registers a new event, adding it to the file.""" if event.name not in self.data["events"]: self.data["events"][event.name] = asdict(event) self._save_events() @@ -222,21 +220,17 @@ def register_event(self, event: TestEvent | str): ) def register_event_once(self, event: TestEvent): - """Registers the event only if it does not already exist. Raises an error if it does.""" if event.name in self.data["events"]: raise ValueError(f"Event '{event.name}' already exists.") - print(f"Event: {event}") self.register_event(event) def get_event(self, event_name: str) -> TestEvent | None: - """Retrieves an event by name. Returns None if it does not exist.""" event_data = self.data["events"].get(event_name) if event_data: return TestEvent(**event_data) return None def get_event_or_raise(self, event_name: str) -> TestEvent: - """Retrieves an event by name. Returns None if it does not exist.""" event_data = self.data["events"].get(event_name) if event_data: return TestEvent(**event_data) From 9754998899876d1421a0885b4e74a38e0a7868f3 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Fri, 23 Aug 2024 18:35:51 +1000 Subject: [PATCH 03/13] Refactored scenario testing to be better --- packages/syft/setup.cfg | 4 + tests/scenario/bigquery/1_setup_test.py | 116 +++----- tests/scenario/bigquery/api.py | 10 + tests/scenario/bigquery/asserts.py | 21 ++ tests/scenario/bigquery/events.json | 8 - tests/scenario/bigquery/events.py | 35 +++ tests/scenario/bigquery/fixtures.py | 323 ----------------------- tests/scenario/bigquery/fixtures_sync.py | 133 ++++++++++ tests/scenario/bigquery/make.py | 11 + tests/scenario/bigquery/partials.py | 23 ++ tests/scenario/bigquery/story.py | 20 ++ tests/scenario/bigquery/users.py | 23 ++ tox.ini | 13 + 13 files changed, 326 insertions(+), 414 deletions(-) create mode 100644 tests/scenario/bigquery/api.py create mode 100644 tests/scenario/bigquery/asserts.py delete mode 100644 tests/scenario/bigquery/events.json create mode 100644 tests/scenario/bigquery/events.py delete mode 100644 tests/scenario/bigquery/fixtures.py create mode 100644 tests/scenario/bigquery/fixtures_sync.py create mode 100644 tests/scenario/bigquery/make.py create mode 100644 tests/scenario/bigquery/partials.py create mode 100644 tests/scenario/bigquery/story.py create mode 100644 tests/scenario/bigquery/users.py diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 64221e223b1..e6b25b2efbc 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -122,6 +122,10 @@ test_plugins = pytest-sugar pytest-lazy-fixture pytest-rerunfailures + pytest-asyncio + pytest-timeout + anyio + unsync coverage faker distro diff --git a/tests/scenario/bigquery/1_setup_test.py b/tests/scenario/bigquery/1_setup_test.py index 0e1048060cd..b39f5449af2 100644 --- a/tests/scenario/bigquery/1_setup_test.py +++ 
b/tests/scenario/bigquery/1_setup_test.py @@ -1,105 +1,55 @@ -# pip install pytest-asyncio pytest-timeout -# stdlib -import asyncio - # third party +from api import get_datasets +from asserts import has +from events import EVENT_DATASET_MOCK_READABLE +from events import EVENT_DATASET_UPLOADED +from events import EVENT_USER_ADMIN_CREATED +from events import EventManager from faker import Faker -from fixtures import * +from fixtures_sync import create_dataset +from fixtures_sync import make_admin +from fixtures_sync import make_server +from fixtures_sync import make_user +from fixtures_sync import upload_dataset +from make import create_users +from partials import with_client import pytest - - -# An async function that returns "Hello, World!" -async def hello_world(): - await asyncio.sleep(1) # Simulate some async work - return "Hello, World!" - - -# # An async test function using pytest-asyncio -# @pytest.mark.asyncio -# async def test_hello_world(): -# result = await hello_world() -# assert result == "Hello, World!" +from story import user_can_read_mock_dataset @pytest.mark.asyncio -async def run_mock_dataframe_scenario(manager, admin, server, set_event: bool = True): - manager.reset_test_state() +async def test_create_dataset_and_read_mock(request): + events = EventManager() + server = make_server(request) + + dataset_get_all = with_client(get_datasets, server) - USERS_CREATED = "users_created" - MOCK_READABLE = "mock_readable" + assert dataset_get_all() == 0 fake = Faker() + admin = make_admin() + events.register(EVENT_USER_ADMIN_CREATED) root_client = admin.client(server) - dataset_name = fake.name() dataset = create_dataset(name=dataset_name) - result = await hello_world() - assert result == "Hello, World!" - upload_dataset(root_client, dataset) - users = [make_user() for i in range(2)] - - def create_users(root_client, manager, users): - for test_user in users: - create_user(root_client, test_user) - manager.register_event(USERS_CREATED) + events.register(EVENT_DATASET_UPLOADED) - def user_can_read_mock_dataset(server, manager, user, dataset_name): - print("waiting ", USERS_CREATED) - with WaitForEvent(manager, USERS_CREATED, retry_secs=1): - print("logging in user") - user_client = user.client(server) - print("getting dataset", dataset_name) - mock = user_client.api.services.dataset[dataset_name].assets[0].mock - df = trade_flow_df_mock(trade_flow_df()) - assert df.equals(mock) - if set_event: - print("REGISTERING EVENT", MOCK_READABLE) - manager.register_event(MOCK_READABLE) + users = [make_user() for i in range(2)] user = users[0] - asyncit( - user_can_read_mock_dataset, - server=server, - manager=manager, - user=user, - dataset_name=dataset_name, - ) - - asyncit(create_users, root_client=root_client, manager=manager, users=users) - - -@pytest.mark.asyncio -async def test_can_read_mock_dataframe(request): - manager = TestEventManager(test_name=request.node.name) - admin = make_admin() - server = make_server(admin) - MOCK_READABLE = "mock_readable" - await run_mock_dataframe_scenario(manager, admin, server) - async with AsyncWaitForEvent(manager, MOCK_READABLE, retry_secs=1, timeout_secs=15): - print("Test Complete") - result = manager.get_event_or_raise(MOCK_READABLE) - assert result - server.land() - loop = asyncio.get_event_loop() - loop.stop() + user_can_read_mock_dataset(server, events, user, dataset_name) + create_users(root_client, events, users) + await has( + lambda: dataset_get_all() == 1, + "1 Dataset", + timeout=15, + retry=1, + ) -@pytest.mark.asyncio -async def 
test_cant_read_mock_dataframe(request): - manager = TestEventManager(test_name=request.node.name) - admin = make_admin() - server = make_server(admin) - MOCK_READABLE = "mock_readable" - await run_mock_dataframe_scenario(manager, admin, server, set_event=False) - async with AsyncWaitForEvent(manager, MOCK_READABLE, retry_secs=1, timeout_secs=15): - print("Test Complete") - with pytest.raises(Exception): - result = manager.get_event_or_raise(MOCK_READABLE) - server.land() - loop = asyncio.get_event_loop() - loop.stop() + await events.wait_for(event_name=EVENT_DATASET_MOCK_READABLE) diff --git a/tests/scenario/bigquery/api.py b/tests/scenario/bigquery/api.py new file mode 100644 index 00000000000..4343a603ff1 --- /dev/null +++ b/tests/scenario/bigquery/api.py @@ -0,0 +1,10 @@ +# third party +from unsync import unsync + + +@unsync +def get_datasets(client): + print("Checking datasets") + num_datasets = len(client.api.services.dataset.get_all()) + print(">>> num datasets", num_datasets) + return num_datasets diff --git a/tests/scenario/bigquery/asserts.py b/tests/scenario/bigquery/asserts.py new file mode 100644 index 00000000000..f71549352ea --- /dev/null +++ b/tests/scenario/bigquery/asserts.py @@ -0,0 +1,21 @@ +# stdlib +import inspect + +# third party +import anyio + + +class FailedAssert(Exception): + pass + + +async def has(expr, expects="", timeout=10, retry=1): + try: + with anyio.fail_after(timeout): + result = expr() + while not result: + print(f"> {expects} {expr}...not yet satisfied") + await anyio.sleep(retry) + except TimeoutError: + lambda_source = inspect.getsource(expr) + raise FailedAssert(f"{lambda_source} {expects}") diff --git a/tests/scenario/bigquery/events.json b/tests/scenario/bigquery/events.json deleted file mode 100644 index 051ff52246d..00000000000 --- a/tests/scenario/bigquery/events.json +++ /dev/null @@ -1,8 +0,0 @@ -{ - "events": { - "users_created": { - "name": "users_created", - "event_time": 1724311872.780498 - } - } -} \ No newline at end of file diff --git a/tests/scenario/bigquery/events.py b/tests/scenario/bigquery/events.py new file mode 100644 index 00000000000..42f761302c7 --- /dev/null +++ b/tests/scenario/bigquery/events.py @@ -0,0 +1,35 @@ +# third party +import anyio + +EVENT_USER_ADMIN_CREATED = "user_admin_created" +EVENT_USERS_CREATED = "users_created" +EVENT_DATASET_UPLOADED = "dataset_uploaded" +EVENT_DATASET_MOCK_READABLE = "dataset_mock_readable" + + +class EventManager: + def __init__(self): + self.events = {} + self.event_waiters = {} + + def register(self, event_name: str): + self.events[event_name] = anyio.Event() + waiters = self.event_waiters.get(event_name, []) + for waiter in waiters: + waiter.set() + + async def wait_for(self, event_name: str, timeout: float = 15.0): + if event_name in self.events: + return self.events[event_name] + + waiter = anyio.Event() + self.event_waiters.setdefault(event_name, []).append(waiter) + + try: + with anyio.move_on_after(timeout) as cancel_scope: + await waiter.wait() + if cancel_scope.cancel_called: + raise TimeoutError(f"Timeout waiting for event: {event_name}") + return self.events[event_name] + finally: + self.event_waiters[event_name].remove(waiter) diff --git a/tests/scenario/bigquery/fixtures.py b/tests/scenario/bigquery/fixtures.py deleted file mode 100644 index fa0e5226b81..00000000000 --- a/tests/scenario/bigquery/fixtures.py +++ /dev/null @@ -1,323 +0,0 @@ -# stdlib -import asyncio -from collections.abc import Callable -from dataclasses import asdict -from dataclasses import dataclass 
-from dataclasses import field -from datetime import datetime -import json -import os -from threading import Event as ThreadingEvent -import time -from typing import Any - -# third party -from faker import Faker -import pandas as pd - -# syft absolute -import syft as sy -from syft import autocache -from syft.service.user.user_roles import ServiceRole - -loop = None - - -def get_or_create_event_loop(): - try: - # Try to get the current running event loop - loop = asyncio.get_running_loop() - except RuntimeError: - # No running event loop, so create a new one - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - return loop - - -loop = get_or_create_event_loop() -loop.set_debug(True) - - -def make_server(test_user_admin): - server = sy.orchestra.launch( - name="test-datasite-1", port="auto", dev_mode=True, reset=True - ) - return server - - -def get_root_client(server, test_user_admin): - return server.login(email=test_user_admin.email, password=test_user_admin.password) - - -def trade_flow_df(): - canada_dataset_url = "https://github.com/OpenMined/datasets/blob/main/trade_flow/ca%20-%20feb%202021.csv?raw=True" - df = pd.read_csv(autocache(canada_dataset_url)) - return df - - -def trade_flow_df_mock(df): - return df[10:20] - - -def create_dataset(name: str): - df = trade_flow_df() - ca_data = df[0:10] - mock_ca_data = trade_flow_df_mock(df) - dataset = sy.Dataset(name=name) - dataset.set_description("Canada Trade Data Markdown Description") - dataset.set_summary("Canada Trade Data Short Summary") - dataset.add_citation("Person, place or thing") - dataset.add_url("https://github.com/OpenMined/datasets/tree/main/trade_flow") - dataset.add_contributor( - name="Andrew Trask", - email="andrew@openmined.org", - note="Andrew runs this datasite and prepared the dataset metadata.", - ) - dataset.add_contributor( - name="Madhava Jay", - email="madhava@openmined.org", - note="Madhava tweaked the description to add the URL because Andrew forgot.", - ) - ctf = sy.Asset(name="canada_trade_flow") - ctf.set_description( - "Canada trade flow represents export & import of different commodities to other countries" - ) - ctf.add_contributor( - name="Andrew Trask", - email="andrew@openmined.org", - note="Andrew runs this datasite and prepared the asset.", - ) - ctf.set_obj(ca_data) - ctf.set_shape(ca_data.shape) - ctf.set_mock(mock_ca_data, mock_is_real=False) - dataset.add_asset(ctf) - return dataset - - -def dataset_exists(root_client, dataset_name: str) -> bool: - datasets = root_client.api.services.dataset - for dataset in datasets: - if dataset.name == dataset_name: - return True - return False - - -def upload_dataset(user_client, dataset): - if not dataset_exists(user_client, dataset): - user_client.upload_dataset(dataset) - else: - print("Dataset already exists") - - -@dataclass -class TestUser: - name: str - email: str - password: str - role: ServiceRole - server_cache: Any | None = None - - def client(self, server=None): - if server is None: - server = self.server_cache - else: - self.server_cache = server - - return server.login(email=self.email, password=self.password) - - -def user_exists(root_client, email: str) -> bool: - users = root_client.api.services.user - for user in users: - if user.email == email: - return True - return False - - -def make_user( - name: str | None = None, - email: str | None = None, - password: str | None = None, - role: ServiceRole = ServiceRole.DATA_SCIENTIST, -): - fake = Faker() - if name is None: - name = fake.name() - if email is None: - email = 
fake.email() - if password is None: - password = fake.password() - - return TestUser(name=name, email=email, password=password, role=role) - - -def make_admin(email="info@openmined.org", password="changethis"): - fake = Faker() - return make_user( - email=email, password=password, name=fake.name(), role=ServiceRole.ADMIN - ) - - -def create_user(root_client, test_user): - if not user_exists(root_client, test_user.email): - fake = Faker() - root_client.register( - name=test_user.name, - email=test_user.email, - password=test_user.password, - password_verify=test_user.password, - institution=fake.company(), - website=fake.url(), - ) - else: - print("User already exists", test_user) - - -@dataclass -class TestEvent: - name: str - event_time: float = field(default_factory=lambda: time.time()) - - def __post_init__(self): - self.event_time = float(self.event_time) - - def __repr__(self): - formatted_time = datetime.fromtimestamp(self.event_time).strftime( - "%Y-%m-%d %H:%M:%S" - ) - return f"TestEvent(name={self.name}, event_time={formatted_time})" - - -class TestEventManager: - def __init__(self, test_name: str): - self.file_path = f"events_{test_name}.json" - self._load_events() - - def _load_events(self): - if os.path.exists(self.file_path): - with open(self.file_path) as f: - self.data = json.load(f) - else: - self.data = {"events": {}} - - def _save_events(self): - with open(self.file_path, "w") as f: - json.dump(self.data, f, indent=4) - - def reset_test_state(self): - self.data = {"events": {}} - self._save_events() - - def register_event(self, event: TestEvent | str): - if isinstance(event, str): - event = TestEvent(name=event) - if event.name not in self.data["events"]: - self.data["events"][event.name] = asdict(event) - self._save_events() - else: - print( - f"Event '{event.name}' already exists. Use register_event_once or reset_test_state first." 
- ) - - def register_event_once(self, event: TestEvent): - if event.name in self.data["events"]: - raise ValueError(f"Event '{event.name}' already exists.") - self.register_event(event) - - def get_event(self, event_name: str) -> TestEvent | None: - event_data = self.data["events"].get(event_name) - if event_data: - return TestEvent(**event_data) - return None - - def get_event_or_raise(self, event_name: str) -> TestEvent: - event_data = self.data["events"].get(event_name) - if event_data: - return TestEvent(**event_data) - raise Exception(f"No event: {event_name}") - - -def asyncit(func: Callable, *args, **kwargs): - print("Got kwargs", kwargs.keys()) - """Wrap a non-async function to run in the background as an asyncio task.""" - - async def async_func(*args, **kwargs): - # Run the function in a background thread using asyncio.to_thread - try: - return await asyncio.to_thread(func, *args, **kwargs) - except Exception as e: - print(f"An error occurred in asyncit: {e}") - - loop = get_or_create_event_loop() - - # Schedule the async function to run as a background task - return loop.create_task(async_func(*args, **kwargs)) - - -class AsyncWaitForEvent: - def __init__( - self, - event_manager, - event_name: str, - retry_secs: int = 1, - timeout_secs: int = None, - ): - self.event_manager = event_manager - self.event_name = event_name - self.retry_secs = retry_secs - self.timeout_secs = timeout_secs - self.event_occurred = asyncio.Event() - - async def _event_waiter(self): - """Internal method that runs asynchronously to wait for the event.""" - elapsed_time = 0 - while not self.event_manager.get_event(self.event_name): - await asyncio.sleep(self.retry_secs) - elapsed_time += self.retry_secs - if self.timeout_secs is not None and elapsed_time >= self.timeout_secs: - break - self.event_occurred.set() - - async def __aenter__(self): - """Starts the event waiter task and waits for the event to occur before returning.""" - self.waiter_task = get_or_create_event_loop().create_task(self._event_waiter()) - await self.event_occurred.wait() - return self - - async def __aexit__(self, exc_type, exc_value, traceback): - """Ensure the event waiter task completes before exiting the context.""" - await self.waiter_task - - -class WaitForEvent: - def __init__( - self, - event_manager, - event_name: str, - retry_secs: int = 1, - timeout_secs: int = None, - ): - self.event_manager = event_manager - self.event_name = event_name - self.retry_secs = retry_secs - self.timeout_secs = timeout_secs - self.event_occurred = ThreadingEvent() - - def _event_waiter(self): - """Internal method that runs synchronously to wait for the event.""" - elapsed_time = 0 - while not self.event_manager.get_event(self.event_name): - time.sleep(self.retry_secs) - elapsed_time += self.retry_secs - if self.timeout_secs is not None and elapsed_time >= self.timeout_secs: - break - self.event_occurred.set() - - def __enter__(self): - """Starts the event waiter and waits for the event to occur.""" - self._event_waiter() - return self - - def __exit__(self, exc_type, exc_value, traceback): - """Nothing specific to do for synchronous exit.""" - pass diff --git a/tests/scenario/bigquery/fixtures_sync.py b/tests/scenario/bigquery/fixtures_sync.py new file mode 100644 index 00000000000..2f25eeb980f --- /dev/null +++ b/tests/scenario/bigquery/fixtures_sync.py @@ -0,0 +1,133 @@ +# stdlib +from typing import Any + +# third party +from faker import Faker +import pandas as pd +from users import TestUser + +# syft absolute +import syft as sy +from 
syft import autocache +from syft.service.user.user_roles import ServiceRole + + +def make_user( + name: str | None = None, + email: str | None = None, + password: str | None = None, + role: ServiceRole = ServiceRole.DATA_SCIENTIST, +): + fake = Faker() + if name is None: + name = fake.name() + if email is None: + email = fake.email() + if password is None: + password = fake.password() + + return TestUser(name=name, email=email, password=password, role=role) + + +def make_admin(email="info@openmined.org", password="changethis"): + fake = Faker() + return make_user( + email=email, password=password, name=fake.name(), role=ServiceRole.ADMIN + ) + + +def trade_flow_df(): + canada_dataset_url = "https://github.com/OpenMined/datasets/blob/main/trade_flow/ca%20-%20feb%202021.csv?raw=True" + df = pd.read_csv(autocache(canada_dataset_url)) + return df + + +def trade_flow_df_mock(df): + return df[10:20] + + +def user_exists(root_client, email: str) -> bool: + users = root_client.api.services.user + for user in users: + if user.email == email: + return True + return False + + +def create_user(root_client, test_user): + if not user_exists(root_client, test_user.email): + fake = Faker() + root_client.register( + name=test_user.name, + email=test_user.email, + password=test_user.password, + password_verify=test_user.password, + institution=fake.company(), + website=fake.url(), + ) + else: + print("User already exists", test_user) + + +def dataset_exists(root_client, dataset_name: str) -> bool: + datasets = root_client.api.services.dataset + for dataset in datasets: + if dataset.name == dataset_name: + return True + return False + + +def upload_dataset(user_client, dataset): + if not dataset_exists(user_client, dataset): + user_client.upload_dataset(dataset) + else: + print("Dataset already exists") + + +def create_dataset(name: str): + df = trade_flow_df() + ca_data = df[0:10] + mock_ca_data = trade_flow_df_mock(df) + dataset = sy.Dataset(name=name) + dataset.set_description("Canada Trade Data Markdown Description") + dataset.set_summary("Canada Trade Data Short Summary") + dataset.add_citation("Person, place or thing") + dataset.add_url("https://github.com/OpenMined/datasets/tree/main/trade_flow") + dataset.add_contributor( + name="Andrew Trask", + email="andrew@openmined.org", + note="Andrew runs this datasite and prepared the dataset metadata.", + ) + dataset.add_contributor( + name="Madhava Jay", + email="madhava@openmined.org", + note="Madhava tweaked the description to add the URL because Andrew forgot.", + ) + ctf = sy.Asset(name="canada_trade_flow") + ctf.set_description( + "Canada trade flow represents export & import of different commodities to other countries" + ) + ctf.add_contributor( + name="Andrew Trask", + email="andrew@openmined.org", + note="Andrew runs this datasite and prepared the asset.", + ) + ctf.set_obj(ca_data) + ctf.set_shape(ca_data.shape) + ctf.set_mock(mock_ca_data, mock_is_real=False) + dataset.add_asset(ctf) + return dataset + + +def make_server(request) -> Any: + print("making server") + server = sy.orchestra.launch( + name="test-datasite-1", port="auto", dev_mode=True, reset=True + ) + + def cleanup(): + print("landing server") + server.land() + + request.addfinalizer(cleanup) + return server diff --git a/tests/scenario/bigquery/make.py b/tests/scenario/bigquery/make.py new file mode 100644 index 00000000000..9d655164aa2 --- /dev/null +++ b/tests/scenario/bigquery/make.py @@ -0,0 +1,11 @@ +# third party +from events import EVENT_USERS_CREATED +from fixtures_sync import 
create_user +from unsync import unsync + + +@unsync +async def create_users(root_client, events, users): + for test_user in users: + create_user(root_client, test_user) + events.register(EVENT_USERS_CREATED) diff --git a/tests/scenario/bigquery/partials.py b/tests/scenario/bigquery/partials.py new file mode 100644 index 00000000000..e62b4fbd04b --- /dev/null +++ b/tests/scenario/bigquery/partials.py @@ -0,0 +1,23 @@ +# stdlib +from collections.abc import Callable + +# third party +from unsync import Unfuture +from unsync import unsync + +# syft absolute +from syft.client.datasite_client import DatasiteClient +from syft.orchestra import ServerHandle + + +def with_client(func, client: unsync | DatasiteClient | ServerHandle) -> Callable: + if isinstance(client, ServerHandle): + client = client.client + + def with_func(): + result = func(client) + if isinstance(result, Unfuture): + result = result.result() + return result + + return with_func diff --git a/tests/scenario/bigquery/story.py b/tests/scenario/bigquery/story.py new file mode 100644 index 00000000000..0817aee8293 --- /dev/null +++ b/tests/scenario/bigquery/story.py @@ -0,0 +1,20 @@ +# third party +from events import EVENT_DATASET_MOCK_READABLE +from events import EVENT_USERS_CREATED +from fixtures_sync import trade_flow_df +from fixtures_sync import trade_flow_df_mock +from unsync import unsync + + +@unsync +async def user_can_read_mock_dataset(server, events, user, dataset_name): + print("waiting ", EVENT_USERS_CREATED) + await events.wait_for(event_name=EVENT_USERS_CREATED) + user_client = user.client(server) + print("getting dataset", dataset_name) + mock = user_client.api.services.dataset[dataset_name].assets[0].mock + df = trade_flow_df_mock(trade_flow_df()) + print("Are we here?") + if df.equals(mock): + print("REGISTERING EVENT", EVENT_DATASET_MOCK_READABLE) + events.register(EVENT_DATASET_MOCK_READABLE) diff --git a/tests/scenario/bigquery/users.py b/tests/scenario/bigquery/users.py new file mode 100644 index 00000000000..e5d142befce --- /dev/null +++ b/tests/scenario/bigquery/users.py @@ -0,0 +1,23 @@ +# stdlib +from dataclasses import dataclass +from typing import Any + +# syft absolute +from syft.service.user.user_roles import ServiceRole + + +@dataclass +class TestUser: + name: str + email: str + password: str + role: ServiceRole + server_cache: Any | None = None + + def client(self, server=None): + if server is None: + server = self.server_cache + else: + self.server_cache = server + + return server.login(email=self.email, password=self.password) diff --git a/tox.ini b/tox.ini index 110aafd69db..efb121bfe2c 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ envlist = syft.publish syft.test.security syft.test.unit + syft.test.scenario syft.test.notebook single_container.launch single_container.destroy @@ -263,6 +264,18 @@ commands = bash -c 'ulimit -n 4096 || true' pytest -n auto --dist loadgroup --durations=20 --disable-warnings +[testenv:syft.test.scenario] +description = Syft Scenario Tests +deps = + -e{toxinidir}/packages/syft[dev,data_science] +allowlist_externals = + bash + uv +changedir = {toxinidir}/tests/scenario +setenv = +commands = + pytest -s --disable-warnings + [testenv:syft.test.notebook] description = Syft Notebook Tests deps = From 4da9809e927d3fd3bc766bf43ad86c377a8db563 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Mon, 26 Aug 2024 18:07:20 +1000 Subject: [PATCH 04/13] WIP Replicating Bigquery tasks - experimenting with different ways to pass event waiter --- .gitignore | 5 +- 
packages/syft/src/syft/util/util.py | 32 ++- tests/scenario/bigquery/1_setup_test.py | 55 ---- tests/scenarios/bigquery/1_setup_test.py | 272 ++++++++++++++++++ tests/scenarios/bigquery/a.py | 22 ++ tests/{scenario => scenarios}/bigquery/api.py | 0 .../bigquery/asserts.py | 0 .../bigquery/events.py | 9 + .../bigquery/fixtures_sync.py | 0 .../{scenario => scenarios}/bigquery/make.py | 5 +- .../bigquery/partials.py | 0 .../bigquery/prototype.ipynb | 0 .../{scenario => scenarios}/bigquery/story.py | 0 .../{scenario => scenarios}/bigquery/users.py | 0 tox.ini | 2 +- 15 files changed, 339 insertions(+), 63 deletions(-) delete mode 100644 tests/scenario/bigquery/1_setup_test.py create mode 100644 tests/scenarios/bigquery/1_setup_test.py create mode 100644 tests/scenarios/bigquery/a.py rename tests/{scenario => scenarios}/bigquery/api.py (100%) rename tests/{scenario => scenarios}/bigquery/asserts.py (100%) rename tests/{scenario => scenarios}/bigquery/events.py (79%) rename tests/{scenario => scenarios}/bigquery/fixtures_sync.py (100%) rename tests/{scenario => scenarios}/bigquery/make.py (54%) rename tests/{scenario => scenarios}/bigquery/partials.py (100%) rename tests/{scenario => scenarios}/bigquery/prototype.ipynb (100%) rename tests/{scenario => scenarios}/bigquery/story.py (100%) rename tests/{scenario => scenarios}/bigquery/users.py (100%) diff --git a/.gitignore b/.gitignore index e016a1d9b9b..72d8b65e3ce 100644 --- a/.gitignore +++ b/.gitignore @@ -81,4 +81,7 @@ out.* .git-blame-ignore-revs # migration data -packages/grid/helm/examples/dev/migration.yaml \ No newline at end of file +packages/grid/helm/examples/dev/migration.yaml + +# dynaconf settings file +**/settings.yaml \ No newline at end of file diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index 10a99212655..f63c4266020 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -10,6 +10,7 @@ from contextlib import contextmanager from copy import deepcopy from datetime import datetime +import inspect import functools import hashlib from itertools import chain @@ -39,6 +40,7 @@ from types import ModuleType from typing import Any + # third party from IPython.display import display from forbiddenfruit import curse @@ -1070,6 +1072,19 @@ def get_latest_tag(registry: str, repo: str) -> str | None: return None +def get_caller_file_path() -> str | None: + stack = inspect.stack() + + for frame_info in stack: + code_context = frame_info.code_context + if code_context and len(code_context) > 0: + if "from syft import test_settings" in str(frame_info.code_context): + caller_file_path = os.path.dirname(os.path.abspath(frame_info.filename)) + return caller_file_path + + return None + + def find_base_dir_with_tox_ini(start_path: str = ".") -> str | None: base_path = os.path.abspath(start_path) while True: @@ -1101,8 +1116,19 @@ def test_settings() -> Any: # third party from dynaconf import Dynaconf - base_dir = find_base_dir_with_tox_ini() - config_files = get_all_config_files(base_dir, ".") if base_dir else [] + config_files = [] + current_path = "." + + # jupyter uses "." 
which resolves to the notebook + if not is_interpreter_jupyter(): + # python uses the file which has from syft import test_settings in it + import_path = get_caller_file_path() + if import_path: + current_path = import_path + + base_dir = find_base_dir_with_tox_ini(current_path) + config_files = get_all_config_files(base_dir, current_path) + config_files = list(reversed(config_files)) # create # can override with # import os @@ -1111,7 +1137,7 @@ def test_settings() -> Any: # Dynaconf settings test_settings = Dynaconf( - settings_files=list(reversed(config_files)), + settings_files=config_files, environments=True, envvar_prefix="TEST", ) diff --git a/tests/scenario/bigquery/1_setup_test.py b/tests/scenario/bigquery/1_setup_test.py deleted file mode 100644 index b39f5449af2..00000000000 --- a/tests/scenario/bigquery/1_setup_test.py +++ /dev/null @@ -1,55 +0,0 @@ -# third party -from api import get_datasets -from asserts import has -from events import EVENT_DATASET_MOCK_READABLE -from events import EVENT_DATASET_UPLOADED -from events import EVENT_USER_ADMIN_CREATED -from events import EventManager -from faker import Faker -from fixtures_sync import create_dataset -from fixtures_sync import make_admin -from fixtures_sync import make_server -from fixtures_sync import make_user -from fixtures_sync import upload_dataset -from make import create_users -from partials import with_client -import pytest -from story import user_can_read_mock_dataset - - -@pytest.mark.asyncio -async def test_create_dataset_and_read_mock(request): - events = EventManager() - server = make_server(request) - - dataset_get_all = with_client(get_datasets, server) - - assert dataset_get_all() == 0 - - fake = Faker() - admin = make_admin() - events.register(EVENT_USER_ADMIN_CREATED) - - root_client = admin.client(server) - dataset_name = fake.name() - dataset = create_dataset(name=dataset_name) - - upload_dataset(root_client, dataset) - - events.register(EVENT_DATASET_UPLOADED) - - users = [make_user() for i in range(2)] - - user = users[0] - - user_can_read_mock_dataset(server, events, user, dataset_name) - create_users(root_client, events, users) - - await has( - lambda: dataset_get_all() == 1, - "1 Dataset", - timeout=15, - retry=1, - ) - - await events.wait_for(event_name=EVENT_DATASET_MOCK_READABLE) diff --git a/tests/scenarios/bigquery/1_setup_test.py b/tests/scenarios/bigquery/1_setup_test.py new file mode 100644 index 00000000000..a5e740c0826 --- /dev/null +++ b/tests/scenarios/bigquery/1_setup_test.py @@ -0,0 +1,272 @@ +# third party +from api import get_datasets +from asserts import has +import asyncio +from events import EVENT_DATASET_MOCK_READABLE +from events import EVENT_DATASET_UPLOADED +from events import EVENT_USER_ADMIN_CREATED +from events import EventManager +from faker import Faker +from fixtures_sync import create_dataset +from fixtures_sync import make_admin +from fixtures_sync import make_server +from fixtures_sync import make_user +from fixtures_sync import upload_dataset +from make import create_users +from partials import with_client +import pytest +from story import user_can_read_mock_dataset +from syft import test_settings +import syft as sy +from asserts import FailedAssert +from unsync import unsync + +from events import EVENT_USERS_CREATED + +EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED = "prebuilt_worker_image_bigquery_created" +EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED = "external_registry_bigquery_created" +EVENT_WORKER_POOL_CREATED = "worker_pool_created" +EVENT_ALLOW_GUEST_SIGNUP_DISABLED = 
"allow_guest_signup_disabled" +EVENT_USERS_CREATED_CHECKED = "users_created_checked" + + +# dataset stuff +# """ +# dataset_get_all = with_client(get_datasets, server) + +# assert dataset_get_all() == 0 + +# dataset_name = fake.name() +# dataset = create_dataset(name=dataset_name) + +# upload_dataset(root_client, dataset) + +# events.register(EVENT_DATASET_UPLOADED) + +# user_can_read_mock_dataset(server, events, user, dataset_name) + +# await has( +# lambda: dataset_get_all() == 1, +# "1 Dataset", +# timeout=15, +# retry=1, +# ) + +# """ + + +@unsync +def get_prebuilt_worker_image(events, client, expected_tag, event_name): + if events.wait_for(event_name=event_name): + worker_images = client.images.get_all() + for worker_image in worker_images: + if expected_tag in str(worker_image.image_identifier): + assert expected_tag in str(worker_image.image_identifier) + return worker_image + raise FailedAssert(f"get_prebuilt_worker_image cannot find {expected_tag}") + + +async def create_prebuilt_worker_image(events, client, expected_tag, event_name): + print("1") + external_registry = test_settings.get("external_registry", default="docker.io") + print("2") + docker_config = sy.PrebuiltWorkerConfig(tag=f"{external_registry}/{expected_tag}") + print("3") + result = client.api.services.worker_image.submit(worker_config=docker_config) + print("4", result) + assert isinstance(result, sy.SyftSuccess) + events.register(event_name) + print("5", event_name) + + +@unsync +def add_external_registry(events, client, event_name): + external_registry = test_settings.get("external_registry", default="docker.io") + result = client.api.services.image_registry.add(external_registry) + assert isinstance(result, sy.SyftSuccess) + events.register(event_name) + + +@unsync +def create_worker_pool( + events, client, worker_pool_name, worker_pool_result, event_name +): + # block until this is available + worker_image = worker_pool_result.result(timeout=5) + + result = client.api.services.worker_pool.launch( + pool_name=worker_pool_name, + image_uid=worker_image.id, + num_workers=1, + ) + assert isinstance(result, sy.SyftSuccess) + events.register(event_name) + + +async def check_worker_pool_exists(events, client, worker_pool_name, event_name): + timeout = 30 + print("waiting for check_worker_pool_exists", event_name, timeout) + await events.wait_for(event_name=event_name, timeout=timeout) + print("its been 30 seconds, trying to get all worker pools") + pools = client.worker_pools.get_all() + print("pools", len(pools), pools) + for pool in pools: + print("pool name", pool.name) + if worker_pool_name == pool.name: + assert worker_pool_name == pool.name + return worker_pool_name == pool.name + + raise FailedAssert( + f"check_worker_pool_exists cannot find worker_pool_name {worker_pool_name}" + ) + + +def set_settings_allow_guest_signup(events, client, enabled, event_name): + result = client.settings.allow_guest_signup(enable=enabled) + assert isinstance(result, sy.SyftSuccess) + events.register(event_name) + + +async def check_users_created(events, client, users, event_name, event_set): + print("check users created") + expected_emails = {user.email for user in users} + found_emails = set() + print("wait for created event", event_name) + await events.wait_for(event_name=event_name) + print("finished waiting getting all the users") + user_results = client.api.services.user.get_all() + for user_result in user_results: + if user_result.email in expected_emails: + found_emails.add(user_result.email) + + print( + 
"len(found_emails) == len(expected_emails)", + len(found_emails) == len(expected_emails), + ) + if len(found_emails) == len(expected_emails): + events.register(event_set) + else: + raise FailedAssert( + f"check_users_created only found {len(found_emails)} of {len(expected_emails)} " + f"emails: {found_emails}, {expected_emails}" + ) + + +@pytest.mark.asyncio +async def test_create_dataset_and_read_mock(request): + events = EventManager() + server = make_server(request) + + admin = make_admin() + events.register(EVENT_USER_ADMIN_CREATED) + await events.wait_for(event_name=EVENT_USER_ADMIN_CREATED) + assert events.happened(EVENT_USER_ADMIN_CREATED) + + root_client = admin.client(server) + worker_pool_name = "bigquery-pool" + + worker_docker_tag = f"openmined/bigquery:{sy.__version__}" + print("running create_prebuilt_worker_image") + asyncio.create_task( + create_prebuilt_worker_image( + events, + root_client, + worker_docker_tag, + EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED, + ) + ) + print("finished queing create_prebuilt_worker_image") + + print("waiting...") + # await events.wait_for( + # event_name=EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED, timeout=30 + # ) + # assert await events.happened(EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED) + + # worker_image_result = get_prebuilt_worker_image( + # events, + # root_client, + # worker_docker_tag, + # EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED, + # ) + + # await events.wait_for(event_name=EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED) + # assert events.happened(EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED) + + # add_external_registry(events, root_client, EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED) + + # await events.wait_for(event_name=EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED) + # assert events.happened(EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED) + + # create_worker_pool( + # events, + # root_client, + # worker_pool_name, + # worker_image_result, + # EVENT_WORKER_POOL_CREATED, + # ) + + # await events.wait_for(event_name=EVENT_WORKER_POOL_CREATED) + # assert events.happened(EVENT_WORKER_POOL_CREATED) + + # check_worker_pool_exists( + # events, root_client, worker_pool_name, EVENT_WORKER_POOL_CREATED + # ) + + # await events.wait_for(event_name=EVENT_WORKER_POOL_CREATED) + # assert events.happened(EVENT_WORKER_POOL_CREATED) + + # set_settings_allow_guest_signup( + # events, root_client, False, EVENT_ALLOW_GUEST_SIGNUP_DISABLED + # ) + + # await events.wait_for(event_name=EVENT_ALLOW_GUEST_SIGNUP_DISABLED) + # assert events.happened(EVENT_ALLOW_GUEST_SIGNUP_DISABLED) + + # users = [make_user() for i in range(2)] + # # user = users[0] + + # create_users(root_client, events, users, EVENT_USERS_CREATED) + + # await events.wait_for(event_name=EVENT_USERS_CREATED) + # assert events.happened(EVENT_USERS_CREATED) + + # check_users_created( + # events, root_client, users, EVENT_USERS_CREATED, EVENT_USERS_CREATED_CHECKED + # ) + + # await events.wait_for(event_name=EVENT_USERS_CREATED_CHECKED) + # assert events.happened(EVENT_USERS_CREATED_CHECKED) + + # check users are created + # high_client.api.services.user.get_all() + + # # check_cant_sign_up + + # # create users + + # # create api endpoints + # # check they respond + + # # login as user + # # test queries + # # submit code via api + # # verify its not accessible yet + + # # continuously checking for + # # new untriaged requests + # # executing them locally + # # submitting the results + + # # users get the results + # # continuously checking + # # assert he random number of rows is there + + # # 
await has( + # # lambda: dataset_get_all() == 1, + # # "1 Dataset", + # # timeout=15, + # # retry=1, + # # ) + + # await events.wait_for(event_name=EVENT_USERS_CREATED_CHECKED, timeout=60) diff --git a/tests/scenarios/bigquery/a.py b/tests/scenarios/bigquery/a.py new file mode 100644 index 00000000000..e8e42185a2c --- /dev/null +++ b/tests/scenarios/bigquery/a.py @@ -0,0 +1,22 @@ +import inspect +import os + + +def get_caller_file_path() -> str | None: + stack = inspect.stack() + print("stack", stack) + + for frame_info in stack: + if "from syft import test_settings" in str(frame_info.code_context): + print(f"File: {frame_info.filename}") + print(f"Line: {frame_info.lineno}") + print(f"Code: {frame_info.code_context[0].strip()}") + caller_file_path = os.path.dirname(os.path.abspath(frame_info.filename)) + print("possible path", caller_file_path) + return caller_file_path + + return None + + +result = get_caller_file_path() +print(result) diff --git a/tests/scenario/bigquery/api.py b/tests/scenarios/bigquery/api.py similarity index 100% rename from tests/scenario/bigquery/api.py rename to tests/scenarios/bigquery/api.py diff --git a/tests/scenario/bigquery/asserts.py b/tests/scenarios/bigquery/asserts.py similarity index 100% rename from tests/scenario/bigquery/asserts.py rename to tests/scenarios/bigquery/asserts.py diff --git a/tests/scenario/bigquery/events.py b/tests/scenarios/bigquery/events.py similarity index 79% rename from tests/scenario/bigquery/events.py rename to tests/scenarios/bigquery/events.py index 42f761302c7..3f8a550c4f6 100644 --- a/tests/scenario/bigquery/events.py +++ b/tests/scenarios/bigquery/events.py @@ -14,6 +14,9 @@ def __init__(self): def register(self, event_name: str): self.events[event_name] = anyio.Event() + # change to two step process to better track event outcomes + self.events[event_name].set() + waiters = self.event_waiters.get(event_name, []) for waiter in waiters: waiter.set() @@ -33,3 +36,9 @@ async def wait_for(self, event_name: str, timeout: float = 15.0): return self.events[event_name] finally: self.event_waiters[event_name].remove(waiter) + + async def happened(self, event_name: str) -> bool: + if event_name in self.events: + event = self.events[event_name] + return event.is_set() + return False diff --git a/tests/scenario/bigquery/fixtures_sync.py b/tests/scenarios/bigquery/fixtures_sync.py similarity index 100% rename from tests/scenario/bigquery/fixtures_sync.py rename to tests/scenarios/bigquery/fixtures_sync.py diff --git a/tests/scenario/bigquery/make.py b/tests/scenarios/bigquery/make.py similarity index 54% rename from tests/scenario/bigquery/make.py rename to tests/scenarios/bigquery/make.py index 9d655164aa2..b05b10af7da 100644 --- a/tests/scenario/bigquery/make.py +++ b/tests/scenarios/bigquery/make.py @@ -1,11 +1,10 @@ # third party -from events import EVENT_USERS_CREATED from fixtures_sync import create_user from unsync import unsync @unsync -async def create_users(root_client, events, users): +async def create_users(root_client, events, users, event_name): for test_user in users: create_user(root_client, test_user) - events.register(EVENT_USERS_CREATED) + events.register(event_name) diff --git a/tests/scenario/bigquery/partials.py b/tests/scenarios/bigquery/partials.py similarity index 100% rename from tests/scenario/bigquery/partials.py rename to tests/scenarios/bigquery/partials.py diff --git a/tests/scenario/bigquery/prototype.ipynb b/tests/scenarios/bigquery/prototype.ipynb similarity index 100% rename from 
tests/scenario/bigquery/prototype.ipynb rename to tests/scenarios/bigquery/prototype.ipynb diff --git a/tests/scenario/bigquery/story.py b/tests/scenarios/bigquery/story.py similarity index 100% rename from tests/scenario/bigquery/story.py rename to tests/scenarios/bigquery/story.py diff --git a/tests/scenario/bigquery/users.py b/tests/scenarios/bigquery/users.py similarity index 100% rename from tests/scenario/bigquery/users.py rename to tests/scenarios/bigquery/users.py diff --git a/tox.ini b/tox.ini index 9853a628fca..ac515c0b922 100644 --- a/tox.ini +++ b/tox.ini @@ -272,7 +272,7 @@ deps = allowlist_externals = bash uv -changedir = {toxinidir}/tests/scenario +changedir = {toxinidir}/tests/scenarios setenv = commands = pytest -s --disable-warnings From b44161ce38b0bee77bddbe94a9ad605eec6bfe9e Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Tue, 27 Aug 2024 14:06:02 +1000 Subject: [PATCH 05/13] More progress on scenario testing framework --- packages/syft/src/syft/service/api/api.py | 10 +- packages/syft/src/syft/util/util.py | 3 +- tests/.gitignore | 1 + tests/scenarios/bigquery/1_setup_test.py | 368 ++++++++++++++-------- tests/scenarios/bigquery/a.py | 22 -- tests/scenarios/bigquery/asserts.py | 30 ++ tests/scenarios/bigquery/events.py | 247 +++++++++++++-- tests/scenarios/bigquery/fixtures_sync.py | 14 +- tests/scenarios/bigquery/make.py | 214 +++++++++++++ tests/scenarios/bigquery/prototype.ipynb | 277 ---------------- 10 files changed, 716 insertions(+), 470 deletions(-) create mode 100644 tests/.gitignore delete mode 100644 tests/scenarios/bigquery/a.py delete mode 100644 tests/scenarios/bigquery/prototype.ipynb diff --git a/packages/syft/src/syft/service/api/api.py b/packages/syft/src/syft/service/api/api.py index 89e19146e99..2d43d88f188 100644 --- a/packages/syft/src/syft/service/api/api.py +++ b/packages/syft/src/syft/service/api/api.py @@ -7,6 +7,7 @@ import linecache import re import textwrap +from textwrap import dedent from typing import Any from typing import cast @@ -176,6 +177,7 @@ def __repr__(self) -> str: @classmethod def validate_api_code(cls, api_code: str) -> str: valid_code = True + api_code = dedent(api_code) try: ast.parse(api_code) except SyntaxError: @@ -558,7 +560,7 @@ def exec_code( ) else: return SyftError( - message="Ops something went wrong during this endpoint execution, please contact your admin." + message="Oops something went wrong during this endpoint execution, please contact your admin." 
) @@ -692,7 +694,8 @@ def api_endpoint( def decorator(f: Callable) -> TwinAPIEndpoint | SyftError: try: helper_functions_dict = { - f.__name__: inspect.getsource(f) for f in (helper_functions or []) + f.__name__: dedent(inspect.getsource(f)) + for f in (helper_functions or []) } res = CreateTwinAPIEndpoint( path=path, @@ -724,7 +727,8 @@ def api_endpoint_method( def decorator(f: Callable) -> Endpoint | SyftError: try: helper_functions_dict = { - f.__name__: inspect.getsource(f) for f in (helper_functions or []) + f.__name__: dedent(inspect.getsource(f)) + for f in (helper_functions or []) } return Endpoint( api_code=inspect.getsource(f), diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index f63c4266020..be3c3f57e41 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -10,9 +10,9 @@ from contextlib import contextmanager from copy import deepcopy from datetime import datetime -import inspect import functools import hashlib +import inspect from itertools import chain from itertools import repeat import json @@ -40,7 +40,6 @@ from types import ModuleType from typing import Any - # third party from IPython.display import display from forbiddenfruit import curse diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 00000000000..48894e3b168 --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +**/*.events \ No newline at end of file diff --git a/tests/scenarios/bigquery/1_setup_test.py b/tests/scenarios/bigquery/1_setup_test.py index a5e740c0826..3f0927b46ab 100644 --- a/tests/scenarios/bigquery/1_setup_test.py +++ b/tests/scenarios/bigquery/1_setup_test.py @@ -1,34 +1,30 @@ -# third party -from api import get_datasets -from asserts import has +# stdlib import asyncio -from events import EVENT_DATASET_MOCK_READABLE -from events import EVENT_DATASET_UPLOADED +import inspect + +# third party +from asserts import ensure_package_installed +from events import EVENT_ALLOW_GUEST_SIGNUP_DISABLED +from events import EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED +from events import EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED +from events import EVENT_USERS_CREATED +from events import EVENT_USERS_CREATED_CHECKED from events import EVENT_USER_ADMIN_CREATED +from events import EVENT_WORKER_POOL_CREATED from events import EventManager +from events import Scenario from faker import Faker -from fixtures_sync import create_dataset from fixtures_sync import make_admin from fixtures_sync import make_server from fixtures_sync import make_user -from fixtures_sync import upload_dataset +from make import create_endpoints_query from make import create_users -from partials import with_client import pytest -from story import user_can_read_mock_dataset -from syft import test_settings -import syft as sy -from asserts import FailedAssert from unsync import unsync -from events import EVENT_USERS_CREATED - -EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED = "prebuilt_worker_image_bigquery_created" -EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED = "external_registry_bigquery_created" -EVENT_WORKER_POOL_CREATED = "worker_pool_created" -EVENT_ALLOW_GUEST_SIGNUP_DISABLED = "allow_guest_signup_disabled" -EVENT_USERS_CREATED_CHECKED = "users_created_checked" - +# syft absolute +import syft as sy +from syft import test_settings # dataset stuff # """ @@ -56,31 +52,29 @@ @unsync -def get_prebuilt_worker_image(events, client, expected_tag, event_name): - if events.wait_for(event_name=event_name): - worker_images = client.images.get_all() - for 
worker_image in worker_images: - if expected_tag in str(worker_image.image_identifier): - assert expected_tag in str(worker_image.image_identifier) - return worker_image - raise FailedAssert(f"get_prebuilt_worker_image cannot find {expected_tag}") +async def get_prebuilt_worker_image(events, client, expected_tag, event_name): + await events.await_for(event_name=event_name, show=True) + worker_images = client.images.get_all() + for worker_image in worker_images: + if expected_tag in str(worker_image.image_identifier): + assert expected_tag in str(worker_image.image_identifier) + return worker_image + print(f"get_prebuilt_worker_image cannot find {expected_tag}") + # raise FailedAssert() +@unsync async def create_prebuilt_worker_image(events, client, expected_tag, event_name): - print("1") external_registry = test_settings.get("external_registry", default="docker.io") - print("2") docker_config = sy.PrebuiltWorkerConfig(tag=f"{external_registry}/{expected_tag}") - print("3") result = client.api.services.worker_image.submit(worker_config=docker_config) - print("4", result) assert isinstance(result, sy.SyftSuccess) + asyncio.sleep(5) events.register(event_name) - print("5", event_name) @unsync -def add_external_registry(events, client, event_name): +async def add_external_registry(events, client, event_name): external_registry = test_settings.get("external_registry", default="docker.io") result = client.api.services.image_registry.add(external_registry) assert isinstance(result, sy.SyftSuccess) @@ -88,7 +82,7 @@ def add_external_registry(events, client, event_name): @unsync -def create_worker_pool( +async def create_worker_pool( events, client, worker_pool_name, worker_pool_result, event_name ): # block until this is available @@ -99,157 +93,254 @@ def create_worker_pool( image_uid=worker_image.id, num_workers=1, ) - assert isinstance(result, sy.SyftSuccess) - events.register(event_name) + + if isinstance(result, list) and isinstance( + result[0], sy.service.worker.worker_pool.ContainerSpawnStatus + ): + events.register(event_name) + else: + print("bad result", result) +@unsync async def check_worker_pool_exists(events, client, worker_pool_name, event_name): timeout = 30 - print("waiting for check_worker_pool_exists", event_name, timeout) - await events.wait_for(event_name=event_name, timeout=timeout) - print("its been 30 seconds, trying to get all worker pools") + await events.await_for(event_name=event_name, timeout=timeout) pools = client.worker_pools.get_all() - print("pools", len(pools), pools) for pool in pools: - print("pool name", pool.name) if worker_pool_name == pool.name: assert worker_pool_name == pool.name return worker_pool_name == pool.name - raise FailedAssert( - f"check_worker_pool_exists cannot find worker_pool_name {worker_pool_name}" - ) + print(f"check_worker_pool_exists cannot find worker_pool_name {worker_pool_name}") + # raise FailedAssert( + + # ) -def set_settings_allow_guest_signup(events, client, enabled, event_name): +@unsync +async def set_settings_allow_guest_signup( + events, client, enabled, event_name: str | None = None +): result = client.settings.allow_guest_signup(enable=enabled) - assert isinstance(result, sy.SyftSuccess) - events.register(event_name) + if event_name: + if isinstance(result, sy.SyftSuccess): + events.register(event_name) + else: + print("cant set settings alow guest signup") +@unsync async def check_users_created(events, client, users, event_name, event_set): - print("check users created") expected_emails = {user.email for user in users} 
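     # Collect which of the expected user emails actually come back from the server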
found_emails = set() - print("wait for created event", event_name) - await events.wait_for(event_name=event_name) - print("finished waiting getting all the users") + await events.await_for(event_name=event_name) user_results = client.api.services.user.get_all() for user_result in user_results: if user_result.email in expected_emails: found_emails.add(user_result.email) - print( - "len(found_emails) == len(expected_emails)", - len(found_emails) == len(expected_emails), - ) if len(found_emails) == len(expected_emails): events.register(event_set) else: - raise FailedAssert( + print( f"check_users_created only found {len(found_emails)} of {len(expected_emails)} " f"emails: {found_emails}, {expected_emails}" ) + # raise FailedAssert() + + +def guest_register(client, test_user): + guest_client = client.guest() + fake = Faker() + result = guest_client.register( + name=test_user.name, + email=test_user.email, + password=test_user.password, + password_verify=test_user.password, + institution=fake.company(), + website=fake.url(), + ) + return result + + +async def result_is( + events, + expr, + matches: bool | type | object, + after: str | None = None, + register: str | None = None, +): + if after: + await events.await_for(event_name=after) + + lambda_source = inspect.getsource(expr) + try: + result = expr() + if isinstance(matches, bool): + assertion = result == matches + if isinstance(matches, type): + assertion = isinstance(result, matches) + else: + if hasattr(result, "message"): + message = result.message.replace("*", "") + assertion = isinstance(result, type(matches)) and message in str(result) + + if assertion and register: + events.register(event_name=register) + return assertion + except Exception as e: + print(f"insinstance({lambda_source}, {matches}). {e}") + + return False + + +@unsync +async def set_endpoint_settings(events, client, path, after: str, register: str): + if after: + await events.await_for(event_name=after) + + # Here, we update the endpoint to timeout after 100s (rather the default of 60s) + result1 = client.api.services.api.update(endpoint_path=path, endpoint_timeout=120) + result2 = client.api.services.api.update( + endpoint_path=path, hide_mock_definition=True + ) + if isinstance(result1, sy.SyftSuccess) and isinstance(result2, sy.SyftSuccess): + events.register(register) + else: + print(f"Failed to update api endpoint. 
{path}") + + +EVENT_QUERY_ENDPOINT_CREATED = "query_endpoint_created" +EVENT_QUERY_ENDPOINT_CONFIGURED = "query_endpoint_configured" + + +def make_test_query(client, path): + query = f"SELECT {test_settings.table_2_col_id}, AVG({test_settings.table_2_col_score}) AS average_score \ + FROM {test_settings.dataset_2}.{test_settings.table_2} \ + GROUP BY {test_settings.table_2_col_id} \ + LIMIT 10000" + + api_method = api_for_path(client, path) + result = api_method(sql_query=query) + return result + + +def api_for_path(client, path): + root = client.api.services + for part in path.split("."): + if hasattr(root, part): + root = getattr(root, part) + else: + print("cant find part", part, path) + return root + + +EVENT_USERS_CAN_QUERY_MOCK = "users_can_query_mock" @pytest.mark.asyncio async def test_create_dataset_and_read_mock(request): + ensure_package_installed("google-cloud-bigquery", "google.cloud.bigquery") + ensure_package_installed("db-dtypes", "db_dtypes") + + scenario = Scenario( + name="test_create_dataset_and_read_mock", + events=[ + EVENT_USER_ADMIN_CREATED, + EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED, + EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED, + EVENT_WORKER_POOL_CREATED, + EVENT_ALLOW_GUEST_SIGNUP_DISABLED, + EVENT_USERS_CREATED, + EVENT_USERS_CREATED_CHECKED, + EVENT_QUERY_ENDPOINT_CREATED, + EVENT_QUERY_ENDPOINT_CONFIGURED, + EVENT_USERS_CAN_QUERY_MOCK, + ], + ) + events = EventManager() + events.add_scenario(scenario) + events.monitor() + server = make_server(request) admin = make_admin() events.register(EVENT_USER_ADMIN_CREATED) - await events.wait_for(event_name=EVENT_USER_ADMIN_CREATED) + + await events.await_for(event_name=EVENT_USER_ADMIN_CREATED) assert events.happened(EVENT_USER_ADMIN_CREATED) root_client = admin.client(server) worker_pool_name = "bigquery-pool" worker_docker_tag = f"openmined/bigquery:{sy.__version__}" - print("running create_prebuilt_worker_image") - asyncio.create_task( - create_prebuilt_worker_image( - events, - root_client, - worker_docker_tag, - EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED, - ) - ) - print("finished queing create_prebuilt_worker_image") - - print("waiting...") - # await events.wait_for( - # event_name=EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED, timeout=30 - # ) - # assert await events.happened(EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED) - - # worker_image_result = get_prebuilt_worker_image( - # events, - # root_client, - # worker_docker_tag, - # EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED, - # ) - - # await events.wait_for(event_name=EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED) - # assert events.happened(EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED) - - # add_external_registry(events, root_client, EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED) - # await events.wait_for(event_name=EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED) - # assert events.happened(EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED) - - # create_worker_pool( - # events, - # root_client, - # worker_pool_name, - # worker_image_result, - # EVENT_WORKER_POOL_CREATED, - # ) - - # await events.wait_for(event_name=EVENT_WORKER_POOL_CREATED) - # assert events.happened(EVENT_WORKER_POOL_CREATED) - - # check_worker_pool_exists( - # events, root_client, worker_pool_name, EVENT_WORKER_POOL_CREATED - # ) - - # await events.wait_for(event_name=EVENT_WORKER_POOL_CREATED) - # assert events.happened(EVENT_WORKER_POOL_CREATED) + create_prebuilt_worker_image( + events, + root_client, + worker_docker_tag, + EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED, + ) - # 
set_settings_allow_guest_signup( - # events, root_client, False, EVENT_ALLOW_GUEST_SIGNUP_DISABLED - # ) + worker_image_result = get_prebuilt_worker_image( + events, + root_client, + worker_docker_tag, + EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED, + ) - # await events.wait_for(event_name=EVENT_ALLOW_GUEST_SIGNUP_DISABLED) - # assert events.happened(EVENT_ALLOW_GUEST_SIGNUP_DISABLED) + add_external_registry(events, root_client, EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED) - # users = [make_user() for i in range(2)] - # # user = users[0] + create_worker_pool( + events, + root_client, + worker_pool_name, + worker_image_result, + EVENT_WORKER_POOL_CREATED, + ) - # create_users(root_client, events, users, EVENT_USERS_CREATED) + check_worker_pool_exists( + events, root_client, worker_pool_name, EVENT_WORKER_POOL_CREATED + ) - # await events.wait_for(event_name=EVENT_USERS_CREATED) - # assert events.happened(EVENT_USERS_CREATED) + set_settings_allow_guest_signup( + events, root_client, False, EVENT_ALLOW_GUEST_SIGNUP_DISABLED + ) - # check_users_created( - # events, root_client, users, EVENT_USERS_CREATED, EVENT_USERS_CREATED_CHECKED - # ) + users = [make_user() for i in range(2)] - # await events.wait_for(event_name=EVENT_USERS_CREATED_CHECKED) - # assert events.happened(EVENT_USERS_CREATED_CHECKED) + create_users(root_client, events, users, EVENT_USERS_CREATED) - # check users are created - # high_client.api.services.user.get_all() + check_users_created( + events, root_client, users, EVENT_USERS_CREATED, EVENT_USERS_CREATED_CHECKED + ) - # # check_cant_sign_up + create_endpoints_query( + events, + root_client, + worker_pool_name=worker_pool_name, + register=EVENT_QUERY_ENDPOINT_CREATED, + ) - # # create users + test_query_path = "bigquery.test_query" + set_endpoint_settings( + events, + root_client, + path=test_query_path, + after=EVENT_QUERY_ENDPOINT_CREATED, + register=EVENT_QUERY_ENDPOINT_CONFIGURED, + ) - # # create api endpoints - # # check they respond + await result_is( + events, + lambda: len(make_test_query(users[0].client(server), test_query_path)) == 10000, + matches=True, + after=[EVENT_QUERY_ENDPOINT_CONFIGURED, EVENT_USERS_CREATED_CHECKED], + register=EVENT_USERS_CAN_QUERY_MOCK, + ) - # # login as user - # # test queries # # submit code via api # # verify its not accessible yet @@ -262,11 +353,16 @@ async def test_create_dataset_and_read_mock(request): # # continuously checking # # assert he random number of rows is there - # # await has( - # # lambda: dataset_get_all() == 1, - # # "1 Dataset", - # # timeout=15, - # # retry=1, - # # ) + res = await result_is( + events, + lambda: guest_register(root_client, make_user()), + matches=sy.SyftError( + message="*You don't have permission to create an account*" + ), + after=EVENT_ALLOW_GUEST_SIGNUP_DISABLED, + ) + + assert res is True - # await events.wait_for(event_name=EVENT_USERS_CREATED_CHECKED, timeout=60) + await events.await_scenario(scenario_name="test_create_dataset_and_read_mock") + assert events.scenario_completed("test_create_dataset_and_read_mock") diff --git a/tests/scenarios/bigquery/a.py b/tests/scenarios/bigquery/a.py deleted file mode 100644 index e8e42185a2c..00000000000 --- a/tests/scenarios/bigquery/a.py +++ /dev/null @@ -1,22 +0,0 @@ -import inspect -import os - - -def get_caller_file_path() -> str | None: - stack = inspect.stack() - print("stack", stack) - - for frame_info in stack: - if "from syft import test_settings" in str(frame_info.code_context): - print(f"File: {frame_info.filename}") - print(f"Line: 
{frame_info.lineno}") - print(f"Code: {frame_info.code_context[0].strip()}") - caller_file_path = os.path.dirname(os.path.abspath(frame_info.filename)) - print("possible path", caller_file_path) - return caller_file_path - - return None - - -result = get_caller_file_path() -print(result) diff --git a/tests/scenarios/bigquery/asserts.py b/tests/scenarios/bigquery/asserts.py index f71549352ea..ad8a97b1eaf 100644 --- a/tests/scenarios/bigquery/asserts.py +++ b/tests/scenarios/bigquery/asserts.py @@ -1,5 +1,8 @@ # stdlib +import importlib.util import inspect +import subprocess +import sys # third party import anyio @@ -19,3 +22,30 @@ async def has(expr, expects="", timeout=10, retry=1): except TimeoutError: lambda_source = inspect.getsource(expr) raise FailedAssert(f"{lambda_source} {expects}") + + +def check_import_exists(module_name: str): + # can pass . paths like google.cloud.bigquery + spec = importlib.util.find_spec(module_name) + return spec is not None + + +def install_package(package_name: str): + try: + subprocess.check_call([sys.executable, "-m", "pip", "install", package_name]) + except subprocess.CalledProcessError: + print(f"pip failed to install {package_name}. Trying uv pip...") + try: + subprocess.check_call(["uv", "pip", "install", package_name]) + except subprocess.CalledProcessError as e: + print( + f"An error occurred while trying to install {package_name} with uv pip: {e}" + ) + + +def ensure_package_installed(package_name, module_name): + if not check_import_exists(module_name): + print(f"{module_name} not found. Installing...") + install_package(package_name) + else: + print(f"{module_name} is already installed.") diff --git a/tests/scenarios/bigquery/events.py b/tests/scenarios/bigquery/events.py index 3f8a550c4f6..b17e5a91b6b 100644 --- a/tests/scenarios/bigquery/events.py +++ b/tests/scenarios/bigquery/events.py @@ -1,44 +1,237 @@ +# stdlib +import asyncio +from dataclasses import dataclass +import inspect +import json +import os +from threading import Lock +import time + # third party -import anyio +from unsync import unsync EVENT_USER_ADMIN_CREATED = "user_admin_created" EVENT_USERS_CREATED = "users_created" EVENT_DATASET_UPLOADED = "dataset_uploaded" EVENT_DATASET_MOCK_READABLE = "dataset_mock_readable" +EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED = "prebuilt_worker_image_bigquery_created" +EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED = "external_registry_bigquery_created" +EVENT_WORKER_POOL_CREATED = "worker_pool_created" +EVENT_ALLOW_GUEST_SIGNUP_DISABLED = "allow_guest_signup_disabled" +EVENT_USERS_CREATED_CHECKED = "users_created_checked" + + +@dataclass +class Scenario: + name: str + events: list[str] + + def add_event(self, event: str): + self.events.append(event) class EventManager: - def __init__(self): - self.events = {} - self.event_waiters = {} + def __init__( + self, + test_name: str | None = None, + test_dir: str | None = None, + reset: bool = True, + ): + self.start_time = time.time() + self.event_file = self._get_event_file(test_name, test_dir) + self.lock = Lock() + self._ensure_file_exists() + self.scenarios = {} + if reset: + self.clear_events() + + def add_scenario(self, scenario: Scenario): + with self.lock: + with open(self.event_file, "r+") as f: + events = json.load(f) + for event in scenario.events: + if event not in events: + events[event] = None + self.scenarios[scenario.name] = scenario.events + f.seek(0) + json.dump(events, f) + f.truncate() + + def wait_scenario( + self, scenario_name: str, timeout: float = 15.0, show: bool = True + ) -> 
bool: + start_time = time.time() + while time.time() - start_time < timeout: + if self.scenario_completed(scenario_name): + return True + if show: + time_left = timeout - (time.time() - start_time) + print(f"wait_for_scenario: {scenario_name}. Time left: {time_left}") + + time.sleep(1) + return False + + async def await_scenario( + self, scenario_name: str, timeout: float = 15.0, show: bool = True + ) -> bool: + start_time = time.time() + while time.time() - start_time < timeout: + if self.scenario_completed(scenario_name): + return True + if show: + time_left = timeout - (time.time() - start_time) + print( + f"async await_for_scenario: {scenario_name}. Time left: {time_left}" + ) + await asyncio.sleep(1) + return False + + def scenario_completed(self, scenario_name: str) -> bool: + with self.lock: + with open(self.event_file) as f: + events = json.load(f) + scenario_events = self.scenarios.get(scenario_name, []) + incomplete_events = [ + event for event in scenario_events if events.get(event) is None + ] + + if incomplete_events: + print( + f"Scenario '{scenario_name}' is incomplete. Missing events: {incomplete_events}" + ) + return False + return True + + def _get_event_file( + self, test_name: str | None = None, test_dir: str | None = None + ): + # Get the calling test function's name + if not test_name: + current_frame = inspect.currentframe() + caller_frame = current_frame.f_back + while caller_frame: + if caller_frame.f_code.co_name.startswith("test_"): + test_name = caller_frame.f_code.co_name + break + caller_frame = caller_frame.f_back + else: + test_name = "unknown_test" + + # Get the directory of the calling test file + if not test_dir: + current_frame = inspect.currentframe() + caller_frame = current_frame.f_back + caller_file = inspect.getfile(caller_frame) + test_dir = os.path.dirname(os.path.abspath(caller_file)) + + # Create a unique filename for this test + return os.path.join(test_dir, f"{test_name}_events.json.events") + + def _ensure_file_exists(self): + if not os.path.exists(self.event_file): + with open(self.event_file, "w") as f: + json.dump({}, f) def register(self, event_name: str): - self.events[event_name] = anyio.Event() - # change to two step process to better track event outcomes - self.events[event_name].set() + with self.lock: + with open(self.event_file, "r+") as f: + now = time.time() + events = json.load(f) + events[event_name] = now + f.seek(0) + json.dump(events, f) + f.truncate() + print(f"> Event: {event_name} occured at: {now}") - waiters = self.event_waiters.get(event_name, []) - for waiter in waiters: - waiter.set() + def wait_for( + self, + event_name: str | list[str] | tuple[str], + timeout: float = 15.0, + show: bool = True, + ) -> bool: + event_names = event_name + if isinstance(event_names, str): + event_names = [event_names] + + start_time = time.time() + while time.time() - start_time < timeout: + if all(self.happened(event_name) for event_name in event_names): + return True + if show: + time_left = timeout - (time.time() - start_time) + print(f"wait_for: {event_names}. 
Time left: {time_left}") + + time.sleep(1) + return False - async def wait_for(self, event_name: str, timeout: float = 15.0): - if event_name in self.events: - return self.events[event_name] + async def await_for( + self, + event_name: str | list[str] | tuple[str], + timeout: float = 15.0, + show: bool = True, + ) -> bool: + event_names = event_name + if isinstance(event_names, str): + event_names = [event_names] - waiter = anyio.Event() - self.event_waiters.setdefault(event_name, []).append(waiter) + start_time = time.time() + while time.time() - start_time < timeout: + if all(self.happened(event_name) for event_name in event_names): + return True + if show: + time_left = timeout - (time.time() - start_time) + print(f"async await_for: {event_names}. Time left: {time_left}") + await asyncio.sleep(1) + return False + def happened(self, event_name: str) -> bool: try: - with anyio.move_on_after(timeout) as cancel_scope: - await waiter.wait() - if cancel_scope.cancel_called: - raise TimeoutError(f"Timeout waiting for event: {event_name}") - return self.events[event_name] - finally: - self.event_waiters[event_name].remove(waiter) - - async def happened(self, event_name: str) -> bool: - if event_name in self.events: - event = self.events[event_name] - return event.is_set() + with self.lock: + with open(self.event_file) as f: + events = json.load(f) + if event_name in events: + return events[event_name] + except Exception as e: + print("e", e) return False + + def get_event_time(self, event_name: str) -> float | None: + with self.lock: + with open(self.event_file) as f: + events = json.load(f) + return events.get(event_name) + + def clear_events(self): + with self.lock: + with open(self.event_file, "w") as f: + json.dump({}, f) + + @unsync + async def monitor(self, period: float = 2): + while True: + await asyncio.sleep(period) + self.flush_monitor() + + def flush_monitor(self): + with self.lock: + with open(self.event_file) as f: + events = json.load(f) + if not events: + return + for event, timestamp in sorted(events.items(), key=lambda x: x[1]): + if timestamp: + now = time.time() + time_since_start = now - timestamp + print( + f"Event: {event} happened {time_since_start:.2f} seconds ago" + ) + else: + print( + f"Event: {event} is registered but has not happened yet. Pending..." 
+ ) + + def __del__(self): + # Clean up the file when the EventManager is destroyed + # if os.path.exists(self.event_file): + # os.remove(self.event_file) + pass diff --git a/tests/scenarios/bigquery/fixtures_sync.py b/tests/scenarios/bigquery/fixtures_sync.py index 2f25eeb980f..4f4a0475609 100644 --- a/tests/scenarios/bigquery/fixtures_sync.py +++ b/tests/scenarios/bigquery/fixtures_sync.py @@ -119,15 +119,23 @@ def create_dataset(name: str): return dataset -def make_server(request) -> Any: +def make_server(request: Any | None = None) -> Any: print("making server") server = sy.orchestra.launch( - name="test-datasite-1", port="auto", dev_mode=True, reset=True + name="test-datasite-1", + port="auto", + dev_mode=True, + reset=True, + n_consumers=1, + create_producer=True, ) def cleanup(): print("landing server") server.land() - request.addfinalizer(cleanup) + if not request: + print("WARNING: No pytest request supplied, no finalizer added") + else: + request.addfinalizer(cleanup) return server diff --git a/tests/scenarios/bigquery/make.py b/tests/scenarios/bigquery/make.py index b05b10af7da..126523f12f5 100644 --- a/tests/scenarios/bigquery/make.py +++ b/tests/scenarios/bigquery/make.py @@ -2,9 +2,223 @@ from fixtures_sync import create_user from unsync import unsync +# syft absolute +import syft as sy +from syft import test_settings + @unsync async def create_users(root_client, events, users, event_name): for test_user in users: create_user(root_client, test_user) events.register(event_name) + + +@unsync +async def create_endpoints_query(events, client, worker_pool_name: str, register: str): + @sy.api_endpoint_method( + settings={ + "credentials": test_settings.gce_service_account.to_dict(), + "region": test_settings.gce_region, + "project_id": test_settings.gce_project_id, + } + ) + def private_query_function( + context, + sql_query: str, + ) -> str: + # third party + from google.cloud import bigquery # noqa: F811 + from google.oauth2 import service_account + + # syft absolute + from syft.service.response import SyftError + + # Auth for Bigquer based on the workload identity + credentials = service_account.Credentials.from_service_account_info( + context.settings["credentials"] + ) + scoped_credentials = credentials.with_scopes( + ["https://www.googleapis.com/auth/cloud-platform"] + ) + + client = bigquery.Client( + credentials=scoped_credentials, + location=context.settings["region"], + ) + + try: + rows = client.query_and_wait( + sql_query, + project=context.settings["project_id"], + ) + + if rows.total_rows > 1_000_000: + return SyftError( + message="Please only write queries that gather aggregate statistics" + ) + + return rows.to_dataframe() + except Exception as e: + # We MUST handle the errors that we want to be visible to the data owners. + # Any exception not catched is visible only to the data owner. 
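+            # BigQuery client exceptions carry an `_errors` list; the checks below use it
+            # to decide which error messages are safe to forward to data scientists.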
+ # not a bigquery exception + if not hasattr(e, "_errors"): + output = f"got exception e: {type(e)} {str(e)}" + return SyftError( + message=f"An error occured executing the API call {output}" + ) + # return SyftError(message="An error occured executing the API call, please contact the domain owner.") + + if e._errors[0]["reason"] in [ + "badRequest", + "blocked", + "duplicate", + "invalidQuery", + "invalid", + "jobBackendError", + "jobInternalError", + "notFound", + "notImplemented", + "rateLimitExceeded", + "resourceInUse", + "resourcesExceeded", + "tableUnavailable", + "timeout", + ]: + return SyftError( + message="Error occured during the call: " + e._errors[0]["message"] + ) + else: + return SyftError( + message="An error occured executing the API call, please contact the domain owner." + ) + + # Define any helper methods for our rate limiter + def is_within_rate_limit(context): + """Rate limiter for custom API calls made by users.""" + # stdlib + import datetime + + state = context.state + settings = context.settings + email = context.user.email + + current_time = datetime.datetime.now() + calls_last_min = [ + 1 if (current_time - call_time).seconds < 60 else 0 + for call_time in state[email] + ] + + return sum(calls_last_min) < settings["CALLS_PER_MIN"] + + # Define a mock endpoint that the researchers can use for testing + @sy.api_endpoint_method( + settings={ + "credentials": test_settings.gce_service_account.to_dict(), + "region": test_settings.gce_region, + "project_id": test_settings.gce_project_id, + "CALLS_PER_MIN": 10, + }, + helper_functions=[is_within_rate_limit], + ) + def mock_query_function( + context, + sql_query: str, + ) -> str: + # stdlib + import datetime + + # third party + from google.cloud import bigquery # noqa: F811 + from google.oauth2 import service_account + + # syft absolute + from syft.service.response import SyftError + + # Auth for Bigquer based on the workload identity + credentials = service_account.Credentials.from_service_account_info( + context.settings["credentials"] + ) + scoped_credentials = credentials.with_scopes( + ["https://www.googleapis.com/auth/cloud-platform"] + ) + + client = bigquery.Client( + credentials=scoped_credentials, + location=context.settings["region"], + ) + + # Store a dict with the calltimes for each user, via the email. + if context.user.email not in context.state.keys(): + context.state[context.user.email] = [] + + if not context.code.is_within_rate_limit(context): + return SyftError(message="Rate limit of calls per minute has been reached.") + + try: + context.state[context.user.email].append(datetime.datetime.now()) + + rows = client.query_and_wait( + sql_query, + project=context.settings["project_id"], + ) + + if rows.total_rows > 1_000_000: + return SyftError( + message="Please only write queries that gather aggregate statistics" + ) + + return rows.to_dataframe() + + except Exception as e: + # not a bigquery exception + if not hasattr(e, "_errors"): + output = f"got exception e: {type(e)} {str(e)}" + return SyftError( + message=f"An error occured executing the API call {output}" + ) + # return SyftError(message="An error occured executing the API call, please contact the domain owner.") + + # Treat all errors that we would like to be forwarded to the data scientists + # By default, any exception is only visible to the data owner. 
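+            # Only the allow-listed BigQuery error reasons below are echoed back to the
+            # caller; any other failure returns a generic message instead.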
+ + if e._errors[0]["reason"] in [ + "badRequest", + "blocked", + "duplicate", + "invalidQuery", + "invalid", + "jobBackendError", + "jobInternalError", + "notFound", + "notImplemented", + "rateLimitExceeded", + "resourceInUse", + "resourcesExceeded", + "tableUnavailable", + "timeout", + ]: + return SyftError( + message="Error occured during the call: " + e._errors[0]["message"] + ) + else: + return SyftError( + message="An error occured executing the API call, please contact the domain owner." + ) + + new_endpoint = sy.TwinAPIEndpoint( + path="bigquery.test_query", + description="This endpoint allows to query Bigquery storage via SQL queries.", + private_function=private_query_function, + mock_function=mock_query_function, + worker_pool=worker_pool_name, + ) + + result = client.custom_api.add(endpoint=new_endpoint) + + if register: + if isinstance(result, sy.SyftSuccess): + events.register(register) + else: + print("Failed to add api endpoint") diff --git a/tests/scenarios/bigquery/prototype.ipynb b/tests/scenarios/bigquery/prototype.ipynb deleted file mode 100644 index d68a8cb4f39..00000000000 --- a/tests/scenarios/bigquery/prototype.ipynb +++ /dev/null @@ -1,277 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [], - "source": [ - "# third party\n", - "from faker import Faker\n", - "from fixtures import *" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "manager = TestEventManager()\n", - "manager.reset_test_state()" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "USERS_CREATED = \"users_created\"\n", - "MOCK_READABLE = \"mock_readable\"" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [], - "source": [ - "fake = Faker()" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "TestUser(name='Eduardo Edwards', email='info@openmined.org', password='changethis', role=, server_cache=None)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "admin = make_admin()\n", - "admin" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Test Complete\n", - "None\n" - ] - }, - { - "ename": "AssertionError", - "evalue": "", - "output_type": "error", - "traceback": [ - "\u001b[0;31m--------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAssertionError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[7], line 5\u001b[0m\n\u001b[1;32m 3\u001b[0m result \u001b[38;5;241m=\u001b[39m manager\u001b[38;5;241m.\u001b[39mget_event(MOCK_READABLE)\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28mprint\u001b[39m(result)\n\u001b[0;32m----> 5\u001b[0m \u001b[38;5;28;01massert\u001b[39;00m result\n", - "\u001b[0;31mAssertionError\u001b[0m: " - ] - } - ], - "source": [ - "async with AsyncWaitForEvent(manager, MOCK_READABLE, retry_secs=1, timeout_secs=10):\n", - " print(\"Test Complete\")\n", - " result = manager.get_event(MOCK_READABLE)\n", - " print(result)\n", - " assert result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - 
"cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "server = make_server(admin)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "root_client = admin.client(server)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "dataset_name = fake.name()\n", - "dataset = create_dataset(name=dataset_name)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "upload_dataset(root_client, dataset)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "users = [make_user() for i in range(2)]\n", - "users" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def create_users(root_client, manager, users):\n", - " for test_user in users:\n", - " create_user(root_client, test_user)\n", - " manager.register_event(USERS_CREATED)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "def user_can_read_mock_dataset(server, manager, user, dataset_name):\n", - " print(\"waiting \", USERS_CREATED)\n", - " with WaitForEvent(manager, USERS_CREATED, retry_secs=1):\n", - " print(\"logging in user\")\n", - " user_client = user.client(server)\n", - " print(\"getting dataset\", dataset_name)\n", - " mock = user_client.api.services.dataset[dataset_name].assets[0].mock\n", - " df = trade_flow_df_mock(trade_flow_df())\n", - " assert df.equals(mock)\n", - " manager.register_event(MOCK_READABLE)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "user = users[0]\n", - "user" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "asyncit(\n", - " user_can_read_mock_dataset,\n", - " server=server,\n", - " manager=manager,\n", - " user=user,\n", - " dataset_name=dataset_name,\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "asyncit(create_users, root_client=root_client, manager=manager, users=users)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "with WaitForEvent(manager, MOCK_READABLE, retry_secs=1, timeout_secs=10):\n", - " print(\"Test Complete\")\n", - " result = manager.get_event(MOCK_READABLE)\n", - " print(result)\n", - " assert result" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.2" - }, - "toc": { - "base_numbering": 1, - "nav_menu": {}, - "number_sections": true, - "sideBar": true, - "skip_h1_title": false, - "title_cell": "Table of Contents", - "title_sidebar": "Contents", - "toc_cell": false, - "toc_position": {}, - "toc_section_display": true, - "toc_window_display": true - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} From 42ba5e2d0d1b17add4a46414b0d210e075b42018 Mon Sep 17 00:00:00 2001 From: 
Madhava Jay Date: Tue, 27 Aug 2024 14:57:49 +1000 Subject: [PATCH 06/13] More progress on scenario --- tests/scenarios/bigquery/events.py | 5 + ...{1_setup_test.py => level_2_basic_test.py} | 114 ++++++++-- tests/scenarios/bigquery/make.py | 195 ++++++++++++++++-- 3 files changed, 281 insertions(+), 33 deletions(-) rename tests/scenarios/bigquery/{1_setup_test.py => level_2_basic_test.py} (75%) diff --git a/tests/scenarios/bigquery/events.py b/tests/scenarios/bigquery/events.py index b17e5a91b6b..5b9bd0a8358 100644 --- a/tests/scenarios/bigquery/events.py +++ b/tests/scenarios/bigquery/events.py @@ -19,6 +19,11 @@ EVENT_WORKER_POOL_CREATED = "worker_pool_created" EVENT_ALLOW_GUEST_SIGNUP_DISABLED = "allow_guest_signup_disabled" EVENT_USERS_CREATED_CHECKED = "users_created_checked" +EVENT_SCHEMA_ENDPOINT_CREATED = "schema_endpoint_created" +EVENT_SUBMIT_QUERY_ENDPOINT_CREATED = "submit_query_endpoint_created" +EVENT_SUBMIT_QUERY_ENDPOINT_CONFIGURED = "submit_query_endpoint_configured" +EVENT_USERS_CAN_QUERY_MOCK = "users_can_query_mock" +EVENT_USERS_CAN_SUBMIT_QUERY = "users_can_submit_query" @dataclass diff --git a/tests/scenarios/bigquery/1_setup_test.py b/tests/scenarios/bigquery/level_2_basic_test.py similarity index 75% rename from tests/scenarios/bigquery/1_setup_test.py rename to tests/scenarios/bigquery/level_2_basic_test.py index 3f0927b46ab..eaf053ea6c6 100644 --- a/tests/scenarios/bigquery/1_setup_test.py +++ b/tests/scenarios/bigquery/level_2_basic_test.py @@ -7,6 +7,11 @@ from events import EVENT_ALLOW_GUEST_SIGNUP_DISABLED from events import EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED from events import EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED +from events import EVENT_SCHEMA_ENDPOINT_CREATED +from events import EVENT_SUBMIT_QUERY_ENDPOINT_CONFIGURED +from events import EVENT_SUBMIT_QUERY_ENDPOINT_CREATED +from events import EVENT_USERS_CAN_QUERY_MOCK +from events import EVENT_USERS_CAN_SUBMIT_QUERY from events import EVENT_USERS_CREATED from events import EVENT_USERS_CREATED_CHECKED from events import EVENT_USER_ADMIN_CREATED @@ -18,6 +23,8 @@ from fixtures_sync import make_server from fixtures_sync import make_user from make import create_endpoints_query +from make import create_endpoints_schema +from make import create_endpoints_submit_query from make import create_users import pytest from unsync import unsync @@ -167,7 +174,7 @@ def guest_register(client, test_user): async def result_is( events, expr, - matches: bool | type | object, + matches: bool | str | type | object, after: str | None = None, register: str | None = None, ): @@ -177,10 +184,14 @@ async def result_is( lambda_source = inspect.getsource(expr) try: result = expr() + assertion = False if isinstance(matches, bool): assertion = result == matches - if isinstance(matches, type): + elif isinstance(matches, type): assertion = isinstance(result, matches) + elif isinstance(matches, str): + message = matches.replace("*", "") + assertion = message in str(result) else: if hasattr(result, "message"): message = result.message.replace("*", "") @@ -196,16 +207,15 @@ async def result_is( @unsync -async def set_endpoint_settings(events, client, path, after: str, register: str): +async def set_endpoint_settings( + events, client, path, kwargs, after: str, register: str +): if after: await events.await_for(event_name=after) # Here, we update the endpoint to timeout after 100s (rather the default of 60s) - result1 = client.api.services.api.update(endpoint_path=path, endpoint_timeout=120) - result2 = 
client.api.services.api.update( - endpoint_path=path, hide_mock_definition=True - ) - if isinstance(result1, sy.SyftSuccess) and isinstance(result2, sy.SyftSuccess): + result = client.api.services.api.update(endpoint_path=path, **kwargs) + if isinstance(result, sy.SyftSuccess): events.register(register) else: print(f"Failed to update api endpoint. {path}") @@ -215,14 +225,32 @@ async def set_endpoint_settings(events, client, path, after: str, register: str) EVENT_QUERY_ENDPOINT_CONFIGURED = "query_endpoint_configured" -def make_test_query(client, path): +def query_sql(): query = f"SELECT {test_settings.table_2_col_id}, AVG({test_settings.table_2_col_score}) AS average_score \ FROM {test_settings.dataset_2}.{test_settings.table_2} \ GROUP BY {test_settings.table_2_col_id} \ LIMIT 10000" + return query + + +def run_code(client, method_name, **kwargs): + service_func_name = method_name + if "*" in method_name: + matcher = method_name.replace("*", "") + all_code = client.api.services.code.get_all() + for code in all_code: + if matcher in code.service_func_name: + service_func_name = code.service_func_name + break + api_method = api_for_path(client, path=f"code.{service_func_name}") + result = api_method(**kwargs) + return result + + +def run_api_path(client, path, **kwargs): api_method = api_for_path(client, path) - result = api_method(sql_query=query) + result = api_method(**kwargs) return result @@ -236,11 +264,11 @@ def api_for_path(client, path): return root -EVENT_USERS_CAN_QUERY_MOCK = "users_can_query_mock" +EVENT_USERS_QUERY_NOT_READY = "users_query_not_ready" @pytest.mark.asyncio -async def test_create_dataset_and_read_mock(request): +async def test_level_2_basic_scenario(request): ensure_package_installed("google-cloud-bigquery", "google.cloud.bigquery") ensure_package_installed("db-dtypes", "db_dtypes") @@ -256,7 +284,12 @@ async def test_create_dataset_and_read_mock(request): EVENT_USERS_CREATED_CHECKED, EVENT_QUERY_ENDPOINT_CREATED, EVENT_QUERY_ENDPOINT_CONFIGURED, + EVENT_SCHEMA_ENDPOINT_CREATED, + EVENT_SUBMIT_QUERY_ENDPOINT_CREATED, + EVENT_SUBMIT_QUERY_ENDPOINT_CONFIGURED, EVENT_USERS_CAN_QUERY_MOCK, + EVENT_USERS_CAN_SUBMIT_QUERY, + EVENT_USERS_QUERY_NOT_READY, ], ) @@ -329,20 +362,71 @@ async def test_create_dataset_and_read_mock(request): events, root_client, path=test_query_path, + kwargs={"endpoint_timeout": 120, "hide_mock_definition": True}, after=EVENT_QUERY_ENDPOINT_CREATED, register=EVENT_QUERY_ENDPOINT_CONFIGURED, ) + print("calling create endpoints schema") + create_endpoints_schema( + events, + root_client, + worker_pool_name=worker_pool_name, + register=EVENT_SCHEMA_ENDPOINT_CREATED, + ) + + create_endpoints_submit_query( + events, + root_client, + worker_pool_name=worker_pool_name, + register=EVENT_SUBMIT_QUERY_ENDPOINT_CREATED, + ) + + submit_query_path = "bigquery.submit_query" + set_endpoint_settings( + events, + root_client, + path=submit_query_path, + kwargs={"hide_mock_definition": True}, + after=EVENT_SUBMIT_QUERY_ENDPOINT_CREATED, + register=EVENT_SUBMIT_QUERY_ENDPOINT_CONFIGURED, + ) + await result_is( events, - lambda: len(make_test_query(users[0].client(server), test_query_path)) == 10000, + lambda: len( + run_api_path( + users[0].client(server), test_query_path, sql_query=query_sql() + ) + ) + == 10000, matches=True, after=[EVENT_QUERY_ENDPOINT_CONFIGURED, EVENT_USERS_CREATED_CHECKED], register=EVENT_USERS_CAN_QUERY_MOCK, ) - # # submit code via api - # # verify its not accessible yet + func_name = "test_func" + + await result_is( + events, + lambda: 
run_api_path( + users[0].client(server), + submit_query_path, + func_name=func_name, + query=query_sql(), + ), + matches="*Query submitted*", + after=[EVENT_SUBMIT_QUERY_ENDPOINT_CONFIGURED, EVENT_USERS_CREATED_CHECKED], + register=EVENT_USERS_CAN_SUBMIT_QUERY, + ) + + await result_is( + events, + lambda: run_code(users[0].client(server), method_name=f"{func_name}*"), + matches=sy.SyftError(message="*Your code is waiting for approval*"), + after=[EVENT_USERS_CAN_SUBMIT_QUERY], + register=EVENT_USERS_QUERY_NOT_READY, + ) # # continuously checking for # # new untriaged requests diff --git a/tests/scenarios/bigquery/make.py b/tests/scenarios/bigquery/make.py index 126523f12f5..dc8a539ed19 100644 --- a/tests/scenarios/bigquery/make.py +++ b/tests/scenarios/bigquery/make.py @@ -7,6 +7,25 @@ from syft import test_settings +# Define any helper methods for our rate limiter +def is_within_rate_limit(context): + """Rate limiter for custom API calls made by users.""" + # stdlib + import datetime + + state = context.state + settings = context.settings + email = context.user.email + + current_time = datetime.datetime.now() + calls_last_min = [ + 1 if (current_time - call_time).seconds < 60 else 0 + for call_time in state[email] + ] + + return sum(calls_last_min) < settings["CALLS_PER_MIN"] + + @unsync async def create_users(root_client, events, users, event_name): for test_user in users: @@ -94,24 +113,6 @@ def private_query_function( message="An error occured executing the API call, please contact the domain owner." ) - # Define any helper methods for our rate limiter - def is_within_rate_limit(context): - """Rate limiter for custom API calls made by users.""" - # stdlib - import datetime - - state = context.state - settings = context.settings - email = context.user.email - - current_time = datetime.datetime.now() - calls_last_min = [ - 1 if (current_time - call_time).seconds < 60 else 0 - for call_time in state[email] - ] - - return sum(calls_last_min) < settings["CALLS_PER_MIN"] - # Define a mock endpoint that the researchers can use for testing @sy.api_endpoint_method( settings={ @@ -222,3 +223,161 @@ def mock_query_function( events.register(register) else: print("Failed to add api endpoint") + + +@unsync +async def create_endpoints_schema(events, client, worker_pool_name: str, register: str): + @sy.api_endpoint( + path="bigquery.schema", + description="This endpoint allows for visualising the metadata of tables available in BigQuery.", + settings={ + "credentials": test_settings.gce_service_account.to_dict(), + "region": test_settings.gce_region, + "project_id": test_settings.gce_project_id, + "dataset_1": test_settings.dataset_1, + "table_1": test_settings.table_1, + "table_2": test_settings.table_2, + "CALLS_PER_MIN": 5, + }, + helper_functions=[ + is_within_rate_limit + ], # Adds ratelimit as this is also a method available to data scientists + worker_pool=worker_pool_name, + ) + def schema_function( + context, + ) -> str: + # stdlib + import datetime + + # third party + from google.cloud import bigquery # noqa: F811 + from google.oauth2 import service_account + import pandas as pd + + # syft absolute + from syft.service.response import SyftError + + # Auth for Bigquer based on the workload identity + credentials = service_account.Credentials.from_service_account_info( + context.settings["credentials"] + ) + scoped_credentials = credentials.with_scopes( + ["https://www.googleapis.com/auth/cloud-platform"] + ) + + client = bigquery.Client( + credentials=scoped_credentials, + 
location=context.settings["region"], + ) + + if context.user.email not in context.state.keys(): + context.state[context.user.email] = [] + + if not context.code.is_within_rate_limit(context): + return SyftError(message="Rate limit of calls per minute has been reached.") + + try: + context.state[context.user.email].append(datetime.datetime.now()) + + # Formats the data schema in a data frame format + # Warning: the only supported format types are primitives, np.ndarrays and pd.DataFrames + + data_schema = [] + for table_id in [ + f"{context.settings["dataset_1"]}.{context.settings["table_1"]}", + f"{context.settings["dataset_1"]}.{context.settings["table_2"]}", + ]: + table = client.get_table(table_id) + for schema in table.schema: + data_schema.append( + { + "project": str(table.project), + "dataset_id": str(table.dataset_id), + "table_id": str(table.table_id), + "schema_name": str(schema.name), + "schema_field": str(schema.field_type), + "description": str(table.description), + "num_rows": str(table.num_rows), + } + ) + return pd.DataFrame(data_schema) + + except Exception as e: + # not a bigquery exception + if not hasattr(e, "_errors"): + output = f"got exception e: {type(e)} {str(e)}" + return SyftError( + message=f"An error occured executing the API call {output}" + ) + # return SyftError(message="An error occured executing the API call, please contact the domain owner.") + + # Should add appropriate error handling for what should be exposed to the data scientists. + return SyftError( + message="An error occured executing the API call, please contact the domain owner." + ) + + result = client.custom_api.add(endpoint=schema_function) + + if register: + if isinstance(result, sy.SyftSuccess): + events.register(register) + else: + print("Failed to add schema_function") + + +@unsync +async def create_endpoints_submit_query( + events, client, worker_pool_name: str, register: str +): + @sy.api_endpoint( + path="bigquery.submit_query", + description="API endpoint that allows you to submit SQL queries to run on the private data.", + worker_pool=worker_pool_name, + settings={"worker": worker_pool_name}, + ) + def submit_query( + context, + func_name: str, + query: str, + ) -> str: + # stdlib + import hashlib + + # syft absolute + import syft as sy + + hash_object = hashlib.new("sha256") + + hash_object.update(context.user.email.encode("utf-8")) + func_name = func_name + "_" + hash_object.hexdigest()[:6] + + @sy.syft_function( + name=func_name, + input_policy=sy.MixedInputPolicy( + endpoint=sy.Constant( + val=context.admin_client.api.services.bigquery.test_query + ), + query=sy.Constant(val=query), + client=context.admin_client, + ), + worker_pool_name=context.settings["worker"], + ) + def execute_query(query: str, endpoint): + res = endpoint(sql_query=query) + return res + + request = context.user_client.code.request_code_execution(execute_query) + if isinstance(request, sy.SyftError): + return request + context.admin_client.requests.set_tags(request, ["autosync"]) + + return f"Query submitted {request}. 
Use `client.code.{func_name}()` to run your query" + + result = client.custom_api.add(endpoint=submit_query) + + if register: + if isinstance(result, sy.SyftSuccess): + events.register(register) + else: + print("Failed to add api endpoint") From c53d013fa4a40b21a40ee2b6c9a9bad473b1469f Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Tue, 27 Aug 2024 15:45:29 +1000 Subject: [PATCH 07/13] Updated for new dev Exception handling code --- .../scenarios/bigquery/level_2_basic_test.py | 25 ++++++++++++++----- 1 file changed, 19 insertions(+), 6 deletions(-) diff --git a/tests/scenarios/bigquery/level_2_basic_test.py b/tests/scenarios/bigquery/level_2_basic_test.py index eaf053ea6c6..42bc93efbec 100644 --- a/tests/scenarios/bigquery/level_2_basic_test.py +++ b/tests/scenarios/bigquery/level_2_basic_test.py @@ -183,7 +183,14 @@ async def result_is( lambda_source = inspect.getsource(expr) try: - result = expr() + try: + result = expr() + except Exception as e: + if isinstance(e, sy.SyftException): + result = e + else: + raise e + assertion = False if isinstance(matches, bool): assertion = result == matches @@ -193,9 +200,15 @@ async def result_is( message = matches.replace("*", "") assertion = message in str(result) else: - if hasattr(result, "message"): - message = result.message.replace("*", "") + if isinstance(result, sy.service.response.SyftResponseMessage): + message = matches.message.replace("*", "") assertion = isinstance(result, type(matches)) and message in str(result) + elif isinstance(result, sy.SyftException): + message = matches.public_message.replace("*", "") + assertion = ( + isinstance(result, type(matches)) + and message in result.public_message + ) if assertion and register: events.register(event_name=register) @@ -423,7 +436,7 @@ async def test_level_2_basic_scenario(request): await result_is( events, lambda: run_code(users[0].client(server), method_name=f"{func_name}*"), - matches=sy.SyftError(message="*Your code is waiting for approval*"), + matches=sy.SyftException(public_message="*Your code is waiting for approval*"), after=[EVENT_USERS_CAN_SUBMIT_QUERY], register=EVENT_USERS_QUERY_NOT_READY, ) @@ -440,8 +453,8 @@ async def test_level_2_basic_scenario(request): res = await result_is( events, lambda: guest_register(root_client, make_user()), - matches=sy.SyftError( - message="*You don't have permission to create an account*" + matches=sy.SyftException( + public_message="*You have no permission to create an account*" ), after=EVENT_ALLOW_GUEST_SIGNUP_DISABLED, ) From cfa158ea0376c1faed97593c8247c8a659cb4ccd Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Wed, 28 Aug 2024 10:53:14 +1000 Subject: [PATCH 08/13] Added admin and data scientist loops --- packages/syft/src/syft/util/autoreload.py | 8 +- tests/scenarios/bigquery/events.py | 2 + .../scenarios/bigquery/level_2_basic_test.py | 118 ++++++++++++++++-- 3 files changed, 115 insertions(+), 13 deletions(-) diff --git a/packages/syft/src/syft/util/autoreload.py b/packages/syft/src/syft/util/autoreload.py index e1f68e45555..b3230c6dc6d 100644 --- a/packages/syft/src/syft/util/autoreload.py +++ b/packages/syft/src/syft/util/autoreload.py @@ -8,8 +8,9 @@ def enable_autoreload() -> None: from IPython import get_ipython ipython = get_ipython() # noqa: F821 - ipython.run_line_magic("load_ext", "autoreload") - ipython.run_line_magic("autoreload", "2") + if hasattr(ipython, "run_line_magic"): + ipython.run_line_magic("load_ext", "autoreload") + ipython.run_line_magic("autoreload", "2") AUTORELOAD_ENABLED = True print("Autoreload 
enabled") except Exception as e: @@ -24,7 +25,8 @@ def disable_autoreload() -> None: from IPython import get_ipython ipython = get_ipython() # noqa: F821 - ipython.run_line_magic("autoreload", "0") + if hasattr(ipython, "run_line_magic"): + ipython.run_line_magic("autoreload", "0") AUTORELOAD_ENABLED = False print("Autoreload disabled.") except Exception as e: diff --git a/tests/scenarios/bigquery/events.py b/tests/scenarios/bigquery/events.py index 5b9bd0a8358..4b1bcf3471b 100644 --- a/tests/scenarios/bigquery/events.py +++ b/tests/scenarios/bigquery/events.py @@ -24,6 +24,8 @@ EVENT_SUBMIT_QUERY_ENDPOINT_CONFIGURED = "submit_query_endpoint_configured" EVENT_USERS_CAN_QUERY_MOCK = "users_can_query_mock" EVENT_USERS_CAN_SUBMIT_QUERY = "users_can_submit_query" +EVENT_ADMIN_APPROVED_FIRST_REQUEST = "admin_approved_first_request" +EVENT_USERS_CAN_GET_APPROVED_RESULT = "users_can_get_approved_result" @dataclass diff --git a/tests/scenarios/bigquery/level_2_basic_test.py b/tests/scenarios/bigquery/level_2_basic_test.py index 42bc93efbec..d02e083b5f6 100644 --- a/tests/scenarios/bigquery/level_2_basic_test.py +++ b/tests/scenarios/bigquery/level_2_basic_test.py @@ -4,12 +4,14 @@ # third party from asserts import ensure_package_installed +from events import EVENT_ADMIN_APPROVED_FIRST_REQUEST from events import EVENT_ALLOW_GUEST_SIGNUP_DISABLED from events import EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED from events import EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED from events import EVENT_SCHEMA_ENDPOINT_CREATED from events import EVENT_SUBMIT_QUERY_ENDPOINT_CONFIGURED from events import EVENT_SUBMIT_QUERY_ENDPOINT_CREATED +from events import EVENT_USERS_CAN_GET_APPROVED_RESULT from events import EVENT_USERS_CAN_QUERY_MOCK from events import EVENT_USERS_CAN_SUBMIT_QUERY from events import EVENT_USERS_CREATED @@ -32,6 +34,8 @@ # syft absolute import syft as sy from syft import test_settings +from syft.service.code.user_code import UserCode +from syft.service.job.job_stash import Job # dataset stuff # """ @@ -257,7 +261,10 @@ def run_code(client, method_name, **kwargs): break api_method = api_for_path(client, path=f"code.{service_func_name}") - result = api_method(**kwargs) + try: + result = api_method(**kwargs) + except Exception as e: + print(">> got an exception while trying to run code", e) return result @@ -280,6 +287,87 @@ def api_for_path(client, path): EVENT_USERS_QUERY_NOT_READY = "users_query_not_ready" +def get_pending(client): + results = [] + for request in client.requests: + if str(request.status) == "RequestStatus.PENDING": + results.append(request) + print( + f"Found pending request: {request.code.constants["query"].val}: {request.id}" + ) + return results + + +def approve_and_deposit(client, request_id): + request = client.requests.get_by_uid(uid=request_id) + code = request.code + + if not isinstance(code, UserCode): + print("NOT A USER CODE???") + + func_name = request.code.service_func_name + job = run_code(client, func_name, blocking=False) + if not isinstance(job, Job): + print("NOT A JOB??") + + job.wait() + job_info = job.info(result=True) + result = request.deposit_result(job_info, approve=True) + print("got result from approving?", result) + return result + + +@unsync +async def triage_requests(events, client, after, register): + print("Waiting for admin account to be created") + if after: + await events.await_for(event_name=after) + while True: + await asyncio.sleep(1) + print("> Admin checking for requests") + requests = get_pending(client) + for request in requests: + 
print("> Admin approving request", request.id) + result = approve_and_deposit(client, request.id) + print("got result from approving reuwest", result) + events.register(event_name=register) + + +def get_approved(client): + results = [] + for request in client.requests: + if str(request.status) == "RequestStatus.APPROVED": + results.append(request) + print( + f"Found approved request: {request.code.constants["query"].val}: {request.id}" + ) + return results + + +@unsync +async def get_results(events, client, method_name, after, register): + method_name = method_name.replace("*", "") + print("Waiting for admin approve or deny") + if after: + await events.await_for(event_name=after) + while True: + await asyncio.sleep(1) + print("> Data Scientist checking for approval") + requests = get_approved(client) + for request in requests: + if method_name in request.code.service_func_name: + print( + f"> Found approved request: {method_name} at {request.code.service_func_name}" + ) + print("> Running and getting result") + result = run_code(client, request.code.service_func_name) + print("> got result", result) + if hasattr(result, "__len__") and len(result) == 10000: + events.register(event_name=register) + else: + print("no match with expected") + + @pytest.mark.asyncio async def test_level_2_basic_scenario(request): ensure_package_installed("google-cloud-bigquery", "google.cloud.bigquery") @@ -303,6 +391,8 @@ async def test_level_2_basic_scenario(request): EVENT_USERS_CAN_QUERY_MOCK, EVENT_USERS_CAN_SUBMIT_QUERY, EVENT_USERS_QUERY_NOT_READY, + EVENT_ADMIN_APPROVED_FIRST_REQUEST, + EVENT_USERS_CAN_GET_APPROVED_RESULT, ], ) @@ -319,6 +409,13 @@ async def test_level_2_basic_scenario(request): assert events.happened(EVENT_USER_ADMIN_CREATED) root_client = admin.client(server) + triage_requests( + events, + root_client, + after=EVENT_USER_ADMIN_CREATED, + register=EVENT_ADMIN_APPROVED_FIRST_REQUEST, + ) + worker_pool_name = "bigquery-pool" worker_docker_tag = f"openmined/bigquery:{sy.__version__}" @@ -441,14 +538,13 @@ async def test_level_2_basic_scenario(request): register=EVENT_USERS_QUERY_NOT_READY, ) - # # continuously checking for - # # new untriaged requests - # # executing them locally - # # submitting the results - - # # users get the results - # # continuously checking - # # assert he random number of rows is there + get_results( + events, + users[0].client(server), + method_name=f"{func_name}*", + after=EVENT_USERS_QUERY_NOT_READY, + register=EVENT_USERS_CAN_GET_APPROVED_RESULT, + ) res = await result_is( events, @@ -461,5 +557,7 @@ async def test_level_2_basic_scenario(request): assert res is True - await events.await_scenario(scenario_name="test_create_dataset_and_read_mock") + await events.await_scenario( + scenario_name="test_create_dataset_and_read_mock", timeout=30 + ) assert events.scenario_completed("test_create_dataset_and_read_mock") From 01038205001bdcdd4cca3773f31c34dd7b68d86f Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Wed, 28 Aug 2024 11:39:47 +1000 Subject: [PATCH 09/13] Fixed issues with exceptions versus errors in asserts --- packages/syft/src/syft/service/response.py | 7 + packages/syft/src/syft/types/errors.py | 5 +- tests/scenarios/bigquery/events.py | 3 + tests/scenarios/bigquery/fixtures_sync.py | 7 +- .../scenarios/bigquery/level_2_basic_test.py | 123 +++++------------- tests/scenarios/bigquery/partials.py | 23 ---- tests/scenarios/bigquery/story.py | 20 --- 7 files changed, 49 insertions(+), 139 deletions(-) delete mode 100644 tests/scenarios/bigquery/partials.py 
delete mode 100644 tests/scenarios/bigquery/story.py diff --git a/packages/syft/src/syft/service/response.py b/packages/syft/src/syft/service/response.py index 4484ab93687..90058d4d8a2 100644 --- a/packages/syft/src/syft/service/response.py +++ b/packages/syft/src/syft/service/response.py @@ -106,6 +106,13 @@ def is_err(self) -> bool: def is_ok(self) -> bool: return False + @classmethod + def from_public_exception( + cls, + exc: Exception, + ) -> Self: + return cls(message=exc.public_message) + @classmethod def from_exception( cls, diff --git a/packages/syft/src/syft/types/errors.py b/packages/syft/src/syft/types/errors.py index ada21e720a2..09d70be9c56 100644 --- a/packages/syft/src/syft/types/errors.py +++ b/packages/syft/src/syft/types/errors.py @@ -197,9 +197,10 @@ def __enter__(self): # type: ignore def __exit__(self, exc_type, exc_value, traceback): # type: ignore message = None expected_exception_type = self.expected_exception - if isinstance(expected_exception_type, SyftException): - message = self.expected_exception.public_message.replace("*", "") + if not isinstance(expected_exception_type, type): expected_exception_type = type(self.expected_exception) + if hasattr(self.expected_exception, "public_message"): + message = self.expected_exception.public_message.replace("*", "") # After block of code if exc_type is None: diff --git a/tests/scenarios/bigquery/events.py b/tests/scenarios/bigquery/events.py index 4b1bcf3471b..b666f7b98de 100644 --- a/tests/scenarios/bigquery/events.py +++ b/tests/scenarios/bigquery/events.py @@ -26,6 +26,9 @@ EVENT_USERS_CAN_SUBMIT_QUERY = "users_can_submit_query" EVENT_ADMIN_APPROVED_FIRST_REQUEST = "admin_approved_first_request" EVENT_USERS_CAN_GET_APPROVED_RESULT = "users_can_get_approved_result" +EVENT_USERS_QUERY_NOT_READY = "users_query_not_ready" +EVENT_QUERY_ENDPOINT_CREATED = "query_endpoint_created" +EVENT_QUERY_ENDPOINT_CONFIGURED = "query_endpoint_configured" @dataclass diff --git a/tests/scenarios/bigquery/fixtures_sync.py b/tests/scenarios/bigquery/fixtures_sync.py index 4f4a0475609..c214e9612a9 100644 --- a/tests/scenarios/bigquery/fixtures_sync.py +++ b/tests/scenarios/bigquery/fixtures_sync.py @@ -119,10 +119,13 @@ def create_dataset(name: str): return dataset -def make_server(request: Any | None = None) -> Any: +def make_server(request: Any | None = None, server_name: str | None = None) -> Any: print("making server") + if server_name is None: + faker = Faker() + server_name = faker.name() server = sy.orchestra.launch( - name="test-datasite-1", + name=server_name, port="auto", dev_mode=True, reset=True, diff --git a/tests/scenarios/bigquery/level_2_basic_test.py b/tests/scenarios/bigquery/level_2_basic_test.py index d02e083b5f6..3f7b4790ba2 100644 --- a/tests/scenarios/bigquery/level_2_basic_test.py +++ b/tests/scenarios/bigquery/level_2_basic_test.py @@ -8,6 +8,8 @@ from events import EVENT_ALLOW_GUEST_SIGNUP_DISABLED from events import EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED from events import EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED +from events import EVENT_QUERY_ENDPOINT_CONFIGURED +from events import EVENT_QUERY_ENDPOINT_CREATED from events import EVENT_SCHEMA_ENDPOINT_CREATED from events import EVENT_SUBMIT_QUERY_ENDPOINT_CONFIGURED from events import EVENT_SUBMIT_QUERY_ENDPOINT_CREATED @@ -16,6 +18,7 @@ from events import EVENT_USERS_CAN_SUBMIT_QUERY from events import EVENT_USERS_CREATED from events import EVENT_USERS_CREATED_CHECKED +from events import EVENT_USERS_QUERY_NOT_READY from events import 
EVENT_USER_ADMIN_CREATED from events import EVENT_WORKER_POOL_CREATED from events import EventManager @@ -37,30 +40,6 @@ from syft.service.code.user_code import UserCode from syft.service.job.job_stash import Job -# dataset stuff -# """ -# dataset_get_all = with_client(get_datasets, server) - -# assert dataset_get_all() == 0 - -# dataset_name = fake.name() -# dataset = create_dataset(name=dataset_name) - -# upload_dataset(root_client, dataset) - -# events.register(EVENT_DATASET_UPLOADED) - -# user_can_read_mock_dataset(server, events, user, dataset_name) - -# await has( -# lambda: dataset_get_all() == 1, -# "1 Dataset", -# timeout=15, -# retry=1, -# ) - -# """ - @unsync async def get_prebuilt_worker_image(events, client, expected_tag, event_name): @@ -70,8 +49,6 @@ async def get_prebuilt_worker_image(events, client, expected_tag, event_name): if expected_tag in str(worker_image.image_identifier): assert expected_tag in str(worker_image.image_identifier) return worker_image - print(f"get_prebuilt_worker_image cannot find {expected_tag}") - # raise FailedAssert() @unsync @@ -109,8 +86,6 @@ async def create_worker_pool( result[0], sy.service.worker.worker_pool.ContainerSpawnStatus ): events.register(event_name) - else: - print("bad result", result) @unsync @@ -123,11 +98,6 @@ async def check_worker_pool_exists(events, client, worker_pool_name, event_name) assert worker_pool_name == pool.name return worker_pool_name == pool.name - print(f"check_worker_pool_exists cannot find worker_pool_name {worker_pool_name}") - # raise FailedAssert( - - # ) - @unsync async def set_settings_allow_guest_signup( @@ -137,8 +107,6 @@ async def set_settings_allow_guest_signup( if event_name: if isinstance(result, sy.SyftSuccess): events.register(event_name) - else: - print("cant set settings alow guest signup") @unsync @@ -153,12 +121,6 @@ async def check_users_created(events, client, users, event_name, event_set): if len(found_emails) == len(expected_emails): events.register(event_set) - else: - print( - f"check_users_created only found {len(found_emails)} of {len(expected_emails)} " - f"emails: {found_emails}, {expected_emails}" - ) - # raise FailedAssert() def guest_register(client, test_user): @@ -187,6 +149,7 @@ async def result_is( lambda_source = inspect.getsource(expr) try: + result = None try: result = expr() except Exception as e: @@ -204,16 +167,22 @@ async def result_is( message = matches.replace("*", "") assertion = message in str(result) else: - if isinstance(result, sy.service.response.SyftResponseMessage): + type_matches = isinstance(result, type(matches)) + message_matches = True + + message = None + if isinstance(matches, sy.service.response.SyftResponseMessage): message = matches.message.replace("*", "") - assertion = isinstance(result, type(matches)) and message in str(result) elif isinstance(result, sy.SyftException): message = matches.public_message.replace("*", "") - assertion = ( - isinstance(result, type(matches)) - and message in result.public_message - ) + if message: + if isinstance(result, sy.service.response.SyftResponseMessage): + message_matches = message in str(result) + elif isinstance(result, sy.SyftException): + message_matches = message in result.public_message + + assertion = type_matches and message_matches if assertion and register: events.register(event_name=register) return assertion @@ -234,12 +203,6 @@ async def set_endpoint_settings( result = client.api.services.api.update(endpoint_path=path, **kwargs) if isinstance(result, sy.SyftSuccess): events.register(register) - 
else: - print(f"Failed to update api endpoint. {path}") - - -EVENT_QUERY_ENDPOINT_CREATED = "query_endpoint_created" -EVENT_QUERY_ENDPOINT_CONFIGURED = "query_endpoint_configured" def query_sql(): @@ -261,10 +224,8 @@ def run_code(client, method_name, **kwargs): break api_method = api_for_path(client, path=f"code.{service_func_name}") - try: - result = api_method(**kwargs) - except Exception as e: - print(">> got an exception while trying to run code", e) + # can raise + result = api_method(**kwargs) return result @@ -279,22 +240,14 @@ def api_for_path(client, path): for part in path.split("."): if hasattr(root, part): root = getattr(root, part) - else: - print("cant find part", part, path) return root -EVENT_USERS_QUERY_NOT_READY = "users_query_not_ready" - - def get_pending(client): results = [] for request in client.requests: if str(request.status) == "RequestStatus.PENDING": results.append(request) - print( - f"Found pending request: {request.code.constants["query"].val}: {request.id}" - ) return results @@ -303,33 +256,28 @@ def approve_and_deposit(client, request_id): code = request.code if not isinstance(code, UserCode): - print("NOT A USER CODE???") + return func_name = request.code.service_func_name job = run_code(client, func_name, blocking=False) if not isinstance(job, Job): - print("NOT A JOB??") + return None job.wait() job_info = job.info(result=True) result = request.deposit_result(job_info, approve=True) - print("got result from approving?", result) return result @unsync async def triage_requests(events, client, after, register): - print("Waiting for admin account to be created") if after: await events.await_for(event_name=after) while True: - await asyncio.sleep(1) - print("> Admin checking for requests") + await asyncio.sleep(2) requests = get_pending(client) for request in requests: - print("> Admin approving request", request.id) - result = approve_and_deposit(client, request.id) - print("got result from approving reuwest", result) + approve_and_deposit(client, request.id) events.register(event_name=register) @@ -338,34 +286,26 @@ def get_approved(client): for request in client.requests: if str(request.status) == "RequestStatus.APPROVED": results.append(request) - print( - f"Found approved request: {request.code.constants["query"].val}: {request.id}" - ) return results @unsync async def get_results(events, client, method_name, after, register): method_name = method_name.replace("*", "") - print("Waiting for admin approve or deny") if after: await events.await_for(event_name=after) while True: await asyncio.sleep(1) - print("> Data Scientist checking for approval") requests = get_approved(client) for request in requests: if method_name in request.code.service_func_name: - print( - f"> Found approved request: {method_name} at {request.code.service_func_name}" - ) - print("> Running and getting result") - result = run_code(client, request.code.service_func_name) - print("> got result", result) - if hasattr(result, "__len__") and len(result) == 10000: - events.register(event_name=register) + job = run_code(client, request.code.service_func_name, blocking=False) + if not isinstance(job, Job): + continue else: - print("no match with expected") + result = job.wait().get() + if hasattr(result, "__len__") and len(result) == 10000: + events.register(event_name=register) @pytest.mark.asyncio @@ -374,7 +314,7 @@ async def test_level_2_basic_scenario(request): ensure_package_installed("db-dtypes", "db_dtypes") scenario = Scenario( - name="test_create_dataset_and_read_mock", + 
name="test_create_apis_and_triage_requests", events=[ EVENT_USER_ADMIN_CREATED, EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED, @@ -477,7 +417,6 @@ async def test_level_2_basic_scenario(request): register=EVENT_QUERY_ENDPOINT_CONFIGURED, ) - print("calling create endpoints schema") create_endpoints_schema( events, root_client, @@ -558,6 +497,6 @@ async def test_level_2_basic_scenario(request): assert res is True await events.await_scenario( - scenario_name="test_create_dataset_and_read_mock", timeout=30 + scenario_name="test_create_apis_and_triage_requests", timeout=30 ) - assert events.scenario_completed("test_create_dataset_and_read_mock") + assert events.scenario_completed("test_create_apis_and_triage_requests") diff --git a/tests/scenarios/bigquery/partials.py b/tests/scenarios/bigquery/partials.py deleted file mode 100644 index e62b4fbd04b..00000000000 --- a/tests/scenarios/bigquery/partials.py +++ /dev/null @@ -1,23 +0,0 @@ -# stdlib -from collections.abc import Callable - -# third party -from unsync import Unfuture -from unsync import unsync - -# syft absolute -from syft.client.datasite_client import DatasiteClient -from syft.orchestra import ServerHandle - - -def with_client(func, client: unsync | DatasiteClient | ServerHandle) -> Callable: - if isinstance(client, ServerHandle): - client = client.client - - def with_func(): - result = func(client) - if isinstance(result, Unfuture): - result = result.result() - return result - - return with_func diff --git a/tests/scenarios/bigquery/story.py b/tests/scenarios/bigquery/story.py deleted file mode 100644 index 0817aee8293..00000000000 --- a/tests/scenarios/bigquery/story.py +++ /dev/null @@ -1,20 +0,0 @@ -# third party -from events import EVENT_DATASET_MOCK_READABLE -from events import EVENT_USERS_CREATED -from fixtures_sync import trade_flow_df -from fixtures_sync import trade_flow_df_mock -from unsync import unsync - - -@unsync -async def user_can_read_mock_dataset(server, events, user, dataset_name): - print("waiting ", EVENT_USERS_CREATED) - await events.wait_for(event_name=EVENT_USERS_CREATED) - user_client = user.client(server) - print("getting dataset", dataset_name) - mock = user_client.api.services.dataset[dataset_name].assets[0].mock - df = trade_flow_df_mock(trade_flow_df()) - print("Are we here?") - if df.equals(mock): - print("REGISTERING EVENT", EVENT_DATASET_MOCK_READABLE) - events.register(EVENT_DATASET_MOCK_READABLE) From 5f3452fa3583b827a4c9000b9a4de0009f6ca5f8 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Wed, 28 Aug 2024 12:03:32 +1000 Subject: [PATCH 10/13] Moved helpers into a subdir added ignore to precommit --- .pre-commit-config.yaml | 2 +- tests/scenarios/bigquery/api.py | 10 ---- .../bigquery/{ => helpers}/asserts.py | 0 .../bigquery/{ => helpers}/events.py | 0 .../bigquery/{ => helpers}/fixtures_sync.py | 2 +- .../scenarios/bigquery/{ => helpers}/make.py | 2 +- .../scenarios/bigquery/{ => helpers}/users.py | 0 .../scenarios/bigquery/level_2_basic_test.py | 54 +++++++++---------- 8 files changed, 30 insertions(+), 40 deletions(-) delete mode 100644 tests/scenarios/bigquery/api.py rename tests/scenarios/bigquery/{ => helpers}/asserts.py (100%) rename tests/scenarios/bigquery/{ => helpers}/events.py (100%) rename tests/scenarios/bigquery/{ => helpers}/fixtures_sync.py (99%) rename tests/scenarios/bigquery/{ => helpers}/make.py (99%) rename tests/scenarios/bigquery/{ => helpers}/users.py (100%) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 9dcd417cccc..3487d8d0915 100644 --- 
a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -31,7 +31,7 @@ repos: exclude: ^(packages/syft/tests/mongomock) - id: name-tests-test always_run: true - exclude: ^(.*/tests/utils/)|^(.*fixtures.py|packages/syft/tests/mongomock) + exclude: ^(.*/tests/utils/)|^(.*fixtures.py|packages/syft/tests/mongomock)|^(tests/scenarios/bigquery/helpers) - id: requirements-txt-fixer always_run: true exclude: "packages/syft/tests/mongomock" diff --git a/tests/scenarios/bigquery/api.py b/tests/scenarios/bigquery/api.py deleted file mode 100644 index 4343a603ff1..00000000000 --- a/tests/scenarios/bigquery/api.py +++ /dev/null @@ -1,10 +0,0 @@ -# third party -from unsync import unsync - - -@unsync -def get_datasets(client): - print("Checking datasets") - num_datasets = len(client.api.services.dataset.get_all()) - print(">>> num datasets", num_datasets) - return num_datasets diff --git a/tests/scenarios/bigquery/asserts.py b/tests/scenarios/bigquery/helpers/asserts.py similarity index 100% rename from tests/scenarios/bigquery/asserts.py rename to tests/scenarios/bigquery/helpers/asserts.py diff --git a/tests/scenarios/bigquery/events.py b/tests/scenarios/bigquery/helpers/events.py similarity index 100% rename from tests/scenarios/bigquery/events.py rename to tests/scenarios/bigquery/helpers/events.py diff --git a/tests/scenarios/bigquery/fixtures_sync.py b/tests/scenarios/bigquery/helpers/fixtures_sync.py similarity index 99% rename from tests/scenarios/bigquery/fixtures_sync.py rename to tests/scenarios/bigquery/helpers/fixtures_sync.py index c214e9612a9..266f27128dd 100644 --- a/tests/scenarios/bigquery/fixtures_sync.py +++ b/tests/scenarios/bigquery/helpers/fixtures_sync.py @@ -3,8 +3,8 @@ # third party from faker import Faker +from helpers.users import TestUser import pandas as pd -from users import TestUser # syft absolute import syft as sy diff --git a/tests/scenarios/bigquery/make.py b/tests/scenarios/bigquery/helpers/make.py similarity index 99% rename from tests/scenarios/bigquery/make.py rename to tests/scenarios/bigquery/helpers/make.py index dc8a539ed19..2ac90833cec 100644 --- a/tests/scenarios/bigquery/make.py +++ b/tests/scenarios/bigquery/helpers/make.py @@ -1,5 +1,5 @@ # third party -from fixtures_sync import create_user +from helpers.fixtures_sync import create_user from unsync import unsync # syft absolute diff --git a/tests/scenarios/bigquery/users.py b/tests/scenarios/bigquery/helpers/users.py similarity index 100% rename from tests/scenarios/bigquery/users.py rename to tests/scenarios/bigquery/helpers/users.py diff --git a/tests/scenarios/bigquery/level_2_basic_test.py b/tests/scenarios/bigquery/level_2_basic_test.py index 3f7b4790ba2..6f4f4372ad7 100644 --- a/tests/scenarios/bigquery/level_2_basic_test.py +++ b/tests/scenarios/bigquery/level_2_basic_test.py @@ -3,34 +3,34 @@ import inspect # third party -from asserts import ensure_package_installed -from events import EVENT_ADMIN_APPROVED_FIRST_REQUEST -from events import EVENT_ALLOW_GUEST_SIGNUP_DISABLED -from events import EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED -from events import EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED -from events import EVENT_QUERY_ENDPOINT_CONFIGURED -from events import EVENT_QUERY_ENDPOINT_CREATED -from events import EVENT_SCHEMA_ENDPOINT_CREATED -from events import EVENT_SUBMIT_QUERY_ENDPOINT_CONFIGURED -from events import EVENT_SUBMIT_QUERY_ENDPOINT_CREATED -from events import EVENT_USERS_CAN_GET_APPROVED_RESULT -from events import EVENT_USERS_CAN_QUERY_MOCK -from events import 
EVENT_USERS_CAN_SUBMIT_QUERY -from events import EVENT_USERS_CREATED -from events import EVENT_USERS_CREATED_CHECKED -from events import EVENT_USERS_QUERY_NOT_READY -from events import EVENT_USER_ADMIN_CREATED -from events import EVENT_WORKER_POOL_CREATED -from events import EventManager -from events import Scenario from faker import Faker -from fixtures_sync import make_admin -from fixtures_sync import make_server -from fixtures_sync import make_user -from make import create_endpoints_query -from make import create_endpoints_schema -from make import create_endpoints_submit_query -from make import create_users +from helpers.asserts import ensure_package_installed +from helpers.events import EVENT_ADMIN_APPROVED_FIRST_REQUEST +from helpers.events import EVENT_ALLOW_GUEST_SIGNUP_DISABLED +from helpers.events import EVENT_EXTERNAL_REGISTRY_BIGQUERY_CREATED +from helpers.events import EVENT_PREBUILT_WORKER_IMAGE_BIGQUERY_CREATED +from helpers.events import EVENT_QUERY_ENDPOINT_CONFIGURED +from helpers.events import EVENT_QUERY_ENDPOINT_CREATED +from helpers.events import EVENT_SCHEMA_ENDPOINT_CREATED +from helpers.events import EVENT_SUBMIT_QUERY_ENDPOINT_CONFIGURED +from helpers.events import EVENT_SUBMIT_QUERY_ENDPOINT_CREATED +from helpers.events import EVENT_USERS_CAN_GET_APPROVED_RESULT +from helpers.events import EVENT_USERS_CAN_QUERY_MOCK +from helpers.events import EVENT_USERS_CAN_SUBMIT_QUERY +from helpers.events import EVENT_USERS_CREATED +from helpers.events import EVENT_USERS_CREATED_CHECKED +from helpers.events import EVENT_USERS_QUERY_NOT_READY +from helpers.events import EVENT_USER_ADMIN_CREATED +from helpers.events import EVENT_WORKER_POOL_CREATED +from helpers.events import EventManager +from helpers.events import Scenario +from helpers.fixtures_sync import make_admin +from helpers.fixtures_sync import make_server +from helpers.fixtures_sync import make_user +from helpers.make import create_endpoints_query +from helpers.make import create_endpoints_schema +from helpers.make import create_endpoints_submit_query +from helpers.make import create_users import pytest from unsync import unsync From cf0a030cddef9cee853fef50a0d84e41c531c1a9 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Wed, 28 Aug 2024 13:34:10 +1000 Subject: [PATCH 11/13] Moved api code into files --- .../scenarios/bigquery/02-configure-api.ipynb | 369 +++--------------- .../bigquery/03-ds-submit-request.ipynb | 7 + .../bigquery/04-do-review-requests.ipynb | 7 + .../bigquery/apis/bigquery/__init__.py | 0 .../bigquery/apis/bigquery/helpers.py | 16 + .../bigquery/apis/bigquery/schema.py | 92 +++++ .../bigquery/apis/bigquery/submit_query.py | 50 +++ .../bigquery/apis/bigquery/test_query.py | 176 +++++++++ packages/syft/setup.cfg | 8 +- packages/syft/src/syft/util/util.py | 45 --- 10 files changed, 397 insertions(+), 373 deletions(-) create mode 100644 notebooks/scenarios/bigquery/apis/bigquery/__init__.py create mode 100644 notebooks/scenarios/bigquery/apis/bigquery/helpers.py create mode 100644 notebooks/scenarios/bigquery/apis/bigquery/schema.py create mode 100644 notebooks/scenarios/bigquery/apis/bigquery/submit_query.py create mode 100644 notebooks/scenarios/bigquery/apis/bigquery/test_query.py diff --git a/notebooks/scenarios/bigquery/02-configure-api.ipynb b/notebooks/scenarios/bigquery/02-configure-api.ipynb index d5d4bf9f716..00ff1677110 100644 --- a/notebooks/scenarios/bigquery/02-configure-api.ipynb +++ b/notebooks/scenarios/bigquery/02-configure-api.ipynb @@ -80,83 +80,9 @@ "metadata": {}, "outputs": [], 
"source": [ - "@sy.api_endpoint_method(\n", - " settings={\n", - " \"credentials\": test_settings.gce_service_account.to_dict(),\n", - " \"region\": test_settings.gce_region,\n", - " \"project_id\": test_settings.gce_project_id,\n", - " }\n", - ")\n", - "def private_query_function(\n", - " context,\n", - " sql_query: str,\n", - ") -> str:\n", - " # third party\n", - " from google.cloud import bigquery # noqa: F811\n", - " from google.oauth2 import service_account\n", - "\n", - " # syft absolute\n", - " from syft import SyftException\n", - "\n", - " # Auth for Bigquer based on the workload identity\n", - " credentials = service_account.Credentials.from_service_account_info(\n", - " context.settings[\"credentials\"]\n", - " )\n", - " scoped_credentials = credentials.with_scopes(\n", - " [\"https://www.googleapis.com/auth/cloud-platform\"]\n", - " )\n", - "\n", - " client = bigquery.Client(\n", - " credentials=scoped_credentials,\n", - " location=context.settings[\"region\"],\n", - " )\n", - "\n", - " try:\n", - " rows = client.query_and_wait(\n", - " sql_query,\n", - " project=context.settings[\"project_id\"],\n", - " )\n", - "\n", - " if rows.total_rows > 1_000_000:\n", - " raise SyftException(\n", - " public_message=\"Please only write queries that gather aggregate statistics\"\n", - " )\n", - "\n", - " return rows.to_dataframe()\n", - " except Exception as e:\n", - " # We MUST handle the errors that we want to be visible to the data owners.\n", - " # Any exception not catched is visible only to the data owner.\n", - " # not a bigquery exception\n", - " if not hasattr(e, \"_errors\"):\n", - " output = f\"got exception e: {type(e)} {str(e)}\"\n", - " raise SyftException(\n", - " public_message=f\"An error occured executing the API call {output}\"\n", - " )\n", - "\n", - " if e._errors[0][\"reason\"] in [\n", - " \"badRequest\",\n", - " \"blocked\",\n", - " \"duplicate\",\n", - " \"invalidQuery\",\n", - " \"invalid\",\n", - " \"jobBackendError\",\n", - " \"jobInternalError\",\n", - " \"notFound\",\n", - " \"notImplemented\",\n", - " \"rateLimitExceeded\",\n", - " \"resourceInUse\",\n", - " \"resourcesExceeded\",\n", - " \"tableUnavailable\",\n", - " \"timeout\",\n", - " ]:\n", - " raise SyftException(\n", - " public_message=\"Error occured during the call: \"\n", - " + e._errors[0][\"message\"]\n", - " )\n", - " else:\n", - " raise SyftException(\n", - " public_message=\"An error occured executing the API call, please contact the domain owner.\"\n", - " )" + "# Look up the worker pools and identify the name of the one that has the required packages\n", + "# After, bind the endpoint to that workerpool\n", + "high_client.worker_pools" ] }, { @@ -165,125 +91,17 @@ "metadata": {}, "outputs": [], "source": [ - "# Define any helper methods for our rate limiter\n", - "\n", - "\n", - "def is_within_rate_limit(context):\n", - " \"\"\"Rate limiter for custom API calls made by users.\"\"\"\n", - " # stdlib\n", - " import datetime\n", - "\n", - " state = context.state\n", - " settings = context.settings\n", - " email = context.user.email\n", - "\n", - " current_time = datetime.datetime.now()\n", - " calls_last_min = [\n", - " 1 if (current_time - call_time).seconds < 60 else 0\n", - " for call_time in state[email]\n", - " ]\n", - "\n", - " return sum(calls_last_min) < settings[\"CALLS_PER_MIN\"]\n", - "\n", - "\n", - "# Define a mock endpoint that the researchers can use for testing\n", + "# third party\n", + "from apis.bigquery import test_query\n", "\n", - "\n", - "@sy.api_endpoint_method(\n", + 
"mock_func = test_query.make_mock(\n", " settings={\n", " \"credentials\": test_settings.gce_service_account.to_dict(),\n", " \"region\": test_settings.gce_region,\n", " \"project_id\": test_settings.gce_project_id,\n", " \"CALLS_PER_MIN\": 10,\n", - " },\n", - " helper_functions=[is_within_rate_limit],\n", - ")\n", - "def mock_query_function(\n", - " context,\n", - " sql_query: str,\n", - ") -> str:\n", - " # stdlib\n", - " import datetime\n", - "\n", - " # third party\n", - " from google.cloud import bigquery # noqa: F811\n", - " from google.oauth2 import service_account\n", - "\n", - " # syft absolute\n", - " from syft import SyftException\n", - "\n", - " # Auth for Bigquer based on the workload identity\n", - " credentials = service_account.Credentials.from_service_account_info(\n", - " context.settings[\"credentials\"]\n", - " )\n", - " scoped_credentials = credentials.with_scopes(\n", - " [\"https://www.googleapis.com/auth/cloud-platform\"]\n", - " )\n", - "\n", - " client = bigquery.Client(\n", - " credentials=scoped_credentials,\n", - " location=context.settings[\"region\"],\n", - " )\n", - "\n", - " # Store a dict with the calltimes for each user, via the email.\n", - " if context.user.email not in context.state.keys():\n", - " context.state[context.user.email] = []\n", - "\n", - " if not context.code.is_within_rate_limit(context):\n", - " raise SyftException(\n", - " public_message=\"Rate limit of calls per minute has been reached.\"\n", - " )\n", - "\n", - " try:\n", - " context.state[context.user.email].append(datetime.datetime.now())\n", - "\n", - " rows = client.query_and_wait(\n", - " sql_query,\n", - " project=context.settings[\"project_id\"],\n", - " )\n", - "\n", - " if rows.total_rows > 1_000_000:\n", - " raise SyftException(\n", - " public_message=\"Please only write queries that gather aggregate statistics\"\n", - " )\n", - "\n", - " return rows.to_dataframe()\n", - "\n", - " except Exception as e:\n", - " # not a bigquery exception\n", - " if not hasattr(e, \"_errors\"):\n", - " output = f\"got exception e: {type(e)} {str(e)}\"\n", - " raise SyftException(\n", - " public_message=f\"An error occured executing the API call {output}\"\n", - " )\n", - "\n", - " # Treat all errors that we would like to be forwarded to the data scientists\n", - " # By default, any exception is only visible to the data owner.\n", - "\n", - " if e._errors[0][\"reason\"] in [\n", - " \"badRequest\",\n", - " \"blocked\",\n", - " \"duplicate\",\n", - " \"invalidQuery\",\n", - " \"invalid\",\n", - " \"jobBackendError\",\n", - " \"jobInternalError\",\n", - " \"notFound\",\n", - " \"notImplemented\",\n", - " \"rateLimitExceeded\",\n", - " \"resourceInUse\",\n", - " \"resourcesExceeded\",\n", - " \"tableUnavailable\",\n", - " \"timeout\",\n", - " ]:\n", - " raise SyftException(\n", - " public_message=\"Error occured during the call: \"\n", - " + e._errors[0][\"message\"]\n", - " )\n", - " else:\n", - " raise SyftException(\n", - " public_message=\"An error occured executing the API call, please contact the domain owner.\"\n", - " )" + " }\n", + ")" ] }, { @@ -292,9 +110,13 @@ "metadata": {}, "outputs": [], "source": [ - "# Look up the worker pools and identify the name of the one that has the required packages\n", - "# After, bind the endpoint to that workerpool\n", - "high_client.worker_pools" + "private_func = test_query.make_private(\n", + " settings={\n", + " \"credentials\": test_settings.gce_service_account.to_dict(),\n", + " \"region\": test_settings.gce_region,\n", + " \"project_id\": 
test_settings.gce_project_id,\n", + " }\n", + ")" ] }, { @@ -306,8 +128,8 @@ "new_endpoint = sy.TwinAPIEndpoint(\n", " path=\"bigquery.test_query\",\n", " description=\"This endpoint allows to query Bigquery storage via SQL queries.\",\n", - " private_function=private_query_function,\n", - " mock_function=mock_query_function,\n", + " private_function=private_func,\n", + " mock_function=mock_func,\n", " worker_pool=this_worker_pool_name,\n", ")\n", "\n", @@ -374,9 +196,10 @@ "metadata": {}, "outputs": [], "source": [ - "@sy.api_endpoint(\n", - " path=\"bigquery.schema\",\n", - " description=\"This endpoint allows for visualising the metadata of tables available in BigQuery.\",\n", + "# third party\n", + "from apis.bigquery import schema\n", + "\n", + "schema_function = schema.make_schema(\n", " settings={\n", " \"credentials\": test_settings.gce_service_account.to_dict(),\n", " \"region\": test_settings.gce_region,\n", @@ -386,86 +209,16 @@ " \"table_2\": test_settings.table_2,\n", " \"CALLS_PER_MIN\": 5,\n", " },\n", - " helper_functions=[\n", - " is_within_rate_limit\n", - " ], # Adds ratelimit as this is also a method available to data scientists\n", " worker_pool=this_worker_pool_name,\n", - ")\n", - "def schema_function(\n", - " context,\n", - ") -> str:\n", - " # stdlib\n", - " import datetime\n", - "\n", - " # third party\n", - " from google.cloud import bigquery # noqa: F811\n", - " from google.oauth2 import service_account\n", - " import pandas as pd\n", - "\n", - " # syft absolute\n", - " from syft import SyftException\n", - "\n", - " # Auth for Bigquer based on the workload identity\n", - " credentials = service_account.Credentials.from_service_account_info(\n", - " context.settings[\"credentials\"]\n", - " )\n", - " scoped_credentials = credentials.with_scopes(\n", - " [\"https://www.googleapis.com/auth/cloud-platform\"]\n", - " )\n", - "\n", - " client = bigquery.Client(\n", - " credentials=scoped_credentials,\n", - " location=context.settings[\"region\"],\n", - " )\n", - "\n", - " if context.user.email not in context.state.keys():\n", - " context.state[context.user.email] = []\n", - "\n", - " if not context.code.is_within_rate_limit(context):\n", - " raise SyftException(\n", - " public_message=\"Rate limit of calls per minute has been reached.\"\n", - " )\n", - "\n", - " try:\n", - " context.state[context.user.email].append(datetime.datetime.now())\n", - "\n", - " # Formats the data schema in a data frame format\n", - " # Warning: the only supported format types are primitives, np.ndarrays and pd.DataFrames\n", - "\n", - " data_schema = []\n", - " for table_id in [\n", - " f\"{context.settings[\"dataset_1\"]}.{context.settings[\"table_1\"]}\",\n", - " f\"{context.settings[\"dataset_1\"]}.{context.settings[\"table_2\"]}\",\n", - " ]:\n", - " table = client.get_table(table_id)\n", - " for schema in table.schema:\n", - " data_schema.append(\n", - " {\n", - " \"project\": str(table.project),\n", - " \"dataset_id\": str(table.dataset_id),\n", - " \"table_id\": str(table.table_id),\n", - " \"schema_name\": str(schema.name),\n", - " \"schema_field\": str(schema.field_type),\n", - " \"description\": str(table.description),\n", - " \"num_rows\": str(table.num_rows),\n", - " }\n", - " )\n", - " return pd.DataFrame(data_schema)\n", - "\n", - " except Exception as e:\n", - " # not a bigquery exception\n", - " if not hasattr(e, \"_errors\"):\n", - " output = f\"got exception e: {type(e)} {str(e)}\"\n", - " raise SyftException(\n", - " public_message=f\"An error occured executing the API 
call {output}\"\n", - " )\n", - "\n", - " # Should add appropriate error handling for what should be exposed to the data scientists.\n", - " raise SyftException(\n", - " public_message=\"An error occured executing the API call, please contact the domain owner.\"\n", - " )\n", - "\n", - "\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ "high_client.custom_api.add(endpoint=schema_function)\n", "high_client.refresh()" ] @@ -476,49 +229,12 @@ "metadata": {}, "outputs": [], "source": [ - "@sy.api_endpoint(\n", - " path=\"bigquery.submit_query\",\n", - " description=\"API endpoint that allows you to submit SQL queries to run on the private data.\",\n", - " worker_pool=this_worker_pool_name,\n", - " settings={\"worker\": this_worker_pool_name},\n", - ")\n", - "def submit_query(\n", - " context,\n", - " func_name: str,\n", - " query: str,\n", - ") -> str:\n", - " # stdlib\n", - " import hashlib\n", - "\n", - " # syft absolute\n", - " import syft as sy\n", - "\n", - " hash_object = hashlib.new(\"sha256\")\n", + "# third party\n", + "from apis.bigquery import submit_query\n", "\n", - " hash_object.update(context.user.email.encode(\"utf-8\"))\n", - " func_name = func_name + \"_\" + hash_object.hexdigest()[:6]\n", - "\n", - " @sy.syft_function(\n", - " name=func_name,\n", - " input_policy=sy.MixedInputPolicy(\n", - " endpoint=sy.Constant(\n", - " val=context.admin_client.api.services.bigquery.test_query\n", - " ),\n", - " query=sy.Constant(val=query),\n", - " client=context.admin_client,\n", - " ),\n", - " worker_pool_name=context.settings[\"worker\"],\n", - " )\n", - " def execute_query(query: str, endpoint):\n", - " res = endpoint(sql_query=query)\n", - " return res\n", - "\n", - " request = context.user_client.code.request_code_execution(execute_query)\n", - " context.admin_client.requests.set_tags(request, [\"autosync\"])\n", - "\n", - " return (\n", - " f\"Query submitted {request}. 
Use `client.code.{func_name}()` to run your query\"\n", - " )" + "submit_query_function = submit_query.make_submit_query(\n", + " settings={}, worker_pool=this_worker_pool_name\n", + ")" ] }, { @@ -527,7 +243,7 @@ "metadata": {}, "outputs": [], "source": [ - "high_client.custom_api.add(endpoint=submit_query)" + "high_client.custom_api.add(endpoint=submit_query_function)" ] }, { @@ -637,9 +353,7 @@ "outputs": [], "source": [ "# Test mock version for wrong queries\n", - "with sy.raises(\n", - " sy.SyftException(public_message=\"*must be qualified with a dataset*\"), show=True\n", - "):\n", + "with sy.raises(sy.SyftException(public_message=\"*must be qualified with a dataset*\")):\n", " _ = high_client.api.services.bigquery.test_query.mock(\n", " sql_query=\"SELECT * FROM invalid_table LIMIT 1\"\n", " )" @@ -693,7 +407,7 @@ "metadata": {}, "outputs": [], "source": [ - "assert len(state[\"info@openmined.org\"]) == 3" + "assert len(state[\"info@openmined.org\"]) >= 3" ] }, { @@ -755,6 +469,13 @@ "source": [ "server.land()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/notebooks/scenarios/bigquery/03-ds-submit-request.ipynb b/notebooks/scenarios/bigquery/03-ds-submit-request.ipynb index 62f7d511dc5..81b0d3b0d4a 100644 --- a/notebooks/scenarios/bigquery/03-ds-submit-request.ipynb +++ b/notebooks/scenarios/bigquery/03-ds-submit-request.ipynb @@ -223,6 +223,13 @@ "source": [ "server.land()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/notebooks/scenarios/bigquery/04-do-review-requests.ipynb b/notebooks/scenarios/bigquery/04-do-review-requests.ipynb index b1df0a88d27..bf45858b875 100644 --- a/notebooks/scenarios/bigquery/04-do-review-requests.ipynb +++ b/notebooks/scenarios/bigquery/04-do-review-requests.ipynb @@ -243,6 +243,13 @@ "source": [ "server.land()" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/notebooks/scenarios/bigquery/apis/bigquery/__init__.py b/notebooks/scenarios/bigquery/apis/bigquery/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/notebooks/scenarios/bigquery/apis/bigquery/helpers.py b/notebooks/scenarios/bigquery/apis/bigquery/helpers.py new file mode 100644 index 00000000000..c22b6f507a4 --- /dev/null +++ b/notebooks/scenarios/bigquery/apis/bigquery/helpers.py @@ -0,0 +1,16 @@ +def is_within_rate_limit(context): + """Rate limiter for custom API calls made by users.""" + # stdlib + import datetime + + state = context.state + settings = context.settings + email = context.user.email + + current_time = datetime.datetime.now() + calls_last_min = [ + 1 if (current_time - call_time).seconds < 60 else 0 + for call_time in state[email] + ] + + return sum(calls_last_min) < settings["CALLS_PER_MIN"] diff --git a/notebooks/scenarios/bigquery/apis/bigquery/schema.py b/notebooks/scenarios/bigquery/apis/bigquery/schema.py new file mode 100644 index 00000000000..7958244b38c --- /dev/null +++ b/notebooks/scenarios/bigquery/apis/bigquery/schema.py @@ -0,0 +1,92 @@ +# syft absolute +import syft as sy + +# relative +from .helpers import is_within_rate_limit + + +def make_schema(settings, worker_pool): + @sy.api_endpoint( + path="bigquery.schema", + description="This endpoint allows for visualising the metadata of tables available in BigQuery.", + settings=settings, + 
helper_functions=[
+            is_within_rate_limit
+        ],  # Adds ratelimit as this is also a method available to data scientists
+        worker_pool=worker_pool,
+    )
+    def schema(
+        context,
+    ) -> str:
+        # stdlib
+        import datetime
+
+        # third party
+        from google.cloud import bigquery  # noqa: F811
+        from google.oauth2 import service_account
+        import pandas as pd
+
+        # syft absolute
+        from syft import SyftException
+
+        # Auth for BigQuery based on the workload identity
+        credentials = service_account.Credentials.from_service_account_info(
+            context.settings["credentials"]
+        )
+        scoped_credentials = credentials.with_scopes(
+            ["https://www.googleapis.com/auth/cloud-platform"]
+        )
+
+        client = bigquery.Client(
+            credentials=scoped_credentials,
+            location=context.settings["region"],
+        )
+
+        if context.user.email not in context.state.keys():
+            context.state[context.user.email] = []
+
+        if not context.code.is_within_rate_limit(context):
+            raise SyftException(
+                public_message="Rate limit of calls per minute has been reached."
+            )
+
+        try:
+            context.state[context.user.email].append(datetime.datetime.now())
+
+            # Formats the data schema in a data frame format
+            # Warning: the only supported format types are primitives, np.ndarrays and pd.DataFrames
+
+            data_schema = []
+            for table_id in [
+                f"{context.settings["dataset_1"]}.{context.settings["table_1"]}",
+                f"{context.settings["dataset_1"]}.{context.settings["table_2"]}",
+            ]:
+                table = client.get_table(table_id)
+                for schema in table.schema:
+                    data_schema.append(
+                        {
+                            "project": str(table.project),
+                            "dataset_id": str(table.dataset_id),
+                            "table_id": str(table.table_id),
+                            "schema_name": str(schema.name),
+                            "schema_field": str(schema.field_type),
+                            "description": str(table.description),
+                            "num_rows": str(table.num_rows),
+                        }
+                    )
+            return pd.DataFrame(data_schema)
+
+        except Exception as e:
+            # not a bigquery exception
+            if not hasattr(e, "_errors"):
+                output = f"got exception e: {type(e)} {str(e)}"
+                raise SyftException(
+                    public_message=f"An error occurred executing the API call {output}"
+                )
+
+            # Should add appropriate error handling for what should be exposed to the data scientists.
+            raise SyftException(
+                public_message="An error occurred executing the API call, please contact the domain owner."
+            )
+
+    return schema
diff --git a/notebooks/scenarios/bigquery/apis/bigquery/submit_query.py b/notebooks/scenarios/bigquery/apis/bigquery/submit_query.py
new file mode 100644
index 00000000000..6d88cb9f7ef
--- /dev/null
+++ b/notebooks/scenarios/bigquery/apis/bigquery/submit_query.py
@@ -0,0 +1,50 @@
+# syft absolute
+import syft as sy
+
+
+def make_submit_query(settings, worker_pool):
+    updated_settings = {"user_code_worker": worker_pool} | settings
+
+    @sy.api_endpoint(
+        path="bigquery.submit_query",
+        description="API endpoint that allows you to submit SQL queries to run on the private data.",
+        worker_pool=worker_pool,
+        settings=updated_settings,
+    )
+    def submit_query(
+        context,
+        func_name: str,
+        query: str,
+    ) -> str:
+        # stdlib
+        import hashlib
+
+        # syft absolute
+        import syft as sy
+
+        hash_object = hashlib.new("sha256")
+
+        hash_object.update(context.user.email.encode("utf-8"))
+        func_name = func_name + "_" + hash_object.hexdigest()[:6]
+
+        @sy.syft_function(
+            name=func_name,
+            input_policy=sy.MixedInputPolicy(
+                endpoint=sy.Constant(
+                    val=context.admin_client.api.services.bigquery.test_query
+                ),
+                query=sy.Constant(val=query),
+                client=context.admin_client,
+            ),
+            worker_pool_name=context.settings["user_code_worker"],
+        )
+        def execute_query(query: str, endpoint):
+            res = endpoint(sql_query=query)
+            return res
+
+        request = context.user_client.code.request_code_execution(execute_query)
+        context.admin_client.requests.set_tags(request, ["autosync"])
+
+        return f"Query submitted {request}. Use `client.code.{func_name}()` to run your query"
+
+    return submit_query
diff --git a/notebooks/scenarios/bigquery/apis/bigquery/test_query.py b/notebooks/scenarios/bigquery/apis/bigquery/test_query.py
new file mode 100644
index 00000000000..4ce3b49d107
--- /dev/null
+++ b/notebooks/scenarios/bigquery/apis/bigquery/test_query.py
@@ -0,0 +1,176 @@
+# syft absolute
+import syft as sy
+
+# relative
+from .helpers import is_within_rate_limit
+
+
+def make_mock(settings):
+    @sy.api_endpoint_method(
+        settings=settings,
+        helper_functions=[is_within_rate_limit],
+    )
+    def mock(
+        context,
+        sql_query: str,
+    ) -> str:
+        # stdlib
+        import datetime
+
+        # third party
+        from google.cloud import bigquery  # noqa: F811
+        from google.oauth2 import service_account
+
+        # syft absolute
+        from syft import SyftException
+
+        # Auth for BigQuery based on the workload identity
+        credentials = service_account.Credentials.from_service_account_info(
+            context.settings["credentials"]
+        )
+        scoped_credentials = credentials.with_scopes(
+            ["https://www.googleapis.com/auth/cloud-platform"]
+        )
+
+        client = bigquery.Client(
+            credentials=scoped_credentials,
+            location=context.settings["region"],
+        )
+
+        # Store a dict with the calltimes for each user, via the email.
+        if context.user.email not in context.state.keys():
+            context.state[context.user.email] = []
+
+        if not context.code.is_within_rate_limit(context):
+            raise SyftException(
+                public_message="Rate limit of calls per minute has been reached."
+            )
+
+        try:
+            context.state[context.user.email].append(datetime.datetime.now())
+
+            rows = client.query_and_wait(
+                sql_query,
+                project=context.settings["project_id"],
+            )
+
+            if rows.total_rows > 1_000_000:
+                raise SyftException(
+                    public_message="Please only write queries that gather aggregate statistics"
+                )
+
+            return rows.to_dataframe()
+
+        except Exception as e:
+            # not a bigquery exception
+            if not hasattr(e, "_errors"):
+                output = f"got exception e: {type(e)} {str(e)}"
+                raise SyftException(
+                    public_message=f"An error occurred executing the API call {output}"
+                )
+
+            # Treat all errors that we would like to be forwarded to the data scientists
+            # By default, any exception is only visible to the data owner.
+
+            if e._errors[0]["reason"] in [
+                "badRequest",
+                "blocked",
+                "duplicate",
+                "invalidQuery",
+                "invalid",
+                "jobBackendError",
+                "jobInternalError",
+                "notFound",
+                "notImplemented",
+                "rateLimitExceeded",
+                "resourceInUse",
+                "resourcesExceeded",
+                "tableUnavailable",
+                "timeout",
+            ]:
+                raise SyftException(
+                    public_message="Error occurred during the call: "
+                    + e._errors[0]["message"]
+                )
+            else:
+                raise SyftException(
+                    public_message="An error occurred executing the API call, please contact the domain owner."
+                )
+
+    return mock
+
+
+def make_private(settings):
+    @sy.api_endpoint_method(settings=settings)
+    def private(
+        context,
+        sql_query: str,
+    ) -> str:
+        # third party
+        from google.cloud import bigquery  # noqa: F811
+        from google.oauth2 import service_account
+
+        # syft absolute
+        from syft import SyftException
+
+        # Auth for BigQuery based on the workload identity
+        credentials = service_account.Credentials.from_service_account_info(
+            context.settings["credentials"]
+        )
+        scoped_credentials = credentials.with_scopes(
+            ["https://www.googleapis.com/auth/cloud-platform"]
+        )
+
+        client = bigquery.Client(
+            credentials=scoped_credentials,
+            location=context.settings["region"],
+        )
+
+        try:
+            rows = client.query_and_wait(
+                sql_query,
+                project=context.settings["project_id"],
+            )
+
+            if rows.total_rows > 1_000_000:
+                raise SyftException(
+                    public_message="Please only write queries that gather aggregate statistics"
+                )
+
+            return rows.to_dataframe()
+        except Exception as e:
+            # We MUST handle the errors that we want to be visible to the data owners.
+            # Any exception not caught is visible only to the data owner.
+            # not a bigquery exception
+            if not hasattr(e, "_errors"):
+                output = f"got exception e: {type(e)} {str(e)}"
+                raise SyftException(
+                    public_message=f"An error occurred executing the API call {output}"
+                )
+
+            if e._errors[0]["reason"] in [
+                "badRequest",
+                "blocked",
+                "duplicate",
+                "invalidQuery",
+                "invalid",
+                "jobBackendError",
+                "jobInternalError",
+                "notFound",
+                "notImplemented",
+                "rateLimitExceeded",
+                "resourceInUse",
+                "resourcesExceeded",
+                "tableUnavailable",
+                "timeout",
+            ]:
+                raise SyftException(
+                    public_message="Error occurred during the call: "
+                    + e._errors[0]["message"]
+                )
+            else:
+                raise SyftException(
+                    public_message="An error occurred executing the API call, please contact the domain owner."
+ ) + + return private diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index e9cf56f12a6..3e4797c6bee 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -121,14 +121,14 @@ test_plugins = pytest-sugar pytest-lazy-fixture pytest-rerunfailures - pytest-asyncio - pytest-timeout - anyio - unsync coverage faker distro dynaconf + ; pytest-asyncio + ; pytest-timeout + ; anyio + ; unsync [options.entry_points] console_scripts = diff --git a/packages/syft/src/syft/util/util.py b/packages/syft/src/syft/util/util.py index 8fe4132ee0f..bc592c00f1a 100644 --- a/packages/syft/src/syft/util/util.py +++ b/packages/syft/src/syft/util/util.py @@ -39,7 +39,6 @@ import types from types import ModuleType from typing import Any -from unittest.mock import Mock # third party from IPython.display import display @@ -1169,47 +1168,3 @@ def repr_truncation(obj: Any, max_elements: int = 10) -> str: r.maxother = 100 # For other objects return r.repr(obj) - - -class MockBigQueryError(Exception): - def __init__(self, errors: list) -> None: - self._errors = errors - super().__init__(self._errors[0]["message"]) - - -class MockBigQueryClient: - def __init__(self, credentials: dict, location: str | None = None) -> None: - self.credentials = credentials - self.location = location - - def query_and_wait(self, sql_query: str, project: str | None = None) -> Mock | None: - if self.credentials["mock_result"] == "timeout": - raise TimeoutError("Simulated query timeout.") - - if self.credentials["mock_result"] == "success": - # Simulate a successful response - rows = Mock() - rows.total_rows = 1 # or any number within acceptable limits - rows.to_dataframe = Mock(return_value="Simulated DataFrame") - return rows - - if self.credentials["mock_result"] == "bigquery_error": - errors = [ - { - "reason": "Simulated BigQuery error.", - "message": "Simulated BigQuery error.", - } - ] - raise MockBigQueryError(errors) - - raise Exception("Simulated non-BigQuery exception.") - - -class MockBigQuery: - @staticmethod - def mock_credentials(mock_result: str = "success") -> dict: - return {"mocked": "credentials", "mock_result": mock_result} - - @staticmethod - def Client(credentials: dict, location: str | None = None) -> MockBigQueryClient: - return MockBigQueryClient(credentials, location) From eba7205bb213014154a49576d100259ba1096ba0 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Wed, 28 Aug 2024 16:31:00 +1000 Subject: [PATCH 12/13] Made bigquery scenario notebooks run without live server --- .github/file-filters.yml | 14 + .github/workflows/pr-tests-syft.yml | 87 ++++++ .../bigquery/01-setup-datasite.ipynb | 9 +- .../scenarios/bigquery/02-configure-api.ipynb | 89 +++--- .../bigquery/03-ds-submit-request.ipynb | 24 +- .../bigquery/{apis/bigquery => }/__init__.py | 0 notebooks/scenarios/bigquery/apis/__init__.py | 23 ++ .../bigquery/apis/bigquery/test_query.py | 176 ------------ .../scenarios/bigquery/apis/live/__init__.py | 0 .../apis/{bigquery => live}/schema.py | 42 ++- .../bigquery/apis/live/test_query.py | 113 ++++++++ .../scenarios/bigquery/apis/mock/__init__.py | 0 .../scenarios/bigquery/apis/mock/data.py | 268 ++++++++++++++++++ .../scenarios/bigquery/apis/mock/schema.py | 52 ++++ .../bigquery/apis/mock/test_query.py | 138 +++++++++ .../{bigquery/helpers.py => rate_limiter.py} | 4 +- .../apis/{bigquery => }/submit_query.py | 0 17 files changed, 788 insertions(+), 251 deletions(-) rename notebooks/scenarios/bigquery/{apis/bigquery => }/__init__.py (100%) create mode 100644 
notebooks/scenarios/bigquery/apis/__init__.py
 delete mode 100644 notebooks/scenarios/bigquery/apis/bigquery/test_query.py
 create mode 100644 notebooks/scenarios/bigquery/apis/live/__init__.py
 rename notebooks/scenarios/bigquery/apis/{bigquery => live}/schema.py (72%)
 create mode 100644 notebooks/scenarios/bigquery/apis/live/test_query.py
 create mode 100644 notebooks/scenarios/bigquery/apis/mock/__init__.py
 create mode 100644 notebooks/scenarios/bigquery/apis/mock/data.py
 create mode 100644 notebooks/scenarios/bigquery/apis/mock/schema.py
 create mode 100644 notebooks/scenarios/bigquery/apis/mock/test_query.py
 rename notebooks/scenarios/bigquery/apis/{bigquery/helpers.py => rate_limiter.py} (76%)
 rename notebooks/scenarios/bigquery/apis/{bigquery => }/submit_query.py (100%)

diff --git a/.github/file-filters.yml b/.github/file-filters.yml
index be000a84640..0d677dcb4e5 100644
--- a/.github/file-filters.yml
+++ b/.github/file-filters.yml
@@ -76,3 +76,17 @@ notebooks:
   - packages/syft/**/*.ini
   - packages/syft/**/*.sh
   - packages/syft/**/*.mako
+
+notebooks_scenario:
+  - .github/workflows/pr-tests-syft.yml
+  - notebooks/scenarios/**/*.ipynb
+  - packages/syft/**/*.py
+  - packages/syft/**/*.capnp
+  - packages/syft/**/*.yml
+  - packages/syft/**/*.cfg
+  - packages/syft/**/*.dockerfile
+  - packages/syft/**/*.toml
+  - packages/syft/**/*.txt
+  - packages/syft/**/*.ini
+  - packages/syft/**/*.sh
+  - packages/syft/**/*.mako
diff --git a/.github/workflows/pr-tests-syft.yml b/.github/workflows/pr-tests-syft.yml
index e10385d95d1..ec97ecbbd04 100644
--- a/.github/workflows/pr-tests-syft.yml
+++ b/.github/workflows/pr-tests-syft.yml
@@ -188,6 +188,93 @@ jobs:
           max_attempts: 3
           command: tox -e syft.test.notebook
 
+  pr-tests-syft-notebook-scenario:
+    strategy:
+      max-parallel: 99
+      matrix:
+        # Disable on Windows until its flakiness is reduced.
+ # os: [ubuntu-latest, macos-latest, windows-latest] + os: [ubuntu-latest, macos-latest] + python-version: ["3.12"] + deployment-type: ["python"] + bump-version: ["False"] + include: + - python-version: "3.11" + os: "ubuntu-latest" + deployment-type: "python" + - python-version: "3.10" + os: "ubuntu-latest" + deployment-type: "python" + - python-version: "3.12" + os: "ubuntu-latest" + deployment-type: "python" + bump-version: "True" + + runs-on: ${{ matrix.os }} + steps: + # - name: Permission to home directory + # if: matrix.os == 'ubuntu-latest' + # run: | + # sudo chown -R $USER:$USER $HOME + - name: "clean .git/config" + if: matrix.os == 'windows-latest' + continue-on-error: true + shell: bash + run: | + echo "deleting ${GITHUB_WORKSPACE}/.git/config" + rm ${GITHUB_WORKSPACE}/.git/config + + - uses: actions/checkout@v4 + + - name: Check for file changes + uses: dorny/paths-filter@v3 + id: changes + with: + base: ${{ github.ref }} + token: ${{ github.token }} + filters: .github/file-filters.yml + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + if: steps.changes.outputs.syft == 'true' || steps.changes.outputs.notebooks_scenario == 'true' + with: + python-version: ${{ matrix.python-version }} + + - name: Install pip packages + if: steps.changes.outputs.syft == 'true' || steps.changes.outputs.notebooks_scenario == 'true' + run: | + python -m pip install --upgrade pip + pip install uv==0.2.17 tox==4.16.0 tox-uv==1.9.0 + uv --version + + - name: Get uv cache dir + id: pip-cache + if: steps.changes.outputs.syft == 'true' || steps.changes.outputs.notebooks_scenario == 'true' + shell: bash + run: | + echo "dir=$(uv cache dir)" >> $GITHUB_OUTPUT + + - name: Load github cache + uses: actions/cache@v4 + if: steps.changes.outputs.syft == 'true' || steps.changes.outputs.notebooks_scenario == 'true' + with: + path: ${{ steps.pip-cache.outputs.dir }} + key: ${{ runner.os }}-uv-py${{ matrix.python-version }}-${{ hashFiles('setup.cfg') }} + restore-keys: | + ${{ runner.os }}-uv-py${{ matrix.python-version }}- + + - name: Run notebook scenario tests + uses: nick-fields/retry@v3 + if: steps.changes.outputs.syft == 'true' || steps.changes.outputs.notebooks_scenario == 'true' + env: + ORCHESTRA_DEPLOYMENT_TYPE: "${{ matrix.deployment-type }}" + TEST_NOTEBOOK_PATHS: "${{ matrix.notebook-paths }}" + BUMP_VERSION: "${{ matrix.bump-version }}" + with: + timeout_seconds: 2400 + max_attempts: 3 + command: tox -e syft.test.notebook.scenario + pr-tests-syft-notebook-single-container: strategy: max-parallel: 99 diff --git a/notebooks/scenarios/bigquery/01-setup-datasite.ipynb b/notebooks/scenarios/bigquery/01-setup-datasite.ipynb index d2feba3536b..9741795e0a8 100644 --- a/notebooks/scenarios/bigquery/01-setup-datasite.ipynb +++ b/notebooks/scenarios/bigquery/01-setup-datasite.ipynb @@ -6,8 +6,6 @@ "metadata": {}, "outputs": [], "source": [ - "# third party\n", - "\n", "# syft absolute\n", "import syft as sy\n", "from syft import test_settings" @@ -256,6 +254,13 @@ "source": [ "# !docker image ls | grep bigquery" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/notebooks/scenarios/bigquery/02-configure-api.ipynb b/notebooks/scenarios/bigquery/02-configure-api.ipynb index 00ff1677110..e32c534465c 100644 --- a/notebooks/scenarios/bigquery/02-configure-api.ipynb +++ b/notebooks/scenarios/bigquery/02-configure-api.ipynb @@ -1,5 +1,20 @@ { "cells": [ + { + "cell_type": "code", + 
"execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# set to use the live APIs\n", + "# import os\n", + "# os.environ[\"TEST_BIGQUERY_APIS_LIVE\"] = \"True\"\n", + "# third party\n", + "from apis import make_schema\n", + "from apis import make_submit_query\n", + "from apis import make_test_query" + ] + }, { "cell_type": "code", "execution_count": null, @@ -91,15 +106,10 @@ "metadata": {}, "outputs": [], "source": [ - "# third party\n", - "from apis.bigquery import test_query\n", - "\n", - "mock_func = test_query.make_mock(\n", + "mock_func = make_test_query(\n", " settings={\n", - " \"credentials\": test_settings.gce_service_account.to_dict(),\n", - " \"region\": test_settings.gce_region,\n", - " \"project_id\": test_settings.gce_project_id,\n", - " \"CALLS_PER_MIN\": 10,\n", + " \"rate_limiter_enabled\": True,\n", + " \"calls_per_min\": 10,\n", " }\n", ")" ] @@ -110,11 +120,9 @@ "metadata": {}, "outputs": [], "source": [ - "private_func = test_query.make_private(\n", + "private_func = make_test_query(\n", " settings={\n", - " \"credentials\": test_settings.gce_service_account.to_dict(),\n", - " \"region\": test_settings.gce_region,\n", - " \"project_id\": test_settings.gce_project_id,\n", + " \"rate_limiter_enabled\": False,\n", " }\n", ")" ] @@ -165,7 +173,12 @@ "metadata": {}, "outputs": [], "source": [ - "# it currently hangs here because the reloaded server in this notebook doesnt recreate the worker consumers" + "dataset_1 = test_settings.get(\"dataset_1\", default=\"dataset_1\")\n", + "dataset_2 = test_settings.get(\"dataset_2\", default=\"dataset_2\")\n", + "table_1 = test_settings.get(\"table_1\", default=\"table_1\")\n", + "table_2 = test_settings.get(\"table_2\", default=\"table_2\")\n", + "table_2_col_id = test_settings.get(\"table_2_col_id\", default=\"table_id\")\n", + "table_2_col_score = test_settings.get(\"table_2_col_score\", default=\"colname\")" ] }, { @@ -176,7 +189,7 @@ "source": [ "# Test mock version\n", "result = high_client.api.services.bigquery.test_query.mock(\n", - " sql_query=f\"SELECT * FROM {test_settings.dataset_1}.{test_settings.table_1} LIMIT 10\"\n", + " sql_query=f\"SELECT * FROM {dataset_1}.{table_1} LIMIT 10\"\n", ")\n", "result" ] @@ -196,18 +209,9 @@ "metadata": {}, "outputs": [], "source": [ - "# third party\n", - "from apis.bigquery import schema\n", - "\n", - "schema_function = schema.make_schema(\n", + "schema_function = make_schema(\n", " settings={\n", - " \"credentials\": test_settings.gce_service_account.to_dict(),\n", - " \"region\": test_settings.gce_region,\n", - " \"project_id\": test_settings.gce_project_id,\n", - " \"dataset_1\": test_settings.dataset_1,\n", - " \"table_1\": test_settings.table_1,\n", - " \"table_2\": test_settings.table_2,\n", - " \"CALLS_PER_MIN\": 5,\n", + " \"calls_per_min\": 5,\n", " },\n", " worker_pool=this_worker_pool_name,\n", ")" @@ -229,10 +233,7 @@ "metadata": {}, "outputs": [], "source": [ - "# third party\n", - "from apis.bigquery import submit_query\n", - "\n", - "submit_query_function = submit_query.make_submit_query(\n", + "submit_query_function = make_submit_query(\n", " settings={}, worker_pool=this_worker_pool_name\n", ")" ] @@ -301,20 +302,11 @@ "source": [ "# Test mock version\n", "result = high_client.api.services.bigquery.test_query.mock(\n", - " sql_query=f\"SELECT * FROM {test_settings.dataset_1}.{test_settings.table_1} LIMIT 10\"\n", + " sql_query=f\"SELECT * FROM {dataset_1}.{table_1} LIMIT 10\"\n", ")\n", "result" ] }, - { - "cell_type": "code", - "execution_count": 
null, - "metadata": {}, - "outputs": [], - "source": [ - "assert len(result) == 10" - ] - }, { "cell_type": "code", "execution_count": null, @@ -323,7 +315,7 @@ "source": [ "# Test private version\n", "result = high_client.api.services.bigquery.test_query.private(\n", - " sql_query=f\"SELECT * FROM {test_settings.dataset_1}.{test_settings.table_1} LIMIT 10\"\n", + " sql_query=f\"SELECT * FROM {dataset_1}.{table_1} LIMIT 10\"\n", ")\n", "result" ] @@ -337,15 +329,6 @@ "assert len(result) == 10" ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# todo can we clean up the duplicate exception messages?" - ] - }, { "cell_type": "code", "execution_count": null, @@ -353,7 +336,9 @@ "outputs": [], "source": [ "# Test mock version for wrong queries\n", - "with sy.raises(sy.SyftException(public_message=\"*must be qualified with a dataset*\")):\n", + "with sy.raises(\n", + " sy.SyftException(public_message=\"*must be qualified with a dataset*\"), show=True\n", + "):\n", " _ = high_client.api.services.bigquery.test_query.mock(\n", " sql_query=\"SELECT * FROM invalid_table LIMIT 1\"\n", " )" @@ -367,7 +352,7 @@ "source": [ "# Test private version\n", "result = high_client.api.services.bigquery.test_query.private(\n", - " sql_query=f\"SELECT * FROM {test_settings.dataset_1}.{test_settings.table_1} LIMIT 1\"\n", + " sql_query=f\"SELECT * FROM {dataset_1}.{table_1} LIMIT 1\"\n", ")\n", "result" ] @@ -439,7 +424,7 @@ "# Testing submit query\n", "result = high_client.api.services.bigquery.submit_query(\n", " func_name=\"my_func\",\n", - " query=f\"SELECT * FROM {test_settings.dataset_1}.{test_settings.table_1} LIMIT 1\",\n", + " query=f\"SELECT * FROM {dataset_1}.{table_1} LIMIT 1\",\n", ")" ] }, diff --git a/notebooks/scenarios/bigquery/03-ds-submit-request.ipynb b/notebooks/scenarios/bigquery/03-ds-submit-request.ipynb index 81b0d3b0d4a..3918f353dba 100644 --- a/notebooks/scenarios/bigquery/03-ds-submit-request.ipynb +++ b/notebooks/scenarios/bigquery/03-ds-submit-request.ipynb @@ -58,6 +58,20 @@ "high_client.api.services.bigquery.test_query" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "dataset_1 = test_settings.get(\"dataset_1\", default=\"dataset_1\")\n", + "dataset_2 = test_settings.get(\"dataset_2\", default=\"dataset_2\")\n", + "table_1 = test_settings.get(\"table_1\", default=\"table_1\")\n", + "table_2 = test_settings.get(\"table_2\", default=\"table_2\")\n", + "table_2_col_id = test_settings.get(\"table_2_col_id\", default=\"table_id\")\n", + "table_2_col_score = test_settings.get(\"table_2_col_score\", default=\"colname\")" + ] + }, { "cell_type": "code", "execution_count": null, @@ -65,9 +79,9 @@ "outputs": [], "source": [ "FUNC_NAME = \"popular\"\n", - "QUERY = f\"SELECT {test_settings.table_2_col_id}, AVG({test_settings.table_2_col_score}) AS average_score \\\n", - " FROM {test_settings.dataset_2}.{test_settings.table_2} \\\n", - " GROUP BY {test_settings.table_2_col_id} \\\n", + "QUERY = f\"SELECT {table_2_col_id}, AVG({table_2_col_score}) AS average_score \\\n", + " FROM {dataset_2}.{table_2} \\\n", + " GROUP BY {table_2_col_id} \\\n", " LIMIT 10000\"\n", "\n", "result = high_client.api.services.bigquery.test_query(sql_query=QUERY)" @@ -158,9 +172,7 @@ "outputs": [], "source": [ "FUNC_NAME = \"large_sample\"\n", - "LARGE_SAMPLE_QUERY = (\n", - " f\"SELECT * FROM {test_settings.dataset_2}.{test_settings.table_2} LIMIT 10000\"\n", - ")" + "LARGE_SAMPLE_QUERY = 
f\"SELECT * FROM {dataset_2}.{table_2} LIMIT 10000\"" ] }, { diff --git a/notebooks/scenarios/bigquery/apis/bigquery/__init__.py b/notebooks/scenarios/bigquery/__init__.py similarity index 100% rename from notebooks/scenarios/bigquery/apis/bigquery/__init__.py rename to notebooks/scenarios/bigquery/__init__.py diff --git a/notebooks/scenarios/bigquery/apis/__init__.py b/notebooks/scenarios/bigquery/apis/__init__.py new file mode 100644 index 00000000000..7231b580696 --- /dev/null +++ b/notebooks/scenarios/bigquery/apis/__init__.py @@ -0,0 +1,23 @@ +# stdlib +import os + +# syft absolute +from syft.util.util import str_to_bool + +# relative +from .submit_query import make_submit_query + +env_var = "TEST_BIGQUERY_APIS_LIVE" +use_live = str_to_bool(str(os.environ.get(env_var, "False"))) +env_name = "Live" if use_live else "Mock" +print(f"Using {env_name} API Code, this will query BigQuery. ${env_var}=={use_live}") + + +if use_live: + # relative + from .live.schema import make_schema + from .live.test_query import make_test_query +else: + # relative + from .mock.schema import make_schema + from .mock.test_query import make_test_query diff --git a/notebooks/scenarios/bigquery/apis/bigquery/test_query.py b/notebooks/scenarios/bigquery/apis/bigquery/test_query.py deleted file mode 100644 index 4ce3b49d107..00000000000 --- a/notebooks/scenarios/bigquery/apis/bigquery/test_query.py +++ /dev/null @@ -1,176 +0,0 @@ -# syft absolute -import syft as sy - -# relative -from .helpers import is_within_rate_limit - - -def make_mock(settings): - @sy.api_endpoint_method( - settings=settings, - helper_functions=[is_within_rate_limit], - ) - def mock( - context, - sql_query: str, - ) -> str: - # stdlib - import datetime - - # third party - from google.cloud import bigquery # noqa: F811 - from google.oauth2 import service_account - - # syft absolute - from syft import SyftException - - # Auth for Bigquer based on the workload identity - credentials = service_account.Credentials.from_service_account_info( - context.settings["credentials"] - ) - scoped_credentials = credentials.with_scopes( - ["https://www.googleapis.com/auth/cloud-platform"] - ) - - client = bigquery.Client( - credentials=scoped_credentials, - location=context.settings["region"], - ) - - # Store a dict with the calltimes for each user, via the email. - if context.user.email not in context.state.keys(): - context.state[context.user.email] = [] - - if not context.code.is_within_rate_limit(context): - raise SyftException( - public_message="Rate limit of calls per minute has been reached." - ) - - try: - context.state[context.user.email].append(datetime.datetime.now()) - - rows = client.query_and_wait( - sql_query, - project=context.settings["project_id"], - ) - - if rows.total_rows > 1_000_000: - raise SyftException( - public_message="Please only write queries that gather aggregate statistics" - ) - - return rows.to_dataframe() - - except Exception as e: - # not a bigquery exception - if not hasattr(e, "_errors"): - output = f"got exception e: {type(e)} {str(e)}" - raise SyftException( - public_message=f"An error occured executing the API call {output}" - ) - - # Treat all errors that we would like to be forwarded to the data scientists - # By default, any exception is only visible to the data owner. 
- - if e._errors[0]["reason"] in [ - "badRequest", - "blocked", - "duplicate", - "invalidQuery", - "invalid", - "jobBackendError", - "jobInternalError", - "notFound", - "notImplemented", - "rateLimitExceeded", - "resourceInUse", - "resourcesExceeded", - "tableUnavailable", - "timeout", - ]: - raise SyftException( - public_message="Error occured during the call: " - + e._errors[0]["message"] - ) - else: - raise SyftException( - public_message="An error occured executing the API call, please contact the domain owner." - ) - - return mock - - -def make_private(settings): - @sy.api_endpoint_method(settings=settings) - def private( - context, - sql_query: str, - ) -> str: - # third party - from google.cloud import bigquery # noqa: F811 - from google.oauth2 import service_account - - # syft absolute - from syft import SyftException - - # Auth for Bigquer based on the workload identity - credentials = service_account.Credentials.from_service_account_info( - context.settings["credentials"] - ) - scoped_credentials = credentials.with_scopes( - ["https://www.googleapis.com/auth/cloud-platform"] - ) - - client = bigquery.Client( - credentials=scoped_credentials, - location=context.settings["region"], - ) - - try: - rows = client.query_and_wait( - sql_query, - project=context.settings["project_id"], - ) - - if rows.total_rows > 1_000_000: - raise SyftException( - public_message="Please only write queries that gather aggregate statistics" - ) - - return rows.to_dataframe() - except Exception as e: - # We MUST handle the errors that we want to be visible to the data owners. - # Any exception not catched is visible only to the data owner. - # not a bigquery exception - if not hasattr(e, "_errors"): - output = f"got exception e: {type(e)} {str(e)}" - raise SyftException( - public_message=f"An error occured executing the API call {output}" - ) - - if e._errors[0]["reason"] in [ - "badRequest", - "blocked", - "duplicate", - "invalidQuery", - "invalid", - "jobBackendError", - "jobInternalError", - "notFound", - "notImplemented", - "rateLimitExceeded", - "resourceInUse", - "resourcesExceeded", - "tableUnavailable", - "timeout", - ]: - raise SyftException( - public_message="Error occured during the call: " - + e._errors[0]["message"] - ) - else: - raise SyftException( - public_message="An error occured executing the API call, please contact the domain owner." 
- ) - - return private diff --git a/notebooks/scenarios/bigquery/apis/live/__init__.py b/notebooks/scenarios/bigquery/apis/live/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/notebooks/scenarios/bigquery/apis/bigquery/schema.py b/notebooks/scenarios/bigquery/apis/live/schema.py similarity index 72% rename from notebooks/scenarios/bigquery/apis/bigquery/schema.py rename to notebooks/scenarios/bigquery/apis/live/schema.py index 7958244b38c..17223a9f2bb 100644 --- a/notebooks/scenarios/bigquery/apis/bigquery/schema.py +++ b/notebooks/scenarios/bigquery/apis/live/schema.py @@ -1,21 +1,36 @@ +# stdlib +from collections.abc import Callable + # syft absolute import syft as sy +from syft import test_settings # relative -from .helpers import is_within_rate_limit +from ..rate_limiter import is_within_rate_limit + +def make_schema(settings: dict, worker_pool: str) -> Callable: + updated_settings = { + "calls_per_min": 5, + "rate_limiter_enabled": True, + "credentials": test_settings.gce_service_account.to_dict(), + "region": test_settings.gce_region, + "project_id": test_settings.gce_project_id, + "dataset_1": test_settings.dataset_1, + "table_1": test_settings.table_1, + "table_2": test_settings.table_2, + } | settings -def make_schema(settings, worker_pool): @sy.api_endpoint( path="bigquery.schema", description="This endpoint allows for visualising the metadata of tables available in BigQuery.", - settings=settings, + settings=updated_settings, helper_functions=[ is_within_rate_limit ], # Adds ratelimit as this is also a method available to data scientists worker_pool=worker_pool, ) - def schema( + def live_schema( context, ) -> str: # stdlib @@ -42,17 +57,18 @@ def schema( location=context.settings["region"], ) - if context.user.email not in context.state.keys(): - context.state[context.user.email] = [] - - if not context.code.is_within_rate_limit(context): - raise SyftException( - public_message="Rate limit of calls per minute has been reached." - ) + # Store a dict with the calltimes for each user, via the email. + if context.settings["rate_limiter_enabled"]: + if context.user.email not in context.state.keys(): + context.state[context.user.email] = [] - try: + if not context.code.is_within_rate_limit(context): + raise SyftException( + public_message="Rate limit of calls per minute has been reached." + ) context.state[context.user.email].append(datetime.datetime.now()) + try: # Formats the data schema in a data frame format # Warning: the only supported format types are primitives, np.ndarrays and pd.DataFrames @@ -89,4 +105,4 @@ def schema( public_message="An error occured executing the API call, please contact the domain owner." 
) - return schema + return live_schema diff --git a/notebooks/scenarios/bigquery/apis/live/test_query.py b/notebooks/scenarios/bigquery/apis/live/test_query.py new file mode 100644 index 00000000000..344879dcb62 --- /dev/null +++ b/notebooks/scenarios/bigquery/apis/live/test_query.py @@ -0,0 +1,113 @@ +# stdlib +from collections.abc import Callable + +# syft absolute +import syft as sy +from syft import test_settings + +# relative +from ..rate_limiter import is_within_rate_limit + + +def make_test_query(settings) -> Callable: + updated_settings = { + "calls_per_min": 10, + "rate_limiter_enabled": True, + "credentials": test_settings.gce_service_account.to_dict(), + "region": test_settings.gce_region, + "project_id": test_settings.gce_project_id, + } | settings + + # these are the same if you allow the rate limiter to be turned on and off + @sy.api_endpoint_method( + settings=updated_settings, + helper_functions=[is_within_rate_limit], + ) + def live_test_query( + context, + sql_query: str, + ) -> str: + # stdlib + import datetime + + # third party + from google.cloud import bigquery # noqa: F811 + from google.oauth2 import service_account + + # syft absolute + from syft import SyftException + + # Auth for Bigquer based on the workload identity + credentials = service_account.Credentials.from_service_account_info( + context.settings["credentials"] + ) + scoped_credentials = credentials.with_scopes( + ["https://www.googleapis.com/auth/cloud-platform"] + ) + + client = bigquery.Client( + credentials=scoped_credentials, + location=context.settings["region"], + ) + + # Store a dict with the calltimes for each user, via the email. + if context.settings["rate_limiter_enabled"]: + if context.user.email not in context.state.keys(): + context.state[context.user.email] = [] + + if not context.code.is_within_rate_limit(context): + raise SyftException( + public_message="Rate limit of calls per minute has been reached." + ) + context.state[context.user.email].append(datetime.datetime.now()) + + try: + rows = client.query_and_wait( + sql_query, + project=context.settings["project_id"], + ) + + if rows.total_rows > 1_000_000: + raise SyftException( + public_message="Please only write queries that gather aggregate statistics" + ) + + return rows.to_dataframe() + + except Exception as e: + # not a bigquery exception + if not hasattr(e, "_errors"): + output = f"got exception e: {type(e)} {str(e)}" + raise SyftException( + public_message=f"An error occured executing the API call {output}" + ) + + # Treat all errors that we would like to be forwarded to the data scientists + # By default, any exception is only visible to the data owner. + + if e._errors[0]["reason"] in [ + "badRequest", + "blocked", + "duplicate", + "invalidQuery", + "invalid", + "jobBackendError", + "jobInternalError", + "notFound", + "notImplemented", + "rateLimitExceeded", + "resourceInUse", + "resourcesExceeded", + "tableUnavailable", + "timeout", + ]: + raise SyftException( + public_message="Error occured during the call: " + + e._errors[0]["message"] + ) + else: + raise SyftException( + public_message="An error occured executing the API call, please contact the domain owner." 
+ ) + + return live_test_query diff --git a/notebooks/scenarios/bigquery/apis/mock/__init__.py b/notebooks/scenarios/bigquery/apis/mock/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/notebooks/scenarios/bigquery/apis/mock/data.py b/notebooks/scenarios/bigquery/apis/mock/data.py new file mode 100644 index 00000000000..61229ff8a7c --- /dev/null +++ b/notebooks/scenarios/bigquery/apis/mock/data.py @@ -0,0 +1,268 @@ +# stdlib +from math import nan + +schema_dict = { + "project": { + 0: "reddit-testing-415005", + 1: "reddit-testing-415005", + 2: "reddit-testing-415005", + 3: "reddit-testing-415005", + 4: "reddit-testing-415005", + 5: "reddit-testing-415005", + 6: "reddit-testing-415005", + 7: "reddit-testing-415005", + 8: "reddit-testing-415005", + 9: "reddit-testing-415005", + 10: "reddit-testing-415005", + 11: "reddit-testing-415005", + 12: "reddit-testing-415005", + 13: "reddit-testing-415005", + 14: "reddit-testing-415005", + 15: "reddit-testing-415005", + 16: "reddit-testing-415005", + 17: "reddit-testing-415005", + 18: "reddit-testing-415005", + 19: "reddit-testing-415005", + 20: "reddit-testing-415005", + 21: "reddit-testing-415005", + 22: "reddit-testing-415005", + }, + "dataset_id": { + 0: "test_1gb", + 1: "test_1gb", + 2: "test_1gb", + 3: "test_1gb", + 4: "test_1gb", + 5: "test_1gb", + 6: "test_1gb", + 7: "test_1gb", + 8: "test_1gb", + 9: "test_1gb", + 10: "test_1gb", + 11: "test_1gb", + 12: "test_1gb", + 13: "test_1gb", + 14: "test_1gb", + 15: "test_1gb", + 16: "test_1gb", + 17: "test_1gb", + 18: "test_1gb", + 19: "test_1gb", + 20: "test_1gb", + 21: "test_1gb", + 22: "test_1gb", + }, + "table_id": { + 0: "subreddits", + 1: "subreddits", + 2: "subreddits", + 3: "subreddits", + 4: "subreddits", + 5: "subreddits", + 6: "subreddits", + 7: "comments", + 8: "comments", + 9: "comments", + 10: "comments", + 11: "comments", + 12: "comments", + 13: "comments", + 14: "comments", + 15: "comments", + 16: "comments", + 17: "comments", + 18: "comments", + 19: "comments", + 20: "comments", + 21: "comments", + 22: "comments", + }, + "schema_name": { + 0: "int64_field_0", + 1: "id", + 2: "name", + 3: "subscribers_count", + 4: "permalink", + 5: "nsfw", + 6: "spam", + 7: "int64_field_0", + 8: "id", + 9: "body", + 10: "parent_id", + 11: "created_at", + 12: "last_modified_at", + 13: "gilded", + 14: "permalink", + 15: "score", + 16: "subreddit_id", + 17: "post_id", + 18: "author_id", + 19: "spam", + 20: "deleted", + 21: "upvote_raio", + 22: "collapsed_in_crowd_control", + }, + "schema_field": { + 0: "INTEGER", + 1: "STRING", + 2: "STRING", + 3: "INTEGER", + 4: "STRING", + 5: "FLOAT", + 6: "BOOLEAN", + 7: "INTEGER", + 8: "STRING", + 9: "STRING", + 10: "STRING", + 11: "INTEGER", + 12: "INTEGER", + 13: "BOOLEAN", + 14: "STRING", + 15: "INTEGER", + 16: "STRING", + 17: "STRING", + 18: "STRING", + 19: "BOOLEAN", + 20: "BOOLEAN", + 21: "FLOAT", + 22: "BOOLEAN", + }, + "description": { + 0: "None", + 1: "None", + 2: "None", + 3: "None", + 4: "None", + 5: "None", + 6: "None", + 7: "None", + 8: "None", + 9: "None", + 10: "None", + 11: "None", + 12: "None", + 13: "None", + 14: "None", + 15: "None", + 16: "None", + 17: "None", + 18: "None", + 19: "None", + 20: "None", + 21: "None", + 22: "None", + }, + "num_rows": { + 0: "2000000", + 1: "2000000", + 2: "2000000", + 3: "2000000", + 4: "2000000", + 5: "2000000", + 6: "2000000", + 7: "2000000", + 8: "2000000", + 9: "2000000", + 10: "2000000", + 11: "2000000", + 12: "2000000", + 13: "2000000", + 14: "2000000", + 15: "2000000", + 16: 
"2000000", + 17: "2000000", + 18: "2000000", + 19: "2000000", + 20: "2000000", + 21: "2000000", + 22: "2000000", + }, +} + + +query_dict = { + "int64_field_0": { + 0: 4, + 1: 5, + 2: 10, + 3: 16, + 4: 17, + 5: 23, + 6: 24, + 7: 25, + 8: 27, + 9: 40, + }, + "id": { + 0: "t5_via1x", + 1: "t5_cv9gn", + 2: "t5_8p2tq", + 3: "t5_8fcro", + 4: "t5_td5of", + 5: "t5_z01fv", + 6: "t5_hmqjk", + 7: "t5_1flyj", + 8: "t5_5rwej", + 9: "t5_uurcv", + }, + "name": { + 0: "/r/mylittlepony", + 1: "/r/polyamory", + 2: "/r/Catholicism", + 3: "/r/cordcutters", + 4: "/r/stevenuniverse", + 5: "/r/entitledbitch", + 6: "/r/engineering", + 7: "/r/nottheonion", + 8: "/r/FoodPorn", + 9: "/r/puppysmiles", + }, + "subscribers_count": { + 0: 4323081, + 1: 2425929, + 2: 4062607, + 3: 7543226, + 4: 2692168, + 5: 2709080, + 6: 8766144, + 7: 2580984, + 8: 7784809, + 9: 3715991, + }, + "permalink": { + 0: "/r//r/mylittlepony", + 1: "/r//r/polyamory", + 2: "/r//r/Catholicism", + 3: "/r//r/cordcutters", + 4: "/r//r/stevenuniverse", + 5: "/r//r/entitledbitch", + 6: "/r//r/engineering", + 7: "/r//r/nottheonion", + 8: "/r//r/FoodPorn", + 9: "/r//r/puppysmiles", + }, + "nsfw": { + 0: nan, + 1: nan, + 2: nan, + 3: nan, + 4: nan, + 5: nan, + 6: nan, + 7: nan, + 8: nan, + 9: nan, + }, + "spam": { + 0: False, + 1: False, + 2: False, + 3: False, + 4: False, + 5: False, + 6: False, + 7: False, + 8: False, + 9: False, + }, +} diff --git a/notebooks/scenarios/bigquery/apis/mock/schema.py b/notebooks/scenarios/bigquery/apis/mock/schema.py new file mode 100644 index 00000000000..a95e04f2f1d --- /dev/null +++ b/notebooks/scenarios/bigquery/apis/mock/schema.py @@ -0,0 +1,52 @@ +# stdlib +from collections.abc import Callable + +# syft absolute +import syft as sy + +# relative +from ..rate_limiter import is_within_rate_limit +from .data import schema_dict + + +def make_schema(settings, worker_pool) -> Callable: + updated_settings = { + "calls_per_min": 5, + "rate_limiter_enabled": True, + "schema_dict": schema_dict, + } | settings + + @sy.api_endpoint( + path="bigquery.schema", + description="This endpoint allows for visualising the metadata of tables available in BigQuery.", + settings=updated_settings, + helper_functions=[is_within_rate_limit], + worker_pool=worker_pool, + ) + def mock_schema( + context, + ) -> str: + # syft absolute + from syft import SyftException + + # Store a dict with the calltimes for each user, via the email. + if context.settings["rate_limiter_enabled"]: + # stdlib + import datetime + + if context.user.email not in context.state.keys(): + context.state[context.user.email] = [] + + if not context.code.is_within_rate_limit(context): + raise SyftException( + public_message="Rate limit of calls per minute has been reached." 
+ ) + context.state[context.user.email].append(datetime.datetime.now()) + + # third party + import pandas as pd + + df = pd.DataFrame(context.settings["schema_dict"]) + return df + + return mock_schema diff --git a/notebooks/scenarios/bigquery/apis/mock/test_query.py b/notebooks/scenarios/bigquery/apis/mock/test_query.py new file mode 100644 index 00000000000..ae028a8cf36 --- /dev/null +++ b/notebooks/scenarios/bigquery/apis/mock/test_query.py @@ -0,0 +1,138 @@ +# stdlib +from collections.abc import Callable + +# syft absolute +import syft as sy + +# relative +from ..rate_limiter import is_within_rate_limit +from .data import query_dict + + +def extract_limit_value(sql_query: str) -> int: + # stdlib + import re + + limit_pattern = re.compile(r"\bLIMIT\s+(\d+)\b", re.IGNORECASE) + match = limit_pattern.search(sql_query) + if match: + return int(match.group(1)) + return None + + +def is_valid_sql(query: str) -> bool: + # stdlib + import sqlite3 + + # Prepare an in-memory SQLite database + conn = sqlite3.connect(":memory:") + cursor = conn.cursor() + + try: + # Use the EXPLAIN QUERY PLAN command to get the query plan + cursor.execute(f"EXPLAIN QUERY PLAN {query}") + except sqlite3.Error as e: + if "no such table" in str(e).lower(): + return True + return False + finally: + conn.close() + + +def adjust_dataframe_rows(df, target_rows: int): + # third party + import pandas as pd + + current_rows = len(df) + + if target_rows > current_rows: + # Repeat rows to match target_rows + repeat_times = (target_rows + current_rows - 1) // current_rows + df_expanded = pd.concat([df] * repeat_times, ignore_index=True).head( + target_rows + ) + else: + # Truncate rows to match target_rows + df_expanded = df.head(target_rows) + + return df_expanded + + +def make_test_query(settings: dict) -> Callable: + updated_settings = { + "calls_per_min": 10, + "rate_limiter_enabled": True, + "query_dict": query_dict, + } | settings + + # these are the same if you allow the rate limiter to be turned on and off + @sy.api_endpoint_method( + settings=updated_settings, + helper_functions=[ + is_within_rate_limit, + extract_limit_value, + is_valid_sql, + adjust_dataframe_rows, + ], + ) + def mock_test_query( + context, + sql_query: str, + ) -> str: + # stdlib + import datetime + + # third party + from google.api_core.exceptions import BadRequest + + # syft absolute + from syft import SyftException + + # Store a dict with the calltimes for each user, via the email. + if context.settings["rate_limiter_enabled"]: + if context.user.email not in context.state.keys(): + context.state[context.user.email] = [] + + if not context.code.is_within_rate_limit(context): + raise SyftException( + public_message="Rate limit of calls per minute has been reached." + ) + context.state[context.user.email].append(datetime.datetime.now()) + + bad_table = "invalid_table" + bad_post = ( + "BadRequest: 400 POST " + "https://bigquery.googleapis.com/bigquery/v2/projects/project-id/" + "queries?prettyPrint=false: " + ) + if bad_table in sql_query: + try: + raise BadRequest( + f'{bad_post} Table "{bad_table}" must be qualified ' + "with a dataset (e.g. dataset.table)." + ) + except Exception as e: + raise SyftException( + public_message=f"*must be qualified with a dataset*. 
{e}"
+                )
+
+        if not context.code.is_valid_sql(sql_query):
+            raise BadRequest(
+                f'{bad_post} Syntax error: Unexpected identifier "{sql_query}" at [1:1]'
+            )
+
+        # third party
+        import pandas as pd
+
+        limit = context.code.extract_limit_value(sql_query)
+        if limit > 1_000_000:
+            raise SyftException(
+                public_message="Please only write queries that gather aggregate statistics"
+            )
+
+        base_df = pd.DataFrame(context.settings["query_dict"])
+
+        df = context.code.adjust_dataframe_rows(base_df, limit)
+        return df
+
+    return mock_test_query
diff --git a/notebooks/scenarios/bigquery/apis/bigquery/helpers.py b/notebooks/scenarios/bigquery/apis/rate_limiter.py
similarity index 76%
rename from notebooks/scenarios/bigquery/apis/bigquery/helpers.py
rename to notebooks/scenarios/bigquery/apis/rate_limiter.py
index c22b6f507a4..8ce319b61f4 100644
--- a/notebooks/scenarios/bigquery/apis/bigquery/helpers.py
+++ b/notebooks/scenarios/bigquery/apis/rate_limiter.py
@@ -1,4 +1,4 @@
-def is_within_rate_limit(context):
+def is_within_rate_limit(context) -> bool:
     """Rate limiter for custom API calls made by users."""
     # stdlib
     import datetime
@@ -13,4 +13,4 @@ def is_within_rate_limit(context):
         for call_time in state[email]
     ]
 
-    return sum(calls_last_min) < settings["CALLS_PER_MIN"]
+    return sum(calls_last_min) < settings.get("calls_per_min", 5)
diff --git a/notebooks/scenarios/bigquery/apis/bigquery/submit_query.py b/notebooks/scenarios/bigquery/apis/submit_query.py
similarity index 100%
rename from notebooks/scenarios/bigquery/apis/bigquery/submit_query.py
rename to notebooks/scenarios/bigquery/apis/submit_query.py

From 213204b9881078bed090e7bb20c823e777fc58c9 Mon Sep 17 00:00:00 2001
From: Madhava Jay
Date: Wed, 28 Aug 2024 16:39:54 +1000
Subject: [PATCH 13/13] fixing running scenario notebooks

---
 .github/workflows/pr-tests-syft.yml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/.github/workflows/pr-tests-syft.yml b/.github/workflows/pr-tests-syft.yml
index ec97ecbbd04..3c56d9acc43 100644
--- a/.github/workflows/pr-tests-syft.yml
+++ b/.github/workflows/pr-tests-syft.yml
@@ -268,7 +268,6 @@ jobs:
         if: steps.changes.outputs.syft == 'true' || steps.changes.outputs.notebooks_scenario == 'true'
         env:
           ORCHESTRA_DEPLOYMENT_TYPE: "${{ matrix.deployment-type }}"
-          TEST_NOTEBOOK_PATHS: "${{ matrix.notebook-paths }}"
           BUMP_VERSION: "${{ matrix.bump-version }}"
         with:
           timeout_seconds: 2400