From 67775cef9d653baa7eba5bb073fe88ebd4fcad84 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 14 Aug 2024 10:11:54 +0200 Subject: [PATCH 001/197] add base stash and update job stash --- .../syft/src/syft/service/job/base_stash.py | 305 ++++++++++++++++++ .../syft/src/syft/service/job/job_stash.py | 80 +---- packages/syft/src/syft/types/uid.py | 4 + 3 files changed, 321 insertions(+), 68 deletions(-) create mode 100644 packages/syft/src/syft/service/job/base_stash.py diff --git a/packages/syft/src/syft/service/job/base_stash.py b/packages/syft/src/syft/service/job/base_stash.py new file mode 100644 index 00000000000..9d72d543330 --- /dev/null +++ b/packages/syft/src/syft/service/job/base_stash.py @@ -0,0 +1,305 @@ +# stdlib + +# stdlib +import base64 +import json +import threading +import typing +from typing import Any, ClassVar +from typing import Generic +import uuid +from uuid import UUID + +# third party +import pydantic +from result import Ok +from result import Result +import sqlalchemy as sa +from sqlalchemy import Column +from sqlalchemy import Row +from sqlalchemy import Table +from sqlalchemy import TypeDecorator +from sqlalchemy import create_engine +from sqlalchemy.orm import DeclarativeBase +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import Session +from sqlalchemy.orm import mapped_column +from sqlalchemy.orm import sessionmaker +from sqlalchemy.types import JSON +from typing_extensions import TypeVar + +# syft absolute +import syft as sy + +# relative +from ...server.credentials import SyftSigningKey +from ...server.credentials import SyftVerifyKey +from ...store.document_store import DocumentStore +from ...store.document_store import PartitionKey +from ...store.linked_obj import LinkedObject +from ...types.datetime import DateTime +from ...types.syft_object import SyftObject +from ...types.uid import UID +from ..action.action_object import Action +from ..action.action_permissions import ActionObjectPermission +from ..response import SyftSuccess + + +class Base(DeclarativeBase): + pass + + +class UIDTypeDecorator(TypeDecorator): + """Converts between Syft UID and UUID.""" + + impl = sa.UUID + cache_ok = True + + def process_bind_param(self, value, dialect): # type: ignore + if value is not None: + return value + + def process_result_value(self, value, dialect): # type: ignore + if value is not None: + return UID(value) + + +class CommonMixin: + id: Mapped[UID] = mapped_column( + default=uuid.uuid4, + primary_key=True, + ) + created_at: Mapped[DateTime] = mapped_column(server_default=sa.func.now()) + + updated_at: Mapped[DateTime] = mapped_column( + server_default=sa.func.now(), + server_onupdate=sa.func.now(), + ) + json_document: Mapped[dict] = mapped_column(JSON, default={}) + + +def model_dump(obj: pydantic.BaseModel) -> dict: + obj_dict = obj.model_dump() + for key, type_ in obj.model_fields.items(): + if type_.annotation is UID: + obj_dict[key] = obj_dict[key].no_dash + elif type_.annotation is SyftVerifyKey: + obj_dict[key] = str(getattr(obj, key)) + elif type_.annotation is SyftSigningKey: + obj_dict[key] = str(getattr(obj, key)) + elif ( + type_.annotation is LinkedObject + or type_.annotation == Any | None + or type_.annotation == Action | None + ): + # not very efficient as it serializes the object twice + data = sy.serialize(getattr(obj, key), to_bytes=True) + base64_data = base64.b64encode(data).decode("utf-8") + obj_dict[key] = base64_data + return obj_dict + + +T = TypeVar("T", bound=pydantic.BaseModel) + + +def 
model_validate(obj_type: type[T], obj_dict: dict) -> T: + for key, type_ in obj_type.model_fields.items(): + if key not in obj_dict: + continue + # FIXME + if type_.annotation is UID or type_.annotation == UID | None: + obj_dict[key] = UID(obj_dict[key]) + elif type_.annotation is SyftVerifyKey: + obj_dict[key] = SyftVerifyKey.from_string(obj_dict[key]) + elif type_.annotation is SyftSigningKey: + obj_dict[key] = SyftSigningKey.from_string(obj_dict[key]) + elif ( + type_.annotation is LinkedObject + or type_.annotation == Any | None + or type_.annotation == Action | None + ): + data = base64.b64decode(obj_dict[key]) + obj_dict[key] = sy.deserialize(data, from_bytes=True) + + return obj_type.model_validate(obj_dict) + + +def _default_dumps(val): # type: ignore + if isinstance(val, UID): + return str(val.no_dash) + elif isinstance(val, UUID): + return val.hex + # raise TypeError(f"Can't serialize {val}, type {type(val)}") + + +def _default_loads(val): # type: ignore + if "UID" in val: + return UID(val) + return val + + +def dumps(d): + return json.dumps(d, default=_default_dumps) + + +def loads(d): + return json.loads(d, object_hook=_default_loads) + + +class SQLiteDBManager: + def __init__(self, server_uid: str) -> None: + self.server_uid = server_uid + self.path = f"sqlite:////tmp/{server_uid}.db" + self.engine = create_engine( + self.path, json_serializer=dumps, json_deserializer=loads + ) + print(f"Connecting to {self.path}") + self.SessionFactory = sessionmaker(bind=self.engine) + self.thread_local = threading.local() + + Base.metadata.create_all(self.engine) + + def get_session(self) -> Session: + if not hasattr(self.thread_local, "session"): + self.thread_local.session = self.SessionFactory() + return self.thread_local.session + + @property + def session(self) -> Session: + return self.get_session() + + +SyftT = TypeVar("SyftT", bound=SyftObject) + + +class ObjectStash(Generic[SyftT]): + object_type: ClassVar[type[SyftT]] + + def __init__(self, store: DocumentStore) -> None: + self.server_uid = store.server_uid + self.verify_key = store.root_verify_key + # is there a better way to init the table + _ = self.table + self.db = SQLiteDBManager(self.server_uid) + + @property + def session(self) -> Session: + return self.db.session + + @property + def table(self) -> Table: + # need to call Base.metadata.create_all(engine) to create the table + table_name = self.object_type.__canonical_name__ + if table_name not in Base.metadata.tables: + Table( + self.object_type.__canonical_name__, + Base.metadata, + Column("id", UIDTypeDecorator, primary_key=True, default=uuid.uuid4), + Column("created_at", sa.DateTime, server_default=sa.func.now()), + Column( + "updated_at", + sa.DateTime, + server_default=sa.func.now(), + server_onupdate=sa.func.now(), + ), + Column("fields", JSON, default={}), + ) + return Base.metadata.tables[table_name] + + def get_by_uid( + self, credentials: SyftVerifyKey, uid: UID + ) -> Result[SyftT | None, str]: + result = self.session.execute( + self.table.select().where(self.table.c.id == uid) + ).first() + if result is None: + return Ok(None) + return Ok(self.row_as_obj(result)) + + def get_one_by_field( + self, credentials: SyftVerifyKey, field_name: str, field_value: str + ) -> Result[SyftT | None, str]: + result = self.session.execute( + self.table.select().where(self.table.c.fields[field_name] == field_value) + ).first() + if result is None: + return Ok(None) + return Ok(self.row_as_obj(result)) + + def get_all_by_field( + self, credentials: SyftVerifyKey, field_name: 
str, field_value: str + ) -> Result[list[SyftT], str]: + result = self.session.execute( + self.table.select().where(self.table.c.fields[field_name] == field_value) + ).all() + objs = [self.row_as_obj(row) for row in result] + return Ok(objs) + + def row_as_obj(self, row: Row): + return model_validate(self.object_type, row.fields) + + def get_all( + self, + credentials: SyftVerifyKey, + order_by: PartitionKey | None = None, + has_permission: bool = False, + ) -> Result[list[SyftT], str]: + stmt = self.table.select() + result = self.session.execute(stmt).all() + + objs = [self.row_as_obj(row) for row in result] + return Ok(objs) + + def update( + self, + credentials: SyftVerifyKey, + obj: SyftT, + has_permission: bool = False, + ) -> Result[SyftT, str]: + stmt = ( + self.table.update() + .where(self.table.c.id == obj.id) + .values(fields=model_dump(obj)) + ) + self.session.execute(stmt) + self.session.commit() + return Ok(obj) + + def set( + self, + credentials: SyftVerifyKey, + obj: SyftT, + add_permissions: list[ActionObjectPermission] | None = None, + add_storage_permission: bool = True, + ignore_duplicates: bool = False, + ) -> Result[SyftT, str]: + stmt = self.table.insert().values( + id=obj.id, + fields=model_dump(obj), + ) + self.session.execute(stmt) + self.session.commit() + return Ok(obj) + + def delete_by_uid( + self, credentials: SyftVerifyKey, uid: UID + ) -> Result[SyftSuccess, str]: + stmt = self.table.delete().where(self.table.c.id == uid) + self.session.execute(stmt) + self.session.commit() + return Ok(SyftSuccess()) + + def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: + pass + + def add_permission(self, permission: ActionObjectPermission) -> None: + pass + + def remove_permission(self, permission: ActionObjectPermission) -> None: + pass + + def has_permission(self, permission: ActionObjectPermission) -> bool: + return True + + def has_storage_permission(self, permission) -> bool: + return True diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 7b32b6281cf..cbcb278492e 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -13,7 +13,6 @@ from pydantic import Field from pydantic import model_validator from result import Err -from result import Ok from result import Result from typing_extensions import Self @@ -24,12 +23,8 @@ from ...server.credentials import SyftVerifyKey from ...service.context import AuthedServiceContext from ...service.worker.worker_pool import SyftWorker -from ...store.document_store import BaseUIDStoreStash from ...store.document_store import DocumentStore -from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey from ...types.datetime import DateTime from ...types.datetime import format_timedelta from ...types.syft_migration import migrate @@ -50,6 +45,7 @@ from ..response import SyftNotReady from ..response import SyftSuccess from ..user.user import UserView +from .base_stash import ObjectStash from .html_template import job_repr_template @@ -842,15 +838,15 @@ def from_job( @instrument -@serializable(canonical_name="JobStash", version=1) -class JobStash(BaseUIDStoreStash): +@serializable(canonical_name="JobStashSQL", version=1) +class JobStash(ObjectStash[Job]): object_type = Job settings: PartitionSettings = PartitionSettings( 
name=Job.__canonical_name__, object_type=Job ) def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) + super().__init__(store) def set_result( self, @@ -858,85 +854,33 @@ def set_result( item: Job, add_permissions: list[ActionObjectPermission] | None = None, ) -> Result[Job | None, str]: - valid = self.check_type(item, self.object_type) - if valid.is_err(): - return SyftError(message=valid.err()) - - # Ensure we never save cached result data in the database, - # as they can be arbitrarily large if ( isinstance(item.result, ActionObject) and item.result.syft_blob_storage_entry_id is not None ): item.result._clear_cache() - return super().update(credentials, item, add_permissions) + return self.update(credentials, item, add_permissions) - def get_by_result_id( - self, - credentials: SyftVerifyKey, - result_id: UID, - ) -> Result[Job | None, str]: - qks = QueryKeys( - qks=[PartitionKey(key="result_id", type_=UID).with_obj(result_id)] + def get_active(self, credentials: SyftVerifyKey) -> Result[list[Job], str]: + return self.get_all_by_field( + credentials=credentials, field_name="status", field_value=JobStatus.CREATED ) - res = self.query_all(credentials=credentials, qks=qks) - if res.is_err(): - return res - - res = res.ok() - if len(res) == 0: - return Ok(None) - elif len(res) > 1: - return Err("multiple Jobs found") - else: - return Ok(res[0]) - - def get_by_parent_id( - self, credentials: SyftVerifyKey, uid: UID - ) -> Result[Job | None, str]: - qks = QueryKeys( - qks=[PartitionKey(key="parent_job_id", type_=UID).with_obj(uid)] - ) - item = self.query_all(credentials=credentials, qks=qks) - return item - - def delete_by_uid( - self, credentials: SyftVerifyKey, uid: UID - ) -> Result[SyftSuccess, str]: - qk = UIDPartitionKey.with_obj(uid) - result = super().delete(credentials=credentials, qk=qk) - if result.is_ok(): - return Ok(SyftSuccess(message=f"ID: {uid} deleted")) - return result - - def get_active(self, credentials: SyftVerifyKey) -> Result[SyftSuccess, str]: - qks = QueryKeys( - qks=[ - PartitionKey(key="status", type_=JobStatus).with_obj( - JobStatus.PROCESSING - ) - ] - ) - return self.query_all(credentials=credentials, qks=qks) def get_by_worker( self, credentials: SyftVerifyKey, worker_id: str ) -> Result[list[Job], str]: - qks = QueryKeys( - qks=[PartitionKey(key="job_worker_id", type_=str).with_obj(worker_id)] + return self.get_all_by_field( + credentials=credentials, field_name="worker_id", field_value=worker_id ) - return self.query_all(credentials=credentials, qks=qks) def get_by_user_code_id( self, credentials: SyftVerifyKey, user_code_id: UID ) -> Result[list[Job], str]: - qks = QueryKeys( - qks=[PartitionKey(key="user_code_id", type_=UID).with_obj(user_code_id)] + return self.get_all_by_field( + credentials=credentials, field_name="user_code_id", field_value=user_code_id ) - return self.query_all(credentials=credentials, qks=qks) - @serializable() class JobV1(SyncableSyftObject): diff --git a/packages/syft/src/syft/types/uid.py b/packages/syft/src/syft/types/uid.py index 96bc5af31b8..de364d7b10a 100644 --- a/packages/syft/src/syft/types/uid.py +++ b/packages/syft/src/syft/types/uid.py @@ -154,6 +154,10 @@ def is_valid_uuid(value: Any) -> bool: def no_dash(self) -> str: return str(self.value).replace("-", "") + @property + def hex(self) -> str: + return self.value.hex + def __repr__(self) -> str: """Returns a human-readable version of the ID From b32b49cff348520c018402f901ee897025a15c0f Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt 
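The first commit (PATCH 001) replaces the partition-key document store with one SQLite table per object type: objects are flattened into a JSON `fields` column by model_dump() and revived by model_validate(), while the engine's JSON codec (json_serializer=dumps) falls back to dash-less hex strings for UID/UUID values. A minimal sketch of that codec as defined in base_stash.py (the dict contents are illustrative, not values from the patch):

    # UIDs are serialized to their dash-less hex form by _default_dumps
    # rather than raising TypeError; plain JSON types pass through.
    from syft.service.job.base_stash import dumps
    from syft.types.uid import UID

    print(dumps({"parent_id": UID(), "status": "created"}))
    # -> {"parent_id": "<32 hex chars>", "status": "created"}

Note that reviving those strings back into UIDs happens field-by-field in model_validate(), driven by the pydantic annotations, not in the JSON deserializer.
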
Date: Wed, 14 Aug 2024 10:19:45 +0200 Subject: [PATCH 002/197] update output stash --- .../syft/src/syft/service/job/base_stash.py | 8 ++-- .../src/syft/service/output/output_service.py | 48 ++++++------------- 2 files changed, 19 insertions(+), 37 deletions(-) diff --git a/packages/syft/src/syft/service/job/base_stash.py b/packages/syft/src/syft/service/job/base_stash.py index 9d72d543330..59146a71fdb 100644 --- a/packages/syft/src/syft/service/job/base_stash.py +++ b/packages/syft/src/syft/service/job/base_stash.py @@ -4,8 +4,8 @@ import base64 import json import threading -import typing -from typing import Any, ClassVar +from typing import Any +from typing import ClassVar from typing import Generic import uuid from uuid import UUID @@ -138,11 +138,11 @@ def _default_loads(val): # type: ignore return val -def dumps(d): +def dumps(d: dict) -> str: return json.dumps(d, default=_default_dumps) -def loads(d): +def loads(d: str) -> dict: return json.loads(d, object_hook=_default_loads) diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 386aaa926a9..d8a2c75f7ad 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -3,19 +3,15 @@ # third party from pydantic import model_validator -from result import Err -from result import Ok from result import Result # relative from ...client.api import APIRegistry from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash from ...store.document_store import DocumentStore from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.syft_object import SYFT_OBJECT_VERSION_1 @@ -25,6 +21,7 @@ from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectREAD from ..context import AuthedServiceContext +from ..job.base_stash import ObjectStash from ..response import SyftError from ..service import AbstractService from ..service import TYPE_TO_SERVICE @@ -190,8 +187,8 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: @instrument -@serializable(canonical_name="OutputStash", version=1) -class OutputStash(BaseUIDStoreStash): +@serializable(canonical_name="OutputStashSQL", version=1) +class OutputStash(ObjectStash[ExecutionOutput]): object_type = ExecutionOutput settings: PartitionSettings = PartitionSettings( name=ExecutionOutput.__canonical_name__, object_type=ExecutionOutput @@ -200,47 +197,32 @@ class OutputStash(BaseUIDStoreStash): def __init__(self, store: DocumentStore) -> None: super().__init__(store) self.store = store - self.settings = self.settings - self._object_type = self.object_type def get_by_user_code_id( self, credentials: SyftVerifyKey, user_code_id: UID ) -> Result[list[ExecutionOutput], str]: - qks = QueryKeys( - qks=[UserCodeIdPartitionKey.with_obj(user_code_id)], - ) - return self.query_all( - credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey + return self.get_all_by_field( + credentials=credentials, + field_name="user_code_id", + field_value=user_code_id, ) def get_by_job_id( self, credentials: SyftVerifyKey, user_code_id: UID ) -> Result[ExecutionOutput | None, str]: - qks = QueryKeys( - 
qks=[JobIdPartitionKey.with_obj(user_code_id)], - ) - res = self.query_all( - credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey + return self.get_one_by_field( + credentials=credentials, + field_name="job_id", + field_value=user_code_id, ) - if res.is_err(): - return res - else: - res = res.ok() - if len(res) == 0: - return Ok(None) - elif len(res) > 1: - return Err(SyftError(message="Too many outputs found")) - else: - return Ok(res[0]) def get_by_output_policy_id( self, credentials: SyftVerifyKey, output_policy_id: UID ) -> Result[list[ExecutionOutput], str]: - qks = QueryKeys( - qks=[OutputPolicyIdPartitionKey.with_obj(output_policy_id)], - ) - return self.query_all( - credentials=credentials, qks=qks, order_by=CreatedAtPartitionKey + return self.get_all_by_field( + credentials=credentials, + field_name="output_policy_id", + field_value=output_policy_id, ) From 1d6a14f0d7ff3e14512955af8195aca365664d9e Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 14 Aug 2024 11:55:08 +0200 Subject: [PATCH 003/197] implement permissions --- .../syft/src/syft/service/job/base_stash.py | 178 +++++++++++++++--- .../syft/src/syft/service/job/job_stash.py | 7 + 2 files changed, 159 insertions(+), 26 deletions(-) diff --git a/packages/syft/src/syft/service/job/base_stash.py b/packages/syft/src/syft/service/job/base_stash.py index 59146a71fdb..079457810d7 100644 --- a/packages/syft/src/syft/service/job/base_stash.py +++ b/packages/syft/src/syft/service/job/base_stash.py @@ -41,7 +41,13 @@ from ...types.syft_object import SyftObject from ...types.uid import UID from ..action.action_object import Action +from ..action.action_permissions import ActionObjectEXECUTE +from ..action.action_permissions import ActionObjectOWNER from ..action.action_permissions import ActionObjectPermission +from ..action.action_permissions import ActionObjectREAD +from ..action.action_permissions import ActionObjectWRITE +from ..action.action_permissions import ActionPermission +from ..action.action_permissions import StoragePermission from ..response import SyftSuccess @@ -89,8 +95,8 @@ def model_dump(obj: pydantic.BaseModel) -> dict: obj_dict[key] = str(getattr(obj, key)) elif ( type_.annotation is LinkedObject - or type_.annotation == Any | None - or type_.annotation == Action | None + or type_.annotation == Any | None # type: ignore + or type_.annotation == Action | None # type: ignore ): # not very efficient as it serializes the object twice data = sy.serialize(getattr(obj, key), to_bytes=True) @@ -195,14 +201,10 @@ def table(self) -> Table: self.object_type.__canonical_name__, Base.metadata, Column("id", UIDTypeDecorator, primary_key=True, default=uuid.uuid4), - Column("created_at", sa.DateTime, server_default=sa.func.now()), - Column( - "updated_at", - sa.DateTime, - server_default=sa.func.now(), - server_onupdate=sa.func.now(), - ), Column("fields", JSON, default={}), + Column("permissions", JSON, default=[]), + Column("created_at", sa.DateTime, server_default=sa.func.now()), + Column("updated_at", sa.DateTime, server_onupdate=sa.func.now()), ) return Base.metadata.tables[table_name] @@ -210,17 +212,33 @@ def get_by_uid( self, credentials: SyftVerifyKey, uid: UID ) -> Result[SyftT | None, str]: result = self.session.execute( - self.table.select().where(self.table.c.id == uid) + self.table.select().where( + sa.and_( + self._get_field_filter("id", uid), + self._get_permission_filter(credentials), + ) + ) ).first() if result is None: return Ok(None) return Ok(self.row_as_obj(result)) + def 
_get_field_filter( + self, field_name: str, field_value: str + ) -> sa.sql.elements.BinaryExpression: + if field_name == "id": + # use id column directly + return self.table.c.id == field_value + return self.table.c.fields[field_name] == field_value + def get_one_by_field( self, credentials: SyftVerifyKey, field_name: str, field_value: str ) -> Result[SyftT | None, str]: result = self.session.execute( - self.table.select().where(self.table.c.fields[field_name] == field_value) + sa.and_( + self._get_field_filter(field_name, field_value), + self._get_permission_filter(credentials), + ) ).first() if result is None: return Ok(None) @@ -229,24 +247,43 @@ def get_one_by_field( def get_all_by_field( self, credentials: SyftVerifyKey, field_name: str, field_value: str ) -> Result[list[SyftT], str]: - result = self.session.execute( - self.table.select().where(self.table.c.fields[field_name] == field_value) - ).all() + stmt = self.table.select().where( + sa.and_( + self._get_field_filter(field_name, field_value), + self._get_permission_filter(credentials), + ) + ) + result = self.session.execute(stmt).all() objs = [self.row_as_obj(row) for row in result] return Ok(objs) - def row_as_obj(self, row: Row): + def row_as_obj(self, row: Row) -> SyftT: return model_validate(self.object_type, row.fields) + def _get_permission_filter( + self, + credentials: SyftVerifyKey, + permission: ActionPermission = ActionPermission.READ, + ) -> sa.sql.elements.BinaryExpression: + # TODO: handle user.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN) + # after user stash is implemented + + return self.table.c.permissions.contains( + ActionObjectREAD( + uid=UID(), # dummy uid, we just need the permission string + credentials=credentials, + ).permission_string + ) + def get_all( self, credentials: SyftVerifyKey, order_by: PartitionKey | None = None, has_permission: bool = False, ) -> Result[list[SyftT], str]: - stmt = self.table.select() + # filter by read permission + stmt = self.table.select().where(self._get_permission_filter(credentials)) result = self.session.execute(stmt).all() - objs = [self.row_as_obj(row) for row in result] return Ok(objs) @@ -258,7 +295,12 @@ def update( ) -> Result[SyftT, str]: stmt = ( self.table.update() - .where(self.table.c.id == obj.id) + .where( + sa.and_( + self._get_field_filter("id", obj.id), + self._get_permission_filter(credentials), + ) + ) .values(fields=model_dump(obj)) ) self.session.execute(stmt) @@ -271,35 +313,119 @@ def set( obj: SyftT, add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, - ignore_duplicates: bool = False, + ignore_duplicates: bool = False, # only used in one place, should use upsert instead ) -> Result[SyftT, str]: + # uid is unique by database constraint + uid = obj.id + + permissions = self.get_ownership_permissions(uid, credentials) + if add_permissions is not None: + add_permission_strings = [p.permission_string for p in add_permissions] + permissions.extend(add_permission_strings) + + storage_permissions = [] + if add_storage_permission: + storage_permissions.append( + StoragePermission( + uid=uid, + server_uid=self.server_uid, + ) + ) + + # TODO: write the storage permissions to the database + + # create the object with the permissions stmt = self.table.insert().values( - id=obj.id, + id=uid, fields=model_dump(obj), + permissions=permissions, + # storage_permissions=storage_permissions, ) self.session.execute(stmt) self.session.commit() return Ok(obj) + def get_ownership_permissions( + self, uid: UID, 
credentials: SyftVerifyKey + ) -> list[str]: + return [ + ActionObjectOWNER(uid=uid, credentials=credentials).permission_string, + ActionObjectWRITE(uid=uid, credentials=credentials).permission_string, + ActionObjectREAD(uid=uid, credentials=credentials).permission_string, + ActionObjectEXECUTE(uid=uid, credentials=credentials).permission_string, + ] + def delete_by_uid( self, credentials: SyftVerifyKey, uid: UID ) -> Result[SyftSuccess, str]: - stmt = self.table.delete().where(self.table.c.id == uid) + stmt = self.table.delete().where( + sa.and_( + self._get_field_filter("id", uid), + self._get_permission_filter(credentials), + ) + ) self.session.execute(stmt) self.session.commit() return Ok(SyftSuccess()) def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: - pass + # TODO: should do this in a single transaction + for permission in permissions: + self.add_permission(permission) + return None def add_permission(self, permission: ActionObjectPermission) -> None: - pass + stmt = ( + self.table.update() + .values( + permissions=sa.func.array_append( + self.table.c.permissions, permission.permission_string + ) + ) + .where( + sa.and_( + self._get_field_filter("id", permission.uid), + self._get_permission_filter( + permission.credentials, ActionPermission.WRITE + ), + ) + ) + ) + self.session.execute(stmt) + self.session.commit() + return None def remove_permission(self, permission: ActionObjectPermission) -> None: - pass + stmt = ( + self.table.update() + .values( + permissions=sa.func.array_remove(self.table.c.permissions, permission) + ) + .where( + sa.and_( + self._get_field_filter("id", permission.uid), + self._get_permission_filter( + permission.credentials, + # since anyone with write permission can add permissions, + # owner check doesn't make sense, it should be write + ActionPermission.OWNER, + ), + ) + ) + ) + self.session.execute(stmt) + self.session.commit() + return None def has_permission(self, permission: ActionObjectPermission) -> bool: - return True + stmt = self.table.select().where( + sa.and_( + self._get_field_filter("id", permission.uid), + self.table.c.permissions.contains(permission.permission_string), + ) + ) + result = self.session.execute(stmt).first() + return result is not None - def has_storage_permission(self, permission) -> bool: + def has_storage_permission(self, permission: StoragePermission) -> bool: return True diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index cbcb278492e..8d3ec259317 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -881,6 +881,13 @@ def get_by_user_code_id( credentials=credentials, field_name="user_code_id", field_value=user_code_id ) + def get_by_parent_id( + self, credentials: SyftVerifyKey, uid: UID + ) -> Result[list[Job], str]: + return self.get_all_by_field( + credentials=credentials, field_name="parent_job_id", field_value=uid + ) + @serializable() class JobV1(SyncableSyftObject): From 8ea22a29182358070384a6a97d9a8f6c89fb7aa3 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 14 Aug 2024 12:17:09 +0200 Subject: [PATCH 004/197] implement request stash --- .../syft/src/syft/service/job/base_stash.py | 7 ++++++ .../src/syft/service/request/request_stash.py | 24 +++++++++---------- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/packages/syft/src/syft/service/job/base_stash.py b/packages/syft/src/syft/service/job/base_stash.py index 
079457810d7..5875be1261b 100644 --- a/packages/syft/src/syft/service/job/base_stash.py +++ b/packages/syft/src/syft/service/job/base_stash.py @@ -85,6 +85,8 @@ class CommonMixin: def model_dump(obj: pydantic.BaseModel) -> dict: + from syft.service.request.request import Change + obj_dict = obj.model_dump() for key, type_ in obj.model_fields.items(): if type_.annotation is UID: @@ -95,6 +97,7 @@ def model_dump(obj: pydantic.BaseModel) -> dict: obj_dict[key] = str(getattr(obj, key)) elif ( type_.annotation is LinkedObject + or type_.annotation == list[Change] or type_.annotation == Any | None # type: ignore or type_.annotation == Action | None # type: ignore ): @@ -102,6 +105,7 @@ def model_dump(obj: pydantic.BaseModel) -> dict: data = sy.serialize(getattr(obj, key), to_bytes=True) base64_data = base64.b64encode(data).decode("utf-8") obj_dict[key] = base64_data + return obj_dict @@ -109,6 +113,8 @@ def model_dump(obj: pydantic.BaseModel) -> dict: def model_validate(obj_type: type[T], obj_dict: dict) -> T: + from syft.service.request.request import Change + for key, type_ in obj_type.model_fields.items(): if key not in obj_dict: continue @@ -121,6 +127,7 @@ def model_validate(obj_type: type[T], obj_dict: dict) -> T: obj_dict[key] = SyftSigningKey.from_string(obj_dict[key]) elif ( type_.annotation is LinkedObject + or type_.annotation == list[Change] or type_.annotation == Any | None or type_.annotation == Action | None ): diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py index b56bd6932e1..952fdb6f2d3 100644 --- a/packages/syft/src/syft/service/request/request_stash.py +++ b/packages/syft/src/syft/service/request/request_stash.py @@ -3,6 +3,7 @@ # third party from result import Ok from result import Result +from syft.service.job.base_stash import ObjectStash # relative from ...serde.serializable import serializable @@ -24,8 +25,8 @@ @instrument -@serializable(canonical_name="RequestStash", version=1) -class RequestStash(BaseUIDStoreStash): +@serializable(canonical_name="RequestStashSQL", version=1) +class RequestStash(ObjectStash[Request]): object_type = Request settings: PartitionSettings = PartitionSettings( name=Request.__canonical_name__, object_type=Request @@ -38,20 +39,17 @@ def get_all_for_verify_key( ) -> Result[list[Request], str]: if isinstance(verify_key, str): verify_key = SyftVerifyKey.from_string(verify_key) - qks = QueryKeys(qks=[RequestingUserVerifyKeyPartitionKey.with_obj(verify_key)]) - return self.query_all( + return self.get_all_by_field( credentials=credentials, - qks=qks, - order_by=OrderByRequestTimeStampPartitionKey, + field_name="requesting_user_verify_key", + field_value=verify_key, ) def get_by_usercode_id( self, credentials: SyftVerifyKey, user_code_id: UID ) -> Result[list[Request], str]: - query = self.get_all(credentials=credentials) - if query.is_err(): - return query - - all_requests: list[Request] = query.ok() - results = [r for r in all_requests if r.code_id == user_code_id] - return Ok(results) + return self.get_all_by_field( + credentials=credentials, + field_name="code_id", + field_value=user_code_id, + ) From 2021a79a9303fff0b699e2d3633a50e5b3988aa9 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 14 Aug 2024 14:11:05 +0200 Subject: [PATCH 005/197] implement user and settings stash --- packages/syft/src/syft/server/server.py | 4 +- .../syft/src/syft/service/job/base_stash.py | 124 +++++++++++++----- .../syft/src/syft/service/job/job_stash.py | 8 +- 
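PATCH 003 stores each row's permissions as a JSON list of permission strings and reduces every access check to a containment test against that list. Roughly, a freshly set() row carries the four ownership entries built by get_ownership_permissions(); the sketch below shows the assumed shape (the exact string format comes from ActionObjectPermission.permission_string and is an assumption here):

    # Illustrative row layout; the verify-key value and the string
    # format are assumptions, not values from the patch.
    row = {
        "id": "<uid hex>",
        "fields": {},  # model_dump() output
        "permissions": [
            "<verify_key>_OWNER",
            "<verify_key>_WRITE",
            "<verify_key>_READ",
            "<verify_key>_EXECUTE",
        ],
    }
    # has_permission() then becomes a single JSON containment check:
    #   table.c.permissions.contains(permission.permission_string)
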
.../src/syft/service/request/request_stash.py | 7 +- .../syft/service/settings/settings_stash.py | 35 +---- .../src/syft/service/user/user_service.py | 14 +- .../syft/src/syft/service/user/user_stash.py | 106 ++++++--------- 7 files changed, 149 insertions(+), 149 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 66bb446fbf5..67b07a67525 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -1622,7 +1622,7 @@ def create_initial_settings(self, admin_email: str) -> ServerSettings | None: notifications_enabled=False, ) result = settings_stash.set( - credentials=self.signing_key.verify_key, settings=new_settings + credentials=self.signing_key.verify_key, obj=new_settings ) if result.is_ok(): return result.ok() @@ -1660,7 +1660,7 @@ def create_admin_new( user.verify_key = user.signing_key.verify_key result = user_stash.set( credentials=server.signing_key.verify_key, - user=user, + obj=user, ignore_duplicates=True, ) if result.is_ok(): diff --git a/packages/syft/src/syft/service/job/base_stash.py b/packages/syft/src/syft/service/job/base_stash.py index 5875be1261b..0ee01146b69 100644 --- a/packages/syft/src/syft/service/job/base_stash.py +++ b/packages/syft/src/syft/service/job/base_stash.py @@ -2,6 +2,7 @@ # stdlib import base64 +from enum import Enum import json import threading from typing import Any @@ -20,6 +21,7 @@ from sqlalchemy import Table from sqlalchemy import TypeDecorator from sqlalchemy import create_engine +from sqlalchemy import func from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import Session @@ -40,7 +42,6 @@ from ...types.datetime import DateTime from ...types.syft_object import SyftObject from ...types.uid import UID -from ..action.action_object import Action from ..action.action_permissions import ActionObjectEXECUTE from ..action.action_permissions import ActionObjectOWNER from ..action.action_permissions import ActionObjectPermission @@ -49,6 +50,7 @@ from ..action.action_permissions import ActionPermission from ..action.action_permissions import StoragePermission from ..response import SyftSuccess +from ..user.user_roles import ServiceRole class Base(DeclarativeBase): @@ -85,21 +87,33 @@ class CommonMixin: def model_dump(obj: pydantic.BaseModel) -> dict: - from syft.service.request.request import Change + # relative + from ...util.misc_objs import HTMLObject + from ...util.misc_objs import MarkdownDescription + from ..action.action_object import Action + from ..request.request import Change + from ..settings.settings import PwdTokenResetConfig obj_dict = obj.model_dump() for key, type_ in obj.model_fields.items(): if type_.annotation is UID: obj_dict[key] = obj_dict[key].no_dash - elif type_.annotation is SyftVerifyKey: - obj_dict[key] = str(getattr(obj, key)) - elif type_.annotation is SyftSigningKey: - obj_dict[key] = str(getattr(obj, key)) + elif ( + type_.annotation is SyftVerifyKey + or type_.annotation == SyftVerifyKey | None + or type_.annotation is SyftSigningKey + or type_.annotation == SyftSigningKey | None + ): + attr = getattr(obj, key) + obj_dict[key] = str(attr) if attr is not None else None elif ( type_.annotation is LinkedObject or type_.annotation == list[Change] or type_.annotation == Any | None # type: ignore or type_.annotation == Action | None # type: ignore + or getattr(type_.annotation, "__origin__", None) is dict + or type_.annotation == HTMLObject | MarkdownDescription + or 
type_.annotation == PwdTokenResetConfig ): # not very efficient as it serializes the object twice data = sy.serialize(getattr(obj, key), to_bytes=True) @@ -113,7 +127,12 @@ def model_dump(obj: pydantic.BaseModel) -> dict: def model_validate(obj_type: type[T], obj_dict: dict) -> T: - from syft.service.request.request import Change + # relative + from ...util.misc_objs import HTMLObject + from ...util.misc_objs import MarkdownDescription + from ..action.action_object import Action + from ..request.request import Change + from ..settings.settings import PwdTokenResetConfig for key, type_ in obj_type.model_fields.items(): if key not in obj_dict: @@ -121,15 +140,28 @@ def model_validate(obj_type: type[T], obj_dict: dict) -> T: # FIXME if type_.annotation is UID or type_.annotation == UID | None: obj_dict[key] = UID(obj_dict[key]) - elif type_.annotation is SyftVerifyKey: - obj_dict[key] = SyftVerifyKey.from_string(obj_dict[key]) - elif type_.annotation is SyftSigningKey: - obj_dict[key] = SyftSigningKey.from_string(obj_dict[key]) + elif ( + type_.annotation is SyftVerifyKey + or type_.annotation == SyftVerifyKey | None + ): + obj_dict[key] = ( + SyftVerifyKey.from_string(obj_dict[key]) if obj_dict[key] else None + ) + elif ( + type_.annotation is SyftSigningKey + or type_.annotation == SyftSigningKey | None + ): + obj_dict[key] = ( + SyftSigningKey.from_string(obj_dict[key]) if obj_dict[key] else None + ) elif ( type_.annotation is LinkedObject or type_.annotation == list[Change] or type_.annotation == Any | None or type_.annotation == Action | None + or getattr(type_.annotation, "__origin__", None) is dict + or type_.annotation == HTMLObject | MarkdownDescription + or type_.annotation == PwdTokenResetConfig ): data = base64.b64decode(obj_dict[key]) obj_dict[key] = sy.deserialize(data, from_bytes=True) @@ -142,7 +174,13 @@ def _default_dumps(val): # type: ignore return str(val.no_dash) elif isinstance(val, UUID): return val.hex - # raise TypeError(f"Can't serialize {val}, type {type(val)}") + elif issubclass(type(val), Enum): + return val.name + elif val is None: + return None + return str(val) + # elif isinstance + raise TypeError(f"Can't serialize {val}, type {type(val)}") def _default_loads(val): # type: ignore @@ -190,7 +228,7 @@ class ObjectStash(Generic[SyftT]): def __init__(self, store: DocumentStore) -> None: self.server_uid = store.server_uid - self.verify_key = store.root_verify_key + self.root_verify_key = store.root_verify_key # is there a better way to init the table _ = self.table self.db = SQLiteDBManager(self.server_uid) @@ -236,16 +274,25 @@ def _get_field_filter( if field_name == "id": # use id column directly return self.table.c.id == field_value - return self.table.c.fields[field_name] == field_value + return func.json_extract(self.table.c.fields, f"$.{field_name}") == field_value - def get_one_by_field( + def _get_by_field( self, credentials: SyftVerifyKey, field_name: str, field_value: str - ) -> Result[SyftT | None, str]: - result = self.session.execute( + ) -> Result[Row, str]: + stmt = self.table.select().where( sa.and_( self._get_field_filter(field_name, field_value), self._get_permission_filter(credentials), ) + ) + result = self.session.execute(stmt) + return result + + def get_one_by_field( + self, credentials: SyftVerifyKey, field_name: str, field_value: str + ) -> Result[SyftT | None, str]: + result = self._get_by_field( + credentials=credentials, field_name=field_name, field_value=field_value ).first() if result is None: return Ok(None) @@ -254,33 +301,44 @@ def 
get_one_by_field( def get_all_by_field( self, credentials: SyftVerifyKey, field_name: str, field_value: str ) -> Result[list[SyftT], str]: - stmt = self.table.select().where( - sa.and_( - self._get_field_filter(field_name, field_value), - self._get_permission_filter(credentials), - ) - ) - result = self.session.execute(stmt).all() + result = self._get_by_field( + credentials=credentials, field_name=field_name, field_value=field_value + ).all() objs = [self.row_as_obj(row) for row in result] return Ok(objs) def row_as_obj(self, row: Row) -> SyftT: return model_validate(self.object_type, row.fields) + def get_role(self, credentials: SyftVerifyKey) -> ServiceRole: + stmt = ( + Table("User", Base.metadata) + .select() + .where( + self._get_field_filter("verify_key", str(credentials)), + ) + ) + result = self.session.execute(stmt).first() + if result is None: + return ServiceRole.GUEST + return ServiceRole[result.fields["role"]] + def _get_permission_filter( self, credentials: SyftVerifyKey, permission: ActionPermission = ActionPermission.READ, ) -> sa.sql.elements.BinaryExpression: - # TODO: handle user.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN) - # after user stash is implemented + if self.get_role(credentials) in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER): + return sa.literal(True) - return self.table.c.permissions.contains( - ActionObjectREAD( + per_object_permission_filter = self.table.c.permissions.contains( + ActionObjectPermission( uid=UID(), # dummy uid, we just need the permission string credentials=credentials, + permission=permission, ).permission_string ) + return per_object_permission_filter def get_all( self, @@ -289,7 +347,9 @@ def get_all( has_permission: bool = False, ) -> Result[list[SyftT], str]: # filter by read permission - stmt = self.table.select().where(self._get_permission_filter(credentials)) + stmt = self.table.select() + if not has_permission: + stmt = stmt.where(self._get_permission_filter(credentials)) result = self.session.execute(stmt).all() objs = [self.row_as_obj(row) for row in result] return Ok(objs) @@ -363,7 +423,7 @@ def get_ownership_permissions( ] def delete_by_uid( - self, credentials: SyftVerifyKey, uid: UID + self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False ) -> Result[SyftSuccess, str]: stmt = self.table.delete().where( sa.and_( @@ -373,7 +433,9 @@ def delete_by_uid( ) self.session.execute(stmt) self.session.commit() - return Ok(SyftSuccess()) + return Ok( + SyftSuccess(message=f"{type(self.object_type).__name__}: {uid} deleted") + ) def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: # TODO: should do this in a single transaction diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 8d3ec259317..238e3549d12 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -871,21 +871,23 @@ def get_by_worker( self, credentials: SyftVerifyKey, worker_id: str ) -> Result[list[Job], str]: return self.get_all_by_field( - credentials=credentials, field_name="worker_id", field_value=worker_id + credentials=credentials, field_name="worker_id", field_value=str(worker_id) ) def get_by_user_code_id( self, credentials: SyftVerifyKey, user_code_id: UID ) -> Result[list[Job], str]: return self.get_all_by_field( - credentials=credentials, field_name="user_code_id", field_value=user_code_id + credentials=credentials, + field_name="user_code_id", + field_value=str(user_code_id), ) def 
get_by_parent_id( self, credentials: SyftVerifyKey, uid: UID ) -> Result[list[Job], str]: return self.get_all_by_field( - credentials=credentials, field_name="parent_job_id", field_value=uid + credentials=credentials, field_name="parent_job_id", field_value=str(uid) ) diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py index 952fdb6f2d3..e0fef0a537f 100644 --- a/packages/syft/src/syft/service/request/request_stash.py +++ b/packages/syft/src/syft/service/request/request_stash.py @@ -1,20 +1,17 @@ # stdlib # third party -from result import Ok from result import Result -from syft.service.job.base_stash import ObjectStash # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys from ...types.datetime import DateTime from ...types.uid import UID from ...util.telemetry import instrument +from ..job.base_stash import ObjectStash from .request import Request RequestingUserVerifyKeyPartitionKey = PartitionKey( @@ -51,5 +48,5 @@ def get_by_usercode_id( return self.get_all_by_field( credentials=credentials, field_name="code_id", - field_value=user_code_id, + field_value=str(user_code_id), ) diff --git a/packages/syft/src/syft/service/settings/settings_stash.py b/packages/syft/src/syft/service/settings/settings_stash.py index 52c134274f7..e662040da0c 100644 --- a/packages/syft/src/syft/service/settings/settings_stash.py +++ b/packages/syft/src/syft/service/settings/settings_stash.py @@ -1,18 +1,15 @@ # stdlib # third party -from result import Result # relative from ...serde.serializable import serializable -from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash from ...store.document_store import DocumentStore from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings from ...types.uid import UID from ...util.telemetry import instrument -from ..action.action_permissions import ActionObjectPermission +from ..job.base_stash import ObjectStash from .settings import ServerSettings NamePartitionKey = PartitionKey(key="name", type_=str) @@ -20,8 +17,8 @@ @instrument -@serializable(canonical_name="SettingsStash", version=1) -class SettingsStash(BaseUIDStoreStash): +@serializable(canonical_name="SettingsStashSQL", version=1) +class SettingsStash(ObjectStash[ServerSettings]): object_type = ServerSettings settings: PartitionSettings = PartitionSettings( name=ServerSettings.__canonical_name__, object_type=ServerSettings @@ -29,29 +26,3 @@ class SettingsStash(BaseUIDStoreStash): def __init__(self, store: DocumentStore) -> None: super().__init__(store=store) - - def set( - self, - credentials: SyftVerifyKey, - settings: ServerSettings, - add_permission: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - ignore_duplicates: bool = False, - ) -> Result[ServerSettings, str]: - res = self.check_type(settings, self.object_type) - # we dont use and_then logic here as it is hard because of the order of the arguments - if res.is_err(): - return res - return super().set(credentials=credentials, obj=res.ok()) - - def update( - self, - credentials: SyftVerifyKey, - settings: ServerSettings, - has_permission: bool = False, - ) -> Result[ServerSettings, str]: - res = 
self.check_type(settings, self.object_type) - # we dont use and_then logic here as it is hard because of the order of the arguments - if res.is_err(): - return res - return super().update(credentials=credentials, obj=res.ok()) diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index fe48010b4f4..bd2806046fc 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -79,7 +79,7 @@ def create( result = self.stash.set( credentials=context.credentials, - user=user, + obj=user, add_permissions=[ ActionObjectPermission( uid=user.id, permission=ActionPermission.ALL_READ @@ -213,7 +213,7 @@ def request_password_reset( user.reset_token_date = datetime.now() result = self.stash.update( - credentials=context.credentials, user=user, has_permission=True + credentials=context.credentials, obj=user, has_permission=True ) if result.is_err(): return SyftError( @@ -271,7 +271,7 @@ def reset_password( user.reset_token_date = None result = self.stash.update( - credentials=root_context.credentials, user=user, has_permission=True + credentials=root_context.credentials, obj=user, has_permission=True ) if result.is_err(): return SyftError( @@ -570,7 +570,7 @@ def update( setattr(user, name, value) result = self.stash.update( - credentials=context.credentials, user=user, has_permission=True + credentials=context.credentials, obj=user, has_permission=True ) if result.is_err(): @@ -587,7 +587,7 @@ def update( settings_data = settings.ok()[0] settings_data.admin_email = user.email settings_stash.update( - credentials=context.credentials, settings=settings_data + credentials=context.credentials, obj=settings_data ) return user.to(UserView) @@ -718,7 +718,7 @@ def register( result = self.stash.set( credentials=user.verify_key, - user=user, + obj=user, add_permissions=[ ActionObjectPermission( uid=user.id, permission=ActionPermission.ALL_READ @@ -798,7 +798,7 @@ def _set_notification_status( result = self.stash.update( credentials=user.verify_key, - user=user, + obj=user, ) if result.is_err(): return SyftError(message=str(result.err())) diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index 894d9a65115..7d96444c524 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -8,16 +8,11 @@ from ...serde.serializable import serializable from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseStash from ...store.document_store import DocumentStore from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey -from ...types.uid import UID from ...util.telemetry import instrument -from ..action.action_permissions import ActionObjectPermission -from ..response import SyftSuccess +from ..job.base_stash import ObjectStash from .user import User from .user_roles import ServiceRole @@ -30,8 +25,8 @@ @instrument -@serializable(canonical_name="UserStash", version=1) -class UserStash(BaseStash): +@serializable(canonical_name="UserStashSQL", version=1) +class UserStash(ObjectStash[User]): object_type = User settings: PartitionSettings = PartitionSettings( name=User.__canonical_name__, @@ -41,51 +36,43 @@ class UserStash(BaseStash): def __init__(self, store: 
DocumentStore) -> None: super().__init__(store=store) - def set( - self, - credentials: SyftVerifyKey, - user: User, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - ignore_duplicates: bool = False, - ) -> Result[User, str]: - res = self.check_type(user, self.object_type) - # we dont use and_then logic here as it is hard because of the order of the arguments - if res.is_err(): - return res - return super().set( - credentials=credentials, - obj=res.ok(), - add_permissions=add_permissions, - ignore_duplicates=ignore_duplicates, - add_storage_permission=add_storage_permission, - ) + self._init_root() + + def _init_root(self) -> None: + # start a transaction + users = self.get_all(self.root_verify_key, has_permission=True) + if not users: + # NOTE this is not thread safe, should use a session and transaction + super().set( + self.root_verify_key, + User( + email="_internal@root.com", + role=ServiceRole.ADMIN, + verify_key=self.root_verify_key, + ), + ) def admin_verify_key(self) -> Result[SyftVerifyKey | None, str]: - return Ok(self.partition.root_verify_key) + return Ok(self.root_verify_key) def admin_user(self) -> Result[User | None, str]: return self.get_by_role( credentials=self.admin_verify_key().ok(), role=ServiceRole.ADMIN ) - def get_by_uid( - self, credentials: SyftVerifyKey, uid: UID - ) -> Result[User | None, str]: - qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) - return self.query_one(credentials=credentials, qks=qks) - def get_by_reset_token( self, credentials: SyftVerifyKey, token: str ) -> Result[User | None, str]: - qks = QueryKeys(qks=[PasswordResetTokenPartitionKey.with_obj(token)]) - return self.query_one(credentials=credentials, qks=qks) + return self.get_one_by_field( + credentials=credentials, field_name="reset_token", field_value=token + ) def get_by_email( self, credentials: SyftVerifyKey, email: str ) -> Result[User | None, str]: - qks = QueryKeys(qks=[EmailPartitionKey.with_obj(email)]) - return self.query_one(credentials=credentials, qks=qks) + return self.get_one_by_field( + credentials=credentials, field_name="email", field_value=email + ) def email_exists(self, email: str) -> bool: res = self.get_by_email(credentials=self.admin_verify_key().ok(), email=email) @@ -97,43 +84,24 @@ def email_exists(self, email: str) -> bool: def get_by_role( self, credentials: SyftVerifyKey, role: ServiceRole ) -> Result[User | None, str]: - qks = QueryKeys(qks=[RolePartitionKey.with_obj(role)]) - return self.query_one(credentials=credentials, qks=qks) + return self.get_one_by_field( + credentials=credentials, field_name="role", field_value=role + ) def get_by_signing_key( self, credentials: SyftVerifyKey, signing_key: SyftSigningKey ) -> Result[User | None, str]: - if isinstance(signing_key, str): - signing_key = SyftSigningKey.from_string(signing_key) - qks = QueryKeys(qks=[SigningKeyPartitionKey.with_obj(signing_key)]) - return self.query_one(credentials=credentials, qks=qks) + return self.get_one_by_field( + credentials=credentials, + field_name="signing_key", + field_value=str(signing_key), + ) def get_by_verify_key( self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey ) -> Result[User | None, str]: - if isinstance(verify_key, str): - verify_key = SyftVerifyKey.from_string(verify_key) - qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(verify_key)]) - return self.query_one(credentials=credentials, qks=qks) - - def delete_by_uid( - self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False - ) -> 
Result[SyftSuccess, str]: - qk = UIDPartitionKey.with_obj(uid) - result = super().delete( - credentials=credentials, qk=qk, has_permission=has_permission - ) - if result.is_ok(): - return Ok(SyftSuccess(message=f"ID: {uid} deleted")) - return result - - def update( - self, credentials: SyftVerifyKey, user: User, has_permission: bool = False - ) -> Result[User, str]: - res = self.check_type(user, self.object_type) - # we dont use and_then logic here as it is hard because of the order of the arguments - if res.is_err(): - return res - return super().update( - credentials=credentials, obj=res.ok(), has_permission=has_permission + return self.get_one_by_field( + credentials=credentials, + field_name="verify_key", + field_value=str(verify_key), ) From 64ea446f6392f454bd0e1a1a317c03d86176df85 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 14 Aug 2024 16:22:50 +0200 Subject: [PATCH 006/197] add compound permissions --- .../syft/src/syft/service/job/base_stash.py | 34 +++++++++++++------ 1 file changed, 24 insertions(+), 10 deletions(-) diff --git a/packages/syft/src/syft/service/job/base_stash.py b/packages/syft/src/syft/service/job/base_stash.py index 0ee01146b69..48ba2d7423a 100644 --- a/packages/syft/src/syft/service/job/base_stash.py +++ b/packages/syft/src/syft/service/job/base_stash.py @@ -328,17 +328,31 @@ def _get_permission_filter( credentials: SyftVerifyKey, permission: ActionPermission = ActionPermission.READ, ) -> sa.sql.elements.BinaryExpression: - if self.get_role(credentials) in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER): + role = self.get_role(credentials) + if role in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER): return sa.literal(True) - per_object_permission_filter = self.table.c.permissions.contains( - ActionObjectPermission( - uid=UID(), # dummy uid, we just need the permission string - credentials=credentials, - permission=permission, - ).permission_string + permission_string = ActionObjectPermission( + uid=UID(), # dummy uid, we just need the permission string + credentials=credentials, + permission=permission, + ).permission_string + + compound_permission_map = { + ActionPermission.READ: ActionPermission.ALL_READ, + ActionPermission.WRITE: ActionPermission.ALL_WRITE, + ActionPermission.EXECUTE: ActionPermission.ALL_EXECUTE, + } + compound_permission_string = ActionObjectPermission( + uid=UID(), # dummy uid, we just need the permission string + credentials=None, # no credentials for compound permissions + permission=compound_permission_map[permission], + ).permission_string + + return sa.or_( + self.table.c.permissions.contains(permission_string), + self.table.c.permissions.contains(compound_permission_string), ) - return per_object_permission_filter def get_all( self, @@ -365,7 +379,7 @@ def update( .where( sa.and_( self._get_field_filter("id", obj.id), - self._get_permission_filter(credentials), + self._get_permission_filter(credentials, ActionObjectWRITE), ) ) .values(fields=model_dump(obj)) @@ -428,7 +442,7 @@ def delete_by_uid( stmt = self.table.delete().where( sa.and_( self._get_field_filter("id", uid), - self._get_permission_filter(credentials), + self._get_permission_filter(credentials, ActionPermission.OWNER), ) ) self.session.execute(stmt) From c6c5038c4f74ba992cb8c8695f567ff00f2b91ce Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 14 Aug 2024 17:40:02 +0200 Subject: [PATCH 007/197] fix permissions --- packages/syft/src/syft/server/server.py | 23 ++--- .../syft/src/syft/service/job/base_stash.py | 91 ++++++++++--------- 2 files 
changed, 60 insertions(+), 54 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 67b07a67525..556c93996a7 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -22,7 +22,6 @@ # third party from nacl.signing import SigningKey from result import Err -from result import Ok from result import Result # relative @@ -1156,16 +1155,18 @@ def handle_api_call_with_unsigned_result( api_call = api_call.message role = self.get_role_for_credentials(credentials=credentials) - settings = self.get_settings() - # TODO: This instance check should be removed once we can ensure that - # self.settings will always return a ServerSettings object. - if ( - settings is not None - and not isinstance(settings, Ok) - and not settings.allow_guest_sessions - and role == ServiceRole.GUEST - ): - return SyftError(message="Server doesn't allow guest sessions.") + # NOTE: can we cache this? + + # settings = self.get_settings() + # # TODO: This instance check should be removed once we can ensure that + # # self.settings will always return a ServerSettings object. + # if ( + # settings is not None + # and not isinstance(settings, Ok) + # and not settings.allow_guest_sessions + # and role == ServiceRole.GUEST + # ): + # return SyftError(message="Server doesn't allow guest sessions.") context = AuthedServiceContext( server=self, credentials=credentials, diff --git a/packages/syft/src/syft/service/job/base_stash.py b/packages/syft/src/syft/service/job/base_stash.py index 48ba2d7423a..8336e846bcb 100644 --- a/packages/syft/src/syft/service/job/base_stash.py +++ b/packages/syft/src/syft/service/job/base_stash.py @@ -86,14 +86,28 @@ class CommonMixin: json_document: Mapped[dict] = mapped_column(JSON, default={}) -def model_dump(obj: pydantic.BaseModel) -> dict: +def should_handle_as_bytes(type_) -> bool: # relative from ...util.misc_objs import HTMLObject from ...util.misc_objs import MarkdownDescription from ..action.action_object import Action from ..request.request import Change + from ..request.request import ChangeStatus from ..settings.settings import PwdTokenResetConfig + return ( + type_.annotation is LinkedObject + or type_.annotation == list[Change] + or type_.annotation == Any | None # type: ignore + or type_.annotation == Action | None # type: ignore + or getattr(type_.annotation, "__origin__", None) is dict + or type_.annotation == HTMLObject | MarkdownDescription + or type_.annotation == PwdTokenResetConfig + or type_.annotation == list[ChangeStatus] + ) + + +def model_dump(obj: pydantic.BaseModel) -> dict: obj_dict = obj.model_dump() for key, type_ in obj.model_fields.items(): if type_.annotation is UID: @@ -106,15 +120,7 @@ def model_dump(obj: pydantic.BaseModel) -> dict: ): attr = getattr(obj, key) obj_dict[key] = str(attr) if attr is not None else None - elif ( - type_.annotation is LinkedObject - or type_.annotation == list[Change] - or type_.annotation == Any | None # type: ignore - or type_.annotation == Action | None # type: ignore - or getattr(type_.annotation, "__origin__", None) is dict - or type_.annotation == HTMLObject | MarkdownDescription - or type_.annotation == PwdTokenResetConfig - ): + elif should_handle_as_bytes(type_): # not very efficient as it serializes the object twice data = sy.serialize(getattr(obj, key), to_bytes=True) base64_data = base64.b64encode(data).decode("utf-8") @@ -128,11 +134,6 @@ def model_dump(obj: pydantic.BaseModel) -> dict: def model_validate(obj_type: type[T], 
obj_dict: dict) -> T: # relative - from ...util.misc_objs import HTMLObject - from ...util.misc_objs import MarkdownDescription - from ..action.action_object import Action - from ..request.request import Change - from ..settings.settings import PwdTokenResetConfig for key, type_ in obj_type.model_fields.items(): if key not in obj_dict: @@ -144,25 +145,24 @@ def model_validate(obj_type: type[T], obj_dict: dict) -> T: type_.annotation is SyftVerifyKey or type_.annotation == SyftVerifyKey | None ): - obj_dict[key] = ( - SyftVerifyKey.from_string(obj_dict[key]) if obj_dict[key] else None - ) + if obj_dict[key] is None: + obj_dict[key] = None + elif isinstance(obj_dict[key], str): + obj_dict[key] = SyftVerifyKey.from_string(obj_dict[key]) + elif isinstance(obj_dict[key], SyftVerifyKey): + obj_dict[key] = obj_dict[key] + elif ( type_.annotation is SyftSigningKey or type_.annotation == SyftSigningKey | None ): - obj_dict[key] = ( - SyftSigningKey.from_string(obj_dict[key]) if obj_dict[key] else None - ) - elif ( - type_.annotation is LinkedObject - or type_.annotation == list[Change] - or type_.annotation == Any | None - or type_.annotation == Action | None - or getattr(type_.annotation, "__origin__", None) is dict - or type_.annotation == HTMLObject | MarkdownDescription - or type_.annotation == PwdTokenResetConfig - ): + if obj_dict[key] is None: + obj_dict[key] = None + elif isinstance(obj_dict[key], str): + obj_dict[key] = SyftSigningKey(signing_key=obj_dict[key]) + elif isinstance(obj_dict[key], SyftSigningKey): + obj_dict[key] = obj_dict[key] + elif should_handle_as_bytes(type_): data = base64.b64decode(obj_dict[key]) obj_dict[key] = sy.deserialize(data, from_bytes=True) @@ -269,19 +269,27 @@ def get_by_uid( return Ok(self.row_as_obj(result)) def _get_field_filter( - self, field_name: str, field_value: str + self, + field_name: str, + field_value: str, + table: Table | None = None, ) -> sa.sql.elements.BinaryExpression: + table = table if table is not None else self.table if field_name == "id": - # use id column directly - return self.table.c.id == field_value - return func.json_extract(self.table.c.fields, f"$.{field_name}") == field_value + return table.c.id == field_value + return func.json_extract(table.c.fields, f"$.{field_name}") == field_value def _get_by_field( - self, credentials: SyftVerifyKey, field_name: str, field_value: str + self, + credentials: SyftVerifyKey, + field_name: str, + field_value: str, + table: Table | None = None, ) -> Result[Row, str]: - stmt = self.table.select().where( + table = table if table is not None else self.table + stmt = table.select().where( sa.and_( - self._get_field_filter(field_name, field_value), + self._get_field_filter(field_name, field_value, table=table), self._get_permission_filter(credentials), ) ) @@ -311,12 +319,9 @@ def row_as_obj(self, row: Row) -> SyftT: return model_validate(self.object_type, row.fields) def get_role(self, credentials: SyftVerifyKey) -> ServiceRole: - stmt = ( - Table("User", Base.metadata) - .select() - .where( - self._get_field_filter("verify_key", str(credentials)), - ) + user_table = Table("User", Base.metadata) + stmt = user_table.select().where( + self._get_field_filter("verify_key", str(credentials), table=user_table), ) result = self.session.execute(stmt).first() if result is None: From c61d93478d8565f5f876411695195ba3bf422db7 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 14 Aug 2024 18:16:39 +0200 Subject: [PATCH 008/197] fix queries --- packages/syft/src/syft/server/server.py | 4 ++++ 
packages/syft/src/syft/service/job/base_stash.py | 12 +++++++++--- packages/syft/src/syft/service/job/job_service.py | 2 +- .../syft/src/syft/service/output/output_service.py | 8 ++++---- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 556c93996a7..1df25014734 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -864,6 +864,10 @@ def init_stores( def job_stash(self) -> JobStash: return self.get_service("jobservice").stash + @property + def output_stash(self) -> JobStash: + return self.get_service("outputservice").stash + @property def worker_stash(self) -> WorkerStash: return self.get_service("workerservice").stash diff --git a/packages/syft/src/syft/service/job/base_stash.py b/packages/syft/src/syft/service/job/base_stash.py index 8336e846bcb..6bd04ac3633 100644 --- a/packages/syft/src/syft/service/job/base_stash.py +++ b/packages/syft/src/syft/service/job/base_stash.py @@ -97,6 +97,9 @@ def should_handle_as_bytes(type_) -> bool: return ( type_.annotation is LinkedObject + or type_.annotation == LinkedObject | None + or type_.annotation == list[UID] | dict[str, UID] | None + or type_.annotation == dict[str, UID] | None or type_.annotation == list[Change] or type_.annotation == Any | None # type: ignore or type_.annotation == Action | None # type: ignore @@ -140,7 +143,10 @@ def model_validate(obj_type: type[T], obj_dict: dict) -> T: continue # FIXME if type_.annotation is UID or type_.annotation == UID | None: - obj_dict[key] = UID(obj_dict[key]) + if obj_dict[key] is None: + obj_dict[key] = None + else: + obj_dict[key] = UID(obj_dict[key]) elif ( type_.annotation is SyftVerifyKey or type_.annotation == SyftVerifyKey | None @@ -384,7 +390,7 @@ def update( .where( sa.and_( self._get_field_filter("id", obj.id), - self._get_permission_filter(credentials, ActionObjectWRITE), + self._get_permission_filter(credentials, ActionPermission.WRITE), ) ) .values(fields=model_dump(obj)) @@ -398,7 +404,7 @@ def set( credentials: SyftVerifyKey, obj: SyftT, add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, + add_storage_permission: bool = True, # TODO: check the default value ignore_duplicates: bool = False, # only used in one place, should use upsert instead ) -> Result[SyftT, str]: # uid is unique by database constraint diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index 3cfbe356f09..2ab1a3cc243 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -166,7 +166,7 @@ def restart( ) context.server.queue_stash.set_placeholder(context.credentials, queue_item) - context.server.job_stash.set(context.credentials, job) + self.stash.set(context.credentials, job) log_service = context.server.get_service("logservice") result = log_service.restart(context, job.log_id) diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index d8a2c75f7ad..ff6b1d75b42 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -204,16 +204,16 @@ def get_by_user_code_id( return self.get_all_by_field( credentials=credentials, field_name="user_code_id", - field_value=user_code_id, + field_value=str(user_code_id), ) def get_by_job_id( - self, credentials: 
SyftVerifyKey, user_code_id: UID + self, credentials: SyftVerifyKey, job_id: UID ) -> Result[ExecutionOutput | None, str]: return self.get_one_by_field( credentials=credentials, field_name="job_id", - field_value=user_code_id, + field_value=str(job_id), ) def get_by_output_policy_id( @@ -222,7 +222,7 @@ def get_by_output_policy_id( return self.get_all_by_field( credentials=credentials, field_name="output_policy_id", - field_value=output_policy_id, + field_value=str(output_policy_id), ) From 84ddf7ba94446f43425a918ed0adf1300a9e27e4 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 15 Aug 2024 10:53:52 +0200 Subject: [PATCH 009/197] implement dataset stash, fix datetime parsing --- .../src/syft/service/dataset/dataset_stash.py | 49 +++++++------------ .../syft/src/syft/service/job/base_stash.py | 18 +++++-- packages/syft/src/syft/types/datetime.py | 10 ++++ 3 files changed, 42 insertions(+), 35 deletions(-) diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index f03715985c0..2b8cccda208 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -8,23 +8,18 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash from ...store.document_store import DocumentStore from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys from ...types.uid import UID from ...util.telemetry import instrument +from ..job.base_stash import ObjectStash from .dataset import Dataset -from .dataset import DatasetUpdate - -NamePartitionKey = PartitionKey(key="name", type_=str) -ActionIDsPartitionKey = PartitionKey(key="action_ids", type_=list[UID]) @instrument -@serializable(canonical_name="DatasetStash", version=1) -class DatasetStash(BaseUIDStoreStash): +@serializable(canonical_name="DatasetStashSQL", version=1) +class DatasetStash(ObjectStash[Dataset]): object_type = Dataset settings: PartitionSettings = PartitionSettings( name=Dataset.__canonical_name__, object_type=Dataset @@ -36,26 +31,18 @@ def __init__(self, store: DocumentStore) -> None: def get_by_name( self, credentials: SyftVerifyKey, name: str ) -> Result[Dataset | None, str]: - qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)]) - return self.query_one(credentials=credentials, qks=qks) - - def update( - self, - credentials: SyftVerifyKey, - dataset_update: DatasetUpdate, - has_permission: bool = False, - ) -> Result[Dataset, str]: - res = self.check_type(dataset_update, DatasetUpdate) - # we dont use and_then logic here as it is hard because of the order of the arguments - if res.is_err(): - return res - return super().update(credentials=credentials, obj=res.ok()) + return self.get_one_by_field( + credentials=credentials, field_name="name", field_value=name + ) def search_action_ids( self, credentials: SyftVerifyKey, uid: UID ) -> Result[list[Dataset], str]: - qks = QueryKeys(qks=[ActionIDsPartitionKey.with_obj(uid)]) - return self.query_all(credentials=credentials, qks=qks) + return self.get_all_by_field( + credentials=credentials, + field_name="action_ids", + field_value=str(uid), + ) def get_all( self, @@ -63,11 +50,9 @@ def get_all( order_by: PartitionKey | None = None, has_permission: bool = False, ) -> Ok[list] | Err[str]: - result = super().get_all(credentials, order_by, has_permission) - 
- if result.is_err(): - return result - filtered_datasets = [ - dataset for dataset in result.ok_value if not dataset.to_be_deleted - ] - return Ok(filtered_datasets) + result = self.get_all_by_field( + credentials=credentials, + field_name="to_be_deleted", + field_value=False, + ) + return result diff --git a/packages/syft/src/syft/service/job/base_stash.py b/packages/syft/src/syft/service/job/base_stash.py index 6bd04ac3633..ae0e54f8e27 100644 --- a/packages/syft/src/syft/service/job/base_stash.py +++ b/packages/syft/src/syft/service/job/base_stash.py @@ -91,6 +91,8 @@ def should_handle_as_bytes(type_) -> bool: from ...util.misc_objs import HTMLObject from ...util.misc_objs import MarkdownDescription from ..action.action_object import Action + from ..dataset.dataset import Asset + from ..dataset.dataset import Contributor from ..request.request import Change from ..request.request import ChangeStatus from ..settings.settings import PwdTokenResetConfig @@ -107,11 +109,15 @@ def should_handle_as_bytes(type_) -> bool: or type_.annotation == HTMLObject | MarkdownDescription or type_.annotation == PwdTokenResetConfig or type_.annotation == list[ChangeStatus] + or type_.annotation == list[Asset] + or type_.annotation == set[Contributor] + or type_.annotation == MarkdownDescription + or type_.annotation == Contributor ) def model_dump(obj: pydantic.BaseModel) -> dict: - obj_dict = obj.model_dump() + obj_dict = dict(obj) # obj.model_dump() does not work when for key, type_ in obj.model_fields.items(): if type_.annotation is UID: obj_dict[key] = obj_dict[key].no_dash @@ -123,6 +129,10 @@ def model_dump(obj: pydantic.BaseModel) -> dict: ): attr = getattr(obj, key) obj_dict[key] = str(attr) if attr is not None else None + elif type_.annotation is DateTime or type_.annotation == DateTime | None: + # FIXME: this is a hack, we should not be converting to string + if obj_dict[key] is not None: + obj_dict[key] = obj_dict[key].utc_timestamp elif should_handle_as_bytes(type_): # not very efficient as it serializes the object twice data = sy.serialize(getattr(obj, key), to_bytes=True) @@ -157,7 +167,9 @@ def model_validate(obj_type: type[T], obj_dict: dict) -> T: obj_dict[key] = SyftVerifyKey.from_string(obj_dict[key]) elif isinstance(obj_dict[key], SyftVerifyKey): obj_dict[key] = obj_dict[key] - + elif type_.annotation is DateTime or type_.annotation == DateTime | None: + if obj_dict[key] is not None: + obj_dict[key] = DateTime.from_timestamp(obj_dict[key]) elif ( type_.annotation is SyftSigningKey or type_.annotation == SyftSigningKey | None @@ -283,7 +295,7 @@ def _get_field_filter( table = table if table is not None else self.table if field_name == "id": return table.c.id == field_value - return func.json_extract(table.c.fields, f"$.{field_name}") == field_value + return table.c.fields[field_name] == func.json_quote(field_value) def _get_by_field( self, diff --git a/packages/syft/src/syft/types/datetime.py b/packages/syft/src/syft/types/datetime.py index afd4eb8c8dc..88b9aada9b3 100644 --- a/packages/syft/src/syft/types/datetime.py +++ b/packages/syft/src/syft/types/datetime.py @@ -1,6 +1,7 @@ # stdlib from datetime import datetime from datetime import timedelta +from datetime import timezone from functools import total_ordering import re from typing import Any @@ -63,6 +64,15 @@ def timedelta(self, other: "DateTime") -> timedelta: utc_timestamp_delta = self.utc_timestamp - other.utc_timestamp return timedelta(seconds=utc_timestamp_delta) + @classmethod + def from_timestamp(cls, ts: float) -> datetime: 
+ return cls(utc_timestamp=ts) + + @classmethod + def from_datetime(cls, dt: datetime) -> "DateTime": + utc_datetime = dt.astimezone(timezone.utc) + return cls(utc_timestamp=utc_datetime.timestamp()) + def format_timedelta(local_timedelta: timedelta) -> str: total_seconds = int(local_timedelta.total_seconds()) From 50b0a952371b74a5ee5ee854216ca234e643be8b Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 15 Aug 2024 11:40:03 +0200 Subject: [PATCH 010/197] check unique constraints --- .../syft/src/syft/service/job/base_stash.py | 44 ++++++++++++++++++- 1 file changed, 42 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/job/base_stash.py b/packages/syft/src/syft/service/job/base_stash.py index ae0e54f8e27..aa033df5c1c 100644 --- a/packages/syft/src/syft/service/job/base_stash.py +++ b/packages/syft/src/syft/service/job/base_stash.py @@ -6,13 +6,13 @@ import json import threading from typing import Any -from typing import ClassVar from typing import Generic import uuid from uuid import UUID # third party import pydantic +from result import Err from result import Ok from result import Result import sqlalchemy as sa @@ -242,7 +242,7 @@ def session(self) -> Session: class ObjectStash(Generic[SyftT]): - object_type: ClassVar[type[SyftT]] + object_type: type[SyftT] def __init__(self, store: DocumentStore) -> None: self.server_uid = store.server_uid @@ -271,6 +271,39 @@ def table(self) -> Table: ) return Base.metadata.tables[table_name] + def _print_query(self, stmt: sa.sql.select) -> None: + print( + stmt.compile( + compile_kwargs={"literal_binds": True}, + dialect=self.session.bind.dialect, + ) + ) + + def is_unique(self, obj: SyftT) -> bool: + unique_fields = self.object_type.__attr_unique__ + if not unique_fields: + return True + filters = [] + for filter_name in unique_fields: + field_value = getattr(obj, filter_name, None) + if field_value is None: + continue + filt = self._get_field_filter( + field_name=filter_name, + # is the str cast correct? how to handle SyftVerifyKey? 
+ field_value=str(field_value), + ) + filters.append(filt) + + stmt = self.table.select().where(sa.or_(*filters)) + results = self.session.execute(stmt).all() + if len(results) > 1: + return False + elif len(results) == 1: + result = results[0] + return result.id == obj.id + return True + def get_by_uid( self, credentials: SyftVerifyKey, uid: UID ) -> Result[SyftT | None, str]: @@ -384,6 +417,7 @@ def get_all( has_permission: bool = False, ) -> Result[list[SyftT], str]: # filter by read permission + # join on verify_key stmt = self.table.select() if not has_permission: stmt = stmt.where(self._get_permission_filter(credentials)) @@ -397,6 +431,9 @@ def update( obj: SyftT, has_permission: bool = False, ) -> Result[SyftT, str]: + if not self.is_unique(obj): + return Err(f"Some fields are not unique for {type(obj).__name__}") + stmt = ( self.table.update() .where( @@ -422,6 +459,9 @@ def set( # uid is unique by database constraint uid = obj.id + if not self.is_unique(obj): + return Err(f"Some fields are not unique for {type(obj).__name__}") + permissions = self.get_ownership_permissions(uid, credentials) if add_permissions is not None: add_permission_strings = [p.permission_string for p in add_permissions] From 95c1f0b20fb61baccb069e38f85da2407a9f3789 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 20 Aug 2024 12:01:59 +0200 Subject: [PATCH 011/197] add json serde --- packages/syft/src/syft/serde/json_serde.py | 358 +++++++++++++++++++++ 1 file changed, 358 insertions(+) create mode 100644 packages/syft/src/syft/serde/json_serde.py diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py new file mode 100644 index 00000000000..23f40b966c5 --- /dev/null +++ b/packages/syft/src/syft/serde/json_serde.py @@ -0,0 +1,358 @@ +# stdlib +import base64 +from collections.abc import Callable +from dataclasses import dataclass +from typing import Annotated +from typing import Any +from typing import Generic +from typing import TypeVar +from typing import Union +from typing import get_args +from typing import get_origin + +# third party +import pydantic +from pydantic import TypeAdapter +from pydantic import ValidationError +from pydantic import ValidationInfo +from pydantic import ValidatorFunctionWrapHandler +from pydantic import WrapValidator +from pydantic_core import PydanticCustomError +from typing_extensions import TypeAliasType + +# syft absolute +import syft as sy + +# relative +from ..server.credentials import SyftSigningKey +from ..server.credentials import SyftVerifyKey +from ..types.datetime import DateTime +from ..types.syft_object import BaseDateTime +from ..types.syft_object_registry import SyftObjectRegistry +from ..types.uid import UID + +T = TypeVar("T") + +JSON_CANONICAL_NAME_FIELD = "__canonical_name__" +JSON_VERSION_FIELD = "__version__" +JSON_DATA_FIELD = "data" + + +# JSON validator from Pydantic docs +# Source: https://docs.pydantic.dev/latest/concepts/types/#named-type-aliases +def json_custom_error_validator( + value: Any, handler: ValidatorFunctionWrapHandler, _info: ValidationInfo +) -> Any: + """ + Simplify the error message to avoid a gross error stemming + from exhaustive checking of all union options. 
+ """ + try: + return handler(value) + except ValidationError: + raise PydanticCustomError( + "invalid_json", + "Input is not valid json", + ) + + +Json = TypeAliasType( # type: ignore + "Json", + Annotated[ + dict[str, "Json"] | list["Json"] | str | int | float | bool | None, # type: ignore + WrapValidator(json_custom_error_validator), + ], +) + +JSON_TYPE_ADAPTER = TypeAdapter(Json) + + +def _is_valid_json(value: Any) -> bool: + try: + JSON_TYPE_ADAPTER.validate_python(value) + return True + except ValidationError: + return False + + +@dataclass +class JSONSerde(Generic[T]): + # TODO add json schema + klass: type[T] + serialize_fn: Callable[[T], Json] | None = None + deserialize_fn: Callable[[Json], T] | None = None + + def serialize(self, obj: T) -> Json: + if self.serialize_fn is None: + return obj # type: ignore + else: + return self.serialize_fn(obj) + + def deserialize(self, obj: Json) -> T: + if self.deserialize_fn is None: + return obj # type: ignore + return self.deserialize_fn(obj) + + +JSON_SERDE_REGISTRY: dict[type[T], JSONSerde[T]] = {} + + +def register_json_serde( + type_: type[T], + serialize: Callable[[T], Json] | None = None, + deserialize: Callable[[Json], T] | None = None, +) -> None: + if type_ in JSON_SERDE_REGISTRY: + raise ValueError(f"Type {type_} is already registered") + + JSON_SERDE_REGISTRY[(type_)] = JSONSerde( + klass=type_, + serialize_fn=serialize, + deserialize_fn=deserialize, + ) + + +# Standard JSON primitives +register_json_serde(int) +register_json_serde(str) +register_json_serde(bool) +register_json_serde(float) +register_json_serde(type(None)) + +# Syft primitives +register_json_serde(UID, lambda uid: uid.no_dash, lambda s: UID(s)) +register_json_serde( + DateTime, lambda dt: dt.utc_timestamp, lambda f: DateTime(utc_timestamp=f) +) +register_json_serde( + BaseDateTime, lambda dt: dt.utc_timestamp, lambda f: BaseDateTime(utc_timestamp=f) +) +register_json_serde(SyftVerifyKey, lambda key: str(key), SyftVerifyKey.from_string) +register_json_serde(SyftSigningKey, lambda key: str(key), SyftSigningKey.from_string) + + +def _is_optional_annotation(annotation: Any) -> Any: + return annotation | None == annotation + + +def _get_nonoptional_annotation(annotation: Any) -> Any: + """Return the type anntation with None type removed, if it is present. 
+ + Args: + annotation (Any): type annotation + + Returns: + Any: type annotation without None type + """ + if _is_optional_annotation(annotation): + args = get_args(annotation) + return Union[tuple(arg for arg in args if arg is not type(None))] # noqa + return annotation + + +def _annotation_is_subclass_of(annotation: Any, cls: type) -> bool: + try: + return issubclass(annotation, cls) + except TypeError: + return False + + +def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> Json: + canonical_name, version = SyftObjectRegistry.get_canonical_name_version(obj) + result = { + JSON_CANONICAL_NAME_FIELD: canonical_name, + JSON_VERSION_FIELD: version, + } + + for key, type_ in obj.model_fields.items(): + result[key] = serialize_json(getattr(obj, key), type_.annotation) + return result + + +def _deserialize_pydantic_from_json( + obj_dict: dict[str, Json], +) -> pydantic.BaseModel: + canonical_name = obj_dict[JSON_CANONICAL_NAME_FIELD] + version = obj_dict[JSON_VERSION_FIELD] + obj_type = SyftObjectRegistry.get_serde_class(canonical_name, version) + + result = {} + for key, type_ in obj_type.model_fields.items(): + result[key] = deserialize_json(obj_dict[key], type_.annotation) + + return obj_type.model_validate(result) + + +def _is_serializable_iterable(annotation: Any) -> bool: + # we can only serialize typed iterables without Union/Any + # NOTE optional is allowed + + # 1. check if it is an iterable + if get_origin(annotation) not in {list, tuple, set, frozenset}: + return False + + # 2. check if iterable annotation is serializable + args = get_args(annotation) + if len(args) != 1: + return False + + inner_type = _get_nonoptional_annotation(args[0]) + return inner_type in JSON_SERDE_REGISTRY or _annotation_is_subclass_of( + inner_type, pydantic.BaseModel + ) + + +def _serialize_iterable_to_json(value: Any, annotation: Any) -> Json: + return [serialize_json(v) for v in value] + + +def _deserialize_iterable_from_json(value: Json, annotation: Any) -> Any: + if not isinstance(value, list): + raise ValueError(f"Cannot deserialize {type(value)} to {annotation}") + + annotation = _get_nonoptional_annotation(annotation) + + if not _is_serializable_iterable(annotation): + raise ValueError(f"Cannot deserialize {annotation} from JSON") + + inner_type = _get_nonoptional_annotation(get_args(annotation)[0]) + return [deserialize_json(v, inner_type) for v in value] + + +def _is_serializable_mapping(annotation: Any) -> bool: + """ + Mapping is serializable if: + - it is a dict + - the key type is str + - the value type is serializable and not a Union + """ + if get_origin(annotation) != dict: + return False + + args = get_args(annotation) + if len(args) != 2: + return False + + key_type, value_type = args + # JSON only allows string keys + if not isinstance(key_type, str): + return False + + # check if value type is serializable + value_type = _get_nonoptional_annotation(value_type) + return value_type in JSON_SERDE_REGISTRY or _annotation_is_subclass_of( + value_type, pydantic.BaseModel + ) + + +def _serialize_mapping_to_json(value: Any, annotation: Any) -> Json: + _, value_type = get_args(annotation) + return {k: serialize_json(v, value_type) for k, v in value.items()} + + +def _deserialize_mapping_from_json(value: Json, annotation: Any) -> Any: + if not isinstance(value, dict): + raise ValueError(f"Cannot deserialize {type(value)} to {annotation}") + + annotation = _get_nonoptional_annotation(annotation) + + if not _is_serializable_mapping(annotation): + raise ValueError(f"Cannot deserialize 
{annotation} from JSON") + + _, value_type = get_args(annotation) + return {k: deserialize_json(v, value_type) for k, v in value.items()} + + +def _serialize_to_json_bytes(obj: Any) -> str: + obj_bytes = sy.serialize(obj, to_bytes=True) + return base64.b64encode(obj_bytes).decode("utf-8") + + +def _deserialize_from_json_bytes(obj: str) -> Any: + obj_bytes = base64.b64decode(obj) + return sy.deserialize(obj_bytes, from_bytes=True) + + +def serialize_json(value: Any, annotation: Any = None) -> Json: + """ + Serialize a value to a JSON-serializable object, using the schema defined by the + provided annotation. + + Serialization is always done according to the annotation, as the same annotation + is used for deserialization. If the annotation is not provided or is ambiguous, + the JSON serialization will fall back to serializing bytes. + + 'Strictly typed' means the annotation is unambiguous during deserialization: + - `int | None` is strictly typed and serialized to int (nullable) + - `str | int` is ambiguous and serialized to bytes + - `list[int]` is strictly typed + - `list`, `list[str | int]`, `list[Any]` are ambiguous and serialized to bytes + - Optional types are treated as strictly typed if the inner type is strictly typed + + The function chooses the appropriate serialization method in the following order: + 1. Method registered in `JSON_SERDE_REGISTRY` for the annotation type. + 2. Pydantic model serialization, including all `SyftObjects`. + 3. Iterable serialization, if the annotation is a strict iterable (e.g., `list[int]`). + 4. Mapping serialization, if the annotation is a strictly typed mapping with string keys. + 5. Serialize the object to bytes and encode it as base64. + + Args: + value (Any): Value to serialize. + annotation (Any, optional): Type annotation for the value. Defaults to None. + + Returns: + Json: JSON-serializable object. + """ + if annotation is None: + annotation = type(value) + + if value is None: + return None + + # Remove None type from annotation if it is present. + annotation = _get_nonoptional_annotation(annotation) + + if annotation in JSON_SERDE_REGISTRY: + return JSON_SERDE_REGISTRY[annotation].serialize(value) + # SyftObject, or any other Pydantic model + elif _annotation_is_subclass_of(annotation, pydantic.BaseModel): + return _serialize_pydantic_to_json(value) + + # Recursive types + # NOTE only strictly annotated iterables and mappings are supported + # example: list[int] is supported, but not list[Union[int, str]] + elif _is_serializable_iterable(annotation): + return _serialize_iterable_to_json(value, annotation) + elif _is_serializable_mapping(annotation): + return _serialize_mapping_to_json(value, annotation) + else: + return _serialize_to_json_bytes(value) + + +def deserialize_json(value: Json, annotation: Any) -> Any: + """Deserialize a JSON-serializable object to a value, using the schema defined by the + provided annotation. Inverse of `serialize_json`. + + Args: + value (Json): JSON-serializable object. + annotation (Any): Type annotation for the value. + + Returns: + Any: Deserialized value. + """ + if value is None: + return None + + # Remove None type from annotation if it is present. 
+ annotation = _get_nonoptional_annotation(annotation) + + if annotation in JSON_SERDE_REGISTRY: + return JSON_SERDE_REGISTRY[annotation].deserialize(value) + elif _annotation_is_subclass_of(annotation, pydantic.BaseModel): + return _deserialize_pydantic_from_json(value) + elif isinstance(value, list): + return _deserialize_iterable_from_json(value, annotation) + elif isinstance(value, dict): + return _deserialize_mapping_from_json(value, annotation) + else: + return _deserialize_from_json_bytes(value) From 2a77c88af095300227454b25f118624e1d521886 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 20 Aug 2024 12:16:33 +0200 Subject: [PATCH 012/197] better validation --- packages/syft/src/syft/serde/json_serde.py | 66 ++++++---------------- 1 file changed, 18 insertions(+), 48 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index 23f40b966c5..5a61a6e4115 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -2,7 +2,7 @@ import base64 from collections.abc import Callable from dataclasses import dataclass -from typing import Annotated +import json from typing import Any from typing import Generic from typing import TypeVar @@ -12,12 +12,6 @@ # third party import pydantic -from pydantic import TypeAdapter -from pydantic import ValidationError -from pydantic import ValidationInfo -from pydantic import ValidatorFunctionWrapHandler -from pydantic import WrapValidator -from pydantic_core import PydanticCustomError from typing_extensions import TypeAliasType # syft absolute @@ -38,42 +32,11 @@ JSON_DATA_FIELD = "data" -# JSON validator from Pydantic docs -# Source: https://docs.pydantic.dev/latest/concepts/types/#named-type-aliases -def json_custom_error_validator( - value: Any, handler: ValidatorFunctionWrapHandler, _info: ValidationInfo -) -> Any: - """ - Simplify the error message to avoid a gross error stemming - from exhaustive checking of all union options. 
- """ - try: - return handler(value) - except ValidationError: - raise PydanticCustomError( - "invalid_json", - "Input is not valid json", - ) - - -Json = TypeAliasType( # type: ignore +Json = TypeAliasType( # type: ignore[misc] "Json", - Annotated[ - dict[str, "Json"] | list["Json"] | str | int | float | bool | None, # type: ignore - WrapValidator(json_custom_error_validator), - ], + dict[str, "Json"] | list["Json"] | str | int | float | bool | None, # type: ignore[misc] ) -JSON_TYPE_ADAPTER = TypeAdapter(Json) - - -def _is_valid_json(value: Any) -> bool: - try: - JSON_TYPE_ADAPTER.validate_python(value) - return True - except ValidationError: - return False - @dataclass class JSONSerde(Generic[T]): @@ -203,7 +166,8 @@ def _is_serializable_iterable(annotation: Any) -> bool: def _serialize_iterable_to_json(value: Any, annotation: Any) -> Json: - return [serialize_json(v) for v in value] + # No need to validate in recursive calls + return [serialize_json(v, validate=False) for v in value] def _deserialize_iterable_from_json(value: Json, annotation: Any) -> Any: @@ -247,7 +211,8 @@ def _is_serializable_mapping(annotation: Any) -> bool: def _serialize_mapping_to_json(value: Any, annotation: Any) -> Json: _, value_type = get_args(annotation) - return {k: serialize_json(v, value_type) for k, v in value.items()} + # No need to validate in recursive calls + return {k: serialize_json(v, value_type, validate=False) for k, v in value.items()} def _deserialize_mapping_from_json(value: Json, annotation: Any) -> Any: @@ -273,7 +238,7 @@ def _deserialize_from_json_bytes(obj: str) -> Any: return sy.deserialize(obj_bytes, from_bytes=True) -def serialize_json(value: Any, annotation: Any = None) -> Json: +def serialize_json(value: Any, annotation: Any = None, validate: bool = True) -> Json: """ Serialize a value to a JSON-serializable object, using the schema defined by the provided annotation. 
@@ -313,20 +278,25 @@ def serialize_json(value: Any, annotation: Any = None) -> Json: annotation = _get_nonoptional_annotation(annotation) if annotation in JSON_SERDE_REGISTRY: - return JSON_SERDE_REGISTRY[annotation].serialize(value) + result = JSON_SERDE_REGISTRY[annotation].serialize(value) # SyftObject, or any other Pydantic model elif _annotation_is_subclass_of(annotation, pydantic.BaseModel): - return _serialize_pydantic_to_json(value) + result = _serialize_pydantic_to_json(value) # Recursive types # NOTE only strictly annotated iterables and mappings are supported # example: list[int] is supported, but not list[Union[int, str]] elif _is_serializable_iterable(annotation): - return _serialize_iterable_to_json(value, annotation) + result = _serialize_iterable_to_json(value, annotation) elif _is_serializable_mapping(annotation): - return _serialize_mapping_to_json(value, annotation) + result = _serialize_mapping_to_json(value, annotation) else: - return _serialize_to_json_bytes(value) + result = _serialize_to_json_bytes(value) + + if validate: + _ = json.dumps(result) + + return result def deserialize_json(value: Json, annotation: Any) -> Any: From 96d12d30e7efc81a8f08b8105494ff1834fe57bc Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 20 Aug 2024 12:17:46 +0200 Subject: [PATCH 013/197] better validation --- packages/syft/src/syft/serde/json_serde.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index 5a61a6e4115..22e40cf1881 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -12,7 +12,6 @@ # third party import pydantic -from typing_extensions import TypeAliasType # syft absolute import syft as sy @@ -32,10 +31,7 @@ JSON_DATA_FIELD = "data" -Json = TypeAliasType( # type: ignore[misc] - "Json", - dict[str, "Json"] | list["Json"] | str | int | float | bool | None, # type: ignore[misc] -) +Json = str | int | float | bool | None | list["Json"] | dict[str, "Json"] @dataclass From da94a2bb2f2c8b5e19bf9b94a1e19e3992100c32 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 20 Aug 2024 13:06:28 +0200 Subject: [PATCH 014/197] add LineageID --- packages/syft/src/syft/serde/json_serde.py | 27 ++++++++++++++-------- 1 file changed, 18 insertions(+), 9 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index 22e40cf1881..3c3333e68b1 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -22,6 +22,7 @@ from ..types.datetime import DateTime from ..types.syft_object import BaseDateTime from ..types.syft_object_registry import SyftObjectRegistry +from ..types.uid import LineageID from ..types.uid import UID T = TypeVar("T") @@ -80,6 +81,7 @@ def register_json_serde( # Syft primitives register_json_serde(UID, lambda uid: uid.no_dash, lambda s: UID(s)) +register_json_serde(LineageID, lambda uid: uid.no_dash, lambda s: LineageID(s)) register_json_serde( DateTime, lambda dt: dt.utc_timestamp, lambda f: DateTime(utc_timestamp=f) ) @@ -90,6 +92,12 @@ def register_json_serde( register_json_serde(SyftSigningKey, lambda key: str(key), SyftSigningKey.from_string) +def _validate_json(value: T) -> T: + # Throws TypeError if value is not JSON-serializable + json.dumps(value) + return value + + def _is_optional_annotation(annotation: Any) -> Any: return annotation | None == annotation @@ -109,7 +117,8 @@ def 
_get_nonoptional_annotation(annotation: Any) -> Any: return annotation -def _annotation_is_subclass_of(annotation: Any, cls: type) -> bool: +def _annotation_issubclass(annotation: Any, cls: type) -> bool: + # issubclass throws TypeError if annotation is not a valid type (eg Union) try: return issubclass(annotation, cls) except TypeError: @@ -156,7 +165,7 @@ def _is_serializable_iterable(annotation: Any) -> bool: return False inner_type = _get_nonoptional_annotation(args[0]) - return inner_type in JSON_SERDE_REGISTRY or _annotation_is_subclass_of( + return inner_type in JSON_SERDE_REGISTRY or _annotation_issubclass( inner_type, pydantic.BaseModel ) @@ -200,7 +209,7 @@ def _is_serializable_mapping(annotation: Any) -> bool: # check if value type is serializable value_type = _get_nonoptional_annotation(value_type) - return value_type in JSON_SERDE_REGISTRY or _annotation_is_subclass_of( + return value_type in JSON_SERDE_REGISTRY or _annotation_issubclass( value_type, pydantic.BaseModel ) @@ -276,12 +285,12 @@ def serialize_json(value: Any, annotation: Any = None, validate: bool = True) -> if annotation in JSON_SERDE_REGISTRY: result = JSON_SERDE_REGISTRY[annotation].serialize(value) # SyftObject, or any other Pydantic model - elif _annotation_is_subclass_of(annotation, pydantic.BaseModel): + elif _annotation_issubclass(annotation, pydantic.BaseModel): result = _serialize_pydantic_to_json(value) - # Recursive types - # NOTE only strictly annotated iterables and mappings are supported - # example: list[int] is supported, but not list[Union[int, str]] + # JSON recursive types + # only strictly annotated iterables and mappings are supported + # example: list[int] is supported, but not list[int | str] elif _is_serializable_iterable(annotation): result = _serialize_iterable_to_json(value, annotation) elif _is_serializable_mapping(annotation): @@ -290,7 +299,7 @@ def serialize_json(value: Any, annotation: Any = None, validate: bool = True) -> result = _serialize_to_json_bytes(value) if validate: - _ = json.dumps(result) + _validate_json(result) return result @@ -314,7 +323,7 @@ def deserialize_json(value: Json, annotation: Any) -> Any: if annotation in JSON_SERDE_REGISTRY: return JSON_SERDE_REGISTRY[annotation].deserialize(value) - elif _annotation_is_subclass_of(annotation, pydantic.BaseModel): + elif _annotation_issubclass(annotation, pydantic.BaseModel): return _deserialize_pydantic_from_json(value) elif isinstance(value, list): return _deserialize_iterable_from_json(value, annotation) From e857cbdbfd45fc53aa864f308c569ff967ba745a Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 21 Aug 2024 10:12:15 +0200 Subject: [PATCH 015/197] remove pydantic validation --- packages/syft/src/syft/serde/json_serde.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index 3c3333e68b1..1a91fc7d222 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -253,11 +253,10 @@ def serialize_json(value: Any, annotation: Any = None, validate: bool = True) -> the JSON serialization will fall back to serializing bytes. 
'Strictly typed' means the annotation is unambiguous during deserialization: - - `int | None` is strictly typed and serialized to int (nullable) - `str | int` is ambiguous and serialized to bytes - `list[int]` is strictly typed - `list`, `list[str | int]`, `list[Any]` are ambiguous and serialized to bytes - - Optional types are treated as strictly typed if the inner type is strictly typed + - Optional types are serializable The function chooses the appropriate serialization method in the following order: 1. Method registered in `JSON_SERDE_REGISTRY` for the annotation type. @@ -329,5 +328,7 @@ def deserialize_json(value: Json, annotation: Any) -> Any: return _deserialize_iterable_from_json(value, annotation) elif isinstance(value, dict): return _deserialize_mapping_from_json(value, annotation) - else: + elif isinstance(value, str): return _deserialize_from_json_bytes(value) + else: + raise ValueError(f"Cannot deserialize {value} to {annotation}") From 5a30ceff43a873f4ae24d2d42d604154be8fe0b7 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 21 Aug 2024 10:39:53 +0200 Subject: [PATCH 016/197] update --- packages/syft/src/syft/serde/json_serde.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index 1a91fc7d222..fb2e48b9514 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -250,13 +250,10 @@ def serialize_json(value: Any, annotation: Any = None, validate: bool = True) -> Serialization is always done according to the annotation, as the same annotation is used for deserialization. If the annotation is not provided or is ambiguous, - the JSON serialization will fall back to serializing bytes. - - 'Strictly typed' means the annotation is unambiguous during deserialization: - - `str | int` is ambiguous and serialized to bytes - - `list[int]` is strictly typed - - `list`, `list[str | int]`, `list[Any]` are ambiguous and serialized to bytes - - Optional types are serializable + the JSON serialization will fall back to serializing bytes. Examples: + - int, `list[int]` are strictly typed + - `str | int`, `list`, `list[str | int]`, `list[Any]` are ambiguous and serialized to bytes + - Optional types (like int | None) are serialized to the not-None type The function chooses the appropriate serialization method in the following order: 1. Method registered in `JSON_SERDE_REGISTRY` for the annotation type. 
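
With the serialize_json docstring settled in patch 016, a short usage sketch may help; it assumes the module is importable as syft.serde.json_serde and that UID compares equal by value:

# Hedged usage sketch for serialize_json/deserialize_json (import paths assumed).
from syft.serde.json_serde import deserialize_json
from syft.serde.json_serde import serialize_json
from syft.types.uid import UID

uid = UID()

# Registered Syft primitive: stored in its compact JSON form (a hex string).
assert deserialize_json(serialize_json(uid, UID), UID) == uid

# Strictly typed iterable: kept as a real JSON list.
assert deserialize_json(serialize_json([1, 2], list[int]), list[int]) == [1, 2]

# Optional annotation: None stays None, otherwise the not-None type is used.
assert deserialize_json(serialize_json(None, int | None), int | None) is None

# Ambiguous annotation: falls back to base64-encoded syft bytes (an opaque string).
blob = serialize_json([1, "a"], list)
assert isinstance(blob, str)

The bytes fallback keeps round-trips lossless, but the stored value is opaque to the JSON-path filters the stashes use, which is one reason strictly typed annotations matter for fields that get queried.
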
From 0af6fb6ea6c45d64956ed8214d51da46519908c2 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 21 Aug 2024 12:10:21 +0200 Subject: [PATCH 017/197] fix typing --- packages/syft/src/syft/serde/json_serde.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index fb2e48b9514..e774ecb58d6 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -125,9 +125,9 @@ def _annotation_issubclass(annotation: Any, cls: type) -> bool: return False -def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> Json: +def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> dict[str, Json]: canonical_name, version = SyftObjectRegistry.get_canonical_name_version(obj) - result = { + result: dict[str, Json] = { JSON_CANONICAL_NAME_FIELD: canonical_name, JSON_VERSION_FIELD: version, } From f83a959003acd260cad7ed791cf7e30797b50f8a Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 21 Aug 2024 13:43:06 +0200 Subject: [PATCH 018/197] use new json serde, add enum support --- packages/syft/src/syft/serde/json_serde.py | 25 ++-- .../syft/src/syft/service/job/base_stash.py | 129 ++---------------- 2 files changed, 28 insertions(+), 126 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index e774ecb58d6..e5e67d818d9 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -2,6 +2,7 @@ import base64 from collections.abc import Callable from dataclasses import dataclass +from enum import Enum import json from typing import Any from typing import Generic @@ -78,6 +79,7 @@ def register_json_serde( register_json_serde(bool) register_json_serde(float) register_json_serde(type(None)) +register_json_serde(pydantic.EmailStr) # Syft primitives register_json_serde(UID, lambda uid: uid.no_dash, lambda s: UID(s)) @@ -140,15 +142,19 @@ def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> dict[str, Json]: def _deserialize_pydantic_from_json( obj_dict: dict[str, Json], ) -> pydantic.BaseModel: - canonical_name = obj_dict[JSON_CANONICAL_NAME_FIELD] - version = obj_dict[JSON_VERSION_FIELD] - obj_type = SyftObjectRegistry.get_serde_class(canonical_name, version) + try: + canonical_name = obj_dict[JSON_CANONICAL_NAME_FIELD] + version = obj_dict[JSON_VERSION_FIELD] + obj_type = SyftObjectRegistry.get_serde_class(canonical_name, version) - result = {} - for key, type_ in obj_type.model_fields.items(): - result[key] = deserialize_json(obj_dict[key], type_.annotation) + result = {} + for key, type_ in obj_type.model_fields.items(): + result[key] = deserialize_json(obj_dict[key], type_.annotation) - return obj_type.model_validate(result) + return obj_type.model_validate(result) + except Exception as e: + print(json.dumps(obj_dict, indent=2)) + raise ValueError(f"Failed to deserialize Pydantic model: {e}") def _is_serializable_iterable(annotation: Any) -> bool: @@ -280,9 +286,10 @@ def serialize_json(value: Any, annotation: Any = None, validate: bool = True) -> if annotation in JSON_SERDE_REGISTRY: result = JSON_SERDE_REGISTRY[annotation].serialize(value) - # SyftObject, or any other Pydantic model elif _annotation_issubclass(annotation, pydantic.BaseModel): result = _serialize_pydantic_to_json(value) + elif _annotation_issubclass(annotation, Enum): + result = value.name # JSON recursive types # only strictly annotated iterables and mappings are supported @@ -321,6 
+328,8 @@ def deserialize_json(value: Json, annotation: Any) -> Any: return JSON_SERDE_REGISTRY[annotation].deserialize(value) elif _annotation_issubclass(annotation, pydantic.BaseModel): return _deserialize_pydantic_from_json(value) + elif _annotation_issubclass(annotation, Enum): + return annotation[value] elif isinstance(value, list): return _deserialize_iterable_from_json(value, annotation) elif isinstance(value, dict): diff --git a/packages/syft/src/syft/service/job/base_stash.py b/packages/syft/src/syft/service/job/base_stash.py index aa033df5c1c..334f2cfc1a3 100644 --- a/packages/syft/src/syft/service/job/base_stash.py +++ b/packages/syft/src/syft/service/job/base_stash.py @@ -1,7 +1,6 @@ # stdlib # stdlib -import base64 from enum import Enum import json import threading @@ -11,7 +10,6 @@ from uuid import UUID # third party -import pydantic from result import Err from result import Ok from result import Result @@ -30,15 +28,13 @@ from sqlalchemy.types import JSON from typing_extensions import TypeVar -# syft absolute -import syft as sy - # relative -from ...server.credentials import SyftSigningKey +from ...serde.json_serde import Json +from ...serde.json_serde import deserialize_json +from ...serde.json_serde import serialize_json from ...server.credentials import SyftVerifyKey from ...store.document_store import DocumentStore from ...store.document_store import PartitionKey -from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.syft_object import SyftObject from ...types.uid import UID @@ -65,7 +61,7 @@ class UIDTypeDecorator(TypeDecorator): def process_bind_param(self, value, dialect): # type: ignore if value is not None: - return value + return value.value def process_result_value(self, value, dialect): # type: ignore if value is not None: @@ -86,108 +82,7 @@ class CommonMixin: json_document: Mapped[dict] = mapped_column(JSON, default={}) -def should_handle_as_bytes(type_) -> bool: - # relative - from ...util.misc_objs import HTMLObject - from ...util.misc_objs import MarkdownDescription - from ..action.action_object import Action - from ..dataset.dataset import Asset - from ..dataset.dataset import Contributor - from ..request.request import Change - from ..request.request import ChangeStatus - from ..settings.settings import PwdTokenResetConfig - - return ( - type_.annotation is LinkedObject - or type_.annotation == LinkedObject | None - or type_.annotation == list[UID] | dict[str, UID] | None - or type_.annotation == dict[str, UID] | None - or type_.annotation == list[Change] - or type_.annotation == Any | None # type: ignore - or type_.annotation == Action | None # type: ignore - or getattr(type_.annotation, "__origin__", None) is dict - or type_.annotation == HTMLObject | MarkdownDescription - or type_.annotation == PwdTokenResetConfig - or type_.annotation == list[ChangeStatus] - or type_.annotation == list[Asset] - or type_.annotation == set[Contributor] - or type_.annotation == MarkdownDescription - or type_.annotation == Contributor - ) - - -def model_dump(obj: pydantic.BaseModel) -> dict: - obj_dict = dict(obj) # obj.model_dump() does not work when - for key, type_ in obj.model_fields.items(): - if type_.annotation is UID: - obj_dict[key] = obj_dict[key].no_dash - elif ( - type_.annotation is SyftVerifyKey - or type_.annotation == SyftVerifyKey | None - or type_.annotation is SyftSigningKey - or type_.annotation == SyftSigningKey | None - ): - attr = getattr(obj, key) - obj_dict[key] = str(attr) if attr is not None else None - 
elif type_.annotation is DateTime or type_.annotation == DateTime | None: - # FIXME: this is a hack, we should not be converting to string - if obj_dict[key] is not None: - obj_dict[key] = obj_dict[key].utc_timestamp - elif should_handle_as_bytes(type_): - # not very efficient as it serializes the object twice - data = sy.serialize(getattr(obj, key), to_bytes=True) - base64_data = base64.b64encode(data).decode("utf-8") - obj_dict[key] = base64_data - - return obj_dict - - -T = TypeVar("T", bound=pydantic.BaseModel) - - -def model_validate(obj_type: type[T], obj_dict: dict) -> T: - # relative - - for key, type_ in obj_type.model_fields.items(): - if key not in obj_dict: - continue - # FIXME - if type_.annotation is UID or type_.annotation == UID | None: - if obj_dict[key] is None: - obj_dict[key] = None - else: - obj_dict[key] = UID(obj_dict[key]) - elif ( - type_.annotation is SyftVerifyKey - or type_.annotation == SyftVerifyKey | None - ): - if obj_dict[key] is None: - obj_dict[key] = None - elif isinstance(obj_dict[key], str): - obj_dict[key] = SyftVerifyKey.from_string(obj_dict[key]) - elif isinstance(obj_dict[key], SyftVerifyKey): - obj_dict[key] = obj_dict[key] - elif type_.annotation is DateTime or type_.annotation == DateTime | None: - if obj_dict[key] is not None: - obj_dict[key] = DateTime.from_timestamp(obj_dict[key]) - elif ( - type_.annotation is SyftSigningKey - or type_.annotation == SyftSigningKey | None - ): - if obj_dict[key] is None: - obj_dict[key] = None - elif isinstance(obj_dict[key], str): - obj_dict[key] = SyftSigningKey(signing_key=obj_dict[key]) - elif isinstance(obj_dict[key], SyftSigningKey): - obj_dict[key] = obj_dict[key] - elif should_handle_as_bytes(type_): - data = base64.b64decode(obj_dict[key]) - obj_dict[key] = sy.deserialize(data, from_bytes=True) - - return obj_type.model_validate(obj_dict) - - -def _default_dumps(val): # type: ignore +def _default_dumps(val: Any) -> Json: # type: ignore if isinstance(val, UID): return str(val.no_dash) elif isinstance(val, UUID): @@ -197,21 +92,19 @@ def _default_dumps(val): # type: ignore elif val is None: return None return str(val) - # elif isinstance - raise TypeError(f"Can't serialize {val}, type {type(val)}") -def _default_loads(val): # type: ignore +def _default_loads(val: Any) -> Any: # type: ignore if "UID" in val: return UID(val) return val -def dumps(d: dict) -> str: +def dumps(d: Any) -> str: return json.dumps(d, default=_default_dumps) -def loads(d: str) -> dict: +def loads(d: str) -> Any: return json.loads(d, object_hook=_default_loads) @@ -367,7 +260,7 @@ def get_all_by_field( return Ok(objs) def row_as_obj(self, row: Row) -> SyftT: - return model_validate(self.object_type, row.fields) + return deserialize_json(row.fields, self.object_type) def get_role(self, credentials: SyftVerifyKey) -> ServiceRole: user_table = Table("User", Base.metadata) @@ -442,7 +335,7 @@ def update( self._get_permission_filter(credentials, ActionPermission.WRITE), ) ) - .values(fields=model_dump(obj)) + .values(fields=serialize_json(obj)) ) self.session.execute(stmt) self.session.commit() @@ -481,7 +374,7 @@ def set( # create the object with the permissions stmt = self.table.insert().values( id=uid, - fields=model_dump(obj), + fields=serialize_json(obj), permissions=permissions, # storage_permissions=storage_permissions, ) From 292e5b591a3649be85c074fa22dee85e2313bff6 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 21 Aug 2024 16:26:46 +0200 Subject: [PATCH 019/197] usercode, status, log --- 
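
Every stash migrated in this patch follows the same recipe, so one schematic example ahead of the file list covers them all. Widget and WidgetStash below are invented for illustration, and the import paths reflect where ObjectStash lives at this point in the series:

# third party
from result import Result

# syft absolute
from syft.server.credentials import SyftVerifyKey
from syft.service.job.base_stash import ObjectStash  # relocated to store/db by a later patch
from syft.store.document_store import DocumentStore
from syft.types.syft_object import SYFT_OBJECT_VERSION_1
from syft.types.syft_object import SyftObject


class Widget(SyftObject):  # hypothetical object, for illustration only
    __canonical_name__ = "Widget"
    __version__ = SYFT_OBJECT_VERSION_1
    name: str


class WidgetStash(ObjectStash[Widget]):
    object_type = Widget

    def __init__(self, store: DocumentStore) -> None:
        super().__init__(store=store)

    def get_by_name(
        self, credentials: SyftVerifyKey, name: str
    ) -> Result[Widget | None, str]:
        # PartitionKey/QueryKeys lookups become queries over the JSON fields column
        return self.get_one_by_field(
            credentials=credentials, field_name="name", field_value=name
        )

get_all_by_field works the same way for multi-row lookups, as the dataset and output stashes in earlier patches already show.
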
 .../syft/src/syft/server/service_registry.py |  3 ++
 .../src/syft/service/code/status_service.py  | 21 +++----------
 .../syft/src/syft/service/code/user_code.py  |  6 ----
 .../src/syft/service/code/user_code_stash.py | 31 +++++++------------
 .../syft/src/syft/service/log/log_stash.py   |  4 +--
 packages/syft/src/syft/types/syft_object.py  |  2 +-
 6 files changed, 22 insertions(+), 45 deletions(-)

diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py
index d7c3555f10c..1504e0eb817 100644
--- a/packages/syft/src/syft/server/service_registry.py
+++ b/packages/syft/src/syft/server/service_registry.py
@@ -133,3 +133,6 @@ def _get_service_from_path(self, path: str) -> AbstractService:
             return self.service_path_map[service_name.lower()]
         except KeyError:
             raise ValueError(f"Service {path} not found.")
+
+    def __iter__(self) -> typing.Iterator[AbstractService]:
+        return iter(self.services)
diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py
index 2baafb2ea57..3ec191daf28 100644
--- a/packages/syft/src/syft/service/code/status_service.py
+++ b/packages/syft/src/syft/service/code/status_service.py
@@ -1,19 +1,15 @@
 # stdlib

 # third party
-from result import Result

 # relative
 from ...serde.serializable import serializable
-from ...server.credentials import SyftVerifyKey
-from ...store.document_store import BaseUIDStoreStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
-from ...store.document_store import UIDPartitionKey
 from ...types.uid import UID
 from ...util.telemetry import instrument
 from ..context import AuthedServiceContext
+from ..job.base_stash import ObjectStash
 from ..response import SyftError
 from ..response import SyftSuccess
 from ..service import AbstractService
@@ -25,8 +21,8 @@


 @instrument
-@serializable(canonical_name="StatusStash", version=1)
-class StatusStash(BaseUIDStoreStash):
+@serializable(canonical_name="StatusSQLStash", version=1)
+class StatusStash(ObjectStash[UserCodeStatusCollection]):
     object_type = UserCodeStatusCollection
     settings: PartitionSettings = PartitionSettings(
         name=UserCodeStatusCollection.__canonical_name__,
@@ -34,16 +30,7 @@ class StatusStash(BaseUIDStoreStash):
     )

     def __init__(self, store: DocumentStore) -> None:
-        super().__init__(store)
-        self.store = store
-        self.settings = self.settings
-        self._object_type = self.object_type
-
-    def get_by_uid(
-        self, credentials: SyftVerifyKey, uid: UID
-    ) -> Result[UserCodeStatusCollection, str]:
-        qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)])
-        return self.query_one(credentials=credentials, qks=qks)
+        super().__init__(store=store)


 @instrument
diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py
index 955005a5895..fda6c4d754c 100644
--- a/packages/syft/src/syft/service/code/user_code.py
+++ b/packages/syft/src/syft/service/code/user_code.py
@@ -48,7 +48,6 @@
 from ...serde.signature import signature_remove_context
 from ...serde.signature import signature_remove_self
 from ...server.credentials import SyftVerifyKey
-from ...store.document_store import PartitionKey
 from ...store.linked_obj import LinkedObject
 from ...types.datetime import DateTime
 from ...types.dicttuple import DictTuple
@@ -108,11 +107,6 @@
     # relative
     from ...service.sync.diff_state import AttrDiff

-UserVerifyKeyPartitionKey = PartitionKey(key="user_verify_key", type_=SyftVerifyKey)
-CodeHashPartitionKey = PartitionKey(key="code_hash", type_=str)
-ServiceFuncNamePartitionKey = PartitionKey(key="service_func_name", type_=str)
-SubmitTimePartitionKey = PartitionKey(key="submit_time", type_=DateTime)
-
 PyCodeObject = Any
diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py
index 0fcb41b2087..123d8969dce 100644
--- a/packages/syft/src/syft/service/code/user_code_stash.py
+++ b/packages/syft/src/syft/service/code/user_code_stash.py
@@ -6,21 +6,16 @@
 # relative
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
-from ...store.document_store import BaseUIDStoreStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
 from ...util.telemetry import instrument
-from .user_code import CodeHashPartitionKey
-from .user_code import ServiceFuncNamePartitionKey
-from .user_code import SubmitTimePartitionKey
+from ..job.base_stash import ObjectStash
 from .user_code import UserCode
-from .user_code import UserVerifyKeyPartitionKey


 @instrument
-@serializable(canonical_name="UserCodeStash", version=1)
-class UserCodeStash(BaseUIDStoreStash):
+@serializable(canonical_name="UserCodeSQLStash", version=1)
+class UserCodeStash(ObjectStash[UserCode]):
     object_type = UserCode
     settings: PartitionSettings = PartitionSettings(
         name=UserCode.__canonical_name__, object_type=UserCode
@@ -29,22 +24,20 @@ class UserCodeStash(BaseUIDStoreStash):
     def __init__(self, store: DocumentStore) -> None:
         super().__init__(store=store)

-    def get_all_by_user_verify_key(
-        self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey
-    ) -> Result[list[UserCode], str]:
-        qks = QueryKeys(qks=[UserVerifyKeyPartitionKey.with_obj(user_verify_key)])
-        return self.query_one(credentials=credentials, qks=qks)
-
     def get_by_code_hash(
         self, credentials: SyftVerifyKey, code_hash: str
     ) -> Result[UserCode | None, str]:
-        qks = QueryKeys(qks=[CodeHashPartitionKey.with_obj(code_hash)])
-        return self.query_one(credentials=credentials, qks=qks)
+        return self.get_one_by_field(
+            credentials=credentials,
+            field_name="code_hash",
+            field_value=code_hash,
+        )

     def get_by_service_func_name(
         self, credentials: SyftVerifyKey, service_func_name: str
     ) -> Result[list[UserCode], str]:
-        qks = QueryKeys(qks=[ServiceFuncNamePartitionKey.with_obj(service_func_name)])
-        return self.query_all(
-            credentials=credentials, qks=qks, order_by=SubmitTimePartitionKey
+        return self.get_all_by_field(
+            credentials=credentials,
+            field_name="service_func_name",
+            field_value=service_func_name,
         )
diff --git a/packages/syft/src/syft/service/log/log_stash.py b/packages/syft/src/syft/service/log/log_stash.py
index 54657982633..f952a70604c 100644
--- a/packages/syft/src/syft/service/log/log_stash.py
+++ b/packages/syft/src/syft/service/log/log_stash.py
@@ -1,15 +1,15 @@
 # relative
 from ...serde.serializable import serializable
-from ...store.document_store import BaseUIDStoreStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionSettings
 from ...util.telemetry import instrument
+from ..job.base_stash import ObjectStash
 from .log import SyftLog


 @instrument
 @serializable(canonical_name="LogStash", version=1)
-class LogStash(BaseUIDStoreStash):
+class LogStash(ObjectStash[SyftLog]):
     object_type = SyftLog
     settings: PartitionSettings = PartitionSettings(
         name=SyftLog.__canonical_name__, object_type=SyftLog
diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py
index 97b70163ee3..ac27f346ea8 100644
--- a/packages/syft/src/syft/types/syft_object.py
+++ b/packages/syft/src/syft/types/syft_object.py
@@ -686,7 +686,7 @@ def _get_api(self) -> "SyftAPI | SyftError":
             )
         if api is None:
             return SyftError(
-                f"Can't access the api. You must login to {self.server_uid}"
+                message=f"Can't access the api. You must login to {self.server_uid}"
             )
         return api

From c9d647c57147a102d2d930c87cd45500b364e375 Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Wed, 21 Aug 2024 23:10:11 +0200
Subject: [PATCH 020/197] organize files

---
 .../src/syft/service/code/status_service.py   |   2 +-
 .../src/syft/service/code/user_code_stash.py  |   2 +-
 .../src/syft/service/dataset/dataset_stash.py |   2 +-
 .../syft/src/syft/service/job/job_stash.py    |   2 +-
 .../syft/src/syft/service/log/log_stash.py    |   2 +-
 .../src/syft/service/output/output_service.py |   2 +-
 .../src/syft/service/request/request_stash.py |   2 +-
 .../syft/service/settings/settings_stash.py   |   2 +-
 .../syft/src/syft/service/user/user_stash.py  |   2 +-
 packages/syft/src/syft/store/db/__init__.py   |   0
 .../{service/job => store/db}/base_stash.py   | 132 ++++--------------
 packages/syft/src/syft/store/db/models.py     |  47 +++++++
 packages/syft/src/syft/store/db/sqlite_db.py  |  38 +++++
 packages/syft/src/syft/store/db/utils.py      |  35 +++++
 14 files changed, 154 insertions(+), 116 deletions(-)
 create mode 100644 packages/syft/src/syft/store/db/__init__.py
 rename packages/syft/src/syft/{service/job => store/db}/base_stash.py (80%)
 create mode 100644 packages/syft/src/syft/store/db/models.py
 create mode 100644 packages/syft/src/syft/store/db/sqlite_db.py
 create mode 100644 packages/syft/src/syft/store/db/utils.py

diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py
index 3ec191daf28..eceb3713123 100644
--- a/packages/syft/src/syft/service/code/status_service.py
+++ b/packages/syft/src/syft/service/code/status_service.py
@@ -4,12 +4,12 @@

 # relative
 from ...serde.serializable import serializable
+from ...store.db.base_stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionSettings
 from ...types.uid import UID
 from ...util.telemetry import instrument
 from ..context import AuthedServiceContext
-from ..job.base_stash import ObjectStash
 from ..response import SyftError
 from ..response import SyftSuccess
 from ..service import AbstractService
diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py
index 123d8969dce..943483d4db5 100644
--- a/packages/syft/src/syft/service/code/user_code_stash.py
+++ b/packages/syft/src/syft/service/code/user_code_stash.py
@@ -6,10 +6,10 @@
 # relative
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
+from ...store.db.base_stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionSettings
 from ...util.telemetry import instrument
-from ..job.base_stash import ObjectStash
 from .user_code import UserCode
diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py
index 2b8cccda208..17b255a2f64 100644
--- a/packages/syft/src/syft/service/dataset/dataset_stash.py
+++ b/packages/syft/src/syft/service/dataset/dataset_stash.py
@@ -8,7 +8,7 @@
 # relative
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
+from ...store.db.base_stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionKey
 from ...store.document_store import PartitionSettings
 from ...types.uid import UID
 from ...util.telemetry import instrument
-from ..job.base_stash import ObjectStash
 from .dataset import Dataset
diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py
index 5c2586e8a4f..d520a18555b 100644
--- a/packages/syft/src/syft/service/job/job_stash.py
+++ b/packages/syft/src/syft/service/job/job_stash.py
@@ -23,6 +23,7 @@
 from ...server.credentials import SyftVerifyKey
 from ...service.context import AuthedServiceContext
 from ...service.worker.worker_pool import SyftWorker
+from ...store.db.base_stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionSettings
 from ...types.datetime import DateTime
@@ -45,7 +46,6 @@
 from ..response import SyftNotReady
 from ..response import SyftSuccess
 from ..user.user import UserView
-from .base_stash import ObjectStash
 from .html_template import job_repr_template
diff --git a/packages/syft/src/syft/service/log/log_stash.py b/packages/syft/src/syft/service/log/log_stash.py
index f952a70604c..01fec3487f6 100644
--- a/packages/syft/src/syft/service/log/log_stash.py
+++ b/packages/syft/src/syft/service/log/log_stash.py
@@ -1,9 +1,9 @@
 # relative
 from ...serde.serializable import serializable
+from ...store.db.base_stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionSettings
 from ...util.telemetry import instrument
-from ..job.base_stash import ObjectStash
 from .log import SyftLog
diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py
index c687bc40444..e57ff658e62 100644
--- a/packages/syft/src/syft/service/output/output_service.py
+++ b/packages/syft/src/syft/service/output/output_service.py
@@ -9,6 +9,7 @@
 from ...client.api import APIRegistry
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
+from ...store.db.base_stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionKey
 from ...store.document_store import PartitionSettings
@@ -21,7 +22,6 @@
 from ..action.action_object import ActionObject
 from ..action.action_permissions import ActionObjectREAD
 from ..context import AuthedServiceContext
-from ..job.base_stash import ObjectStash
 from ..response import SyftError
 from ..service import AbstractService
 from ..service import TYPE_TO_SERVICE
diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py
index e0fef0a537f..cee6464a419 100644
--- a/packages/syft/src/syft/service/request/request_stash.py
+++ b/packages/syft/src/syft/service/request/request_stash.py
@@ -6,12 +6,12 @@
 # relative
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
+from ...store.db.base_stash import ObjectStash
 from ...store.document_store import PartitionKey
 from ...store.document_store import PartitionSettings
 from ...types.datetime import DateTime
 from ...types.uid import UID
 from ...util.telemetry import instrument
-from ..job.base_stash import ObjectStash
 from .request import Request

 RequestingUserVerifyKeyPartitionKey = PartitionKey(
diff --git a/packages/syft/src/syft/service/settings/settings_stash.py b/packages/syft/src/syft/service/settings/settings_stash.py
index e662040da0c..c4c4179e0d7 100644
--- a/packages/syft/src/syft/service/settings/settings_stash.py
+++ b/packages/syft/src/syft/service/settings/settings_stash.py
@@ -4,12 +4,12 @@

 # relative
 from ...serde.serializable import serializable
+from ...store.db.base_stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionKey
 from ...store.document_store import PartitionSettings
 from ...types.uid import UID
 from ...util.telemetry import instrument
-from ..job.base_stash import ObjectStash
 from .settings import ServerSettings

 NamePartitionKey = PartitionKey(key="name", type_=str)
diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py
index 7d96444c524..5b8ae398ea3 100644
--- a/packages/syft/src/syft/service/user/user_stash.py
+++ b/packages/syft/src/syft/service/user/user_stash.py
@@ -8,11 +8,11 @@
 from ...serde.serializable import serializable
 from ...server.credentials import SyftSigningKey
 from ...server.credentials import SyftVerifyKey
+from ...store.db.base_stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionKey
 from ...store.document_store import PartitionSettings
 from ...util.telemetry import instrument
-from ..job.base_stash import ObjectStash
 from .user import User
 from .user_roles import ServiceRole
diff --git a/packages/syft/src/syft/store/db/__init__.py b/packages/syft/src/syft/store/db/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/packages/syft/src/syft/service/job/base_stash.py b/packages/syft/src/syft/store/db/base_stash.py
similarity index 80%
rename from packages/syft/src/syft/service/job/base_stash.py
rename to packages/syft/src/syft/store/db/base_stash.py
index 334f2cfc1a3..06da3d82896 100644
--- a/packages/syft/src/syft/service/job/base_stash.py
+++ b/packages/syft/src/syft/store/db/base_stash.py
@@ -1,13 +1,8 @@
 # stdlib

 # stdlib
-from enum import Enum
-import json
-import threading
-from typing import Any
 from typing import Generic
 import uuid
-from uuid import UUID

 # third party
 from result import Err
@@ -17,121 +12,34 @@
 from sqlalchemy import Column
 from sqlalchemy import Row
 from sqlalchemy import Table
-from sqlalchemy import TypeDecorator
-from sqlalchemy import create_engine
 from sqlalchemy import func
-from sqlalchemy.orm import DeclarativeBase
-from sqlalchemy.orm import Mapped
 from sqlalchemy.orm import Session
-from sqlalchemy.orm import mapped_column
-from sqlalchemy.orm import sessionmaker
 from sqlalchemy.types import JSON
 from typing_extensions import TypeVar

 # relative
-from ...serde.json_serde import Json
 from ...serde.json_serde import deserialize_json
 from ...serde.json_serde import serialize_json
 from ...server.credentials import SyftVerifyKey
-from ...store.document_store import DocumentStore
-from ...store.document_store import PartitionKey
-from ...types.datetime import DateTime
+from ...service.action.action_permissions import ActionObjectEXECUTE
+from ...service.action.action_permissions import ActionObjectOWNER
+from ...service.action.action_permissions import ActionObjectPermission
+from ...service.action.action_permissions import ActionObjectREAD
+from ...service.action.action_permissions import ActionObjectWRITE
+from ...service.action.action_permissions import ActionPermission
+from ...service.action.action_permissions import StoragePermission
+from ...service.response import SyftSuccess
+from ...service.user.user_roles import ServiceRole
 from ...types.syft_object import SyftObject
 from ...types.uid import UID
-from ..action.action_permissions import ActionObjectEXECUTE
-from ..action.action_permissions import ActionObjectOWNER
-from ..action.action_permissions import ActionObjectPermission
-from ..action.action_permissions import ActionObjectREAD
-from ..action.action_permissions import ActionObjectWRITE
-from ..action.action_permissions import ActionPermission
-from ..action.action_permissions import StoragePermission
-from ..response import SyftSuccess
-from ..user.user_roles import ServiceRole
-
-
-class Base(DeclarativeBase):
-    pass
-
-
-class UIDTypeDecorator(TypeDecorator):
-    """Converts between Syft UID and UUID."""
-
-    impl = sa.UUID
-    cache_ok = True
-
-    def process_bind_param(self, value, dialect):  # type: ignore
-        if value is not None:
-            return value.value
-
-    def process_result_value(self, value, dialect):  # type: ignore
-        if value is not None:
-            return UID(value)
-
-
-class CommonMixin:
-    id: Mapped[UID] = mapped_column(
-        default=uuid.uuid4,
-        primary_key=True,
-    )
-    created_at: Mapped[DateTime] = mapped_column(server_default=sa.func.now())
-
-    updated_at: Mapped[DateTime] = mapped_column(
-        server_default=sa.func.now(),
-        server_onupdate=sa.func.now(),
-    )
-    json_document: Mapped[dict] = mapped_column(JSON, default={})
-
-
-def _default_dumps(val: Any) -> Json:  # type: ignore
-    if isinstance(val, UID):
-        return str(val.no_dash)
-    elif isinstance(val, UUID):
-        return val.hex
-    elif issubclass(type(val), Enum):
-        return val.name
-    elif val is None:
-        return None
-    return str(val)
-
-
-def _default_loads(val: Any) -> Any:  # type: ignore
-    if "UID" in val:
-        return UID(val)
-    return val
-
-
-def dumps(d: Any) -> str:
-    return json.dumps(d, default=_default_dumps)
-
-
-def loads(d: str) -> Any:
-    return json.loads(d, object_hook=_default_loads)
-
-
-class SQLiteDBManager:
-    def __init__(self, server_uid: str) -> None:
-        self.server_uid = server_uid
-        self.path = f"sqlite:////tmp/{server_uid}.db"
-        self.engine = create_engine(
-            self.path, json_serializer=dumps, json_deserializer=loads
-        )
-        print(f"Connecting to {self.path}")
-        self.SessionFactory = sessionmaker(bind=self.engine)
-        self.thread_local = threading.local()
-
-        Base.metadata.create_all(self.engine)
-
-    def get_session(self) -> Session:
-        if not hasattr(self.thread_local, "session"):
-            self.thread_local.session = self.SessionFactory()
-        return self.thread_local.session
-
-    @property
-    def session(self) -> Session:
-        return self.get_session()
-
+from ..document_store import DocumentStore
+from ..document_store import PartitionKey
+from .models import Base
+from .models import UIDTypeDecorator
+from .sqlite_db import SQLiteDBManager

 SyftT = TypeVar("SyftT", bound=SyftObject)
+T = TypeVar("T")


 class ObjectStash(Generic[SyftT]):
@@ -144,6 +52,13 @@ def __init__(self, store: DocumentStore) -> None:
         _ = self.table
         self.db = SQLiteDBManager(self.server_uid)

+    def check_type(self, obj: T, type_: type) -> Result[T, str]:
+        return (
+            Ok(obj)
+            if isinstance(obj, type_)
+            else Err(f"{type(obj)} does not match required type: {type_}")
+        )
+
     @property
     def session(self) -> Session:
         return self.db.session
@@ -324,6 +239,7 @@ def update(
         obj: SyftT,
         has_permission: bool = False,
     ) -> Result[SyftT, str]:
+        # TODO has_permission is not used
         if not self.is_unique(obj):
             return Err(f"Some fields are not unique for {type(obj).__name__}")

@@ -395,6 +311,7 @@ def get_ownership_permissions(
     def delete_by_uid(
         self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False
     ) -> Result[SyftSuccess, str]:
+        # TODO check delete permissions
         stmt = self.table.delete().where(
             sa.and_(
                 self._get_field_filter("id", uid),
@@ -467,4 +384,5 @@ def has_permission(self, permission: ActionObjectPermission) -> bool:
         return result is not None

     def has_storage_permission(self, permission: StoragePermission) -> bool:
+        # TODO
         return True
diff --git a/packages/syft/src/syft/store/db/models.py b/packages/syft/src/syft/store/db/models.py
new file mode 100644
index 00000000000..0288f46d34d
--- /dev/null
+++ b/packages/syft/src/syft/store/db/models.py
@@ -0,0 +1,47 @@
+# stdlib
+import uuid
+
+# third party
+import sqlalchemy as sa
+from sqlalchemy import TypeDecorator
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.orm import Mapped
+from sqlalchemy.orm import mapped_column
+from sqlalchemy.types import JSON
+
+# relative
+from ...types.datetime import DateTime
+from ...types.uid import UID
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class UIDTypeDecorator(TypeDecorator):
+    """Converts between Syft UID and UUID."""
+
+    impl = sa.UUID
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):  # type: ignore
+        if value is not None:
+            return value.value
+
+    def process_result_value(self, value, dialect):  # type: ignore
+        if value is not None:
+            return UID(value)
+
+
+class CommonMixin:
+    id: Mapped[UID] = mapped_column(
+        default=uuid.uuid4,
+        primary_key=True,
+    )
+    created_at: Mapped[DateTime] = mapped_column(server_default=sa.func.now())
+
+    updated_at: Mapped[DateTime] = mapped_column(
+        server_default=sa.func.now(),
+        server_onupdate=sa.func.now(),
+    )
+    json_document: Mapped[dict] = mapped_column(JSON, default={})
diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py
new file mode 100644
index 00000000000..578ca910c48
--- /dev/null
+++ b/packages/syft/src/syft/store/db/sqlite_db.py
@@ -0,0 +1,38 @@
+# stdlib
+import threading
+
+# third party
+from sqlalchemy import create_engine
+from sqlalchemy.orm import Session
+from sqlalchemy.orm import sessionmaker
+
+# relative
+from ...types.uid import UID
+from .models import Base
+from .utils import dumps
+from .utils import loads
+
+
+class SQLiteDBManager:
+    def __init__(self, server_uid: UID) -> None:
+        self.server_uid = server_uid
+        self.path = f"sqlite:////tmp/{str(server_uid)}.db"
+        self.engine = create_engine(
+            self.path, json_serializer=dumps, json_deserializer=loads
+        )
+        print(f"Connecting to {self.path}")
+        self.Session = sessionmaker(bind=self.engine)
+        self.thread_local = threading.local()
+
+        Base.metadata.create_all(self.engine)
+
+    # TODO remove
+    def get_session_threading_local(self) -> Session:
+        if not hasattr(self.thread_local, "session"):
+            self.thread_local.session = self.Session()
+        return self.thread_local.session
+
+    # TODO remove
+    @property
+    def session(self) -> Session:
+        return self.get_session_threading_local()
diff --git a/packages/syft/src/syft/store/db/utils.py b/packages/syft/src/syft/store/db/utils.py
new file mode 100644
index 00000000000..870186fa66c
--- /dev/null
+++ b/packages/syft/src/syft/store/db/utils.py
@@ -0,0 +1,35 @@
+# stdlib
+from enum import Enum
+import json
+from typing import Any
+from uuid import UUID
+
+# relative
+from ...serde.json_serde import Json
+from ...types.uid import UID
+
+
+def _default_dumps(val: Any) -> Json:  # type: ignore
+    if isinstance(val, UID):
+        return str(val.no_dash)
+    elif isinstance(val, UUID):
+        return val.hex
+    elif issubclass(type(val), Enum):
+        return val.name
+    elif val is None:
+        return None
+    return str(val)
+
+
+def _default_loads(val: Any) -> Any:  # type: ignore
+    if "UID" in val:
+        return UID(val)
+    return val
+
+
+def dumps(d: Any) -> str:
+    return json.dumps(d, default=_default_dumps)
+
+
+def loads(d: str) -> Any:
+    return json.loads(d, object_hook=_default_loads)

From 3529f4429eb652ac1802d9776efd3f22f9a594af Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Wed, 21 Aug 2024 23:25:17 +0200
Subject: [PATCH 021/197] cleaning

---
 .../src/syft/service/code/status_service.py   |  2 +-
 .../src/syft/service/code/user_code_stash.py  |  2 +-
 .../src/syft/service/dataset/dataset_stash.py |  2 +-
 .../syft/src/syft/service/job/job_stash.py    |  2 +-
 .../syft/src/syft/service/log/log_stash.py    |  2 +-
 .../src/syft/service/output/output_service.py |  2 +-
 .../src/syft/service/request/request_stash.py |  2 +-
 .../syft/service/settings/settings_stash.py   |  2 +-
 .../syft/src/syft/service/user/user_stash.py  |  2 +-
 packages/syft/src/syft/store/db/models.py     | 19 -------------------
 .../syft/store/db/{base_stash.py => stash.py} |  0
 11 files changed, 9 insertions(+), 28 deletions(-)
 rename packages/syft/src/syft/store/db/{base_stash.py => stash.py} (100%)

diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py
index eceb3713123..62de0a020fe 100644
--- a/packages/syft/src/syft/service/code/status_service.py
+++ b/packages/syft/src/syft/service/code/status_service.py
@@ -4,7 +4,7 @@

 # relative
 from ...serde.serializable import serializable
-from ...store.db.base_stash import ObjectStash
+from ...store.db.stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionSettings
 from ...types.uid import UID
diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py
index 943483d4db5..6d514f0a448 100644
--- a/packages/syft/src/syft/service/code/user_code_stash.py
+++ b/packages/syft/src/syft/service/code/user_code_stash.py
@@ -6,7 +6,7 @@
 # relative
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
-from ...store.db.base_stash import ObjectStash
+from ...store.db.stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionSettings
 from ...util.telemetry import instrument
diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py
index 17b255a2f64..b42719ee579 100644
--- a/packages/syft/src/syft/service/dataset/dataset_stash.py
+++ b/packages/syft/src/syft/service/dataset/dataset_stash.py
@@ -8,7 +8,7 @@
 # relative
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
-from ...store.db.base_stash import ObjectStash
+from ...store.db.stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionKey
 from ...store.document_store import PartitionSettings
diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py
index d520a18555b..c2d612dfcbf 100644
--- a/packages/syft/src/syft/service/job/job_stash.py
+++ b/packages/syft/src/syft/service/job/job_stash.py
@@ -23,7 +23,7 @@
 from ...server.credentials import SyftVerifyKey
 from ...service.context import AuthedServiceContext
 from ...service.worker.worker_pool import SyftWorker
-from ...store.db.base_stash import ObjectStash
+from ...store.db.stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionSettings
 from ...types.datetime import DateTime
diff --git a/packages/syft/src/syft/service/log/log_stash.py b/packages/syft/src/syft/service/log/log_stash.py
index 01fec3487f6..54d7c9ba04b 100644
--- a/packages/syft/src/syft/service/log/log_stash.py
+++ b/packages/syft/src/syft/service/log/log_stash.py
@@ -1,6 +1,6 @@
 # relative
 from ...serde.serializable import serializable
-from ...store.db.base_stash import ObjectStash
+from ...store.db.stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionSettings
 from ...util.telemetry import instrument
diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py
index e57ff658e62..6c6ce1ab400 100644
--- a/packages/syft/src/syft/service/output/output_service.py
+++ b/packages/syft/src/syft/service/output/output_service.py
@@ -9,7 +9,7 @@
 from ...client.api import APIRegistry
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
-from ...store.db.base_stash import ObjectStash
+from ...store.db.stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionKey
 from ...store.document_store import PartitionSettings
diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py
index cee6464a419..42e818e554b 100644
--- a/packages/syft/src/syft/service/request/request_stash.py
+++ b/packages/syft/src/syft/service/request/request_stash.py
@@ -6,7 +6,7 @@
 # relative
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
-from ...store.db.base_stash import ObjectStash
+from ...store.db.stash import ObjectStash
 from ...store.document_store import PartitionKey
 from ...store.document_store import PartitionSettings
 from ...types.datetime import DateTime
diff --git a/packages/syft/src/syft/service/settings/settings_stash.py b/packages/syft/src/syft/service/settings/settings_stash.py
index c4c4179e0d7..49adfe8979f 100644
--- a/packages/syft/src/syft/service/settings/settings_stash.py
+++ b/packages/syft/src/syft/service/settings/settings_stash.py
@@ -4,7 +4,7 @@

 # relative
 from ...serde.serializable import serializable
-from ...store.db.base_stash import ObjectStash
+from ...store.db.stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionKey
 from ...store.document_store import PartitionSettings
diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py
index 5b8ae398ea3..faf4fff404b 100644
--- a/packages/syft/src/syft/service/user/user_stash.py
+++ b/packages/syft/src/syft/service/user/user_stash.py
@@ -8,7 +8,7 @@
 from ...serde.serializable import serializable
 from ...server.credentials import SyftSigningKey
 from ...server.credentials import SyftVerifyKey
-from ...store.db.base_stash import ObjectStash
+from ...store.db.stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionKey
 from ...store.document_store import PartitionSettings
diff --git a/packages/syft/src/syft/store/db/models.py b/packages/syft/src/syft/store/db/models.py
index 0288f46d34d..92dd2c2a6fa 100644
--- a/packages/syft/src/syft/store/db/models.py
+++ b/packages/syft/src/syft/store/db/models.py
@@ -1,16 +1,11 @@
 # stdlib
-import uuid

 # third party
 import sqlalchemy as sa
 from sqlalchemy import TypeDecorator
 from sqlalchemy.orm import DeclarativeBase
-from sqlalchemy.orm import Mapped
-from sqlalchemy.orm import mapped_column
-from sqlalchemy.types import JSON

 # relative
-from ...types.datetime import DateTime
 from ...types.uid import UID
@@ -31,17 +26,3 @@ def process_bind_param(self, value, dialect):  # type: ignore
     def process_result_value(self, value, dialect):  # type: ignore
         if value is not None:
             return UID(value)
-
-
-class CommonMixin:
-    id: Mapped[UID] = mapped_column(
-        default=uuid.uuid4,
-        primary_key=True,
-    )
-    created_at: Mapped[DateTime] = mapped_column(server_default=sa.func.now())
-
-    updated_at: Mapped[DateTime] = mapped_column(
-        server_default=sa.func.now(),
-        server_onupdate=sa.func.now(),
-    )
-    json_document: Mapped[dict] = mapped_column(JSON, default={})
diff --git a/packages/syft/src/syft/store/db/base_stash.py b/packages/syft/src/syft/store/db/stash.py
similarity index 100%
rename from packages/syft/src/syft/store/db/base_stash.py
rename to packages/syft/src/syft/store/db/stash.py

From 343ed69912736f2e3a48de66f1285f24e8fa8387 Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Thu, 22 Aug 2024 12:05:53 +0200
Subject: [PATCH 022/197] add attr_searchable

---
 packages/syft/src/syft/serde/json_serde.py    | 53 +++++++++++++++++++
 .../syft/src/syft/service/dataset/dataset.py  |  1 +
 .../syft/src/syft/service/job/job_stash.py    |  2 +
 packages/syft/src/syft/store/db/stash.py      |  8 +++
 4 files changed, 64 insertions(+)

diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py
index e5e67d818d9..8f7d32caf24 100644
--- a/packages/syft/src/syft/serde/json_serde.py
+++ b/packages/syft/src/syft/serde/json_serde.py
@@ -136,9 +136,62 @@ def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> dict[str, Json]:
     for key, type_ in obj.model_fields.items():
         result[key] = serialize_json(getattr(obj, key), type_.annotation)
+
+    result = _serialize_searchable_attrs(obj, result)
+
     return result


+def get_property_return_type(obj: Any, attr_name: str) -> Any:
+    """
+    Get the return type annotation of a @property.
+    """
+    cls = type(obj)
+    attr = getattr(cls, attr_name, None)
+
+    if isinstance(attr, property):
+        return attr.fget.__annotations__.get("return", None)
+
+    return None
+
+
+def _serialize_searchable_attrs(
+    obj: pydantic.BaseModel, obj_dict: dict[str, Json], raise_errors: bool = True
+) -> dict[str, Json]:
+    """
+    Add searchable attrs to the serialized object dict, if they are not already present.
+    Needed for adding non-field attributes (like @property)
+
+    Args:
+        obj (pydantic.BaseModel): Object to serialize.
+        obj_dict (dict[str, Json]): Serialized object dict. Should contain the object's fields.
+        raise_errors (bool, optional): Raise errors if an attribute cannot be accessed.
+            If False, the attribute will be skipped. Defaults to True.
+
+    Raises:
+        Exception: Any exception raised when accessing an attribute.
+
+    Returns:
+        dict[str, Json]: Serialized object dict including searchable attributes.
+    """
+    searchable_attrs: list[str] = getattr(obj, "__attr_searchable__", [])
+    for attr in searchable_attrs:
+        if attr not in obj_dict:
+            try:
+                value = getattr(obj, attr)
+            except Exception as e:
+                if raise_errors:
+                    raise e
+                else:
+                    continue
+            property_annotation = get_property_return_type(obj, attr)
+            obj_dict[attr] = serialize_json(
+                value, validate=False, annotation=property_annotation
+            )
+
+    return obj_dict
+
+
 def _deserialize_pydantic_from_json(
     obj_dict: dict[str, Json],
 ) -> pydantic.BaseModel:
diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py
index 9f8909d989a..8b6431375cd 100644
--- a/packages/syft/src/syft/service/dataset/dataset.py
+++ b/packages/syft/src/syft/service/dataset/dataset.py
@@ -538,6 +538,7 @@ def _repr_html_(self) -> Any:
             {self.assets._repr_html_()}
             """

+    @property
     def action_ids(self) -> list[UID]:
         return [asset.action_id for asset in self.asset_list if asset.action_id]
diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py
index c2d612dfcbf..dccfa54abe9 100644
--- a/packages/syft/src/syft/service/job/job_stash.py
+++ b/packages/syft/src/syft/service/job/job_stash.py
@@ -117,6 +117,7 @@ class Job(SyncableSyftObject):
         "user_code_id",
         "result_id",
     ]
+
     __repr_attrs__ = [
         "id",
         "result",
@@ -125,6 +126,7 @@ class Job(SyncableSyftObject):
         "creation_time",
         "user_code_name",
     ]
+
     __exclude_sync_diff_attrs__ = ["action", "server_uid"]
     __table_coll_widths__ = [
         "min-content",
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index 06da3d82896..e9a257ca83c 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -239,6 +239,14 @@ def update(
         obj: SyftT,
         has_permission: bool = False,
     ) -> Result[SyftT, str]:
+        """
+        NOTE: We cannot do partial updates on the database,
+        because we are using computed fields that are not known to the DB or ORM:
+        - serialize_json will add computed fields to the JSON stored in the database
+        - If we update a single field in the JSON, the computed fields can get out of sync.
+        - To fix, we either need db-supported computed fields, or know in our ORM which fields should be re-computed.
+        """
+
         # TODO has_permission is not used
         if not self.is_unique(obj):
             return Err(f"Some fields are not unique for {type(obj).__name__}")

From 05e83060532af9b11d1fcfd654f738adac135896 Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Thu, 22 Aug 2024 17:37:20 +0200
Subject: [PATCH 023/197] add order_by for sync stash

---
 .../syft/src/syft/service/api/api_stash.py | 30 +++++----
 packages/syft/src/syft/store/db/stash.py   | 64 +++++++++++++++--
 2 files changed, 76 insertions(+), 18 deletions(-)

diff --git a/packages/syft/src/syft/service/api/api_stash.py b/packages/syft/src/syft/service/api/api_stash.py
index 3b6daac1422..94567d3ce05 100644
--- a/packages/syft/src/syft/service/api/api_stash.py
+++ b/packages/syft/src/syft/service/api/api_stash.py
@@ -8,7 +8,7 @@
 # relative
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
-from ...store.document_store import BaseUIDStoreStash
+from ...store.db.stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionSettings
 from .api import TwinAPIEndpoint
@@ -16,8 +16,8 @@
 MISSING_PATH_STRING = "Endpoint path: {path} does not exist."
-@serializable(canonical_name="TwinAPIEndpointStash", version=1)
-class TwinAPIEndpointStash(BaseUIDStoreStash):
+@serializable(canonical_name="TwinAPIEndpointSQLStash", version=1)
+class TwinAPIEndpointStash(ObjectStash[TwinAPIEndpoint]):
     object_type = TwinAPIEndpoint
     settings: PartitionSettings = PartitionSettings(
         name=TwinAPIEndpoint.__canonical_name__, object_type=TwinAPIEndpoint
@@ -29,19 +29,20 @@ def __init__(self, store: DocumentStore) -> None:
     def get_by_path(
         self, credentials: SyftVerifyKey, path: str
     ) -> Result[TwinAPIEndpoint, str]:
-        endpoint_results = self.get_all(credentials=credentials)
-        if endpoint_results.is_err():
-            return endpoint_results
-
-        endpoints = []
-        if endpoint_results.is_ok():
-            endpoints = endpoint_results.ok()
+        # TODO standardize by returning None if endpoint doesn't exist.
+        res_or_err = self.get_one_by_field(
+            credentials=credentials,
+            field_name="path",
+            field_value=path,
+        )

-        for endpoint in endpoints:
-            if endpoint.path == path:
-                return Ok(endpoint)
+        if res_or_err.is_err():
+            return res_or_err

-        return Err(MISSING_PATH_STRING.format(path=path))
+        res = res_or_err.ok()
+        if res is None:
+            return Err(MISSING_PATH_STRING.format(path=path))
+        return Ok(res)

     def path_exists(self, credentials: SyftVerifyKey, path: str) -> Result[bool, str]:
         result = self.get_by_path(credentials=credentials, path=path)
@@ -60,6 +61,7 @@ def upsert(
         has_permission: bool = False,
     ) -> Result[TwinAPIEndpoint, str]:
         """Upsert an endpoint."""
+        # TODO has_permission is not used.
         result = self.path_exists(credentials=credentials, path=endpoint.path)

         if result.is_err():
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index e9a257ca83c..dfe3bddec19 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -1,6 +1,7 @@
 # stdlib

 # stdlib
+from typing import Any
 from typing import Generic
 import uuid

@@ -33,7 +34,6 @@
 from ...types.syft_object import SyftObject
 from ...types.uid import UID
 from ..document_store import DocumentStore
-from ..document_store import PartitionKey
 from .models import Base
 from .models import UIDTypeDecorator
 from .sqlite_db import SQLiteDBManager
@@ -144,6 +144,10 @@ def _get_by_field(
         field_name: str,
         field_value: str,
         table: Table | None = None,
+        order_by: str | None = None,
+        sort_order: str = "asc",
+        limit: int | None = None,
+        offset: int | None = None,
     ) -> Result[Row, str]:
         table = table if table is not None else self.table
         stmt = table.select().where(
@@ -152,6 +156,9 @@
                 self._get_permission_filter(credentials),
             )
         )
+        stmt = self._apply_order_by(stmt, order_by, sort_order)
+        stmt = self._apply_limit_offset(stmt, limit, offset)
+
         result = self.session.execute(stmt)
         return result

@@ -166,10 +173,23 @@ def get_one_by_field(
         return Ok(self.row_as_obj(result))

     def get_all_by_field(
-        self, credentials: SyftVerifyKey, field_name: str, field_value: str
+        self,
+        credentials: SyftVerifyKey,
+        field_name: str,
+        field_value: str,
+        order_by: str | None = None,
+        sort_order: str = "asc",
+        limit: int | None = None,
+        offset: int | None = None,
     ) -> Result[list[SyftT], str]:
         result = self._get_by_field(
-            credentials=credentials, field_name=field_name, field_value=field_value
+            credentials=credentials,
+            field_name=field_name,
+            field_value=field_value,
+            order_by=order_by,
+            sort_order=sort_order,
+            limit=limit,
+            offset=offset,
         ).all()
         objs = [self.row_as_obj(row) for row in result]
         return Ok(objs)
@@ -218,17 +238,53 @@ def _get_permission_filter(
             self.table.c.permissions.contains(compound_permission_string),
         )

+    def _apply_limit_offset(
+        self,
+        stmt: Any,
+        limit: int | None = None,
+        offset: int | None = None,
+    ) -> Any:
+        if offset is not None:
+            stmt = stmt.offset(offset)
+        if limit is not None:
+            stmt = stmt.limit(limit)
+        return stmt
+
+    def _apply_order_by(
+        self,
+        stmt: Any,
+        order_by: str | None = None,
+        sort_order: str = "asc",
+    ) -> Any:
+        default_order_by = self.table.c.created_at
+        default_order_by = (
+            default_order_by.desc() if sort_order == "desc" else default_order_by
+        )
+        if order_by is None:
+            return stmt.order_by(default_order_by)
+        else:
+            order_by_col = self.table.c.fields[order_by]
+            order_by = order_by_col.desc() if sort_order == "desc" else order_by_col
+            return stmt.order_by(order_by, default_order_by)
+
     def get_all(
         self,
         credentials: SyftVerifyKey,
-        order_by: PartitionKey | None = None,
         has_permission: bool = False,
+        order_by: str | None = None,
+        sort_order: str = "asc",
+        limit: int | None = None,
+        offset: int | None = None,
     ) -> Result[list[SyftT], str]:
         # filter by read permission
         # join on verify_key
         stmt = self.table.select()
         if not has_permission:
             stmt = stmt.where(self._get_permission_filter(credentials))
+
+        stmt = self._apply_order_by(stmt, order_by, sort_order)
+        stmt = self._apply_limit_offset(stmt, limit, offset)
+
         result = self.session.execute(stmt).all()
         objs = [self.row_as_obj(row) for row in result]
         return Ok(objs)

From e23af54b808f297e91b38194a0140225081f72a6 Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Thu, 22 Aug 2024 18:44:55 +0200
Subject: [PATCH 024/197] sync stash

---
 .../src/syft/service/sync/sync_service.py    |  2 +-
 .../syft/src/syft/service/sync/sync_stash.py | 53 ++++++-------------
 packages/syft/src/syft/store/db/stash.py     |  4 +-
 3 files changed, 19 insertions(+), 40 deletions(-)

diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py
index a7a0dfbe2df..32686cb75dc 100644
--- a/packages/syft/src/syft/service/sync/sync_service.py
+++ b/packages/syft/src/syft/service/sync/sync_service.py
@@ -395,7 +395,7 @@ def build_current_state(
         permissions = {}
         storage_permissions = {}

-        previous_state = self.stash.get_latest(context=context)
+        previous_state = self.stash.get_latest(context.credentials)
         if previous_state.is_err():
             return previous_state
         previous_state = previous_state.ok()
diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py
index 6c989c1b8ef..b3cf80a362f 100644
--- a/packages/syft/src/syft/service/sync/sync_stash.py
+++ b/packages/syft/src/syft/service/sync/sync_stash.py
@@ -3,7 +3,6 @@
 # stdlib

 # stdlib
-import threading

 # third party
 from result import Ok
@@ -11,13 +10,13 @@

 # relative
 from ...serde.serializable import serializable
-from ...store.document_store import BaseUIDStoreStash
+from ...server.credentials import SyftVerifyKey
+from ...store.db.stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionKey
 from ...store.document_store import PartitionSettings
 from ...types.datetime import DateTime
 from ...util.telemetry import instrument
-from ..context import AuthedServiceContext
 from .sync_state import SyncState

 OrderByDatePartitionKey = PartitionKey(key="created_at", type_=DateTime)
@@ -25,7 +24,7 @@

 @instrument
 @serializable(canonical_name="SyncStash", version=1)
-class SyncStash(BaseUIDStoreStash):
+class SyncStash(ObjectStash[SyncState]):
     object_type = SyncState
     settings: PartitionSettings = PartitionSettings(
         name=SyncState.__canonical_name__,
@@ -35,45 +34,23 @@ class SyncStash(BaseUIDStoreStash):
     def __init__(self, store: DocumentStore):
         super().__init__(store)
         self.store = store
-        self.settings = self.settings
-        self._object_type = self.object_type
         self.last_state: SyncState | None = None

-    def get_latest(
-        self, context: AuthedServiceContext
-    ) -> Result[SyncState | None, str]:
-        # print("SyncStash.get_latest called")
+    def get_latest(self, credentials: SyftVerifyKey) -> Result[SyncState | None, str]:
         if self.last_state is not None:
             return Ok(self.last_state)
-        all_states = self.get_all(
-            credentials=context.server.verify_key,  # type: ignore
-            order_by=OrderByDatePartitionKey,
+
+        states_or_err = self.get_all(
+            credentials=credentials,
+            sort_order="desc",
+            limit=1,
         )

-        if all_states.is_err():
-            return all_states
+        if states_or_err.is_err():
+            return states_or_err

-        all_states = all_states.ok()
-        if len(all_states) > 0:
-            self.last_state = all_states[-1]
-            return Ok(all_states[-1])
+        last_state = states_or_err.ok()
+        if len(last_state) > 0:
+            self.last_state = last_state[0]
+            return Ok(last_state[0])

         return Ok(None)
-
-    def set(  # type: ignore
-        self,
-        context: AuthedServiceContext,
-        item: SyncState,
-        **kwargs,
-    ) -> Result[SyncState, str]:
-        self.last_state = item
-
-        # use threading
-        threading.Thread(
-            target=super().set,
-            args=(
-                context,
-                item,
-            ),
-            kwargs=kwargs,
-        ).start()
-        return Ok(item)
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index dfe3bddec19..9500cff952c 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -74,7 +74,9 @@ def table(self) -> Table:
                 Column("id", UIDTypeDecorator, primary_key=True, default=uuid.uuid4),
                 Column("fields", JSON, default={}),
                 Column("permissions", JSON, default=[]),
-                Column("created_at", sa.DateTime, server_default=sa.func.now()),
+                Column(
+                    "created_at", sa.DateTime, server_default=sa.func.now(), index=True
+                ),
                 Column("updated_at", sa.DateTime, server_onupdate=sa.func.now()),
             )
         return Base.metadata.tables[table_name]

From 85f39917aea69fcfa4a933747aaebaeb007e70ee Mon Sep 17 00:00:00 2001
From: Aziz Berkay Yesilyurt
Date: Fri, 23 Aug 2024 11:14:20 +0200
Subject: [PATCH 025/197] implement code history

---
 .../code_history/code_history_stash.py   | 43 +++++-----
 packages/syft/src/syft/store/db/stash.py | 82 +++++++++++++++----
 2 files changed, 88 insertions(+), 37 deletions(-)

diff --git a/packages/syft/src/syft/service/code_history/code_history_stash.py b/packages/syft/src/syft/service/code_history/code_history_stash.py
index c419416664a..ee5ff63d6ba 100644
--- a/packages/syft/src/syft/service/code_history/code_history_stash.py
+++ b/packages/syft/src/syft/service/code_history/code_history_stash.py
@@ -6,19 +6,18 @@
 # relative
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
-from ...store.document_store import BaseUIDStoreStash
+from ...store.db.stash import ObjectStash
 from ...store.document_store import DocumentStore
 from ...store.document_store import PartitionKey
 from ...store.document_store import PartitionSettings
-from ...store.document_store import QueryKeys
 from .code_history import CodeHistory

 NamePartitionKey = PartitionKey(key="service_func_name", type_=str)
 VerifyKeyPartitionKey = PartitionKey(key="user_verify_key", type_=SyftVerifyKey)

-@serializable(canonical_name="CodeHistoryStash", version=1)
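+        # NOTE: no read-permission filter is applied here yet, unlike get_by_uid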
+        # TODO use COUNT(*) instead of SELECT
+        stmt = self.table.select().where(self._get_field_filter("id", uid))
+        result = self.session.execute(stmt).first()
+        return result is not None
+
     def get_by_uid(
-        self, credentials: SyftVerifyKey, uid: UID
+        self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False
     ) -> Result[SyftT | None, str]:
-        result = self.session.execute(
-            self.table.select().where(
-                sa.and_(
-                    self._get_field_filter("id", uid),
-                    self._get_permission_filter(credentials),
-                )
-            )
-        ).first()
+        # TODO implement has_permission
+        stmt = self.table.select()
+        stmt = stmt.where(self._get_field_filter("id", uid))
+        stmt = self._apply_permission_filter(stmt, credentials, has_permission)
+        result = self.session.execute(stmt).first()
+
         if result is None:
             return Ok(None)
         return Ok(self.row_as_obj(result))
@@ -150,6 +155,7 @@ def _get_by_fields(
         sort_order: str = "asc",
         limit: int | None = None,
         offset: int | None = None,
+        has_permission: bool = False,
     ) -> Result[Row, str]:
         table = table if table is not None else self.table
         filters = []
@@ -157,11 +163,10 @@ def _get_by_fields(
             filt = self._get_field_filter(field_name, field_value, table=table)
             filters.append(filt)

-        stmt = table.select().where(
-            sa.and_(
-                sa.and_(*filters),
-                self._get_permission_filter(credentials),
-            )
+        stmt = table.select()
+        stmt = stmt.where(sa.and_(*filters))
+        stmt = self._apply_permission_filter(
+            stmt, credentials, has_permission=has_permission
         )
         stmt = self._apply_order_by(stmt, order_by, sort_order)
         stmt = self._apply_limit_offset(stmt, limit, offset)
@@ -175,10 +180,16 @@ def get_one_by_field(
    def get_one_by_field(
        self,
        credentials: SyftVerifyKey,
        field_name: str,
        field_value: str,
+        has_permission: bool = False,
    ) -> Result[SyftT | None, str]:
        result = self._get_by_fields(
            credentials=credentials,
            fields={field_name: field_value},
+            has_permission=has_permission,
        ).first()
        if result is None:
            return Ok(None)
        return Ok(self.row_as_obj(result))

    def get_one_by_fields(
        self,
        credentials: SyftVerifyKey,
        fields: dict[str, str],
+        has_permission: bool = False,
    ) -> Result[SyftT | None, str]:
        result = self._get_by_fields(
            credentials=credentials,
            fields=fields,
+            has_permission=has_permission,
        ).first()
        if result is None:
            return Ok(None)
        return Ok(self.row_as_obj(result))

    def get_all_by_fields(
        self,
        credentials: SyftVerifyKey,
        fields: dict[str, str],
        order_by: str | None = None,
        sort_order: str = "asc",
        limit: int | None = None,
        offset: int | None = None,
+        has_permission: bool = False,
    ) -> Result[list[SyftT], str]:
        # sanity check if the field is not a list, set etc.
        for field_name in fields:
@@ -216,6 +229,7 @@ def get_all_by_fields(
            sort_order=sort_order,
            limit=limit,
            offset=offset,
+            has_permission=has_permission,
        ).all()
        objs = [self.row_as_obj(row) for row in result]
        return Ok(objs)
@@ -229,6 +243,7 @@ def get_all_by_field(
        sort_order: str = "asc",
        limit: int | None = None,
        offset: int | None = None,
+        has_permission: bool = False,
    ) -> Result[list[SyftT], str]:
        result = self._get_by_fields(
            credentials=credentials,
@@ -237,6 +252,7 @@
            sort_order=sort_order,
            limit=limit,
            offset=offset,
+            has_permission=has_permission,
        ).all()
        objs = [self.row_as_obj(row) for row in result]
        return Ok(objs)
@@ -250,16 +266,14 @@ def get_all_contains(
        sort_order: str = "asc",
        limit: int | None = None,
        offset: int | None = None,
+        has_permission: bool = False,
    ) -> Result[list[SyftT], str]:
        # TODO write filter logic, merge with get_all

        stmt = self.table.select().where(
-            sa.and_(
-                self.table.c.fields[field_name].contains(func.json_quote(field_value)),
-                self._get_permission_filter(credentials),
-            )
+            self.table.c.fields[field_name].contains(func.json_quote(field_value)),
        )
-
+        stmt = self._apply_permission_filter(stmt, credentials, has_permission)
        stmt = self._apply_order_by(stmt, order_by, sort_order)
        stmt = self._apply_limit_offset(stmt, limit, offset)
@@ -387,12 +401,17 @@ def update(
        if not self.is_unique(obj):
            return Err(f"Some fields are not unique for {type(obj).__name__}")

+        has_permission_stmt = (
+            sa.literal(True)
+            if has_permission
+            else self._get_permission_filter(credentials, ActionPermission.WRITE)
+        )
        stmt = (
            self.table.update()
            .where(
                sa.and_(
                    self._get_field_filter("id", obj.id),
-                    self._get_permission_filter(credentials, ActionPermission.WRITE),
+                    has_permission_stmt,
                )
            )
            .values(fields=serialize_json(obj))
@@ -518,6 +537,7 @@ def remove_permission(self, permission: ActionObjectPermission) -> None:
        return None

    def has_permission(self, permission: ActionObjectPermission) -> bool:
+        # TODO: should check for compound permissions
        stmt = self.table.select().where(
            sa.and_(
                self._get_field_filter("id", permission.uid),

From 812b4a547e0ed862e96534b585f6be7b8b63aa93 Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Fri, 23 Aug 2024 12:30:32 +0200
Subject: [PATCH 029/197] action store WIP

---
 .../src/syft/service/action/action_service.py |  10 +-
 .../src/syft/service/action/action_store.py   | 189 +++++-------------
 2 files changed, 52 insertions(+), 147 deletions(-)

diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py
index ab82c80a2b8..d5049c3e2be 100644
--- a/packages/syft/src/syft/service/action/action_service.py
+++ b/packages/syft/src/syft/service/action/action_service.py
@@ -44,7 +44,7 @@
 from .action_permissions import ActionObjectPermission
 from .action_permissions import ActionObjectREAD
 from .action_permissions import ActionPermission
-from .action_store import ActionStore
+from .action_store import ActionObjectStash
 from .action_types import action_type_for_type
 from .numpy import NumpyArrayObject
 from .pandas import PandasDataFrameObject  # noqa: F401
@@ -55,9 +55,7 @@

 @serializable(canonical_name="ActionService", version=1)
 class ActionService(AbstractService):
-    store_type = ActionStore
-
-    def __init__(self, store: ActionStore) -> None:
+    def __init__(self, store: ActionObjectStash) -> None:
         self.store = store

     @service_method(path="action.np_array", name="np_array")
@@ -352,7 +350,7 @@ def get_mock(
         self, context: AuthedServiceContext, uid: UID
     ) -> Result[SyftError, SyftObject]:
         """Get a pointer from the action store"""
-        result = self.store.get_mock(uid=uid)
+        result = self.store.get_mock(credentials=context.credentials, uid=uid)
         if result.is_ok():
             return result.ok()
         return SyftError(message=result.err())
@@ -935,7 +933,7 @@ def exists(
         self, context: AuthedServiceContext, obj_id: UID
     ) -> Result[SyftSuccess, SyftError]:
         """Checks if the given object id exists in the Action Store"""
-        if self.store.exists(obj_id):
+        if self.store.exists(context.credentials, obj_id):
             return SyftSuccess(message=f"Object: {obj_id} exists")
         else:
             return SyftError(message=f"Object: {obj_id} does not exist")
diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py
index 250b3c5e9b5..c487ff5ced0 100644
--- a/packages/syft/src/syft/service/action/action_store.py
+++ b/packages/syft/src/syft/service/action/action_store.py
@@ -13,7 +13,6 @@
 from ...serde.serializable import serializable
 from ...server.credentials import SyftSigningKey
 from ...server.credentials import SyftVerifyKey
-from ...store.dict_document_store import DictStoreConfig
 from ...store.document_store import BasePartitionSettings
 from ...store.document_store import DocumentStore
 from ...store.document_store import StoreConfig
@@ -30,81 +29,38 @@
 from .action_permissions import ActionObjectWRITE
 from .action_permissions import ActionPermission
 from .action_permissions import StoragePermission
+from ...store.db.stash import ObjectStash
+from .action_object import ActionObject

 lock = threading.RLock()


 class ActionStore:
     pass


 @serializable(canonical_name="KeyValueActionStore", version=1)
 class KeyValueActionStore(ActionStore):
     """Generic Key-Value Action store.

     Parameters:
         store_config: StoreConfig
             Backend specific configuration, including connection configuration, database name, or client class type.
         root_verify_key: Optional[SyftVerifyKey]
             Signature verification key, used for checking access permissions.
- """ - - def __init__( - self, - server_uid: UID, - store_config: StoreConfig, - root_verify_key: SyftVerifyKey | None = None, - document_store: DocumentStore | None = None, - ) -> None: - self.server_uid = server_uid - self.store_config = store_config - self.settings = BasePartitionSettings(name="Action") - self.data = self.store_config.backing_store( - "data", self.settings, self.store_config - ) - self.permissions = self.store_config.backing_store( - "permissions", self.settings, self.store_config, ddtype=set - ) - self.storage_permissions = self.store_config.backing_store( - "storage_permissions", self.settings, self.store_config, ddtype=set - ) - - if root_verify_key is None: - root_verify_key = SyftSigningKey.generate().verify_key - self.root_verify_key = root_verify_key - - self.__user_stash = None - if document_store is not None: - # relative - from ...service.user.user_stash import UserStash - - self.__user_stash = UserStash(store=document_store) +@serializable(canonical_name="ActionObjectSQLStore", version=1) +class ActionObjectStash(ObjectStash[ActionObject]): def get( self, uid: UID, credentials: SyftVerifyKey, has_permission: bool = False - ) -> Result[SyftObject, str]: + ) -> Result[ActionObject, str]: uid = uid.id # We only need the UID from LineageID or UID - # if you get something you need READ permission - read_permission = ActionObjectREAD(uid=uid, credentials=credentials) - if has_permission or self.has_permission(read_permission): - try: - if isinstance(uid, LineageID): - syft_object = self.data[uid.id] - elif isinstance(uid, UID): - syft_object = self.data[uid] - else: - raise Exception(f"Unrecognized UID type: {type(uid)}") - return Ok(syft_object) - except Exception as e: - return Err(f"Could not find item with uid {uid}, {e}") - return Err(f"Permission: {read_permission} denied") - - def get_mock(self, uid: UID) -> Result[SyftObject, str]: + # TODO remove and use get_by_uid instead + result_or_err = self.get_by_uid( + credentials=credentials, + uid=uid, + has_permission=has_permission, + ) + if result_or_err.is_err(): + return Err(result_or_err.err()) + + result = result_or_err.ok() + if result is None: + return Err(f"Could not find item with uid {uid}") + return Ok(result) + + def get_mock(self, credentials: SyftVerifyKey, uid: UID) -> Result[SyftObject, str]: uid = uid.id # We only need the UID from LineageID or UID try: - syft_object = self.data[uid] + syft_object = self.get_by_uid( + credentials=credentials, uid=uid, has_permission=True + ) # type: ignore if isinstance(syft_object, TwinObject) and not is_action_data_empty( syft_object.mock ): @@ -122,31 +78,37 @@ def get_pointer( uid = uid.id # We only need the UID from LineageID or UID try: - if uid in self.data: - obj = self.data[uid] - read_permission = ActionObjectREAD(uid=uid, credentials=credentials) + result_or_err = self.get_by_uid( + credentials=credentials, uid=uid, has_permission=True + ) + has_permissions = self.has_permission( + ActionObjectREAD(uid=uid, credentials=credentials) + ) + if result_or_err.is_err(): + return Err(result_or_err.err()) - # if you have permission you can have private data - if self.has_permission(read_permission): - if isinstance(obj, TwinObject): - return Ok(obj.private.syft_point_to(server_uid)) - return Ok(obj.syft_point_to(server_uid)) + obj = result_or_err.ok() + if obj is None: + return Err("Permission denied") - # if its a twin with a mock anyone can have this + if has_permissions: if isinstance(obj, TwinObject): - return Ok(obj.mock.syft_point_to(server_uid)) + 
return Ok(obj.private.syft_point_to(server_uid)) + return Ok(obj.syft_point_to(server_uid)) + + # if its a twin with a mock anyone can have this + if isinstance(obj, TwinObject): + return Ok(obj.mock.syft_point_to(server_uid)) - # finally worst case you get ActionDataEmpty so you can still trace - return Ok(obj.as_empty().syft_point_to(server_uid)) + # finally worst case you get ActionDataEmpty so you can still trace + return Ok(obj.as_empty().syft_point_to(server_uid)) - return Err("Permission denied") except Exception as e: return Err(str(e)) - def exists(self, uid: UID) -> bool: - uid = uid.id # We only need the UID from LineageID or UID - - return uid in self.data + def exists(self, credentials: SyftVerifyKey, uid: UID) -> bool: + uid = uid.id + return super().exists(credentials=credentials, uid=uid) def set( self, @@ -162,7 +124,7 @@ def set( write_permission = ActionObjectWRITE(uid=uid, credentials=credentials) can_write = self.has_permission(write_permission) - if not self.exists(uid=uid): + if not self.exists(credentials=credentials, uid=uid): # attempt to claim it for writing if has_result_read_permission: ownership_result = self.take_ownership(uid=uid, credentials=credentials) @@ -233,7 +195,7 @@ def delete(self, uid: UID, credentials: SyftVerifyKey) -> Result[SyftSuccess, st return Ok(SyftSuccess(message=f"ID: {uid} deleted")) return Err(f"Permission: {owner_permission} denied") - def has_permission(self, permission: ActionObjectPermission) -> bool: + def _has_permission(self, permission: ActionObjectPermission) -> bool: if not isinstance(permission.permission, ActionPermission): raise Exception(f"ObjectPermission type: {permission.permission} not valid") @@ -369,58 +331,3 @@ def migrate_data( return Ok(True) return Err("You don't have permissions to migrate data.") - - -@serializable(canonical_name="DictActionStore", version=1) -class DictActionStore(KeyValueActionStore): - """Dictionary-Based Key-Value Action store. - - Parameters: - store_config: StoreConfig - Backend specific configuration, including client class type. - root_verify_key: Optional[SyftVerifyKey] - Signature verification key, used for checking access permissions. - """ - - def __init__( - self, - server_uid: UID, - store_config: StoreConfig | None = None, - root_verify_key: SyftVerifyKey | None = None, - document_store: DocumentStore | None = None, - ) -> None: - store_config = store_config if store_config is not None else DictStoreConfig() - super().__init__( - server_uid=server_uid, - store_config=store_config, - root_verify_key=root_verify_key, - document_store=document_store, - ) - - -@serializable(canonical_name="SQLiteActionStore", version=1) -class SQLiteActionStore(KeyValueActionStore): - """SQLite-Based Key-Value Action store. - - Parameters: - store_config: StoreConfig - SQLite specific configuration, including connection settings or client class type. - root_verify_key: Optional[SyftVerifyKey] - Signature verification key, used for checking access permissions. - """ - - pass - - -@serializable(canonical_name="MongoActionStore", version=1) -class MongoActionStore(KeyValueActionStore): - """Mongo-Based Action store. - - Parameters: - store_config: StoreConfig - Mongo specific configuration. - root_verify_key: Optional[SyftVerifyKey] - Signature verification key, used for checking access permissions. 
- """ - - pass From f771c5979ff7ffb58588031d8b43db4f42ca4905 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 23 Aug 2024 14:31:06 +0200 Subject: [PATCH 030/197] implement action stash --- .../src/syft/service/action/action_service.py | 6 +- .../src/syft/service/action/action_store.py | 246 +++--------------- .../service/migration/migration_service.py | 7 +- packages/syft/src/syft/store/db/stash.py | 90 ++++--- 4 files changed, 92 insertions(+), 257 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index d5049c3e2be..8a21d415b47 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -179,7 +179,7 @@ def _set( or has_result_read_permission ) - result = self.store.set( + result = self.store.set_or_update( uid=action_object.id, credentials=context.credentials, syft_object=action_object, @@ -933,7 +933,7 @@ def exists( self, context: AuthedServiceContext, obj_id: UID ) -> Result[SyftSuccess, SyftError]: """Checks if the given object id exists in the Action Store""" - if self.store.exists(context.credentials, obj_id): + if self.store.exists(context.credentials, obj_id.id): return SyftSuccess(message=f"Object: {obj_id} exists") else: return SyftError(message=f"Object: {obj_id} does not exist") @@ -1030,7 +1030,7 @@ def _delete_from_action_store( if res.is_err(): return SyftError(message=res.err()) else: - res = self.store.delete(credentials=context.credentials, uid=uid) + res = self.store.delete_by_uid(credentials=context.credentials, uid=uid.id) if res.is_err(): return SyftError(message=res.err()) diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index c487ff5ced0..4f011ee7e58 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -1,9 +1,6 @@ # future from __future__ import annotations -# stdlib -import threading - # third party from result import Err from result import Ok @@ -11,26 +8,18 @@ # relative from ...serde.serializable import serializable -from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey -from ...store.document_store import BasePartitionSettings -from ...store.document_store import DocumentStore -from ...store.document_store import StoreConfig +from ...store.db.stash import ObjectStash from ...types.syft_object import SyftObject from ...types.twin_object import TwinObject -from ...types.uid import LineageID from ...types.uid import UID from ..response import SyftSuccess +from .action_object import ActionObject from .action_object import is_action_data_empty from .action_permissions import ActionObjectEXECUTE -from .action_permissions import ActionObjectOWNER -from .action_permissions import ActionObjectPermission from .action_permissions import ActionObjectREAD from .action_permissions import ActionObjectWRITE -from .action_permissions import ActionPermission from .action_permissions import StoragePermission -from ...store.db.stash import ObjectStash -from .action_object import ActionObject @serializable(canonical_name="ActionObjectSQLStore", version=1) @@ -106,11 +95,7 @@ def get_pointer( except Exception as e: return Err(str(e)) - def exists(self, credentials: SyftVerifyKey, uid: UID) -> bool: - uid = uid.id - return super().exists(credentials=credentials, uid=uid) - - def set( + def set_or_update( # type: 
ignore self, uid: UID, credentials: SyftVerifyKey, @@ -120,214 +105,47 @@ def set( ) -> Result[SyftSuccess, Err]: uid = uid.id # We only need the UID from LineageID or UID - # if you set something you need WRITE permission - write_permission = ActionObjectWRITE(uid=uid, credentials=credentials) - can_write = self.has_permission(write_permission) - - if not self.exists(credentials=credentials, uid=uid): - # attempt to claim it for writing + if self.exists(credentials=credentials, uid=uid): + permissions = [] if has_result_read_permission: - ownership_result = self.take_ownership(uid=uid, credentials=credentials) - can_write = True if ownership_result.is_ok() else False + permissions.append(ActionObjectREAD(uid=uid, credentials=credentials)) else: - # root takes owneship, but you can still write - ownership_result = self.take_ownership( - uid=uid, credentials=self.root_verify_key - ) - can_write = True if ownership_result.is_ok() else False - - if can_write: - self.data[uid] = syft_object - if uid not in self.permissions: - # create default permissions - self.permissions[uid] = set() - if has_result_read_permission: - self.add_permission(ActionObjectREAD(uid=uid, credentials=credentials)) - else: - self.add_permissions( + permissions.extend( [ ActionObjectWRITE(uid=uid, credentials=credentials), ActionObjectEXECUTE(uid=uid, credentials=credentials), ] ) - - if uid not in self.storage_permissions: - # create default storage permissions - self.storage_permissions[uid] = set() + storage_permission = [] if add_storage_permission: - self.add_storage_permission( + storage_permission.append( StoragePermission(uid=uid, server_uid=self.server_uid) ) - return Ok(SyftSuccess(message=f"Set for ID: {uid}")) - return Err(f"Permission: {write_permission} denied") - - def take_ownership( - self, uid: UID, credentials: SyftVerifyKey - ) -> Result[SyftSuccess, str]: - uid = uid.id # We only need the UID from LineageID or UID - - # first person using this UID can claim ownership - if uid not in self.permissions and uid not in self.data: - self.add_permissions( - [ - ActionObjectOWNER(uid=uid, credentials=credentials), - ActionObjectWRITE(uid=uid, credentials=credentials), - ActionObjectREAD(uid=uid, credentials=credentials), - ActionObjectEXECUTE(uid=uid, credentials=credentials), - ] + self.update( + credentials=credentials, + obj=syft_object, ) - return Ok(SyftSuccess(message=f"Ownership of ID: {uid} taken.")) - return Err(f"UID: {uid} already owned.") - - def delete(self, uid: UID, credentials: SyftVerifyKey) -> Result[SyftSuccess, str]: - uid = uid.id # We only need the UID from LineageID or UID - - # if you delete something you need OWNER permission - # is it bad to evict a key and have someone else reuse it? - # perhaps we should keep permissions but no data? 
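For orientation, the set_or_update flow this patch introduces is an upsert with ownership defaults: update the row when the uid already exists, otherwise insert it and grant the creator (or root) the full permission set. A toy sketch of that branching follows; the dict store and the permission-string format are stand-ins, not the stash's real storage:

# Toy upsert mirroring set_or_update's control flow; illustrative only.
from result import Ok, Result

_rows: dict[str, dict] = {}  # stand-in for the SQL table


def set_or_update(uid: str, fields: dict, owner: str) -> Result[str, str]:
    if uid in _rows:
        # existing row: update the serialized fields in place
        _rows[uid]["fields"] = fields
    else:
        # new row: the creator is granted the full permission set
        perms = [f"{owner}_{p}" for p in ("OWNER", "READ", "WRITE", "EXECUTE")]
        _rows[uid] = {"fields": fields, "permissions": perms}
    return Ok(f"Set for ID: {uid}")


print(set_or_update("obj-1", {"x": 1}, owner="abc").ok())  # Set for ID: obj-1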
- owner_permission = ActionObjectOWNER(uid=uid, credentials=credentials) - if self.has_permission(owner_permission): - if uid in self.data: - del self.data[uid] - if uid in self.permissions: - del self.permissions[uid] - return Ok(SyftSuccess(message=f"ID: {uid} deleted")) - return Err(f"Permission: {owner_permission} denied") - - def _has_permission(self, permission: ActionObjectPermission) -> bool: - if not isinstance(permission.permission, ActionPermission): - raise Exception(f"ObjectPermission type: {permission.permission} not valid") - - if ( - permission.credentials is not None - and self.root_verify_key.verify == permission.credentials.verify - ): - return True - - if self.__user_stash is not None: - # relative - from ...service.user.user_roles import ServiceRole - - res = self.__user_stash.get_by_verify_key( - credentials=permission.credentials, - verify_key=permission.credentials, - ) - - if ( - res.is_ok() - and (user := res.ok()) is not None - and user.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN) - ): - return True - - if ( - permission.uid in self.permissions - and permission.permission_string in self.permissions[permission.uid] - ): - return True - - # 🟡 TODO 14: add ALL_READ, ALL_EXECUTE etc - if permission.permission == ActionPermission.OWNER: - pass - elif permission.permission == ActionPermission.READ: - pass - elif permission.permission == ActionPermission.WRITE: - pass - elif permission.permission == ActionPermission.EXECUTE: - pass - - return False - - def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool: - return all(self.has_permission(p) for p in permissions) - - def add_permission(self, permission: ActionObjectPermission) -> None: - permissions = self.permissions[permission.uid] - permissions.add(permission.permission_string) - self.permissions[permission.uid] = permissions - - def remove_permission(self, permission: ActionObjectPermission) -> None: - permissions = self.permissions[permission.uid] - permissions.remove(permission.permission_string) - self.permissions[permission.uid] = permissions - - def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: - for permission in permissions: - self.add_permission(permission) - - def _get_permissions_for_uid(self, uid: UID) -> Result[set[str], str]: - if uid in self.permissions: - return Ok(self.permissions[uid]) - return Err(f"No permissions found for uid: {uid}") - - def get_all_permissions(self) -> Result[dict[UID, set[str]], str]: - return Ok(dict(self.permissions.items())) - - def add_storage_permission(self, permission: StoragePermission) -> None: - permissions = self.storage_permissions[permission.uid] - permissions.add(permission.server_uid) - self.storage_permissions[permission.uid] = permissions - - def add_storage_permissions(self, permissions: list[StoragePermission]) -> None: - for permission in permissions: - self.add_storage_permission(permission) - - def remove_storage_permission(self, permission: StoragePermission) -> None: - permissions = self.storage_permissions[permission.uid] - permissions.remove(permission.server_uid) - self.storage_permissions[permission.uid] = permissions - - def has_storage_permission(self, permission: StoragePermission | UID) -> bool: - if isinstance(permission, UID): - permission = StoragePermission(uid=permission, server_uid=self.server_uid) - - if permission.uid in self.storage_permissions: - return permission.server_uid in self.storage_permissions[permission.uid] - return False - - def _get_storage_permissions_for_uid(self, 
uid: UID) -> Result[set[UID], str]:
-        if uid in self.storage_permissions:
-            return Ok(self.storage_permissions[uid])
-        return Err(f"No storage permissions found for uid: {uid}")
-
-    def get_all_storage_permissions(self) -> Result[dict[UID, set[UID]], str]:
-        return Ok(dict(self.storage_permissions.items()))
-
-    def _all(
-        self,
-        credentials: SyftVerifyKey,
-        has_permission: bool | None = False,
-    ) -> Result[list[SyftObject], str]:
-        # this checks permissions
-        res = [self.get(uid, credentials, has_permission) for uid in self.data.keys()]
-        result = [x.ok() for x in res if x.is_ok()]
-        return Ok(result)
-
-    def migrate_data(
-        self, to_klass: SyftObject, credentials: SyftVerifyKey
-    ) -> Result[bool, str]:
-        has_root_permission = credentials == self.root_verify_key
-
-        if has_root_permission:
-            for key, value in self.data.items():
-                try:
-                    if value.__canonical_name__ != to_klass.__canonical_name__:
-                        continue
-                    migrated_value = value.migrate_to(to_klass.__version__)
-                except Exception as e:
-                    return Err(
-                        f"Failed to migrate data to {to_klass} {to_klass.__version__} for qk: {key}. Exception: {e}"
-                    )
-                result = self.set(
-                    uid=key,
-                    credentials=credentials,
-                    syft_object=migrated_value,
-                )
+            self.add_permissions(permissions)
+            self.add_storage_permissions(storage_permission)
+            return Ok(SyftSuccess(message=f"Set for ID: {uid}"))
 
-                if result.is_err():
-                    return result.err()
+        owner_credentials = (
+            credentials if has_result_read_permission else self.root_verify_key
+        )
+        # if not has_result_read_permission
+        # root takes ownership, but you can still write and execute
+        super().set(
+            credentials=owner_credentials,
+            obj=syft_object,
+            add_permissions=[
+                ActionObjectWRITE(uid=uid, credentials=credentials),
+                ActionObjectEXECUTE(uid=uid, credentials=credentials),
+            ],
+            add_storage_permission=add_storage_permission,
+        )
 
-            return Ok(True)
+        return Ok(SyftSuccess(message=f"Set for ID: {uid}"))
 
-        return Err("You don't have permissions to migrate data.")
+
+    def set(self, *args, **kwargs):  # type: ignore
+        raise Exception("Use `ActionObjectStash.set_or_update` instead.")
diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py
index 56f29ccb17a..3e468191777 100644
--- a/packages/syft/src/syft/service/migration/migration_service.py
+++ b/packages/syft/src/syft/service/migration/migration_service.py
@@ -19,6 +19,7 @@ from ..action.action_object import ActionObject
 from ..action.action_permissions import ActionObjectPermission
 from ..action.action_permissions import StoragePermission
+from ..action.action_store import ActionObjectStash
 from ..action.action_store import KeyValueActionStore
 from ..context import AuthedServiceContext
 from ..response import SyftError
 from ..response import SyftSuccess
@@ -549,7 +550,7 @@ def _get_migration_actionobjects(
         )
         result_dict: dict[type[SyftObject], list[SyftObject]] = defaultdict(list)
         action_store = context.server.action_store
-        action_store_objects_result = action_store._all(
+        action_store_objects_result = action_store.get_all(
             context.credentials, has_permission=True
         )
         if action_store_objects_result.is_err():
@@ -580,9 +581,9 @@ def _update_migrated_actionobjects(
         self, context: AuthedServiceContext, objects: list[SyftObject]
     ) -> Result[str, str]:
         # Track all object types from action store
-        action_store = context.server.action_store
+        action_store: ActionObjectStash = context.server.action_store
         for obj in objects:
-            res = action_store.set(
+            res = action_store.set_or_update(
                 uid=obj.id,
credentials=context.credentials, syft_object=obj ) if res.is_err(): diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 5518047befb..2bfc4b07f0b 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -426,12 +426,14 @@ def set( obj: SyftT, add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, # TODO: check the default value - ignore_duplicates: bool = False, # only used in one place, should use upsert instead + ignore_duplicates: bool = False, ) -> Result[SyftT, str]: # uid is unique by database constraint uid = obj.id - if not self.is_unique(obj): + if self.exists(credentials, uid) or not self.is_unique(obj): + if ignore_duplicates: + return Ok(obj) return Err(f"Some fields are not unique for {type(obj).__name__}") permissions = self.get_ownership_permissions(uid, credentials) @@ -494,20 +496,10 @@ def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: return None def add_permission(self, permission: ActionObjectPermission) -> None: - stmt = ( - self.table.update() - .values( - permissions=sa.func.array_append( - self.table.c.permissions, permission.permission_string - ) - ) - .where( - sa.and_( - self._get_field_filter("id", permission.uid), - self._get_permission_filter( - permission.credentials, ActionPermission.WRITE - ), - ) + # TODO: handle duplicates + stmt = self.table.update(self.table.c.id == permission.uid).values( + permissions=sa.func.array_append( + self.table.c.permissions, permission.permission_string ) ) self.session.execute(stmt) @@ -515,38 +507,62 @@ def add_permission(self, permission: ActionObjectPermission) -> None: return None def remove_permission(self, permission: ActionObjectPermission) -> None: - stmt = ( - self.table.update() - .values( - permissions=sa.func.array_remove(self.table.c.permissions, permission) - ) - .where( - sa.and_( - self._get_field_filter("id", permission.uid), - self._get_permission_filter( - permission.credentials, - # since anyone with write permission can add permissions, - # owner check doesn't make sense, it should be write - ActionPermission.OWNER, - ), - ) - ) + # TODO: handle duplicates + stmt = self.table.update(self.table.c.id == permission.uid).values( + permissions=sa.func.array_remove(self.table.c.permissions, permission) ) self.session.execute(stmt) self.session.commit() return None + def remove_storage_permission(self, permission: StoragePermission) -> None: + # TODO + return None + + def _get_storage_permissions_for_uid(self, uid: UID) -> Result[set[UID], str]: + # TODO + return Ok(set()) + + def get_all_storage_permissions(self) -> Result[dict[UID, set[UID]], str]: + # TODO + return Ok({}) + def has_permission(self, permission: ActionObjectPermission) -> bool: + return self.has_permissions([permission]) + + def has_storage_permission(self, permission: StoragePermission) -> bool: + # TODO + return True + + def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool: + # NOTE: maybe we should use a permissions table to check all permissions at once # TODO: should check for compound permissions + permission_filters = [ + sa.and_( + self._get_field_filter("id", p.uid), + self.table.c.permissions.contains(p.permission_string), + ) + for p in permissions + ] + stmt = self.table.select().where( sa.and_( - self._get_field_filter("id", permission.uid), - self.table.c.permissions.contains(permission.permission_string), + *permission_filters, ) ) result = 
self.session.execute(stmt).first() return result is not None - def has_storage_permission(self, permission: StoragePermission) -> bool: - # TODO - return True + def _get_permissions_for_uid(self, uid: UID) -> Result[set[str], str]: + stmt = self.table.select(self.table.c.id, self.table.c.permissions).where( + self._get_field_filter("id", uid) + ) + result = self.session.execute(stmt).first() + if result is None: + return Err(f"No permissions found for uid: {uid}") + return Ok(set(result.permissions)) + + def get_all_permissions(self) -> Result[dict[UID, set[str]], str]: + stmt = self.table.select(self.table.c.id, self.table.c.permissions) + results = self.session.execute(stmt).all() + return Ok({row.id: set(row.permissions) for row in results}) From db703f4248037b2f5928f450409e1985b5330c63 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 23 Aug 2024 14:47:54 +0200 Subject: [PATCH 031/197] implement storage permissions --- .../syft/service/action/action_permissions.py | 4 + packages/syft/src/syft/store/db/stash.py | 98 +++++++++++++++---- 2 files changed, 82 insertions(+), 20 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_permissions.py b/packages/syft/src/syft/service/action/action_permissions.py index 03992eeab07..8177222bd1a 100644 --- a/packages/syft/src/syft/service/action/action_permissions.py +++ b/packages/syft/src/syft/service/action/action_permissions.py @@ -122,3 +122,7 @@ def _coll_repr_(self) -> dict[str, Any]: "uid": str(self.uid), "server_uid": str(self.server_uid), } + + @property + def permission_string(self) -> str: + return str(self.server_uid) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 2bfc4b07f0b..2280cd718cd 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -15,6 +15,7 @@ from sqlalchemy import Select from sqlalchemy import Table from sqlalchemy import func +from sqlalchemy import select from sqlalchemy.orm import Session from sqlalchemy.types import JSON from typing_extensions import TypeVar @@ -75,6 +76,7 @@ def _create_table(self) -> Table: Column("id", UIDTypeDecorator, primary_key=True, default=uuid.uuid4), Column("fields", JSON, default={}), Column("permissions", JSON, default=[]), + Column("storage_permissions", JSON, default=[]), Column( "created_at", sa.DateTime, server_default=sa.func.now(), index=True ), @@ -444,20 +446,15 @@ def set( storage_permissions = [] if add_storage_permission: storage_permissions.append( - StoragePermission( - uid=uid, - server_uid=self.server_uid, - ) + self.server_uid, ) - # TODO: write the storage permissions to the database - # create the object with the permissions stmt = self.table.insert().values( id=uid, fields=serialize_json(obj), permissions=permissions, - # storage_permissions=storage_permissions, + storage_permissions=storage_permissions, ) self.session.execute(stmt) self.session.commit() @@ -496,18 +493,27 @@ def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: return None def add_permission(self, permission: ActionObjectPermission) -> None: - # TODO: handle duplicates - stmt = self.table.update(self.table.c.id == permission.uid).values( - permissions=sa.func.array_append( - self.table.c.permissions, permission.permission_string + # handle duplicates by removing the permission first + stmt = ( + self.table.update() + .where(self.table.c.id == permission.uid) + .values( + permissions=func.array_append( + func.array_remove( + 
select(self.table.c.permissions)
+                        .where(self.table.c.id == permission.uid)
+                        .scalar_subquery(),
+                        permission.permission_string,
+                    ),
+                    permission.permission_string,
+                )
+            )
+        )
+
         self.session.execute(stmt)
         self.session.commit()
-        return None
 
     def remove_permission(self, permission: ActionObjectPermission) -> None:
-        # TODO: handle duplicates
         stmt = self.table.update(self.table.c.id == permission.uid).values(
             permissions=sa.func.array_remove(self.table.c.permissions, permission)
         )
@@ -516,23 +522,51 @@ def remove_permission(self, permission: ActionObjectPermission) -> None:
         return None
 
     def remove_storage_permission(self, permission: StoragePermission) -> None:
-        # TODO
+        stmt = self.table.update(self.table.c.id == permission.uid).values(
+            storage_permissions=sa.func.array_remove(
+                self.table.c.storage_permissions, permission.permission_string
+            )
+        )
+        self.session.execute(stmt)
+        self.session.commit()
         return None
 
     def _get_storage_permissions_for_uid(self, uid: UID) -> Result[set[UID], str]:
-        # TODO
-        return Ok(set())
+        stmt = self.table.select(
+            self.table.c.id, self.table.c.storage_permissions
+        ).where(self.table.c.id == uid)
+        result = self.session.execute(stmt).first()
+        if result is None:
+            return Err(f"No storage permissions found for uid: {uid}")
+        return Ok(set(result.storage_permissions))
 
     def get_all_storage_permissions(self) -> Result[dict[UID, set[UID]], str]:
-        # TODO
-        return Ok({})
+        stmt = self.table.select(self.table.c.id, self.table.c.storage_permissions)
+        results = self.session.execute(stmt).all()
+        return Ok({row.id: set(row.storage_permissions) for row in results})
 
     def has_permission(self, permission: ActionObjectPermission) -> bool:
         return self.has_permissions([permission])
 
     def has_storage_permission(self, permission: StoragePermission) -> bool:
-        # TODO
-        return True
+        return self.has_storage_permissions([permission])
+
+    def has_storage_permissions(self, permissions: list[StoragePermission]) -> bool:
+        permission_filters = [
+            sa.and_(
+                self._get_field_filter("id", p.uid),
+                self.table.c.storage_permissions.contains(p.server_uid),
+            )
+            for p in permissions
+        ]
+
+        stmt = self.table.select().where(
+            sa.and_(
+                *permission_filters,
+            )
+        )
+        result = self.session.execute(stmt).first()
+        return result is not None
 
     def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool:
         # NOTE: maybe we should use a permissions table to check all permissions at once
@@ -553,6 +587,30 @@ def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool:
         result = self.session.execute(stmt).first()
         return result is not None
 
+    def add_storage_permission(self, permission: StoragePermission) -> None:
+        stmt = (
+            self.table.update()
+            .where(self.table.c.id == permission.uid)
+            .values(
+                storage_permissions=func.array_append(
+                    func.array_remove(
+                        select(self.table.c.storage_permissions)
+                        .where(self.table.c.id == permission.uid)
+                        .scalar_subquery(),
+                        permission.server_uid,
+                    ),
+                    permission.server_uid,
+                )
+            )
+        )
+        self.session.execute(stmt)
+        self.session.commit()
+        return None
+
+    def add_storage_permissions(self, permissions: list[StoragePermission]) -> None:
+        for permission in permissions:
+            self.add_storage_permission(permission)
+
     def _get_permissions_for_uid(self, uid: UID) -> Result[set[str], str]:
         stmt = self.table.select(self.table.c.id, self.table.c.permissions).where(
             self._get_field_filter("id", uid)
         )
From b7aa781dc159ff3c3c8470e4b6f9d18066a8578d Mon Sep 17 00:00:00 2001
From: Aziz Berkay Yesilyurt
Date: Fri, 23 Aug 2024 17:11:49 +0200
Subject:
[PATCH 032/197] action object fixes --- packages/syft/src/syft/server/server.py | 16 ++-- .../syft/src/syft/server/service_registry.py | 4 +- .../src/syft/service/action/action_service.py | 6 +- .../src/syft/service/action/action_store.py | 2 + .../service/migration/migration_service.py | 34 +++----- .../syft/src/syft/service/request/request.py | 7 +- packages/syft/src/syft/store/db/stash.py | 80 +++++++++---------- 7 files changed, 70 insertions(+), 79 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 1df25014734..a0096272044 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -40,10 +40,7 @@ from ..protocol.data_protocol import get_data_protocol from ..service.action.action_object import Action from ..service.action.action_object import ActionObject -from ..service.action.action_store import ActionStore -from ..service.action.action_store import DictActionStore -from ..service.action.action_store import MongoActionStore -from ..service.action.action_store import SQLiteActionStore +from ..service.action.action_store import ActionObjectStash from ..service.blob_storage.service import BlobStorageService from ..service.code.user_code_service import UserCodeService from ..service.code.user_code_stash import UserCodeStash @@ -831,11 +828,8 @@ def init_stores( ) if isinstance(action_store_config, SQLiteStoreConfig): - self.action_store: ActionStore = SQLiteActionStore( - server_uid=self.id, - store_config=action_store_config, - root_verify_key=self.verify_key, - document_store=self.document_store, + self.action_store: ActionObjectStash = ActionObjectStash( + store=self.document_store, ) elif isinstance(action_store_config, MongoStoreConfig): # We add the python id of the current server in order @@ -844,14 +838,14 @@ def init_stores( # different thread through the garbage collection action_store_config.client_config.server_obj_python_id = id(self) - self.action_store = MongoActionStore( + self.action_store = ActionObjectStash( server_uid=self.id, root_verify_key=self.verify_key, store_config=action_store_config, document_store=self.document_store, ) else: - self.action_store = DictActionStore( + self.action_store = ActionObjectStash( server_uid=self.id, root_verify_key=self.verify_key, document_store=self.document_store, diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index 1504e0eb817..9bd3816162e 100644 --- a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -9,7 +9,7 @@ # relative from ..serde.serializable import serializable from ..service.action.action_service import ActionService -from ..service.action.action_store import ActionStore +from ..service.action.action_store import ActionObjectStash from ..service.api.api_service import APIService from ..service.attestation.attestation_service import AttestationService from ..service.blob_storage.service import BlobStorageService @@ -110,7 +110,7 @@ def _construct_services(cls, server: "Server") -> dict[str, AbstractService]: service_dict = {} for field_name, service_cls in cls.get_service_classes().items(): svc_kwargs: dict[str, Any] = {} - if issubclass(service_cls.store_type, ActionStore): + if issubclass(service_cls.store_type, ActionObjectStash): svc_kwargs["store"] = server.action_store else: svc_kwargs["store"] = server.document_store diff --git a/packages/syft/src/syft/service/action/action_service.py 
b/packages/syft/src/syft/service/action/action_service.py index 8a21d415b47..7694a72b3f0 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -10,6 +10,8 @@ from result import Ok from result import Result +from syft.store.document_store import DocumentStore + # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey @@ -55,8 +57,8 @@ @serializable(canonical_name="ActionService", version=1) class ActionService(AbstractService): - def __init__(self, store: ActionObjectStash) -> None: - self.store = store + def __init__(self, store: DocumentStore) -> None: + self.store = ActionObjectStash(store) @service_method(path="action.np_array", name="np_array") def np_array(self, context: AuthedServiceContext, data: Any) -> Any: diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index 4f011ee7e58..4acbbbdc81c 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -24,6 +24,8 @@ @serializable(canonical_name="ActionObjectSQLStore", version=1) class ActionObjectStash(ObjectStash[ActionObject]): + object_type = ActionObject + def get( self, uid: UID, credentials: SyftVerifyKey, has_permission: bool = False ) -> Result[ActionObject, str]: diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 3e468191777..4ffc636c3e8 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -20,7 +20,6 @@ from ..action.action_permissions import ActionObjectPermission from ..action.action_permissions import StoragePermission from ..action.action_store import ActionObjectStash -from ..action.action_store import KeyValueActionStore from ..context import AuthedServiceContext from ..response import SyftError from ..response import SyftSuccess @@ -133,32 +132,23 @@ def get_all_store_metadata( document_store_object_types: list[type[SyftObject]] | None = None, include_action_store: bool = True, ) -> dict[str, StoreMetadata] | SyftError: - res = self._get_all_store_metadata( - context, - document_store_object_types=document_store_object_types, - include_action_store=include_action_store, - ) - if res.is_err(): - return SyftError(message=res.value) - else: - return res.ok() + # res = self._get_all_store_metadata( + # context, + # document_store_object_types=document_store_object_types, + # include_action_store=include_action_store, + # ) + # if res.is_err(): + # return SyftError(message=res.value) + # else: + # return res.ok() + raise Exception("Not implemented") def _get_partition_from_type( self, context: AuthedServiceContext, object_type: type[SyftObject], - ) -> Result[KeyValueActionStore | StorePartition, str]: - object_partition: KeyValueActionStore | StorePartition | None = None - if issubclass(object_type, ActionObject): - object_partition = cast(KeyValueActionStore, context.server.action_store) - else: - canonical_name = object_type.__canonical_name__ # type: ignore[unreachable] - object_partition = self.store.partitions.get(canonical_name) - - if object_partition is None: - return Err(f"Object partition not found for {object_type}") # type: ignore - - return Ok(object_partition) + ): + return None def _get_store_metadata( self, diff --git 
a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 9b5bb00ca22..9bbc51450ce 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -13,6 +13,11 @@ from result import Result from typing_extensions import Self +from syft.service.action.action_permissions import ( + ActionObjectPermission, + ActionPermission, +) + # relative from ...abstract_server import ServerSideType from ...client.api import APIRegistry @@ -41,8 +46,6 @@ from ...util.util import prompt_warning_message from ..action.action_object import ActionObject from ..action.action_service import ActionService -from ..action.action_store import ActionObjectPermission -from ..action.action_store import ActionPermission from ..blob_storage.service import BlobStorageService from ..code.user_code import UserCode from ..code.user_code import UserCodeStatus diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 2280cd718cd..fce1a640daf 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -1,7 +1,7 @@ # stdlib # stdlib -from typing import Any +from typing import Any, Set from typing import Generic import uuid @@ -422,44 +422,6 @@ def update( self.session.commit() return Ok(obj) - def set( - self, - credentials: SyftVerifyKey, - obj: SyftT, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, # TODO: check the default value - ignore_duplicates: bool = False, - ) -> Result[SyftT, str]: - # uid is unique by database constraint - uid = obj.id - - if self.exists(credentials, uid) or not self.is_unique(obj): - if ignore_duplicates: - return Ok(obj) - return Err(f"Some fields are not unique for {type(obj).__name__}") - - permissions = self.get_ownership_permissions(uid, credentials) - if add_permissions is not None: - add_permission_strings = [p.permission_string for p in add_permissions] - permissions.extend(add_permission_strings) - - storage_permissions = [] - if add_storage_permission: - storage_permissions.append( - self.server_uid, - ) - - # create the object with the permissions - stmt = self.table.insert().values( - id=uid, - fields=serialize_json(obj), - permissions=permissions, - storage_permissions=storage_permissions, - ) - self.session.execute(stmt) - self.session.commit() - return Ok(obj) - def get_ownership_permissions( self, uid: UID, credentials: SyftVerifyKey ) -> list[str]: @@ -531,7 +493,7 @@ def remove_storage_permission(self, permission: StoragePermission) -> None: self.session.commit() return None - def _get_storage_permissions_for_uid(self, uid: UID) -> Result[set[UID], str]: + def _get_storage_permissions_for_uid(self, uid: UID) -> Result[Set[UID], str]: stmt = self.table.select( self.table.c.id, self.table.c.storage_permissions ).where(self.table.c.id == uid) @@ -624,3 +586,41 @@ def get_all_permissions(self) -> Result[dict[UID, set[str]], str]: stmt = self.table.select(self.table.c.id, self.table.c.permissions) results = self.session.execute(stmt).all() return Ok({row.id: set(row.permissions) for row in results}) + + def set( + self, + credentials: SyftVerifyKey, + obj: SyftT, + add_permissions: list[ActionObjectPermission] | None = None, + add_storage_permission: bool = True, # TODO: check the default value + ignore_duplicates: bool = False, + ) -> Result[SyftT, str]: + # uid is unique by database constraint + uid = obj.id + + if 
self.exists(credentials, uid) or not self.is_unique(obj): + if ignore_duplicates: + return Ok(obj) + return Err(f"Some fields are not unique for {type(obj).__name__}") + + permissions = self.get_ownership_permissions(uid, credentials) + if add_permissions is not None: + add_permission_strings = [p.permission_string for p in add_permissions] + permissions.extend(add_permission_strings) + + storage_permissions = [] + if add_storage_permission: + storage_permissions.append( + self.server_uid, + ) + + # create the object with the permissions + stmt = self.table.insert().values( + id=uid, + fields=serialize_json(obj), + permissions=permissions, + storage_permissions=storage_permissions, + ) + self.session.execute(stmt) + self.session.commit() + return Ok(obj) From 43ce9e575ef7ae9165e3e9347d83431f8f146693 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 23 Aug 2024 17:24:42 +0200 Subject: [PATCH 033/197] fix get_mock --- packages/syft/src/syft/service/action/action_store.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index 4acbbbdc81c..b3a877ec1f9 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -49,13 +49,14 @@ def get_mock(self, credentials: SyftVerifyKey, uid: UID) -> Result[SyftObject, s uid = uid.id # We only need the UID from LineageID or UID try: - syft_object = self.get_by_uid( + obj_or_err = self.get_by_uid( credentials=credentials, uid=uid, has_permission=True ) # type: ignore - if isinstance(syft_object, TwinObject) and not is_action_data_empty( - syft_object.mock - ): - return Ok(syft_object.mock) + if obj_or_err.is_err(): + return Err(obj_or_err.err()) + obj = obj_or_err.ok() + if isinstance(obj, TwinObject) and not is_action_data_empty(obj.mock): + return Ok(obj.mock) return Err("No mock") except Exception as e: return Err(f"Could not find item with uid {uid}, {e}") From 72df3b252dde1b732b689fb17403f9c097caba1d Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 23 Aug 2024 17:29:55 +0200 Subject: [PATCH 034/197] add exclude attrs --- packages/syft/src/syft/serde/json_serde.py | 11 ++++++++++- packages/syft/src/syft/serde/recursive.py | 5 ++--- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index 39177ffe620..900e0fa30ed 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -25,6 +25,7 @@ from ..types.syft_object_registry import SyftObjectRegistry from ..types.uid import LineageID from ..types.uid import UID +from .recursive import DEFAULT_EXCLUDE_ATTRS T = TypeVar("T") @@ -129,13 +130,21 @@ def _annotation_issubclass(annotation: Any, cls: type) -> bool: def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> dict[str, Json]: canonical_name, version = SyftObjectRegistry.get_canonical_name_version(obj) + serde_attributes = SyftObjectRegistry.get_serde_properties(canonical_name, version) + exclude_attrs = serde_attributes[4] + result: dict[str, Json] = { JSON_CANONICAL_NAME_FIELD: canonical_name, JSON_VERSION_FIELD: version, } for key, type_ in obj.model_fields.items(): - result[key] = serialize_json(getattr(obj, key), type_.annotation) + if key in exclude_attrs or key in DEFAULT_EXCLUDE_ATTRS: + continue + try: + result[key] = serialize_json(getattr(obj, key), type_.annotation) + except 
Exception as e: + raise ValueError(f"Failed to serialize attribute {key}: {e}") result = _serialize_searchable_attrs(obj, result) diff --git a/packages/syft/src/syft/serde/recursive.py b/packages/syft/src/syft/serde/recursive.py index 33bf94c8d4f..da3b37ceb62 100644 --- a/packages/syft/src/syft/serde/recursive.py +++ b/packages/syft/src/syft/serde/recursive.py @@ -25,6 +25,7 @@ recursive_scheme = get_capnp_schema("recursive_serde.capnp").RecursiveSerde SPOOLED_FILE_MAX_SIZE_SERDE = 50 * (1024**2) # 50MB +DEFAULT_EXCLUDE_ATTRS: set[str] = {"syft_pre_hooks__", "syft_post_hooks__"} def get_types(cls: type, keys: list[str] | None = None) -> list[type] | None: @@ -192,9 +193,7 @@ def recursive_serde_register( attribute_list.update(["value"]) exclude_attrs = [] if exclude_attrs is None else exclude_attrs - attribute_list = ( - attribute_list - set(exclude_attrs) - {"syft_pre_hooks__", "syft_post_hooks__"} - ) + attribute_list = attribute_list - set(exclude_attrs) - DEFAULT_EXCLUDE_ATTRS if inheritable_attrs and attribute_list and not is_pydantic: # only set __syft_serializable__ for non-pydantic classes because From 948ef0d4f9efae50a49cc353d293932898023e6d Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 23 Aug 2024 17:43:53 +0200 Subject: [PATCH 035/197] fix deserialize --- packages/syft/src/syft/serde/json_serde.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index 900e0fa30ed..f352a884c36 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -211,6 +211,8 @@ def _deserialize_pydantic_from_json( result = {} for key, type_ in obj_type.model_fields.items(): + if key not in obj_dict: + continue result[key] = deserialize_json(obj_dict[key], type_.annotation) return obj_type.model_validate(result) From 4d8c159564af9a8e5c78f0e0b9059475c002cae3 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 23 Aug 2024 18:06:27 +0200 Subject: [PATCH 036/197] fix error handling --- packages/syft/src/syft/client/api.py | 2 +- packages/syft/src/syft/serde/json_serde.py | 1 + packages/syft/src/syft/service/action/action_service.py | 3 +-- packages/syft/src/syft/service/dataset/dataset.py | 6 +++--- packages/syft/src/syft/service/project/project.py | 6 +++--- packages/syft/src/syft/service/request/request.py | 7 ++----- packages/syft/src/syft/store/db/stash.py | 4 ++-- 7 files changed, 13 insertions(+), 16 deletions(-) diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 3577bc81112..6e080b25178 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -1069,7 +1069,7 @@ def make_call(self, api_call: SyftAPICall, cache_result: bool = True) -> Result: if result.is_ok(): result = result.ok() else: - result = result.err() + return result # we update the api when we create objects that change it self.update_api(result) return result diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index f352a884c36..65793c52e09 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -217,6 +217,7 @@ def _deserialize_pydantic_from_json( return obj_type.model_validate(result) except Exception as e: + print(f"Failed to deserialize Pydantic model: {e}") print(json.dumps(obj_dict, indent=2)) raise ValueError(f"Failed to deserialize Pydantic model: {e}") diff --git 
a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 7694a72b3f0..704de7890b4 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -10,11 +10,10 @@ from result import Ok from result import Result -from syft.store.document_store import DocumentStore - # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey +from ...store.document_store import DocumentStore from ...types.datetime import DateTime from ...types.syft_object import SyftObject from ...types.twin_object import TwinObject diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index 8b6431375cd..376f0f252ef 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -279,10 +279,10 @@ def _private_data(self) -> Result[Any, str]: if api is None or api.services is None: return Ok(None) res = api.services.action.get(self.action_id) - if self.has_permission(res): - return Ok(res.syft_action_data) - else: + if isinstance(res, Err): return Err("You do not have permission to access private data.") + else: + return Ok(res.syft_action_data) @property def data(self) -> Any: diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index e1c1dfe1b47..3af3fd9bfbd 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -1274,12 +1274,12 @@ def send(self, return_all_projects: bool = False) -> Project | list[Project]: projects_map = self._create_projects(self.clients) # bootstrap project with pending events on leader server's project - self._bootstrap_events(projects_map[leader]) + self._bootstrap_events(projects_map[leader.id]) if return_all_projects: return list(projects_map.values()) - return projects_map[leader] + return projects_map[leader.id] except SyftException as exp: return SyftError(message=str(exp)) @@ -1316,7 +1316,7 @@ def _create_projects(self, clients: list[SyftClient]) -> dict[SyftClient, Projec result = client.api.services.project.create_project(project=self) if isinstance(result, SyftError): raise SyftException(result.message) - projects[client] = result + projects[client.id] = result return projects diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 9bbc51450ce..c299df2d9b9 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -13,11 +13,6 @@ from result import Result from typing_extensions import Self -from syft.service.action.action_permissions import ( - ActionObjectPermission, - ActionPermission, -) - # relative from ...abstract_server import ServerSideType from ...client.api import APIRegistry @@ -45,6 +40,8 @@ from ...util.notebook_ui.icons import Icon from ...util.util import prompt_warning_message from ..action.action_object import ActionObject +from ..action.action_permissions import ActionObjectPermission +from ..action.action_permissions import ActionPermission from ..action.action_service import ActionService from ..blob_storage.service import BlobStorageService from ..code.user_code import UserCode diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index fce1a640daf..74ccdfde30a 100644 --- 
a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -1,7 +1,7 @@ # stdlib # stdlib -from typing import Any, Set +from typing import Any from typing import Generic import uuid @@ -493,7 +493,7 @@ def remove_storage_permission(self, permission: StoragePermission) -> None: self.session.commit() return None - def _get_storage_permissions_for_uid(self, uid: UID) -> Result[Set[UID], str]: + def _get_storage_permissions_for_uid(self, uid: UID) -> Result[set[UID], str]: stmt = self.table.select( self.table.c.id, self.table.c.storage_permissions ).where(self.table.c.id == uid) From b7bd8cb112d574564db7c17af8f23ba81da07daf Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 23 Aug 2024 18:10:28 +0200 Subject: [PATCH 037/197] type hint fixes --- packages/syft/src/syft/server/server.py | 9 ++------- .../syft/src/syft/service/action/action_store.py | 3 ++- .../src/syft/service/migration/migration_service.py | 12 +++--------- packages/syft/src/syft/service/project/project.py | 10 +++++----- 4 files changed, 12 insertions(+), 22 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index a0096272044..1a87d07d47e 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -839,16 +839,11 @@ def init_stores( action_store_config.client_config.server_obj_python_id = id(self) self.action_store = ActionObjectStash( - server_uid=self.id, - root_verify_key=self.verify_key, - store_config=action_store_config, - document_store=self.document_store, + store=self.document_store, ) else: self.action_store = ActionObjectStash( - server_uid=self.id, - root_verify_key=self.verify_key, - document_store=self.document_store, + store=self.document_store, ) self.action_store_config = action_store_config diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index b3a877ec1f9..4ec3c7bdedd 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -17,6 +17,7 @@ from .action_object import ActionObject from .action_object import is_action_data_empty from .action_permissions import ActionObjectEXECUTE +from .action_permissions import ActionObjectPermission from .action_permissions import ActionObjectREAD from .action_permissions import ActionObjectWRITE from .action_permissions import StoragePermission @@ -109,7 +110,7 @@ def set_or_update( # type: ignore uid = uid.id # We only need the UID from LineageID or UID if self.exists(credentials=credentials, uid=uid): - permissions = [] + permissions: list[ActionObjectPermission] = [] if has_result_read_permission: permissions.append(ActionObjectREAD(uid=uid, credentials=credentials)) else: diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 4ffc636c3e8..f92c3c2bec4 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -1,7 +1,6 @@ # stdlib from collections import defaultdict import sys -from typing import cast # third party from result import Err @@ -132,6 +131,7 @@ def get_all_store_metadata( document_store_object_types: list[type[SyftObject]] | None = None, include_action_store: bool = True, ) -> dict[str, StoreMetadata] | SyftError: + # FIXME # res = self._get_all_store_metadata( # context, # 
document_store_object_types=document_store_object_types, @@ -143,19 +143,13 @@ def get_all_store_metadata( # return res.ok() raise Exception("Not implemented") - def _get_partition_from_type( - self, - context: AuthedServiceContext, - object_type: type[SyftObject], - ): - return None - def _get_store_metadata( self, context: AuthedServiceContext, object_type: type[SyftObject], ) -> Result[StoreMetadata, str]: - object_partition = self._get_partition_from_type(context, object_type) + # FIXME + object_partition = ... if object_partition.is_err(): return object_partition object_partition = object_partition.ok() diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py index 3af3fd9bfbd..b51830bfc57 100644 --- a/packages/syft/src/syft/service/project/project.py +++ b/packages/syft/src/syft/service/project/project.py @@ -1274,12 +1274,12 @@ def send(self, return_all_projects: bool = False) -> Project | list[Project]: projects_map = self._create_projects(self.clients) # bootstrap project with pending events on leader server's project - self._bootstrap_events(projects_map[leader.id]) + self._bootstrap_events(projects_map[leader.id]) # type: ignore if return_all_projects: return list(projects_map.values()) - return projects_map[leader.id] + return projects_map[leader.id] # type: ignore except SyftException as exp: return SyftError(message=str(exp)) @@ -1309,14 +1309,14 @@ def _exchange_routes(self, leader: SyftClient, followers: list[SyftClient]) -> N self.leader_server_route = connection_to_route(leader.connection) - def _create_projects(self, clients: list[SyftClient]) -> dict[SyftClient, Project]: - projects: dict[SyftClient, Project] = {} + def _create_projects(self, clients: list[SyftClient]) -> dict[UID, Project]: + projects: dict[UID, Project] = {} for client in clients: result = client.api.services.project.create_project(project=self) if isinstance(result, SyftError): raise SyftException(result.message) - projects[client.id] = result + projects[client.id] = result # type: ignore return projects From bb84ae6c056a9cecb6d9cf6115d15e9e5fa9ebc1 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 23 Aug 2024 18:55:51 +0200 Subject: [PATCH 038/197] fix perms --- .../syft/service/dataset/dataset_service.py | 8 ++- packages/syft/src/syft/store/db/stash.py | 50 ++++++++++--------- 2 files changed, 30 insertions(+), 28 deletions(-) diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index 3ef76414593..6b46b987dc2 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -29,7 +29,6 @@ from .dataset import CreateDataset from .dataset import Dataset from .dataset import DatasetPageView -from .dataset import DatasetUpdate from .dataset_stash import DatasetStash logger = logging.getLogger(__name__) @@ -258,10 +257,9 @@ def delete( return_msg.append(f"Asset with id '{asset.id}' successfully deleted.") # soft delete the dataset object from the store - dataset_update = DatasetUpdate( - id=uid, name=f"_deleted_{dataset.name}_{uid}", to_be_deleted=True - ) - result = self.stash.update(context.credentials, dataset_update) + dataset.name = f"_deleted_{dataset.name}_{uid}" + dataset.to_be_deleted = True + result = self.stash.update(context.credentials, dataset) if result.is_err(): return SyftError(message=result.err()) return_msg.append(f"Dataset with id '{uid}' successfully deleted.") 
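The stash.py hunk below (patch 038) replaces the PostgreSQL-style array_append/array_remove calls with SQLite's JSON1 json_insert, appending to a JSON-encoded permissions column. A minimal, self-contained sketch of that append pattern, against an illustrative table rather than Syft's schema:

# json_insert with the "$[#]" path writes one element past the current
# last index, i.e. it appends to the JSON array.
import sqlalchemy as sa

metadata = sa.MetaData()
objects = sa.Table(
    "objects",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("permissions", sa.JSON, default=list),
)

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(sa.insert(objects).values(id="obj-1", permissions=[]))
    conn.execute(
        sa.update(objects)
        .where(objects.c.id == "obj-1")
        .values(
            permissions=sa.func.json_insert(
                objects.c.permissions, "$[#]", "READ_obj-1"
            )
        )
    )
    stored = conn.execute(sa.select(objects.c.permissions)).scalar_one()
    print(stored)  # ['READ_obj-1']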
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 74ccdfde30a..8bb6b9fbe23 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -15,7 +15,6 @@ from sqlalchemy import Select from sqlalchemy import Table from sqlalchemy import func -from sqlalchemy import select from sqlalchemy.orm import Session from sqlalchemy.types import JSON from typing_extensions import TypeVar @@ -455,18 +454,13 @@ def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: return None def add_permission(self, permission: ActionObjectPermission) -> None: - # handle duplicates by removing the permission first stmt = ( self.table.update() .where(self.table.c.id == permission.uid) .values( - permissions=func.array_append( - func.array_remove( - select(self.table.c.permissions) - .where(self.table.c.id == permission.uid) - .scalar_subquery(), - permission.permission_string, - ), + permissions=func.json_insert( + self.table.c.permissions, + "$[#]", permission.permission_string, ) ) @@ -475,19 +469,33 @@ def add_permission(self, permission: ActionObjectPermission) -> None: self.session.execute(stmt) self.session.commit() - def remove_permission(self, permission: ActionObjectPermission) -> None: + def remove_permission( + self, permission: ActionObjectPermission + ) -> Result[None, str]: + permissions_or_err = self._get_permissions_for_uid(permission.uid) + if permissions_or_err.is_err(): + return permissions_or_err + permissions = permissions_or_err.ok() + permissions.remove(permission.permission_string) + stmt = self.table.update(self.table.c.id == permission.uid).values( - permissions=sa.func.array_remove(self.table.c.permissions, permission) + permissions=list(permissions) ) self.session.execute(stmt) self.session.commit() return None - def remove_storage_permission(self, permission: StoragePermission) -> None: + def remove_storage_permission( + self, permission: StoragePermission + ) -> Result[None, str]: + permissions_or_err = self._get_storage_permissions_for_uid(permission.uid) + if permissions_or_err.is_err(): + return permissions_or_err + permissions = permissions_or_err.ok() + permissions.discard(permission.permission_string) + stmt = self.table.update(self.table.c.id == permission.uid).values( - storage_permissions=sa.func.array_remove( - self.table.c.storage_permissions, permission.permission_string - ) + storage_permissions=list(permissions) ) self.session.execute(stmt) self.session.commit() @@ -554,14 +562,10 @@ def add_storage_permission(self, permission: StoragePermission) -> None: self.table.update() .where(self.table.c.id == permission.uid) .values( - storage_permissions=func.array_append( - func.array_remove( - select(self.table.c.storage_permissions) - .where(self.table.c.id == permission.uid) - .scalar_subquery(), - permission.server_uid, - ), - permission.server_uid, + storage_permissions=func.json_insert( + self.table.c.storage_permissions, + "$[#]", + permission.permission_string, ) ) )
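Because SQLite exposes no ARRAY type, the removal paths above switch to a read-modify-write flow: load the stored permission set, drop the entry in Python, then write the whole list back. A reduced sketch of that flow, with illustrative names rather than Syft's API, reusing the result library's Ok/Err types:

from result import Err, Ok, Result
from sqlalchemy import select

def remove_permission_entry(session, table, uid, entry: str) -> Result[None, str]:
    # read the row's current JSON-encoded permission list
    stored = session.execute(
        select(table.c.permissions).where(table.c.id == uid)
    ).scalar_one_or_none()
    if stored is None:
        return Err(f"No permissions found for uid: {uid}")
    # drop the entry in Python and write the whole list back
    permissions = set(stored)
    permissions.discard(entry)
    session.execute(
        table.update().where(table.c.id == uid).values(permissions=list(permissions))
    )
    session.commit()
    return Ok(None)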
From 5e3666d241d733865966a865d8fee9a7e370b03c Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 23 Aug 2024 18:59:32 +0200 Subject: [PATCH 039/197] fix test --- packages/syft/src/syft/server/server.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 1a87d07d47e..ba44934f08c 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -1345,7 +1345,7 @@ def add_queueitem_to_queue( action_service = self.get_service("actionservice") - if not action_service.store.exists(uid=action.result_id): + if not action_service.store.exists(credentials, uid=action.result_id): result = action_service.set_result_to_store( result_action_object=result_obj, context=context, From f25a86a5f621ff42c6f8ddf7123d0bfdfd90e604 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 23 Aug 2024 19:12:56 +0200 Subject: [PATCH 040/197] fix get_perms --- .../src/syft/service/request/request_stash.py | 4 +--- packages/syft/src/syft/store/db/stash.py | 23 +++++++++++-------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py index 42e818e554b..7af8db6f119 100644 --- a/packages/syft/src/syft/service/request/request_stash.py +++ b/packages/syft/src/syft/service/request/request_stash.py @@ -34,12 +34,10 @@ def get_all_for_verify_key( credentials: SyftVerifyKey, verify_key: SyftVerifyKey, ) -> Result[list[Request], str]: - if isinstance(verify_key, str): - verify_key = SyftVerifyKey.from_string(verify_key) return self.get_all_by_field( credentials=credentials, field_name="requesting_user_verify_key", - field_value=verify_key, + field_value=str(verify_key), ) def get_by_usercode_id( diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 8bb6b9fbe23..8c738d231d8 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -15,6 +15,7 @@ from sqlalchemy import Select from sqlalchemy import Table from sqlalchemy import func +from sqlalchemy import select from sqlalchemy.orm import Session from sqlalchemy.types import JSON from typing_extensions import TypeVar @@ -478,8 +479,10 @@ def remove_permission( permissions = permissions_or_err.ok() permissions.remove(permission.permission_string) - stmt = self.table.update(self.table.c.id == permission.uid).values( - permissions=list(permissions) + stmt = ( + self.table.update() + .where(self.table.c.id == permission.uid) + .values(permissions=list(permissions)) ) self.session.execute(stmt) self.session.commit() @@ -494,8 +497,10 @@ def remove_storage_permission( permissions = permissions_or_err.ok() permissions.discard(permission.permission_string) - stmt = self.table.update(self.table.c.id == permission.uid).values( - storage_permissions=list(permissions) + stmt = ( + self.table.update() + .where(self.table.c.id == permission.uid) + .values(storage_permissions=list(permissions)) ) self.session.execute(stmt) self.session.commit() @@ -578,16 +583,14 @@ def add_storage_permissions(self, permissions: list[StoragePermission]) -> None: self.add_storage_permission(permission) def _get_permissions_for_uid(self, uid: UID) -> Result[set[str], str]: - stmt = self.table.select(self.table.c.id, self.table.c.permissions).where( - self._get_field_filter("id", uid) - ) - result = self.session.execute(stmt).first() + stmt = select(self.table.c.permissions).where(self.table.c.id == uid) + result = self.session.execute(stmt).scalar_one_or_none() if result is None: return Err(f"No permissions found for uid: {uid}") - return Ok(set(result.permissions)) + return Ok(set(result)) def get_all_permissions(self) -> Result[dict[UID, set[str]], str]: - stmt = self.table.select(self.table.c.id, self.table.c.permissions) + stmt = select(self.table.c.id, self.table.c.permissions) results = self.session.execute(stmt).all() return Ok({row.id: set(row.permissions) for row in results})
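Patch 040 above narrows the single-row permission reads to one selected column unwrapped with scalar_one_or_none(). A small sketch of the two row-handling styles these hunks move between, using an illustrative table:

from sqlalchemy import select

def read_permissions(session, table, uid):
    # .first() yields a Row (or None), so the column must be unpacked;
    # .scalar_one_or_none() yields the lone column value directly.
    row = session.execute(
        select(table.c.id, table.c.permissions).where(table.c.id == uid)
    ).first()
    value = session.execute(
        select(table.c.permissions).where(table.c.id == uid)
    ).scalar_one_or_none()
    assert (row is None and value is None) or row.permissions == value
    return value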
From 7b31c9ab38c7c48bfb63f059223643c531ec8335 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 23 Aug 2024 19:31:19 +0200 Subject: [PATCH 041/197] fix sync --- packages/syft/src/syft/service/job/job_stash.py | 7 +++++++ .../src/syft/service/output/output_service.py | 4 ++-- .../syft/src/syft/service/sync/sync_service.py | 5 +++-- packages/syft/src/syft/store/db/stash.py | 16 ++++++++++------ 4 files changed, 22 insertions(+), 10 deletions(-) diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index dccfa54abe9..97e06741213 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -892,6 +892,13 @@ def get_by_parent_id( credentials=credentials, field_name="parent_job_id", field_value=str(uid) ) + def get_by_result_id( + self, credentials: SyftVerifyKey, uid: UID + ) -> Result[Job | None, str]: + return self.get_one_by_field( + credentials=credentials, field_name="result_id", field_value=str(uid) + ) + @serializable() class JobV1(SyncableSyftObject): diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 6c6ce1ab400..93b78f06489 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -323,11 +323,11 @@ def has_output_read_permissions( roles=ADMIN_ROLE_LEVEL, ) def get_by_job_id( - self, context: AuthedServiceContext, user_code_id: UID + self, context: AuthedServiceContext, job_id: UID ) -> ExecutionOutput | None | SyftError: result = self.stash.get_by_job_id( credentials=context.server.verify_key, # type: ignore - user_code_id=user_code_id, + job_id=job_id, ) if result.is_ok(): return result.ok() diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 32686cb75dc..952d3bf8c22 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -11,6 +11,7 @@ # relative from ...client.api import ServerIdentity from ...serde.serializable import serializable +from ...store.db.stash import ObjectStash from ...store.document_store import BaseStash from ...store.document_store import DocumentStore from ...store.linked_obj import LinkedObject @@ -40,12 +41,12 @@ logger = logging.getLogger(__name__) -def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> Any: +def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> ObjectStash: if isinstance(item, ActionObject): service = context.server.get_service("actionservice") # type: ignore return service.store # type: ignore service = context.server.get_service(TYPE_TO_SERVICE[type(item)]) # type: ignore - return service.stash.partition + return service.stash @instrument
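get_store above now resolves the owning stash through the TYPE_TO_SERVICE mapping. A stripped-down sketch of that type-keyed dispatch; the classes are stand-ins, not Syft's:

class JobService: ...
class LogService: ...

class Job: ...
class SyftLog: ...

TYPE_TO_SERVICE = {Job: JobService, SyftLog: LogService}

def service_for(item: object) -> type:
    # dispatch on the concrete type of the synced item
    return TYPE_TO_SERVICE[type(item)]

assert service_for(Job()) is JobService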
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 8c738d231d8..81c9ce0b83f 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -507,18 +507,22 @@ def remove_storage_permission( return None def _get_storage_permissions_for_uid(self, uid: UID) -> Result[set[UID], str]: - stmt = self.table.select( - self.table.c.id, self.table.c.storage_permissions - ).where(self.table.c.id == uid) + stmt = select(self.table.c.id, self.table.c.storage_permissions).where( + self.table.c.id == uid + ) result = self.session.execute(stmt).first() if result is None: return Err(f"No storage permissions found for uid: {uid}") - return Ok(set(result.storage_permissions)) + return Ok({UID(uid) for uid in result.storage_permissions}) def get_all_storage_permissions(self) -> Result[dict[UID, set[UID]], str]: - stmt = self.table.select(self.table.c.id, self.table.c.storage_permissions) + stmt = select(self.table.c.id, self.table.c.storage_permissions) results = self.session.execute(stmt).all() - return Ok({row.id: set(row.storage_permissions) for row in results}) + + # convert the stored strings back into UID objects + return Ok( + {row.id: {UID(uid) for uid in row.storage_permissions} for row in results} + )
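Patch 042 below introduces a PostgreSQL-backed document store on top of psycopg2. A minimal connection sketch using the same keyword arguments the new PostgreSQLStoreClientConfig carries; every value here is a placeholder:

import psycopg2

# psycopg2.connect accepts the dbname/user/password/host/port keywords
# that the client config in the patch exposes as fields.
connection = psycopg2.connect(
    dbname="syft",
    user="postgres",
    password="changeme",
    host="localhost",
    port=5432,
)
cursor = connection.cursor()
cursor.execute("SELECT 1")
print(cursor.fetchone())  # (1,)
connection.close()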
From 0789e08e3f09e035c3529ee674154754d23ef074 Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Mon, 26 Aug 2024 12:14:59 +0800 Subject: [PATCH 042/197] Add postgres document store --- packages/syft/setup.cfg | 1 + .../syft/store/postgresql_document_store.py | 137 ++++++++++++++++++ 2 files changed, 138 insertions(+) create mode 100644 packages/syft/src/syft/store/postgresql_document_store.py diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 9d97a98589e..76a20100533 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -67,6 +67,7 @@ syft = jinja2==3.1.4 tenacity==8.3.0 nh3==0.2.17 + psycopg2-binary==2.9.9 install_requires = %(syft)s diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py new file mode 100644 index 00000000000..bed00634f16 --- /dev/null +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -0,0 +1,137 @@ +# stdlib +from collections import defaultdict +import logging + +# third party +import psycopg2 +from pydantic import Field + +# relative +from ..serde.serializable import serializable +from .document_store import DocumentStore +from .document_store import PartitionSettings +from .document_store import StoreClientConfig +from .document_store import StoreConfig +from .kv_document_store import KeyValueBackingStore +from .locks import LockingConfig +from .locks import NoLockingConfig +from .locks import SyftLock +from .sqlite_document_store import SQLiteBackingStore +from .sqlite_document_store import SQLiteStorePartition +from .sqlite_document_store import cache_key +from .sqlite_document_store import raise_exception + +logger = logging.getLogger(__name__) + +_CONNECTION_POOL_DB: dict[str, psycopg2.extensions.connection] = {} +_CONNECTION_POOL_CUR: dict[str, psycopg2.extensions.cursor] = {} +REF_COUNTS: dict[str, int] = defaultdict(int) + + +# https://www.psycopg.org/docs/module.html#psycopg2.connect +@serializable(canonical_name="PostgreSQLStoreClientConfig", version=1) +class PostgreSQLStoreClientConfig(StoreClientConfig): + dbname: str + user: str + password: str + host: str + port: int + + +@serializable(canonical_name="PostgreSQLStorePartition", version=1) +class PostgreSQLStorePartition(SQLiteStorePartition): + pass + + +@serializable(canonical_name="PostgreSQLDocumentStore", version=1) +class PostgreSQLDocumentStore(DocumentStore): + partition_type = PostgreSQLStorePartition + + +@serializable( + attrs=["index_name", "settings", "store_config"], + canonical_name="PostgreSQLBackingStore", + version=1, +) +class PostgreSQLBackingStore(SQLiteBackingStore): + def __init__( + self, + index_name: str, + settings: PartitionSettings, + store_config: StoreConfig, + ddtype: type | None = None, + ) -> None: + self.index_name = index_name + self.settings = settings + self.store_config = store_config + self._ddtype = ddtype + if self.store_config.client_config: + self.dbname = self.store_config.client_config.dbname + + self.lock = SyftLock(NoLockingConfig()) + self.create_table() + REF_COUNTS[cache_key(self.dbname)] += 1 + + def _connect(self) -> None: + if self.store_config.client_config: + connection = psycopg2.connect( + dbname=self.store_config.client_config.dbname, + user=self.store_config.client_config.user, + password=self.store_config.client_config.password, + host=self.store_config.client_config.host, + port=self.store_config.client_config.port, + ) + + _CONNECTION_POOL_DB[cache_key(self.dbname)] = connection + + def create_table(self) -> None: + try: + with self.lock: + self.cur.execute( + f"create table {self.table_name} (uid VARCHAR(32) NOT NULL PRIMARY KEY, " # nosec + + "repr TEXT NOT NULL, value BYTEA NOT NULL, " # nosec + + "sqltime TIMESTAMP NOT NULL DEFAULT NOW())" # nosec + ) + self.db.commit() + except Exception as e: + raise_exception(self.table_name, e) + + @property + def db(self) -> psycopg2.extensions.connection: + if cache_key(self.dbname) not in _CONNECTION_POOL_DB: + self._connect() + return _CONNECTION_POOL_DB[cache_key(self.dbname)] + + @property + def cur(self) -> psycopg2.extensions.cursor: + if cache_key(self.dbname) not in _CONNECTION_POOL_CUR: + _CONNECTION_POOL_CUR[cache_key(self.dbname)] = self.db.cursor() + + return _CONNECTION_POOL_CUR[cache_key(self.dbname)] + + def _close(self) -> None: + self._commit() + REF_COUNTS[cache_key(self.dbname)] -= 1 + if REF_COUNTS[cache_key(self.dbname)] <= 0: + # once you close it seems like other object references can't re-use the + # same connection + + self.db.close() + db_key = cache_key(self.dbname) + if db_key in _CONNECTION_POOL_CUR: + # NOTE if we don't remove the cursor, the cursor cache_key can clash with a future thread id + del _CONNECTION_POOL_CUR[db_key] + del _CONNECTION_POOL_DB[cache_key(self.dbname)] + else: + # don't close yet because another PostgreSQLBackingStore is probably still open + pass + + +@serializable() +class PostgreSQLStoreConfig(StoreConfig): + __canonical_name__ = "PostgreSQLStoreConfig" + + client_config: PostgreSQLStoreClientConfig + store_type: type[DocumentStore] = PostgreSQLDocumentStore + backing_store: type[KeyValueBackingStore] = PostgreSQLBackingStore + locking_config: LockingConfig = Field(default_factory=NoLockingConfig) From 439f400174cd0066b25e20ef29b01c63f034f37a Mon Sep 17 00:00:00 2001 From: Kien Dang Date: Mon, 26 Aug 2024 12:16:26 +0800 Subject: [PATCH 043/197] Add postgres action store --- packages/syft/src/syft/service/action/action_store.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index 250b3c5e9b5..2ddc36942d1 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -424,3 +424,8 @@ class MongoActionStore(KeyValueActionStore): """ pass + + +@serializable(canonical_name="PostgreSQLActionStore", version=1) +class PostgreSQLActionStore(KeyValueActionStore): + pass From b0be81a623c1da9243f385416014b10a51a0d301 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 26 Aug 2024 10:13:59 +0200 Subject: [PATCH 044/197] add project stash --- .../src/syft/service/project/project_stash.py | 32 +++++++------------ .../tests/syft/stores/store_fixtures_test.py | 3 -- 2 files
changed, 12 insertions(+), 23 deletions(-) diff --git a/packages/syft/src/syft/service/project/project_stash.py b/packages/syft/src/syft/service/project/project_stash.py index dcd258938b3..268af4af911 100644 --- a/packages/syft/src/syft/service/project/project_stash.py +++ b/packages/syft/src/syft/service/project/project_stash.py @@ -6,12 +6,9 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey -from ...types.uid import UID from ...util.telemetry import instrument from ..request.request import Request from ..response import SyftError @@ -22,32 +19,27 @@ @instrument -@serializable(canonical_name="ProjectStash", version=1) -class ProjectStash(BaseUIDStoreStash): +@serializable(canonical_name="ProjectSQLStash", version=1) +class ProjectStash(ObjectStash[Project]): object_type = Project settings: PartitionSettings = PartitionSettings( name=Project.__canonical_name__, object_type=Project ) def get_all_for_verify_key( - self, credentials: SyftVerifyKey, verify_key: VerifyKeyPartitionKey + self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey ) -> Result[list[Request], SyftError]: - if isinstance(verify_key, str): - verify_key = SyftVerifyKey.from_string(verify_key) - qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(verify_key)]) - return self.query_all( + return self.get_all_by_field( credentials=credentials, - qks=qks, + field_name="user_verify_key", + field_value=str(verify_key), ) - def get_by_uid( - self, credentials: SyftVerifyKey, uid: UID - ) -> Result[Project | None, str]: - qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) - return self.query_one(credentials=credentials, qks=qks) - def get_by_name( self, credentials: SyftVerifyKey, project_name: str ) -> Result[Project | None, str]: - qks = QueryKeys(qks=[NamePartitionKey.with_obj(project_name)]) - return self.query_one(credentials=credentials, qks=qks) + return self.get_one_by_field( + credentials=credentials, + field_name="name", + field_value=project_name, + ) diff --git a/packages/syft/tests/syft/stores/store_fixtures_test.py b/packages/syft/tests/syft/stores/store_fixtures_test.py index eb4aabd3cf4..cb4a5ecb64c 100644 --- a/packages/syft/tests/syft/stores/store_fixtures_test.py +++ b/packages/syft/tests/syft/stores/store_fixtures_test.py @@ -13,9 +13,6 @@ from syft.server.credentials import SyftVerifyKey from syft.service.action.action_permissions import ActionObjectPermission from syft.service.action.action_permissions import ActionPermission -from syft.service.action.action_store import DictActionStore -from syft.service.action.action_store import MongoActionStore -from syft.service.action.action_store import SQLiteActionStore from syft.service.queue.queue_stash import QueueStash from syft.service.user.user import User from syft.service.user.user import UserCreate From 94d80d409b429e18674b9140edd6d6aae00fe558 Mon Sep 17 00:00:00 2001 From: dk Date: Mon, 26 Aug 2024 15:41:09 +0700 Subject: [PATCH 045/197] [syft/stores] Update `server.py` to use `PostgreSQL` stores --- packages/syft/src/syft/server/server.py | 23 +++++------------------ 1 file changed, 5 insertions(+), 18 deletions(-) diff --git a/packages/syft/src/syft/server/server.py 
b/packages/syft/src/syft/server/server.py index eca1b56a8d8..0424bfc29d5 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -43,7 +43,7 @@ from ..service.action.action_object import ActionObject from ..service.action.action_store import ActionStore from ..service.action.action_store import DictActionStore -from ..service.action.action_store import MongoActionStore +from ..service.action.action_store import PostgreSQLActionStore from ..service.action.action_store import SQLiteActionStore from ..service.blob_storage.service import BlobStorageService from ..service.code.user_code_service import UserCodeService @@ -100,7 +100,7 @@ from ..store.dict_document_store import DictStoreConfig from ..store.document_store import StoreConfig from ..store.linked_obj import LinkedObject -from ..store.mongo_document_store import MongoStoreConfig +from ..store.postgresql_document_store import PostgreSQLStoreConfig from ..store.sqlite_document_store import SQLiteStoreClientConfig from ..store.sqlite_document_store import SQLiteStoreConfig from ..types.datetime import DATETIME_FORMAT @@ -842,13 +842,6 @@ def init_stores( document_store_config: StoreConfig, action_store_config: StoreConfig, ) -> None: - # We add the python id of the current server in order - # to create one connection per Server object in MongoClientCache - # so that we avoid closing the connection from a - # different thread through the garbage collection - if isinstance(document_store_config, MongoStoreConfig): - document_store_config.client_config.server_obj_python_id = id(self) - self.document_store_config = document_store_config self.document_store = document_store_config.store_type( server_uid=self.id, @@ -863,17 +856,11 @@ def init_stores( root_verify_key=self.verify_key, document_store=self.document_store, ) - elif isinstance(action_store_config, MongoStoreConfig): - # We add the python id of the current server in order - # to create one connection per Server object in MongoClientCache - # so that we avoid closing the connection from a - # different thread through the garbage collection - action_store_config.client_config.server_obj_python_id = id(self) - - self.action_store = MongoActionStore( + elif isinstance(action_store_config, PostgreSQLStoreConfig): + self.action_store = PostgreSQLActionStore( server_uid=self.id, - root_verify_key=self.verify_key, store_config=action_store_config, + root_verify_key=self.verify_key, document_store=self.document_store, ) else: From 42aced9d51f340d766c58d514327b59bef528f28 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 26 Aug 2024 17:52:53 +0200 Subject: [PATCH 046/197] move sqlite connection to Server --- packages/syft/src/syft/server/server.py | 59 ++++++++++++------- .../syft/src/syft/server/service_registry.py | 21 ++++++- .../src/syft/service/action/action_service.py | 2 + .../src/syft/service/action/action_store.py | 2 - .../syft/src/syft/service/api/api_stash.py | 4 -- .../service/blob_storage/remote_profile.py | 8 +-- .../src/syft/service/blob_storage/stash.py | 5 -- .../src/syft/service/code/status_service.py | 4 -- .../src/syft/service/code/user_code_stash.py | 4 -- .../code_history/code_history_stash.py | 4 -- .../src/syft/service/dataset/dataset_stash.py | 4 -- .../syft/src/syft/service/job/job_stash.py | 4 -- .../syft/src/syft/service/log/log_stash.py | 4 -- .../src/syft/service/output/output_service.py | 5 -- .../src/syft/service/request/request_stash.py | 1 - .../syft/service/settings/settings_stash.py | 4 -- 
.../syft/src/syft/service/sync/sync_stash.py | 5 +- .../syft/src/syft/service/user/user_stash.py | 11 ++-- packages/syft/src/syft/store/db/sqlite_db.py | 43 ++++++++++++-- packages/syft/src/syft/store/db/stash.py | 33 ++++++++--- .../syft/src/syft/store/document_store.py | 24 +------- 21 files changed, 133 insertions(+), 118 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index bea67abb5b9..09f6f437aca 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -23,6 +23,7 @@ from nacl.signing import SigningKey from result import Err from result import Result +from syft.store.db.sqlite_db import SQLiteDBConfig, SQLiteDBManager # relative from .. import __version__ @@ -386,6 +387,7 @@ def __init__( use_sqlite=use_sqlite, store_type="Action Store", ) + self.init_stores( action_store_config=action_store_config, document_store_config=document_store_config, @@ -393,6 +395,9 @@ def __init__( # construct services only after init stores self.services: ServiceRegistry = ServiceRegistry.for_server(self) + self.db.init_tables() + self.services.user.stash.init_root_user() + self.action_store = self.services.action.store create_admin_new( # nosec B106 name=root_username, @@ -852,28 +857,38 @@ def init_stores( store_config=document_store_config, ) - if isinstance(action_store_config, SQLiteStoreConfig): - self.action_store: ActionObjectStash = ActionObjectStash( - store=self.document_store, - ) - elif isinstance(action_store_config, MongoStoreConfig): - # We add the python id of the current server in order - # to create one connection per Server object in MongoClientCache - # so that we avoid closing the connection from a - # different thread through the garbage collection - action_store_config.client_config.server_obj_python_id = id(self) - - self.action_store = ActionObjectStash( - store=self.document_store, - ) - else: - self.action_store = ActionObjectStash( - store=self.document_store, - ) + # if isinstance(action_store_config, SQLiteStoreConfig): + # self.action_store: ActionObjectStash = ActionObjectStash( + # store=self.document_store, + # ) + # elif isinstance(action_store_config, MongoStoreConfig): + # # We add the python id of the current server in order + # # to create one connection per Server object in MongoClientCache + # # so that we avoid closing the connection from a + # # different thread through the garbage collection + # action_store_config.client_config.server_obj_python_id = id(self) + + # self.action_store = ActionObjectStash( + # store=self.document_store, + # ) + # else: + # self.action_store = ActionObjectStash( + # store=self.document_store, + # ) self.action_store_config = action_store_config self.queue_stash = QueueStash(store=self.document_store) + json_db_config = SQLiteDBConfig( + filename=f"{self.id}_json.db", + path=self.get_temp_dir("db"), + ) + self.db = SQLiteDBManager( + config=json_db_config, + server_uid=self.id, + root_verify_key=self.signing_key.verify_key, + ) + @property def job_stash(self) -> JobStash: return self.get_service("jobservice").stash @@ -954,7 +969,7 @@ def get_settings(self) -> ServerSettings | None: @property def settings(self) -> ServerSettings: - settings_stash = SettingsStash(store=self.document_store) + settings_stash = self.services.settings.stash if self.signing_key is None: raise ValueError(f"{self} has no signing key") settings = settings_stash.get_all(self.signing_key.verify_key) @@ -1592,7 +1607,7 @@ def get_unauthed_context( def 
create_initial_settings(self, admin_email: str) -> ServerSettings | None: try: - settings_stash = SettingsStash(store=self.document_store) + settings_stash = SettingsStash(store=self.services.settings.stash) if self.signing_key is None: logger.debug( "create_initial_settings failed as there is no signing key" @@ -1655,10 +1670,10 @@ def create_admin_new( name: str, email: str, password: str, - server: AbstractServer, + server: Server, ) -> User | None: try: - user_stash = UserStash(store=server.document_store) + user_stash = server.services.user.stash row_exists = user_stash.get_by_email( credentials=server.signing_key.verify_key, email=email ).ok() diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index 9bd3816162e..3cf81125c07 100644 --- a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -6,6 +6,8 @@ from typing import Any from typing import TYPE_CHECKING +from syft.store.db.stash import ObjectStash + # relative from ..serde.serializable import serializable from ..service.action.action_service import ActionService @@ -105,13 +107,30 @@ def get_service_classes( if issubclass(cls, AbstractService) } + @classmethod + def _uses_new_store(cls, service_cls: type[AbstractService]) -> bool: + stash_annotation = service_cls.__annotations__.get("stash") + try: + if issubclass(stash_annotation, ObjectStash): + return True + return False + except Exception as e: + return False + @classmethod def _construct_services(cls, server: "Server") -> dict[str, AbstractService]: service_dict = {} for field_name, service_cls in cls.get_service_classes().items(): svc_kwargs: dict[str, Any] = {} - if issubclass(service_cls.store_type, ActionObjectStash): + + # Use new DB + if cls._uses_new_store(service_cls): + svc_kwargs["store"] = server.db + + # Use old DB + elif issubclass(service_cls.store_type, ActionObjectStash): svc_kwargs["store"] = server.action_store + else: svc_kwargs["store"] = server.document_store diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 704de7890b4..8c179119de9 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -56,6 +56,8 @@ @serializable(canonical_name="ActionService", version=1) class ActionService(AbstractService): + stash: ActionObjectStash + def __init__(self, store: DocumentStore) -> None: self.store = ActionObjectStash(store) diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index 4ec3c7bdedd..c8afe5b7c33 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -25,8 +25,6 @@ @serializable(canonical_name="ActionObjectSQLStore", version=1) class ActionObjectStash(ObjectStash[ActionObject]): - object_type = ActionObject - def get( self, uid: UID, credentials: SyftVerifyKey, has_permission: bool = False ) -> Result[ActionObject, str]: diff --git a/packages/syft/src/syft/service/api/api_stash.py b/packages/syft/src/syft/service/api/api_stash.py index 94567d3ce05..b1c9bf37037 100644 --- a/packages/syft/src/syft/service/api/api_stash.py +++ b/packages/syft/src/syft/service/api/api_stash.py @@ -18,14 +18,10 @@ @serializable(canonical_name="TwinAPIEndpointSQLStash", version=1) class TwinAPIEndpointStash(ObjectStash[TwinAPIEndpoint]): - object_type = 
TwinAPIEndpoint settings: PartitionSettings = PartitionSettings( name=TwinAPIEndpoint.__canonical_name__, object_type=TwinAPIEndpoint ) - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - def get_by_path( self, credentials: SyftVerifyKey, path: str ) -> Result[TwinAPIEndpoint, str]: diff --git a/packages/syft/src/syft/service/blob_storage/remote_profile.py b/packages/syft/src/syft/service/blob_storage/remote_profile.py index d3e275625ae..dc8d52c9c26 100644 --- a/packages/syft/src/syft/service/blob_storage/remote_profile.py +++ b/packages/syft/src/syft/service/blob_storage/remote_profile.py @@ -5,6 +5,7 @@ from ...store.document_store import PartitionSettings from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftObject +from ...store.db.stash import ObjectStash @serializable() @@ -24,12 +25,9 @@ class AzureRemoteProfile(RemoteProfile): container_name: str -@serializable(canonical_name="RemoteProfileStash", version=1) -class RemoteProfileStash(BaseUIDStoreStash): +@serializable(canonical_name="RemoteProfileSQLStash", version=1) +class RemoteProfileStash(ObjectStash[RemoteProfile]): object_type = RemoteProfile settings: PartitionSettings = PartitionSettings( name=RemoteProfile.__canonical_name__, object_type=RemoteProfile ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) diff --git a/packages/syft/src/syft/service/blob_storage/stash.py b/packages/syft/src/syft/service/blob_storage/stash.py index 67ddd5f8ebb..5cc2dd57c56 100644 --- a/packages/syft/src/syft/service/blob_storage/stash.py +++ b/packages/syft/src/syft/service/blob_storage/stash.py @@ -1,17 +1,12 @@ # relative from ...serde.serializable import serializable from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store import PartitionSettings from ...types.blob_storage import BlobStorageEntry @serializable(canonical_name="BlobStorageSQLStash", version=1) class BlobStorageStash(ObjectStash[BlobStorageEntry]): - object_type = BlobStorageEntry settings: PartitionSettings = PartitionSettings( name=BlobStorageEntry.__canonical_name__, object_type=BlobStorageEntry ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index 62de0a020fe..1c65e45849c 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -23,15 +23,11 @@ @instrument @serializable(canonical_name="StatusSQLStash", version=1) class StatusStash(ObjectStash[UserCodeStatusCollection]): - object_type = UserCodeStatusCollection settings: PartitionSettings = PartitionSettings( name=UserCodeStatusCollection.__canonical_name__, object_type=UserCodeStatusCollection, ) - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - @instrument @serializable(canonical_name="UserCodeStatusService", version=1) diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py index 6d514f0a448..046bdece148 100644 --- a/packages/syft/src/syft/service/code/user_code_stash.py +++ b/packages/syft/src/syft/service/code/user_code_stash.py @@ -16,14 +16,10 @@ @instrument @serializable(canonical_name="UserCodeSQLStash", version=1) class UserCodeStash(ObjectStash[UserCode]): - object_type = UserCode settings: PartitionSettings 
= PartitionSettings( name=UserCode.__canonical_name__, object_type=UserCode ) - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - def get_by_code_hash( self, credentials: SyftVerifyKey, code_hash: str ) -> Result[UserCode | None, str]: diff --git a/packages/syft/src/syft/service/code_history/code_history_stash.py b/packages/syft/src/syft/service/code_history/code_history_stash.py index 01acd28b41f..60d733c2e8d 100644 --- a/packages/syft/src/syft/service/code_history/code_history_stash.py +++ b/packages/syft/src/syft/service/code_history/code_history_stash.py @@ -18,14 +18,10 @@ @serializable(canonical_name="CodeHistoryStashSQL", version=1) class CodeHistoryStash(ObjectStash[CodeHistory]): - object_type = CodeHistory settings: PartitionSettings = PartitionSettings( name=CodeHistory.__canonical_name__, object_type=CodeHistory ) - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - def get_by_service_func_name_and_verify_key( self, credentials: SyftVerifyKey, diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index 81f06098c06..63ad3f8bafa 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -19,14 +19,10 @@ @instrument @serializable(canonical_name="DatasetStashSQL", version=1) class DatasetStash(ObjectStash[Dataset]): - object_type = Dataset settings: PartitionSettings = PartitionSettings( name=Dataset.__canonical_name__, object_type=Dataset ) - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - def get_by_name( self, credentials: SyftVerifyKey, name: str ) -> Result[Dataset | None, str]: diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 97e06741213..4206ffb3272 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -842,14 +842,10 @@ def from_job( @instrument @serializable(canonical_name="JobStashSQL", version=1) class JobStash(ObjectStash[Job]): - object_type = Job settings: PartitionSettings = PartitionSettings( name=Job.__canonical_name__, object_type=Job ) - def __init__(self, store: DocumentStore) -> None: - super().__init__(store) - def set_result( self, credentials: SyftVerifyKey, diff --git a/packages/syft/src/syft/service/log/log_stash.py b/packages/syft/src/syft/service/log/log_stash.py index 54d7c9ba04b..360284ca0a3 100644 --- a/packages/syft/src/syft/service/log/log_stash.py +++ b/packages/syft/src/syft/service/log/log_stash.py @@ -10,10 +10,6 @@ @instrument @serializable(canonical_name="LogStash", version=1) class LogStash(ObjectStash[SyftLog]): - object_type = SyftLog settings: PartitionSettings = PartitionSettings( name=SyftLog.__canonical_name__, object_type=SyftLog ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 93b78f06489..e489cfa719d 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -192,15 +192,10 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: @instrument @serializable(canonical_name="OutputStashSQL", version=1) class OutputStash(ObjectStash[ExecutionOutput]): - object_type = ExecutionOutput settings: 
PartitionSettings = PartitionSettings( name=ExecutionOutput.__canonical_name__, object_type=ExecutionOutput ) - def __init__(self, store: DocumentStore) -> None: - super().__init__(store) - self.store = store - def get_by_user_code_id( self, credentials: SyftVerifyKey, user_code_id: UID ) -> Result[list[ExecutionOutput], str]: diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py index 7af8db6f119..1afa60addd1 100644 --- a/packages/syft/src/syft/service/request/request_stash.py +++ b/packages/syft/src/syft/service/request/request_stash.py @@ -24,7 +24,6 @@ @instrument @serializable(canonical_name="RequestStashSQL", version=1) class RequestStash(ObjectStash[Request]): - object_type = Request settings: PartitionSettings = PartitionSettings( name=Request.__canonical_name__, object_type=Request ) diff --git a/packages/syft/src/syft/service/settings/settings_stash.py b/packages/syft/src/syft/service/settings/settings_stash.py index 49adfe8979f..e7d6e260e16 100644 --- a/packages/syft/src/syft/service/settings/settings_stash.py +++ b/packages/syft/src/syft/service/settings/settings_stash.py @@ -19,10 +19,6 @@ @instrument @serializable(canonical_name="SettingsStashSQL", version=1) class SettingsStash(ObjectStash[ServerSettings]): - object_type = ServerSettings settings: PartitionSettings = PartitionSettings( name=ServerSettings.__canonical_name__, object_type=ServerSettings ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py index b3cf80a362f..1769917bf69 100644 --- a/packages/syft/src/syft/service/sync/sync_stash.py +++ b/packages/syft/src/syft/service/sync/sync_stash.py @@ -7,6 +7,7 @@ # third party from result import Ok from result import Result +from syft.store.db.sqlite_db import DBManager # relative from ...serde.serializable import serializable @@ -25,15 +26,13 @@ @instrument @serializable(canonical_name="SyncStash", version=1) class SyncStash(ObjectStash[SyncState]): - object_type = SyncState settings: PartitionSettings = PartitionSettings( name=SyncState.__canonical_name__, object_type=SyncState, ) - def __init__(self, store: DocumentStore): + def __init__(self, store: DBManager) -> None: super().__init__(store) - self.store = store self.last_state: SyncState | None = None def get_latest(self, credentials: SyftVerifyKey) -> Result[SyncState | None, str]: diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index faf4fff404b..38a918e5eb2 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -3,13 +3,13 @@ # third party from result import Ok from result import Result +from syft.store.db.sqlite_db import DBManager # relative from ...serde.serializable import serializable from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings from ...util.telemetry import instrument @@ -27,18 +27,15 @@ @instrument @serializable(canonical_name="UserStashSQL", version=1) class UserStash(ObjectStash[User]): - object_type = User settings: PartitionSettings = PartitionSettings( name=User.__canonical_name__, object_type=User, ) - def 
__init__(self, store: DocumentStore) -> None: - super().__init__(store=store) + def __init__(self, store: DBManager) -> None: + super().__init__(store) - self._init_root() - - def _init_root(self) -> None: + def init_root_user(self) -> None: # start a transaction users = self.get_all(self.root_verify_key, has_permission=True) if not users: diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index 578ca910c48..397547831ea 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -1,29 +1,64 @@ # stdlib import threading +from pathlib import Path # third party +from pydantic import BaseModel, Field from sqlalchemy import create_engine from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker +from syft.server.credentials import SyftVerifyKey # relative from ...types.uid import UID from .models import Base from .utils import dumps from .utils import loads +import tempfile -class SQLiteDBManager: - def __init__(self, server_uid: UID) -> None: +class DBConfig(BaseModel): + pass + + +class SQLiteDBConfig(DBConfig): + filename: str = "jsondb.sqlite" + path: Path = Field(default_factory=tempfile.gettempdir) + + +class DBManager: + def __init__( + self, + config: DBConfig, + server_uid: UID, + root_verify_key: SyftVerifyKey, + ) -> None: + self.config = config self.server_uid = server_uid - self.path = f"sqlite:////tmp/{str(server_uid)}.db" + self.root_verify_key = root_verify_key + + +class SQLiteDBManager(DBManager): + def __init__( + self, + config: SQLiteDBConfig, + server_uid: UID, + root_verify_key: SyftVerifyKey, + ) -> None: + self.config = config + self.root_verify_key = root_verify_key + self.server_uid = server_uid + + self.filepath = config.path / config.filename + self.path = f"sqlite:///{self.filepath.resolve()}" self.engine = create_engine( self.path, json_serializer=dumps, json_deserializer=loads ) - print(f"Connecting to {self.path}") self.Session = sessionmaker(bind=self.engine) + # TODO use AuthedServiceContext for session management instead of threading.local self.thread_local = threading.local() + def init_tables(self) -> None: Base.metadata.create_all(self.engine) # TODO remove diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 81c9ce0b83f..09c45018c4d 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -3,6 +3,7 @@ # stdlib from typing import Any from typing import Generic +from typing import get_args import uuid # third party @@ -35,25 +36,41 @@ from ...service.user.user_roles import ServiceRole from ...types.syft_object import SyftObject from ...types.uid import UID -from ..document_store import DocumentStore from .models import Base from .models import UIDTypeDecorator -from .sqlite_db import SQLiteDBManager +from .sqlite_db import DBManager SyftT = TypeVar("SyftT", bound=SyftObject) T = TypeVar("T") class ObjectStash(Generic[SyftT]): - object_type: type[SyftT] table: Table + object_type: type[SyftT] - def __init__(self, store: DocumentStore) -> None: - self.server_uid = store.server_uid - self.root_verify_key = store.root_verify_key - # is there a better way to init the table + def __init__(self, store: DBManager) -> None: + self.db = store + self.object_type = self.get_object_type() self.table = self._create_table() - self.db = SQLiteDBManager(self.server_uid) + + @classmethod + def get_object_type(cls) -> type[SyftT]: + generic_args = 
get_args(cls.__orig_bases__[0]) + if len(generic_args) != 1: + raise TypeError("ObjectStash must have a single generic argument") + elif not issubclass(generic_args[0], SyftObject): + raise TypeError( + "ObjectStash generic argument must be a subclass of SyftObject" + ) + return generic_args[0] + + @property + def server_uid(self) -> UID: + return self.db.server_uid + + @property + def root_verify_key(self) -> SyftVerifyKey: + return self.db.root_verify_key def check_type(self, obj: T, type_: type) -> Result[T, str]: return ( diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 4012bc25ca5..5211cc0fed9 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -596,30 +596,8 @@ def __init__( def __has_admin_permissions( self, settings: PartitionSettings ) -> Callable[[SyftVerifyKey], bool]: - # relative - from ..service.user.user import User - from ..service.user.user_roles import ServiceRole - from ..service.user.user_stash import UserStash - - # leave out UserStash to avoid recursion - # TODO: pass the callback from BaseStash instead of DocumentStore - # so that this works with UserStash after the sqlite thread fix is merged - if settings.object_type is User: - return lambda credentials: False - - user_stash = UserStash(store=self) - def has_admin_permissions(credentials: SyftVerifyKey) -> bool: - res = user_stash.get_by_verify_key( - credentials=credentials, - verify_key=credentials, - ) - - return ( - res.is_ok() - and (user := res.ok()) is not None - and user.role in (ServiceRole.DATA_OWNER, ServiceRole.ADMIN) - ) + return credentials == self.root_verify_key return has_admin_permissions From 5762e6debe862c7db8034666f291199a8e78050e Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 26 Aug 2024 17:53:02 +0200 Subject: [PATCH 047/197] move sqlite connection to Server --- packages/syft/src/syft/server/server.py | 5 ++--- packages/syft/src/syft/server/service_registry.py | 5 ++--- packages/syft/src/syft/service/api/api_stash.py | 1 - .../syft/src/syft/service/blob_storage/remote_profile.py | 4 +--- packages/syft/src/syft/service/code/user_code_stash.py | 1 - .../src/syft/service/code_history/code_history_stash.py | 1 - packages/syft/src/syft/service/dataset/dataset_stash.py | 1 - packages/syft/src/syft/service/job/job_stash.py | 1 - packages/syft/src/syft/service/log/log_stash.py | 1 - .../syft/src/syft/service/settings/settings_stash.py | 1 - packages/syft/src/syft/service/sync/sync_stash.py | 3 +-- packages/syft/src/syft/service/user/user_stash.py | 2 +- packages/syft/src/syft/store/db/sqlite_db.py | 9 +++++---- 13 files changed, 12 insertions(+), 23 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 09f6f437aca..4cec2ffb21b 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -23,7 +23,6 @@ from nacl.signing import SigningKey from result import Err from result import Result -from syft.store.db.sqlite_db import SQLiteDBConfig, SQLiteDBManager # relative from .. 
import __version__ @@ -41,7 +40,6 @@ from ..protocol.data_protocol import get_data_protocol from ..service.action.action_object import Action from ..service.action.action_object import ActionObject -from ..service.action.action_store import ActionObjectStash from ..service.blob_storage.service import BlobStorageService from ..service.code.user_code_service import UserCodeService from ..service.code.user_code_stash import UserCodeStash @@ -81,7 +79,6 @@ from ..service.user.user import UserView from ..service.user.user_roles import ServiceRole from ..service.user.user_service import UserService -from ..service.user.user_stash import UserStash from ..service.worker.utils import DEFAULT_WORKER_IMAGE_TAG from ..service.worker.utils import DEFAULT_WORKER_POOL_NAME from ..service.worker.utils import create_default_image @@ -94,6 +91,8 @@ from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit +from ..store.db.sqlite_db import SQLiteDBConfig +from ..store.db.sqlite_db import SQLiteDBManager from ..store.dict_document_store import DictStoreConfig from ..store.document_store import StoreConfig from ..store.linked_obj import LinkedObject diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index 3cf81125c07..eed65ae8382 100644 --- a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -6,8 +6,6 @@ from typing import Any from typing import TYPE_CHECKING -from syft.store.db.stash import ObjectStash - # relative from ..serde.serializable import serializable from ..service.action.action_service import ActionService @@ -42,6 +40,7 @@ from ..service.worker.worker_image_service import SyftWorkerImageService from ..service.worker.worker_pool_service import SyftWorkerPoolService from ..service.worker.worker_service import WorkerService +from ..store.db.stash import ObjectStash if TYPE_CHECKING: # relative @@ -114,7 +113,7 @@ def _uses_new_store(cls, service_cls: type[AbstractService]) -> bool: if issubclass(stash_annotation, ObjectStash): return True return False - except Exception as e: + except Exception: return False @classmethod diff --git a/packages/syft/src/syft/service/api/api_stash.py b/packages/syft/src/syft/service/api/api_stash.py index b1c9bf37037..703b74f48b1 100644 --- a/packages/syft/src/syft/service/api/api_stash.py +++ b/packages/syft/src/syft/service/api/api_stash.py @@ -9,7 +9,6 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store import PartitionSettings from .api import TwinAPIEndpoint diff --git a/packages/syft/src/syft/service/blob_storage/remote_profile.py b/packages/syft/src/syft/service/blob_storage/remote_profile.py index dc8d52c9c26..2c0cd38ee79 100644 --- a/packages/syft/src/syft/service/blob_storage/remote_profile.py +++ b/packages/syft/src/syft/service/blob_storage/remote_profile.py @@ -1,11 +1,9 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import BaseUIDStoreStash -from ...store.document_store import DocumentStore +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftObject -from 
...store.db.stash import ObjectStash @serializable() diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py index 046bdece148..98c246a90dc 100644 --- a/packages/syft/src/syft/service/code/user_code_stash.py +++ b/packages/syft/src/syft/service/code/user_code_stash.py @@ -7,7 +7,6 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store import PartitionSettings from ...util.telemetry import instrument from .user_code import UserCode diff --git a/packages/syft/src/syft/service/code_history/code_history_stash.py b/packages/syft/src/syft/service/code_history/code_history_stash.py index 60d733c2e8d..6a9db58ed35 100644 --- a/packages/syft/src/syft/service/code_history/code_history_stash.py +++ b/packages/syft/src/syft/service/code_history/code_history_stash.py @@ -7,7 +7,6 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings from .code_history import CodeHistory diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index 63ad3f8bafa..fe934f10bf3 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -9,7 +9,6 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store import PartitionSettings from ...types.uid import UID from ...util.telemetry import instrument diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 4206ffb3272..8f2d0ff78ff 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -24,7 +24,6 @@ from ...service.context import AuthedServiceContext from ...service.worker.worker_pool import SyftWorker from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store import PartitionSettings from ...types.datetime import DateTime from ...types.datetime import format_timedelta diff --git a/packages/syft/src/syft/service/log/log_stash.py b/packages/syft/src/syft/service/log/log_stash.py index 360284ca0a3..5c1e89c31e9 100644 --- a/packages/syft/src/syft/service/log/log_stash.py +++ b/packages/syft/src/syft/service/log/log_stash.py @@ -1,7 +1,6 @@ # relative from ...serde.serializable import serializable from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store import PartitionSettings from ...util.telemetry import instrument from .log import SyftLog diff --git a/packages/syft/src/syft/service/settings/settings_stash.py b/packages/syft/src/syft/service/settings/settings_stash.py index e7d6e260e16..5dd62576e8d 100644 --- a/packages/syft/src/syft/service/settings/settings_stash.py +++ b/packages/syft/src/syft/service/settings/settings_stash.py @@ -5,7 +5,6 @@ # relative from ...serde.serializable import serializable from ...store.db.stash import ObjectStash -from 
...store.document_store import DocumentStore from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings from ...types.uid import UID diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py index 1769917bf69..aa34f809238 100644 --- a/packages/syft/src/syft/service/sync/sync_stash.py +++ b/packages/syft/src/syft/service/sync/sync_stash.py @@ -7,13 +7,12 @@ # third party from result import Ok from result import Result -from syft.store.db.sqlite_db import DBManager # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings from ...types.datetime import DateTime diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index 38a918e5eb2..d4689dc182b 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -3,12 +3,12 @@ # third party from result import Ok from result import Result -from syft.store.db.sqlite_db import DBManager # relative from ...serde.serializable import serializable from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index 397547831ea..475b2dc8cdf 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -1,20 +1,21 @@ # stdlib -import threading from pathlib import Path +import tempfile +import threading # third party -from pydantic import BaseModel, Field +from pydantic import BaseModel +from pydantic import Field from sqlalchemy import create_engine from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker -from syft.server.credentials import SyftVerifyKey # relative +from ...server.credentials import SyftVerifyKey from ...types.uid import UID from .models import Base from .utils import dumps from .utils import loads -import tempfile class DBConfig(BaseModel): From 78fe0c760c9221be01780620fc5473f91967066d Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 26 Aug 2024 19:14:30 +0200 Subject: [PATCH 048/197] implement notification stash --- .vscode/launch.json | 7 ++ .../notification/notification_stash.py | 62 ++++++----------- .../syft/service/notifier/notifier_service.py | 14 ++-- .../syft/service/notifier/notifier_stash.py | 67 ++++--------------- packages/syft/src/syft/store/db/stash.py | 8 +-- 5 files changed, 50 insertions(+), 108 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index bb5d6e9c00a..7e30fe06537 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -4,6 +4,13 @@ // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "configurations": [ + { + "name": "Python Debugger: Current File", + "type": "debugpy", + "request": "launch", + "program": "${file}", + "console": "integratedTerminal" + }, { "name": "Syft Debugger", "type": "debugpy", diff --git 
a/packages/syft/src/syft/service/notification/notification_stash.py b/packages/syft/src/syft/service/notification/notification_stash.py index 8521080fee5..fd94dff6ccc 100644 --- a/packages/syft/src/syft/service/notification/notification_stash.py +++ b/packages/syft/src/syft/service/notification/notification_stash.py @@ -6,11 +6,11 @@ from result import Result # relative +from ...serde.json_serde import serialize_json from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import BaseUIDStoreStash +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings from ...store.document_store import QueryKeys from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime @@ -33,45 +33,29 @@ @instrument -@serializable(canonical_name="NotificationStash", version=1) -class NotificationStash(BaseUIDStoreStash): - object_type = Notification - settings: PartitionSettings = PartitionSettings( - name=Notification.__canonical_name__, - object_type=Notification, - ) - +@serializable(canonical_name="NotificationSQLStash", version=1) +class NotificationStash(ObjectStash[Notification]): def get_all_inbox_for_verify_key( self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey ) -> Result[list[Notification], str]: - qks = QueryKeys( - qks=[ - ToUserVerifyKeyPartitionKey.with_obj(verify_key), - ] - ) - return self.get_all_for_verify_key( - credentials=credentials, verify_key=verify_key, qks=qks + return self.get_all_by_field( + credentials, field_name="to_user_verify_key", field_value=str(verify_key) ) def get_all_sent_for_verify_key( self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey ) -> Result[list[Notification], str]: - qks = QueryKeys( - qks=[ - FromUserVerifyKeyPartitionKey.with_obj(verify_key), - ] + return self.get_all_by_field( + credentials, + field_name="from_user_verify_key", + field_value=str(verify_key), ) - return self.get_all_for_verify_key(credentials, verify_key=verify_key, qks=qks) def get_all_for_verify_key( self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey, qks: QueryKeys ) -> Result[list[Notification], str]: - if isinstance(verify_key, str): - verify_key = SyftVerifyKey.from_string(verify_key) - return self.query_all( - credentials, - qks=qks, - order_by=OrderByCreatedAtTimeStampPartitionKey, + return self.get_all_by_field( + credentials, field_name="to_user_verify_key", field_value=str(verify_key) ) def get_all_by_verify_key_for_status( @@ -80,16 +64,12 @@ def get_all_by_verify_key_for_status( verify_key: SyftVerifyKey, status: NotificationStatus, ) -> Result[list[Notification], str]: - qks = QueryKeys( - qks=[ - ToUserVerifyKeyPartitionKey.with_obj(verify_key), - StatusPartitionKey.with_obj(status), - ] - ) - return self.query_all( + return self.get_all_by_fields( credentials, - qks=qks, - order_by=OrderByCreatedAtTimeStampPartitionKey, + fields={ + "to_user_verify_key": str(verify_key), + "status": status.value, + }, ) def get_notification_for_linked_obj( @@ -97,12 +77,10 @@ def get_notification_for_linked_obj( credentials: SyftVerifyKey, linked_obj: LinkedObject, ) -> Result[Notification, str]: - qks = QueryKeys( - qks=[ - LinkedObjectPartitionKey.with_obj(linked_obj), - ] + # TODO does this work?
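The TODO above is asking the right question. The replacement that follows matches a LinkedObject by comparing its JSON serialization to the value stored in the JSON fields column, which is only reliable if serialization is deterministic. A minimal stdlib-only sketch of the pitfall (the dict keys here are illustrative, not the real LinkedObject schema):

    import json

    # Two logically identical payloads serialized with different key order.
    a = json.dumps({"server_uid": "abc", "object_uid": "def"})
    b = json.dumps({"object_uid": "def", "server_uid": "abc"})

    assert json.loads(a) == json.loads(b)  # equal as objects
    assert a != b                          # but not equal as strings

    # A canonical form (sorted keys, fixed separators) makes string or
    # JSON-column equality stable:
    def canon(d: dict) -> str:
        return json.dumps(d, sort_keys=True, separators=(",", ":"))

    assert canon(json.loads(a)) == canon(json.loads(b))

If serialize_json does not guarantee a canonical ordering, the get_one_by_fields lookup below can silently miss matching rows.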
+ return self.get_one_by_fields( + credentials, fields={"linked_obj": serialize_json(linked_obj)} ) - return self.query_one(credentials=credentials, qks=qks) def update_notification_status( self, credentials: SyftVerifyKey, uid: UID, status: NotificationStatus diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index c8c09ba3d50..e0d87176878 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -80,7 +80,7 @@ def set_notifier_active_to_true( if notifier is None: return SyftError(message="Notifier settings not found.") notifier.active = True - result = self.stash.update(credentials=context.credentials, settings=notifier) + result = self.stash.update(credentials=context.credentials, obj=notifier) if result.is_err(): return SyftError(message=result.err()) return SyftSuccess(message="notifier.active set to true.") @@ -100,7 +100,7 @@ def set_notifier_active_to_false( return SyftError(message="Notifier settings not found.") notifier.active = False - result = self.stash.update(credentials=context.credentials, settings=notifier) + result = self.stash.update(credentials=context.credentials, obj=notifier) if result.is_err(): return SyftError(message=result.err()) return SyftSuccess(message="notifier.active set to false.") @@ -210,7 +210,7 @@ def turn_on( "Email credentials are valid. Updating the notifier settings in the db." ) - result = self.stash.update(credentials=context.credentials, settings=notifier) + result = self.stash.update(credentials=context.credentials, obj=notifier) if result.is_err(): return SyftError(message=result.err()) @@ -237,7 +237,7 @@ def turn_off( notifier = result.ok() notifier.active = False - result = self.stash.update(credentials=context.credentials, settings=notifier) + result = self.stash.update(credentials=context.credentials, obj=notifier) if result.is_err(): return SyftError(message=result.err()) @@ -294,7 +294,7 @@ def init_notifier( """ try: # Create a new NotifierStash since its a static method. 
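The notifier hunks around this point all repeat one guard pattern: call the stash, check is_err(), and surface the error; the keyword rename from settings= to obj= is what lets NotifierStash reuse the generic ObjectStash[T].update() signature instead of a bespoke one. A standalone sketch of that pattern using the same result library imported elsewhere in this patch (update_flag is a hypothetical stand-in for the stash call):

    from result import Err, Ok, Result

    def update_flag(active: object) -> Result[str, str]:
        # Hypothetical stand-in for stash.update(credentials=..., obj=...)
        if not isinstance(active, bool):
            return Err("active must be a bool")
        return Ok(f"notifier.active set to {str(active).lower()}.")

    result = update_flag(True)
    if result.is_err():
        print(f"error: {result.err()}")  # mirrors `return SyftError(message=result.err())`
    else:
        print(result.ok())               # prints: notifier.active set to true.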
- notifier_stash = NotifierStash(store=server.document_store) + notifier_stash = NotifierStash(store=server.db) result = notifier_stash.get(server.signing_key.verify_key) if result.is_err(): raise Exception(f"Could not create notifier: {result}") @@ -348,7 +348,7 @@ def set_email_rate_limit( notifier = notifier.ok() notifier.email_rate_limit[email_type.value] = daily_limit - result = self.stash.update(credentials=context.credentials, settings=notifier) + result = self.stash.update(credentials=context.credentials, obj=notifier) if result.is_err(): return SyftError(message="Couldn't update the notifier.") @@ -410,7 +410,7 @@ def dispatch_notification( ) } - result = self.stash.update(credentials=admin_key, settings=notifier) + result = self.stash.update(credentials=admin_key, obj=notifier) if result.is_err(): return SyftError(message="Couldn't update the notifier.") diff --git a/packages/syft/src/syft/service/notifier/notifier_stash.py b/packages/syft/src/syft/service/notifier/notifier_stash.py index 1d28c90a380..d20292a58b1 100644 --- a/packages/syft/src/syft/service/notifier/notifier_stash.py +++ b/packages/syft/src/syft/service/notifier/notifier_stash.py @@ -8,14 +8,11 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...service.response import SyftError -from ...store.document_store import BaseStash -from ...store.document_store import DocumentStore +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey from ...store.document_store import PartitionSettings from ...types.uid import UID from ...util.telemetry import instrument -from ..action.action_permissions import ActionObjectPermission from .notifier import NotifierSettings NamePartitionKey = PartitionKey(key="name", type_=str) @@ -23,62 +20,22 @@ @instrument -@serializable(canonical_name="NotifierStash", version=1) -class NotifierStash(BaseStash): +@serializable(canonical_name="NotifierStashSQL", version=1) +class NotifierStash(ObjectStash[NotifierSettings]): object_type = NotifierSettings settings: PartitionSettings = PartitionSettings( name=NotifierSettings.__canonical_name__, object_type=NotifierSettings ) - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - - def admin_verify_key(self) -> SyftVerifyKey: - return self.partition.root_verify_key - # TODO: should this method behave like a singleton? def get(self, credentials: SyftVerifyKey) -> Result[NotifierSettings, Err]: """Get Settings""" - result = self.get_all(credentials) - if result.is_ok(): - settings = result.ok() - if len(settings) == 0: - return Ok( - None - ) # TODO: Stash shouldn't be empty after init. Return Err instead? - result = settings[ - 0 - ] # TODO: Should we check if theres more than one? 
=> Report corruption - return Ok(result) - else: - return Err(SyftError(message=result.err())) - - def set( - self, - credentials: SyftVerifyKey, - settings: NotifierSettings, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - ignore_duplicates: bool = False, - ) -> Result[NotifierSettings, Err]: - result = self.check_type(settings, self.object_type) - # we dont use and_then logic here as it is hard because of the order of the arguments - if result.is_err(): - return Err(SyftError(message=result.err())) - return super().set( - credentials=credentials, obj=result.ok() - ) # TODO check if result isInstance(Ok) - - def update( - self, - credentials: SyftVerifyKey, - settings: NotifierSettings, - has_permission: bool = False, - ) -> Result[NotifierSettings, Err]: - result = self.check_type(settings, self.object_type) - # we dont use and_then logic here as it is hard because of the order of the arguments - if result.is_err(): - return Err(SyftError(message=result.err())) - return super().update( - credentials=credentials, obj=result.ok() - ) # TODO check if result isInstance(Ok) + # actually get latest settings + results = self.get_all(credentials, limit=1) + match results: + case Ok(settings) if len(settings) > 0: + return Ok(settings[0]) + case Ok(_): + return Ok(None) + case Err(e): + return Err(e) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 09c45018c4d..e31411c2f0f 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -305,13 +305,13 @@ def row_as_obj(self, row: Row) -> SyftT: def get_role(self, credentials: SyftVerifyKey) -> ServiceRole: user_table = Table("User", Base.metadata) - stmt = user_table.select().where( + stmt = select(user_table.c.fields["role"]).where( self._get_field_filter("verify_key", str(credentials), table=user_table), ) - result = self.session.execute(stmt).first() - if result is None: + role = self.session.scalar(stmt) + if role is None: return ServiceRole.GUEST - return ServiceRole[result.fields["role"]] + return ServiceRole[role] def _get_permission_filter( self, From 662c5a6f91f089690430a983189a7e81d52add58 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 26 Aug 2024 19:27:18 +0200 Subject: [PATCH 049/197] signature handling improvements --- packages/syft/src/syft/client/api.py | 8 ++++---- packages/syft/src/syft/serde/signature.py | 9 +++++++++ packages/syft/src/syft/store/db/sqlite_db.py | 13 +++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 6e080b25178..ea4707a0ec1 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -31,8 +31,7 @@ from ..serde.serializable import serializable from ..serde.serialize import _serialize from ..serde.signature import Signature -from ..serde.signature import signature_remove_context -from ..serde.signature import signature_remove_self +from ..serde.signature import signature_remove from ..server.credentials import SyftSigningKey from ..server.credentials import SyftVerifyKey from ..service.context import AuthedServiceContext @@ -1111,9 +1110,10 @@ def build_endpoint_tree( api_module = APIModule(path="", refresh_callback=self.refresh_api_callback) for v in endpoints.values(): signature = v.signature + args_to_remove = ["context"] if not v.has_self: - signature = signature_remove_self(signature) - signature = 
signature_remove_context(signature) + args_to_remove.append("self") + signature = signature_remove(signature, args_to_remove) if isinstance(v, APIEndpoint): endpoint_function = generate_remote_function( self, diff --git a/packages/syft/src/syft/serde/signature.py b/packages/syft/src/syft/serde/signature.py index 23b0a556fca..0887d148367 100644 --- a/packages/syft/src/syft/serde/signature.py +++ b/packages/syft/src/syft/serde/signature.py @@ -86,6 +86,15 @@ def signature_remove_context(signature: Signature) -> Signature: ) +def signature_remove(signature: Signature, args: list[str]) -> Signature: + params = dict(signature.parameters) + for arg in args: + params.pop(arg, None) + return Signature( + list(params.values()), return_annotation=signature.return_annotation + ) + + def get_str_signature_from_docstring(doc: str, callable_name: str) -> str | None: if not doc or callable_name not in doc: return None diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index 475b2dc8cdf..64a5259582a 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -6,6 +6,7 @@ # third party from pydantic import BaseModel from pydantic import Field +import sqlalchemy as sa from sqlalchemy import create_engine from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker @@ -55,10 +56,22 @@ def __init__( self.engine = create_engine( self.path, json_serializer=dumps, json_deserializer=loads ) + print(f"Connecting to {self.path}") self.Session = sessionmaker(bind=self.engine) + # TODO use AuthedServiceContext for session management instead of threading.local self.thread_local = threading.local() + self.update_settings() + + def update_settings(self) -> None: + connection = self.engine.connect() + + connection.execute(sa.text("PRAGMA journal_mode = WAL")) + connection.execute(sa.text("PRAGMA busy_timeout = 5000")) + connection.execute(sa.text("PRAGMA temp_store = 2")) + connection.execute(sa.text("PRAGMA synchronous = 1")) + def init_tables(self) -> None: Base.metadata.create_all(self.engine) From ad4e38b19fd6930d9295327dc7aad7475ae0ad01 Mon Sep 17 00:00:00 2001 From: dk Date: Tue, 27 Aug 2024 15:34:07 +0700 Subject: [PATCH 050/197] [syft/server] replace mongo store config with postgres - add environment and settings variables for postgres - new error handling for postgresql store --- packages/grid/backend/grid/core/config.py | 9 +++--- packages/grid/backend/grid/core/server.py | 31 ++++++++++--------- packages/grid/default.env | 11 ++++++- .../syft/store/postgresql_document_store.py | 19 +++++++----- 4 files changed, 43 insertions(+), 27 deletions(-) diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 0cd6e026e03..0d50d9c9078 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -126,10 +126,11 @@ def get_emails_enabled(self) -> Self: # NETWORK_CHECK_INTERVAL: int = int(os.getenv("NETWORK_CHECK_INTERVAL", 60)) # DATASITE_CHECK_INTERVAL: int = int(os.getenv("DATASITE_CHECK_INTERVAL", 60)) CONTAINER_HOST: str = str(os.getenv("CONTAINER_HOST", "docker")) - MONGO_HOST: str = str(os.getenv("MONGO_HOST", "")) - MONGO_PORT: int = int(os.getenv("MONGO_PORT", 27017)) - MONGO_USERNAME: str = str(os.getenv("MONGO_USERNAME", "")) - MONGO_PASSWORD: str = str(os.getenv("MONGO_PASSWORD", "")) + POSTGRESQL_DBNAME: str = str(os.getenv("POSTGRESQL_DBNAME", "")) + POSTGRESQL_HOST: str = 
str(os.getenv("POSTGRESQL_HOST", "")) + POSTGRESQL_PORT: int = int(os.getenv("POSTGRESQL_PORT", 27017)) + POSTGRESQL_USERNAME: str = str(os.getenv("POSTGRESQL_USERNAME", "")) + POSTGRESQL_PASSWORD: str = str(os.getenv("POSTGRESQL_PASSWORD", "")) DEV_MODE: bool = True if os.getenv("DEV_MODE", "false").lower() == "true" else False # ZMQ stuff QUEUE_PORT: int = int(os.getenv("QUEUE_PORT", 5556)) diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py index 9802eece8e8..d284c8edb4e 100644 --- a/packages/grid/backend/grid/core/server.py +++ b/packages/grid/backend/grid/core/server.py @@ -14,8 +14,8 @@ from syft.service.queue.zmq_queue import ZMQQueueConfig from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig from syft.store.blob_storage.seaweedfs import SeaweedFSConfig -from syft.store.mongo_client import MongoStoreClientConfig -from syft.store.mongo_document_store import MongoStoreConfig +from syft.store.postgresql_document_store import PostgreSQLStoreClientConfig +from syft.store.postgresql_document_store import PostgreSQLStoreConfig from syft.store.sqlite_document_store import SQLiteStoreClientConfig from syft.store.sqlite_document_store import SQLiteStoreConfig from syft.types.uid import UID @@ -36,17 +36,6 @@ def queue_config() -> ZMQQueueConfig: return queue_config -def mongo_store_config() -> MongoStoreConfig: - mongo_client_config = MongoStoreClientConfig( - hostname=settings.MONGO_HOST, - port=settings.MONGO_PORT, - username=settings.MONGO_USERNAME, - password=settings.MONGO_PASSWORD, - ) - - return MongoStoreConfig(client_config=mongo_client_config) - - def sql_store_config() -> SQLiteStoreConfig: client_config = SQLiteStoreClientConfig( filename=f"{UID.from_string(get_server_uid_env())}.sqlite", @@ -55,6 +44,18 @@ def sql_store_config() -> SQLiteStoreConfig: return SQLiteStoreConfig(client_config=client_config) +def postgresql_store_config() -> PostgreSQLStoreConfig: + postgresql_client_config = PostgreSQLStoreClientConfig( + dbname=settings.POSTGRESQL_DBNAME, + host=settings.POSTGRESQL_HOST, + port=settings.POSTGRESQL_PORT, + username=settings.POSTGRESQL_USERNAME, + password=settings.POSTGRESQL_PASSWORD, + ) + + return PostgreSQLStoreConfig(client_config=postgresql_client_config) + + def seaweedfs_config() -> SeaweedFSConfig: seaweed_client_config = SeaweedFSClientConfig( host=settings.S3_ENDPOINT, @@ -87,7 +88,9 @@ def seaweedfs_config() -> SeaweedFSConfig: worker_class = worker_classes[server_type] single_container_mode = settings.SINGLE_CONTAINER_MODE -store_config = sql_store_config() if single_container_mode else mongo_store_config() +store_config = ( + sql_store_config() if single_container_mode else postgresql_store_config() +) blob_storage_config = None if single_container_mode else seaweedfs_config() queue_config = queue_config() diff --git a/packages/grid/default.env b/packages/grid/default.env index 3018a4c2ce2..791778206fa 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -115,4 +115,13 @@ ENABLE_SIGNUP=False DOCKER_IMAGE_ENCLAVE_ATTESTATION=openmined/syft-enclave-attestation # Rathole Config -RATHOLE_PORT=2333 \ No newline at end of file +RATHOLE_PORT=2333 + +# PostgresSQL Config +# POSTGRESQL_IMAGE=postgres +# export POSTGRESQL_VERSION="15" +POSTGRESQL_DBNAME=syftdb_postgres +POSTGRESQL_HOST=localhost +POSTGRESQL_PORT=5432 +POSTGRESQL_USERNAME=syft_postgres +POSTGRESQL_PASSWORD=changethis \ No newline at end of file diff --git a/packages/syft/src/syft/store/postgresql_document_store.py 
b/packages/syft/src/syft/store/postgresql_document_store.py index bed00634f16..c6e02e0d1dc 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -4,10 +4,13 @@ # third party import psycopg2 +from psycopg2.extensions import connection +from psycopg2.extensions import cursor from pydantic import Field # relative from ..serde.serializable import serializable +from ..types.errors import SyftException from .document_store import DocumentStore from .document_store import PartitionSettings from .document_store import StoreClientConfig @@ -19,12 +22,11 @@ from .sqlite_document_store import SQLiteBackingStore from .sqlite_document_store import SQLiteStorePartition from .sqlite_document_store import cache_key -from .sqlite_document_store import raise_exception +from .sqlite_document_store import special_exception_public_message logger = logging.getLogger(__name__) - -_CONNECTION_POOL_DB: dict[str, psycopg2.Connection] = {} -_CONNECTION_POOL_CUR: dict[str, psycopg2.Cursor] = {} +_CONNECTION_POOL_DB: dict[str, connection] = {} +_CONNECTION_POOL_CUR: dict[str, cursor] = {} REF_COUNTS: dict[str, int] = defaultdict(int) @@ -32,7 +34,7 @@ @serializable(canonical_name="PostgreSQLStoreClientConfig", version=1) class PostgreSQLStoreClientConfig(StoreClientConfig): dbname: str - user: str + username: str password: str host: str port: int @@ -94,16 +96,17 @@ def create_table(self) -> None: ) self.db.commit() except Exception as e: - raise_exception(self.table_name, e) + public_message = special_exception_public_message(self.table_name, e) + raise SyftException.from_exception(e, public_message=public_message) @property - def db(self) -> psycopg2.Connection: + def db(self) -> connection: if cache_key(self.dbname) not in _CONNECTION_POOL_DB: self._connect() return _CONNECTION_POOL_DB[cache_key(self.dbname)] @property - def cur(self) -> psycopg2.Cursor: + def cur(self) -> cursor: if cache_key(self.db_filename) not in _CONNECTION_POOL_CUR: _CONNECTION_POOL_CUR[cache_key(self.dbname)] = self.db.cursor() From 48cc473654dcdd66e693f264933078687c587552 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 27 Aug 2024 14:33:44 +0200 Subject: [PATCH 051/197] add stash --- packages/syft/src/syft/store/db/stash.py | 83 +++++++++++++----------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index c5e61941b0c..598edc42980 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -7,13 +7,10 @@ import uuid # third party -from result import Err -from result import Ok from result import Result import sqlalchemy as sa from sqlalchemy import Column from sqlalchemy import Row -from sqlalchemy import Select from sqlalchemy import Table from sqlalchemy import func from sqlalchemy import select @@ -113,8 +110,12 @@ def _print_query(self, stmt: sa.sql.select) -> None: ) ) + @property + def unique_fields(self) -> list[str]: + return getattr(self.object_type, "__attr_unique__", []) + def is_unique(self, obj: SyftT) -> bool: - unique_fields = self.object_type.__attr_unique__ + unique_fields = self.unique_fields if not unique_fields: return True filters = [] @@ -346,10 +347,10 @@ def _get_permission_filter( def _apply_limit_offset( self, - stmt: Select, + stmt: T, limit: int | None = None, offset: int | None = None, - ) -> Any: + ) -> T: if offset is not None: stmt = stmt.offset(offset) if limit is not None: @@ 
-358,10 +359,10 @@ def _apply_limit_offset( def _apply_order_by( self, - stmt: Select, + stmt: T, order_by: str | None = None, sort_order: str = "asc", - ) -> Any: + ) -> T: default_order_by = self.table.c.created_at default_order_by = ( default_order_by.desc() if sort_order == "desc" else default_order_by @@ -375,11 +376,11 @@ def _apply_order_by( def _apply_permission_filter( self, - stmt: Select, + stmt: T, credentials: SyftVerifyKey, permission: ActionPermission = ActionPermission.READ, has_permission: bool = False, - ) -> Any: + ) -> T: if not has_permission: stmt = stmt.where( self._get_permission_filter(credentials, permission=permission) @@ -424,22 +425,15 @@ def update( if not self.is_unique(obj): raise StashException(f"Some fields are not unique for {type(obj).__name__}") - # TODO error handling - has_permission_stmt = ( - self._get_permission_filter(credentials, ActionPermission.WRITE) - if has_permission - else sa.literal(True) - ) - stmt = ( - self.table.update() - .where( - sa.and_( - self._get_field_filter("id", obj.id), - has_permission_stmt, - ) - ) - .values(fields=serialize_json(obj)) + stmt = self.table.update().where(self._get_field_filter("id", obj.id)) + stmt = self._apply_permission_filter( + stmt, + credentials, + has_permission=has_permission, + permission=ActionPermission.WRITE, ) + stmt = stmt.values(fields=serialize_json(obj)) + self.session.execute(stmt) self.session.commit() @@ -463,8 +457,8 @@ def delete_by_uid( stmt = self._apply_permission_filter( stmt, credentials, - has_permission, permission=ActionPermission.WRITE, + has_permission=has_permission, ) self.session.execute(stmt) self.session.commit() @@ -472,11 +466,13 @@ def delete_by_uid( def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: # TODO: should do this in a single transaction + # TODO add error handling for permission in permissions: self.add_permission(permission) return None def add_permission(self, permission: ActionObjectPermission) -> None: + # TODO add error handling stmt = ( self.table.update() .where(self.table.c.id == permission.uid) @@ -493,8 +489,9 @@ def add_permission(self, permission: ActionObjectPermission) -> None: self.session.commit() def remove_permission(self, permission: ActionObjectPermission) -> None: + # TODO not threadsafe try: - permissions = self._get_permissions_for_uid(permission.uid) + permissions = self._get_permissions_for_uid(permission.uid).unwrap() permissions.remove(permission.permission_string) except (NotFoundException, KeyError): # TODO add error handling to permissions @@ -510,8 +507,9 @@ def remove_permission(self, permission: ActionObjectPermission) -> None: return None def remove_storage_permission(self, permission: StoragePermission) -> None: + # TODO not threadsafe try: - permissions = self._get_storage_permissions_for_uid(permission.uid) + permissions = self._get_storage_permissions_for_uid(permission.uid).unwrap() permissions.remove(permission.server_uid) except (NotFoundException, KeyError): # TODO add error handling to permissions @@ -526,6 +524,7 @@ def remove_storage_permission(self, permission: StoragePermission) -> None: self.session.commit() return None + @as_result(StashException) def _get_storage_permissions_for_uid(self, uid: UID) -> set[UID]: stmt = select(self.table.c.id, self.table.c.storage_permissions).where( self.table.c.id == uid @@ -535,13 +534,14 @@ def _get_storage_permissions_for_uid(self, uid: UID) -> set[UID]: raise NotFoundException(f"No storage permissions found for uid: {uid}") return {UID(uid) 
for uid in result.storage_permissions} - def get_all_storage_permissions(self) -> Result[dict[UID, set[UID]], str]: + @as_result(StashException) + def get_all_storage_permissions(self) -> dict[UID, set[UID]]: stmt = select(self.table.c.id, self.table.c.storage_permissions) results = self.session.execute(stmt).all() - # make uid return { - row.id: {(UID(uid) for uid in row.storage_permissions)} for row in results + UID(row.id): {UID(uid) for uid in row.storage_permissions} + for row in results } def has_permission(self, permission: ActionObjectPermission) -> bool: @@ -568,7 +568,7 @@ def has_storage_permissions(self, permissions: list[StoragePermission]) -> bool: return result def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool: - # TODO: we should use a permissions table to check all permissions at once + # TODO: we should use a permissions table to check all permissions at once # TODO: should check for compound permissions permission_filters = [ sa.and_( @@ -606,6 +606,7 @@ def add_storage_permissions(self, permissions: list[StoragePermission]) -> None: for permission in permissions: self.add_storage_permission(permission) + @as_result(StashException) def _get_permissions_for_uid(self, uid: UID) -> set[str]: stmt = select(self.table.c.permissions).where(self.table.c.id == uid) result = self.session.execute(stmt).scalar_one_or_none() @@ -613,11 +614,13 @@ def _get_permissions_for_uid(self, uid: UID) -> set[str]: return NotFoundException(f"No permissions found for uid: {uid}") return set(result) - def get_all_permissions(self) -> Result[dict[UID, set[str]], str]: + @as_result(StashException) + def get_all_permissions(self) -> dict[UID, set[str]]: stmt = select(self.table.c.id, self.table.c.permissions) results = self.session.execute(stmt).all() - return Ok({row.id: set(row.permissions) for row in results}) + return {UID(row.id): set(row.permissions) for row in results} + @as_result(SyftException, StashException) def set( self, credentials: SyftVerifyKey, @@ -626,13 +629,17 @@ def set( add_storage_permission: bool = True, # TODO: check the default value ignore_duplicates: bool = False, ) -> Result[SyftT, str]: - # uid is unique by database constraint uid = obj.id + # check if the object already exists if self.exists(credentials, uid) or not self.is_unique(obj): if ignore_duplicates: - return Ok(obj) - return Err(f"Some fields are not unique for {type(obj).__name__}") + return obj + unique_fields_str = ", ".join(self.unique_fields) + raise SyftException( + public_message=f"Duplication Key Error for {obj}.\n" + f"The fields that should be unique are {unique_fields_str}." 
+ ) permissions = self.get_ownership_permissions(uid, credentials) if add_permissions is not None: @@ -654,4 +661,4 @@ def set( ) self.session.execute(stmt) self.session.commit() - return Ok(obj) + return self.get_by_uid(credentials, uid) From 72c7c58ea75d387f6bb1726e3ce17bd953a90192 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 27 Aug 2024 14:43:19 +0200 Subject: [PATCH 052/197] fix bugs --- .../service/migration/migration_service.py | 2 +- .../syft/service/notifier/notifier_service.py | 2 +- .../syft/service/notifier/notifier_stash.py | 4 ++- .../syft/src/syft/service/user/user_stash.py | 5 +++- packages/syft/src/syft/store/db/stash.py | 27 ++++++++++++------- packages/syft/src/syft/types/result.py | 2 +- 6 files changed, 28 insertions(+), 14 deletions(-) diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index b43c192b428..e9933f95537 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -124,7 +124,7 @@ def _get_partition_from_type( self, context: AuthedServiceContext, object_type: type[SyftObject], - ) -> KeyValueActionStore | StorePartition: + ) -> StorePartition: object_partition: KeyValueActionStore | StorePartition | None = None if issubclass(object_type, ActionObject): object_partition = cast(KeyValueActionStore, context.server.action_store) diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index 5cc12e995cd..9c782cd712b 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -298,7 +298,7 @@ def init_notifier( if should_update: notifier_stash.update( - credentials=server.signing_key.verify_key, settings=notifier + credentials=server.signing_key.verify_key, obj=notifier ).unwrap() else: notifier_stash.set(server.signing_key.verify_key, notifier).unwrap() diff --git a/packages/syft/src/syft/service/notifier/notifier_stash.py b/packages/syft/src/syft/service/notifier/notifier_stash.py index fd7fc15e615..91271feca48 100644 --- a/packages/syft/src/syft/service/notifier/notifier_stash.py +++ b/packages/syft/src/syft/service/notifier/notifier_stash.py @@ -34,4 +34,6 @@ def get(self, credentials: SyftVerifyKey) -> NotifierSettings | None: result = self.get_all(credentials, limit=1).unwrap() if len(result) > 0: return result[0] - return None + raise NotFoundException( + public_message="No settings found for the current user." 
+ ) diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index ed0f78239e6..0e2cf52d238 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -6,6 +6,7 @@ from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result +from ...types.uid import UID from ...util.telemetry import instrument from .user import User from .user_roles import ServiceRole @@ -14,14 +15,16 @@ @instrument @serializable(canonical_name="UserStashSQL", version=1) class UserStash(ObjectStash[User]): + @as_result(StashException) def init_root_user(self) -> None: # start a transaction - users = self.get_all(self.root_verify_key, has_permission=True) + users = self.get_all(self.root_verify_key, has_permission=True).unwrap() if not users: # NOTE this is not thread safe, should use a session and transaction super().set( self.root_verify_key, User( + id=UID(), email="_internal@root.com", role=ServiceRole.ADMIN, verify_key=self.root_verify_key, diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 598edc42980..bc3a6e45aaf 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -7,7 +7,6 @@ import uuid # third party -from result import Result import sqlalchemy as sa from sqlalchemy import Column from sqlalchemy import Row @@ -152,7 +151,9 @@ def get_by_uid( ) -> SyftT: stmt = self.table.select() stmt = stmt.where(self._get_field_filter("id", uid)) - stmt = self._apply_permission_filter(stmt, credentials, has_permission) + stmt = self._apply_permission_filter( + stmt, credentials=credentials, has_permission=has_permission + ) result = self.session.execute(stmt).first() if result is None: @@ -192,7 +193,7 @@ def _get_by_fields( stmt = table.select() stmt = stmt.where(sa.and_(*filters)) stmt = self._apply_permission_filter( - stmt, credentials, has_permission=has_permission + stmt, credentials=credentials, has_permission=has_permission ) stmt = self._apply_order_by(stmt, order_by, sort_order) stmt = self._apply_limit_offset(stmt, limit, offset) @@ -212,7 +213,7 @@ def get_one_by_field( credentials=credentials, fields={field_name: field_value}, has_permission=has_permission, - ) + ).unwrap() @as_result(SyftException, StashException, NotFoundException) def get_one_by_fields( @@ -292,7 +293,9 @@ def get_all_contains( stmt = self.table.select().where( self.table.c.fields[field_name].contains(func.json_quote(field_value)), ) - stmt = self._apply_permission_filter(stmt, credentials, has_permission) + stmt = self._apply_permission_filter( + stmt, credentials=credentials, has_permission=has_permission + ) stmt = self._apply_order_by(stmt, order_by, sort_order) stmt = self._apply_limit_offset(stmt, limit, offset) @@ -377,6 +380,7 @@ def _apply_order_by( def _apply_permission_filter( self, stmt: T, + *, credentials: SyftVerifyKey, permission: ActionPermission = ActionPermission.READ, has_permission: bool = False, @@ -399,7 +403,12 @@ def get_all( ) -> list[SyftT]: stmt = self.table.select() - stmt = self._apply_permission_filter(stmt, credentials, has_permission) + stmt = self._apply_permission_filter( + stmt, + credentials=credentials, + has_permission=has_permission, + permission=ActionPermission.READ, + ) stmt = self._apply_order_by(stmt, order_by, sort_order) stmt = self._apply_limit_offset(stmt, limit, offset) @@ -456,7 
+465,7 @@ def delete_by_uid( stmt = self.table.delete().where(self._get_field_filter("id", uid)) stmt = self._apply_permission_filter( stmt, - credentials, + credentials=credentials, permission=ActionPermission.WRITE, has_permission=has_permission, ) @@ -628,7 +637,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, # TODO: check the default value ignore_duplicates: bool = False, - ) -> Result[SyftT, str]: + ) -> SyftT: uid = obj.id # check if the object already exists @@ -661,4 +670,4 @@ def set( ) self.session.execute(stmt) self.session.commit() - return self.get_by_uid(credentials, uid) + return self.get_by_uid(credentials, uid).unwrap() diff --git a/packages/syft/src/syft/types/result.py b/packages/syft/src/syft/types/result.py index 020198bc8bc..b58c97d0ce9 100644 --- a/packages/syft/src/syft/types/result.py +++ b/packages/syft/src/syft/types/result.py @@ -112,7 +112,7 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> Result[T, BE]: if isinstance(output, Ok) or isinstance(output, Err): raise _AsResultError( f"Functions decorated with `as_result` should not return Result.\n" - f"Did you forget to unwrap() the result?\n" + f"Did you forget to unwrap() the result in {func.__name__}?\n" f"result: {output}" ) return Ok(output) From fee78f23133717e62a33e1fc578dc7f178c20fdb Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 27 Aug 2024 15:29:24 +0200 Subject: [PATCH 053/197] fix first notebook --- packages/syft/src/syft/server/server.py | 22 ++++++++++--------- .../syft/src/syft/server/service_registry.py | 4 ++-- .../src/syft/service/action/action_service.py | 2 +- .../src/syft/service/user/user_service.py | 4 +++- .../syft/src/syft/service/user/user_stash.py | 2 +- packages/syft/src/syft/store/db/stash.py | 14 +++++------- 6 files changed, 25 insertions(+), 23 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 83c894bc614..f37fbffacc2 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -71,7 +71,6 @@ from ..service.service import UserServiceConfigRegistry from ..service.settings.settings import ServerSettings from ..service.settings.settings import ServerSettingsUpdate -from ..service.settings.settings_stash import SettingsStash from ..service.user.user import User from ..service.user.user import UserCreate from ..service.user.user import UserView @@ -398,13 +397,13 @@ def __init__( # construct services only after init stores self.services: ServiceRegistry = ServiceRegistry.for_server(self) self.db.init_tables() - self.services.user.stash.init_root_user() + # self.services.user.stash.init_root_user() self.action_store = self.services.action.store - create_admin_new( # nosec B106 + create_admin_new( name=root_username, email=root_email, - password=root_password, + password=root_password, # nosec server=self, ) @@ -875,14 +874,15 @@ def init_stores( self.action_store_config = action_store_config self.queue_stash = QueueStash(store=self.document_store) + # TODO fix database filename + reset json_db_config = SQLiteDBConfig( - filename=f"{self.id}_json.db", + filename=f"{self.id}_{UID().hex}_json.db", path=self.get_temp_dir("db"), ) self.db = SQLiteDBManager( config=json_db_config, server_uid=self.id, - root_verify_key=self.signing_key.verify_key, + root_verify_key=self.verify_key, ) @property @@ -953,7 +953,7 @@ def get_settings(self) -> ServerSettings | None: if self.signing_key is None: raise ValueError(f"{self} has no 
signing key") - settings_stash = SettingsStash(store=self.document_store) + settings_stash = self.services.settings.stash try: settings = settings_stash.get_all(self.signing_key.verify_key).unwrap() @@ -1702,14 +1702,16 @@ def create_admin_new( # 🟡 TODO: change later but for now this gives the main user super user automatically user = create_user.to(User) user.signing_key = server.signing_key - user.verify_key = user.signing_key.verify_key + user.verify_key = server.verify_key new_user = user_stash.set( - credentials=server.signing_key.verify_key, + credentials=server.verify_key, obj=user, - ignore_duplicates=True, + ignore_duplicates=False, ).unwrap() + print(f"Created admin {new_user}") + logger.debug(f"Created admin {new_user.email}") return new_user diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index eed65ae8382..8a2df655d92 100644 --- a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -124,12 +124,12 @@ def _construct_services(cls, server: "Server") -> dict[str, AbstractService]: # Use new DB if cls._uses_new_store(service_cls): + print("Using new store:", service_cls) svc_kwargs["store"] = server.db # Use old DB elif issubclass(service_cls.store_type, ActionObjectStash): - svc_kwargs["store"] = server.action_store - + svc_kwargs["store"] = server.action_store # type: ignore else: svc_kwargs["store"] = server.document_store diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 84941661b83..5436d8c5712 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -331,7 +331,7 @@ def get_pointer( @service_method(path="action.get_mock", name="get_mock", roles=GUEST_ROLE_LEVEL) def get_mock(self, context: AuthedServiceContext, uid: UID) -> SyftObject: """Get a pointer from the action store""" - return self.store.get_mock(credentials=context.credentials, uid=uid) + return self.store.get_mock(credentials=context.credentials, uid=uid).unwrap() @service_method( path="action.has_storage_permission", diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 54e87a63ccf..357697fd19e 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -507,7 +507,9 @@ def update( if user.role == ServiceRole.ADMIN: settings_stash = SettingsStash(store=self.store) - settings = settings_stash.get_all(context.credentials).unwrap() + settings = settings_stash.get_all( + context.credentials, limit=1, sort_order="desc" + ).unwrap() # TODO: Chance to refactor here in settings, as we're always doing get_att[0] if len(settings) > 0: diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index 0e2cf52d238..b6140a1735b 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -50,7 +50,7 @@ def get_by_reset_token(self, credentials: SyftVerifyKey, token: str) -> User: @as_result(StashException, NotFoundException) def get_by_email(self, credentials: SyftVerifyKey, email: str) -> User: - self.get_one_by_field( + return self.get_one_by_field( credentials=credentials, field_name="email", field_value=email ).unwrap() diff --git a/packages/syft/src/syft/store/db/stash.py 
b/packages/syft/src/syft/store/db/stash.py index bc3a6e45aaf..b3a984db336 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -157,9 +157,7 @@ def get_by_uid( result = self.session.execute(stmt).first() if result is None: - raise NotFoundException( - f"{type(self.object_type).__name__}: {uid} not found" - ) + raise NotFoundException(f"{self.object_type.__name__}: {uid} not found") return self.row_as_obj(result) def _get_field_filter( @@ -228,7 +226,7 @@ def get_one_by_fields( has_permission=has_permission, ).first() if result is None: - raise NotFoundException(f"{type(self.object_type).__name__}: not found") + raise NotFoundException(f"{self.object_type.__name__}: not found") return self.row_as_obj(result) @as_result(SyftException, StashException, NotFoundException) @@ -274,7 +272,7 @@ def get_all_by_field( limit=limit, offset=offset, has_permission=has_permission, - ) + ).unwrap() @as_result(SyftException, StashException, NotFoundException) def get_all_contains( @@ -364,7 +362,7 @@ def _apply_order_by( self, stmt: T, order_by: str | None = None, - sort_order: str = "asc", + sort_order: str = "desc", ) -> T: default_order_by = self.table.c.created_at default_order_by = ( @@ -437,9 +435,9 @@ def update( stmt = self.table.update().where(self._get_field_filter("id", obj.id)) stmt = self._apply_permission_filter( stmt, - credentials, - has_permission=has_permission, + credentials=credentials, permission=ActionPermission.WRITE, + has_permission=has_permission, ) stmt = stmt.values(fields=serialize_json(obj)) From d68e9864d5a8a8e1fb6eb28a31c7326cc533f743 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 27 Aug 2024 15:38:32 +0200 Subject: [PATCH 054/197] fix settingsstash --- .../syft/src/syft/service/settings/settings_service.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index 4e552ba26cc..00bb25ccd48 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -109,13 +109,15 @@ def update( def _update( self, context: AuthedServiceContext, settings: ServerSettingsUpdate ) -> ServerSettings: - all_settings = self.stash.get_all(context.credentials).unwrap() + all_settings = self.stash.get_all( + context.credentials, limit=1, sort_order="desc" + ).unwrap() if len(all_settings) > 0: new_settings = all_settings[0].model_copy( update=settings.to_dict(exclude_empty=True) ) update_result = self.stash.update( - context.credentials, settings=new_settings + context.credentials, obj=new_settings ).unwrap() notifier_service = context.server.get_service("notifierservice") @@ -155,7 +157,9 @@ def set_server_side_type_dangerous( public_message=f"Not a valid server_side_type, please use one of the options from: {side_type_options}" ) - current_settings = self.stash.get_all(context.credentials).unwrap() + current_settings = self.stash.get_all( + context.credentials, limit=1, sort_order="desc" + ).unwrap() if len(current_settings) > 0: new_settings = current_settings[0] new_settings.server_side_type = server_side_type From 635afbed0896675631f08d140d7c8f461f75b582 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 27 Aug 2024 15:41:45 +0200 Subject: [PATCH 055/197] fix order --- packages/syft/src/syft/service/notifier/notifier_stash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/packages/syft/src/syft/service/notifier/notifier_stash.py b/packages/syft/src/syft/service/notifier/notifier_stash.py index 91271feca48..1afdad6d6f9 100644 --- a/packages/syft/src/syft/service/notifier/notifier_stash.py +++ b/packages/syft/src/syft/service/notifier/notifier_stash.py @@ -31,7 +31,7 @@ class NotifierStash(ObjectStash[NotifierSettings]): def get(self, credentials: SyftVerifyKey) -> NotifierSettings | None: """Get Settings""" # actually get latest settings - result = self.get_all(credentials, limit=1).unwrap() + result = self.get_all(credentials, limit=1, sort_order="desc").unwrap() if len(result) > 0: return result[0] raise NotFoundException( From 4c6459514ebabfb3ecdd41deadbc62a2962e8278 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 27 Aug 2024 17:58:11 +0200 Subject: [PATCH 056/197] fix sorting, add get_index --- packages/syft/src/syft/client/api.py | 2 + packages/syft/src/syft/serde/json_serde.py | 61 +++++++++++++++---- packages/syft/src/syft/server/server.py | 2 - .../syft/src/syft/server/service_registry.py | 5 +- packages/syft/src/syft/service/user/user.py | 1 + .../src/syft/service/user/user_service.py | 21 ++++++- packages/syft/src/syft/store/db/stash.py | 60 +++++++++++++----- packages/syft/src/syft/types/syft_object.py | 2 +- 8 files changed, 120 insertions(+), 34 deletions(-) diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index ff7a43a0c6a..a3816a49b56 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -740,6 +740,8 @@ def __getattr__(self, name: str) -> Any: ) def __getitem__(self, key: str | int) -> Any: + if hasattr(self, "get_index"): + return self.get_index(key) if hasattr(self, "get_all"): return self.get_all()[key] raise NotImplementedError diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index 65793c52e09..3969c85c19e 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from enum import Enum import json +import typing from typing import Any from typing import Generic from typing import TypeVar @@ -21,6 +22,7 @@ from ..server.credentials import SyftSigningKey from ..server.credentials import SyftVerifyKey from ..types.datetime import DateTime +from ..types.errors import SyftException from ..types.syft_object import BaseDateTime from ..types.syft_object_registry import SyftObjectRegistry from ..types.uid import LineageID @@ -37,6 +39,10 @@ Json = str | int | float | bool | None | list["Json"] | dict[str, "Json"] +class JSONSerdeError(SyftException): + pass + + @dataclass class JSONSerde(Generic[T]): # TODO add json schema @@ -44,6 +50,10 @@ class JSONSerde(Generic[T]): serialize_fn: Callable[[T], Json] | None = None deserialize_fn: Callable[[Json], T] | None = None + def _check_type(self, obj: Any) -> None: + if not isinstance(obj, self.klass): + raise JSONSerdeError(f"Expected {self.klass}, got {type(obj)}") + def serialize(self, obj: T) -> Json: if self.serialize_fn is None: return obj # type: ignore @@ -53,7 +63,8 @@ def serialize(self, obj: T) -> Json: def deserialize(self, obj: Json) -> T: if self.deserialize_fn is None: return obj # type: ignore - return self.deserialize_fn(obj) + else: + return self.deserialize_fn(obj) # type: ignore JSON_SERDE_REGISTRY: dict[type[T], JSONSerde[T]] = {} @@ -101,11 +112,18 @@ def _validate_json(value: T) -> T: return value -def _is_optional_annotation(annotation: Any) -> Any: 
- return annotation | None == annotation +def _is_optional_annotation(annotation: Any) -> bool: + try: + return annotation | None == annotation + except TypeError: + return False + + +def _is_annotated_type(annotation: Any) -> bool: + return get_origin(annotation) == typing.Annotated -def _get_nonoptional_annotation(annotation: Any) -> Any: +def _unwrap_optional_annotation(annotation: Any) -> Any: """Return the type anntation with None type removed, if it is present. Args: @@ -120,6 +138,24 @@ def _get_nonoptional_annotation(annotation: Any) -> Any: return annotation +def _unwrap_annotated(annotation: Any) -> Any: + # Convert Annotated[T, ...] to T + return get_args(annotation)[0] + + +def _unwrap_type_annotation(annotation: Any) -> Any: + """ + recursively unwrap type annotations, removing Annotated and Optional types + """ + if _is_annotated_type(annotation): + res = _unwrap_annotated(annotation) + return _unwrap_type_annotation(res) + elif _is_optional_annotation(annotation): + res = _unwrap_optional_annotation(annotation) + return _unwrap_type_annotation(res) + return annotation + + def _annotation_issubclass(annotation: Any, cls: type) -> bool: # issubclass throws TypeError if annotation is not a valid type (eg Union) try: @@ -235,7 +271,7 @@ def _is_serializable_iterable(annotation: Any) -> bool: if len(args) != 1: return False - inner_type = _get_nonoptional_annotation(args[0]) + inner_type = _unwrap_type_annotation(args[0]) return inner_type in JSON_SERDE_REGISTRY or _annotation_issubclass( inner_type, pydantic.BaseModel ) @@ -250,12 +286,12 @@ def _deserialize_iterable_from_json(value: Json, annotation: Any) -> Any: if not isinstance(value, list): raise ValueError(f"Cannot deserialize {type(value)} to {annotation}") - annotation = _get_nonoptional_annotation(annotation) + annotation = _unwrap_type_annotation(annotation) if not _is_serializable_iterable(annotation): raise ValueError(f"Cannot deserialize {annotation} from JSON") - inner_type = _get_nonoptional_annotation(get_args(annotation)[0]) + inner_type = _unwrap_type_annotation(get_args(annotation)[0]) return [deserialize_json(v, inner_type) for v in value] @@ -279,7 +315,7 @@ def _is_serializable_mapping(annotation: Any) -> bool: return False # check if value type is serializable - value_type = _get_nonoptional_annotation(value_type) + value_type = _unwrap_type_annotation(value_type) return value_type in JSON_SERDE_REGISTRY or _annotation_issubclass( value_type, pydantic.BaseModel ) @@ -295,7 +331,7 @@ def _deserialize_mapping_from_json(value: Json, annotation: Any) -> Any: if not isinstance(value, dict): raise ValueError(f"Cannot deserialize {type(value)} to {annotation}") - annotation = _get_nonoptional_annotation(annotation) + annotation = _unwrap_type_annotation(annotation) if not _is_serializable_mapping(annotation): raise ValueError(f"Cannot deserialize {annotation} from JSON") @@ -347,7 +383,7 @@ def serialize_json(value: Any, annotation: Any = None, validate: bool = True) -> return None # Remove None type from annotation if it is present. - annotation = _get_nonoptional_annotation(annotation) + annotation = _unwrap_type_annotation(annotation) if annotation in JSON_SERDE_REGISTRY: result = JSON_SERDE_REGISTRY[annotation].serialize(value) @@ -394,7 +430,10 @@ def deserialize_json(value: Json, annotation: Any = None) -> Any: return None # Remove None type from annotation if it is present. 
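Taken together, the helpers added above form a small recursion: strip Annotated[...] wrappers, strip the | None part of optionals, and repeat until a bare type remains. A self-contained sketch with simplified bodies (the real versions are _unwrap_type_annotation and friends in json_serde.py):

    import typing
    from typing import Annotated, Any, get_args, get_origin

    def _is_optional(annotation: Any) -> bool:
        try:
            # X | None is idempotent, so equality holds only for optionals
            return annotation | None == annotation
        except TypeError:
            return False

    def unwrap(annotation: Any) -> Any:
        if get_origin(annotation) is Annotated:  # Annotated[T, ...] -> T
            return unwrap(get_args(annotation)[0])
        if _is_optional(annotation):  # T | None -> T
            args = tuple(a for a in get_args(annotation) if a is not type(None))
            return unwrap(args[0] if len(args) == 1 else typing.Union[args])
        return annotation

    assert unwrap(Annotated[int | None, "meta"]) is int
    assert unwrap(str | None) is str
    assert unwrap(dict) is dict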
- annotation = _get_nonoptional_annotation(annotation) + if annotation is None: + raise ValueError("Annotation is required for deserialization") + + annotation = _unwrap_type_annotation(annotation) if annotation in JSON_SERDE_REGISTRY: return JSON_SERDE_REGISTRY[annotation].deserialize(value) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index f37fbffacc2..93e23a7f23f 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -1710,8 +1710,6 @@ def create_admin_new( ignore_duplicates=False, ).unwrap() - print(f"Created admin {new_user}") - logger.debug(f"Created admin {new_user.email}") return new_user diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index 8a2df655d92..ad2b313e94e 100644 --- a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -9,7 +9,6 @@ # relative from ..serde.serializable import serializable from ..service.action.action_service import ActionService -from ..service.action.action_store import ActionObjectStash from ..service.api.api_service import APIService from ..service.attestation.attestation_service import AttestationService from ..service.blob_storage.service import BlobStorageService @@ -124,14 +123,12 @@ def _construct_services(cls, server: "Server") -> dict[str, AbstractService]: # Use new DB if cls._uses_new_store(service_cls): - print("Using new store:", service_cls) svc_kwargs["store"] = server.db # Use old DB - elif issubclass(service_cls.store_type, ActionObjectStash): - svc_kwargs["store"] = server.action_store # type: ignore else: svc_kwargs["store"] = server.document_store + print("Using old store:", service_cls) service = service_cls(**svc_kwargs) service_dict[field_name] = service diff --git a/packages/syft/src/syft/service/user/user.py b/packages/syft/src/syft/service/user/user.py index 0fb8a87e7fd..6c24554478f 100644 --- a/packages/syft/src/syft/service/user/user.py +++ b/packages/syft/src/syft/service/user/user.py @@ -76,6 +76,7 @@ class User(SyftObject): # version __canonical_name__ = "User" __version__ = SYFT_OBJECT_VERSION_2 + __order_by__ = ("email", "asc") id: UID | None = None # type: ignore[assignment] diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 357697fd19e..2145b520759 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -321,18 +321,37 @@ def view(self, context: AuthedServiceContext, uid: UID) -> UserView: def get_all( self, context: AuthedServiceContext, + order_by: str | None = None, + sort_order: str | None = None, page_size: int | None = 0, page_index: int | None = 0, ) -> list[UserView]: if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]: users = self.stash.get_all( - context.credentials, has_permission=True + context.credentials, + has_permission=True, + order_by=order_by, + sort_order=sort_order, ).unwrap() else: users = self.stash.get_all(context.credentials).unwrap() users = [user.to(UserView) for user in users] return _paginate(users, page_size, page_index) + @service_method( + path="user.get_index", name="get_index", roles=DATA_OWNER_ROLE_LEVEL + ) + def get_index( + self, + context: AuthedServiceContext, + index: int, + ) -> UserView: + return ( + self.stash.get_index(credentials=context.credentials, index=index) + .unwrap() + .to(UserView) + ) + def 
signing_key_for_verify_key(self, verify_key: SyftVerifyKey) -> UserPrivateKey:
         user = self.stash.get_by_verify_key(
             credentials=self.stash.admin_verify_key(), verify_key=verify_key
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index b3a984db336..1f44b187da9 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -45,7 +45,7 @@ class ObjectStash(Generic[SyftT]):
     table: Table
-    object_type: type
+    object_type: type[SyftObject]
 
     def __init__(self, store: DBManager) -> None:
         self.db = store
@@ -177,7 +177,7 @@ def _get_by_fields(
         fields: dict[str, str],
         table: Table | None = None,
         order_by: str | None = None,
-        sort_order: str = "asc",
+        sort_order: str | None = None,
         limit: int | None = None,
         offset: int | None = None,
         has_permission: bool = False,
@@ -235,7 +235,7 @@ def get_all_by_fields(
         credentials: SyftVerifyKey,
         fields: dict[str, str],
         order_by: str | None = None,
-        sort_order: str = "asc",
+        sort_order: str | None = None,
         limit: int | None = None,
         offset: int | None = None,
         has_permission: bool = False,
@@ -259,7 +259,7 @@ def get_all_by_field(
         field_name: str,
         field_value: str,
         order_by: str | None = None,
-        sort_order: str = "asc",
+        sort_order: str | None = None,
         limit: int | None = None,
         offset: int | None = None,
         has_permission: bool = False,
@@ -281,7 +281,7 @@ def get_all_contains(
         field_name: str,
         field_value: str,
         order_by: str | None = None,
-        sort_order: str = "asc",
+        sort_order: str | None = None,
         limit: int | None = None,
         offset: int | None = None,
         has_permission: bool = False,
@@ -300,6 +300,21 @@ def get_all_contains(
         result = self.session.execute(stmt).all()
         return [self.row_as_obj(row) for row in result]
 
+    @as_result(SyftException, StashException, NotFoundException)
+    def get_index(
+        self, credentials: SyftVerifyKey, index: int, has_permission: bool = False
+    ) -> SyftT:
+        items = self.get_all(
+            credentials,
+            has_permission=has_permission,
+            limit=1,
+            offset=index,
+        ).unwrap()
+
+        if len(items) == 0:
+            raise NotFoundException(f"No item found at index {index}")
+        return items[0]
+
     def row_as_obj(self, row: Row) -> SyftT:
         # TODO make unwrappable serde
         return deserialize_json(row.fields)
@@ -358,22 +373,37 @@ def _apply_limit_offset(
             stmt = stmt.limit(limit)
         return stmt
 
+    def _get_order_by_col(self, order_by: str, sort_order: str | None = None) -> Column:
+        # TODO connect+rename created_date to created_at
+        if sort_order is None:
+            sort_order = "asc"
+
+        if order_by == "id":
+            col = self.table.c.id
+        elif order_by == "created_date" or order_by == "created_at":
+            col = self.table.c.created_at
+        else:
+            col = self.table.c.fields[order_by]
+
+        return col.desc() if sort_order.lower() == "desc" else col.asc()
+
     def _apply_order_by(
         self,
         stmt: T,
         order_by: str | None = None,
-        sort_order: str = "desc",
+        sort_order: str | None = None,
     ) -> T:
-        default_order_by = self.table.c.created_at
-        default_order_by = (
-            default_order_by.desc() if sort_order == "desc" else default_order_by
-        )
         if order_by is None:
-            return stmt.order_by(default_order_by)
+            order_by, default_sort_order = self.object_type.__order_by__
+            sort_order = sort_order or default_sort_order
+
+        order_by_col = self._get_order_by_col(order_by, sort_order)
+
+        if order_by == "id":
+            return stmt.order_by(order_by_col)
         else:
-            order_by_col = self.table.c.fields[order_by]
-            order_by = order_by_col.desc() if sort_order == "desc" else order_by_col
-            return stmt.order_by(order_by, default_order_by)
+            secondary_order_by = 
self._get_order_by_col("id", sort_order) + return stmt.order_by(order_by_col, secondary_order_by) def _apply_permission_filter( self, @@ -395,7 +425,7 @@ def get_all( credentials: SyftVerifyKey, has_permission: bool = False, order_by: str | None = None, - sort_order: str = "asc", + sort_order: str | None = None, limit: int | None = None, offset: int | None = None, ) -> list[SyftT]: diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index 09b8139b9bd..516de5974c7 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -395,7 +395,6 @@ class SyftObject(SyftObjectVersioned): # all objects have a UID id: UID - created_date: BaseDateTime | None = None updated_date: BaseDateTime | None = None deleted_date: BaseDateTime | None = None @@ -410,6 +409,7 @@ def make_id(cls, values: Any) -> Any: values["id"] = id_field.annotation() return values + __order_by__: ClassVar[tuple[str, str]] = ("created_date", "desc") __attr_searchable__: ClassVar[ list[str] ] = [] # keys which can be searched in the ORM From 2541fb8f4f255d9969eb91d818a16d1f47081d2e Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 28 Aug 2024 09:09:04 +0200 Subject: [PATCH 057/197] fix bugs --- packages/syft/src/syft/server/server.py | 2 +- .../src/syft/service/action/action_service.py | 5 +++- .../syft/service/code/user_code_service.py | 24 +++++++------------ .../syft/src/syft/service/dataset/dataset.py | 5 ++-- 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index f37fbffacc2..69820caefba 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -876,7 +876,7 @@ def init_stores( # TODO fix database filename + reset json_db_config = SQLiteDBConfig( - filename=f"{self.id}_{UID().hex}_json.db", + filename=f"{self.id}_json.db", path=self.get_temp_dir("db"), ) self.db = SQLiteDBManager( diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 5436d8c5712..434245503c3 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -46,6 +46,7 @@ from .action_permissions import ActionObjectPermission from .action_permissions import ActionObjectREAD from .action_permissions import ActionPermission +from .action_permissions import StoragePermission from .action_store import ActionObjectStash from .action_types import action_type_for_type from .numpy import NumpyArrayObject @@ -339,7 +340,9 @@ def get_mock(self, context: AuthedServiceContext, uid: UID) -> SyftObject: roles=GUEST_ROLE_LEVEL, ) def has_storage_permission(self, context: AuthedServiceContext, uid: UID) -> bool: - return self.store.has_storage_permission(uid) + return self.store.has_storage_permission( + StoragePermission(uid=uid, server_uid=context.server.id) + ) def has_read_permission(self, context: AuthedServiceContext, uid: UID) -> bool: return self.store.has_permissions( diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index f6182339e33..c61dc2065ad 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -86,6 +86,7 @@ def submit( message="User Code Submitted", require_api_update=True, value=user_code ) + 
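    # A minimal sketch of the Result convention used below (assuming the
    # as_result decorator from syft.types.result): wrapping _submit with
    # @as_result(SyftException) makes it return a Result that captures any
    # raised SyftException, so callers decide when to surface it:
    #
    #     user_code = self._submit(context, submit_code, exists_ok=True).unwrap()
    #
    # .unwrap() returns the wrapped value or re-raises the captured exception.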
@as_result(SyftException) def _submit( self, context: AuthedServiceContext, @@ -111,14 +112,13 @@ def _submit( context.credentials, code_hash=get_code_hash(submit_code.code, context.credentials), ).unwrap() - - if not exists_ok: - raise SyftException( - public_message="The code to be submitted already exists" - ) - return existing_code except NotFoundException: - pass + existing_code = None + + if not exists_ok and existing_code is not None: + raise SyftException( + public_message="UserCode with this code already exists", + ) code = submit_code.to(UserCode, context=context) @@ -286,13 +286,7 @@ def _get_or_submit_user_code( - If the code is a SubmitUserCode and the code hash does not exist, submit the code """ if isinstance(code, UserCode): - # Get existing UserCode - try: - return self.stash.get_by_uid(context.credentials, code.id).unwrap() - except NotFoundException as exc: - raise NotFoundException.from_exception( - exc, public_message=f"UserCode {code.id} not found on this server" - ) + return self.stash.get_by_uid(context.credentials, code.id).unwrap() else: # code: SubmitUserCode # Submit new UserCode, or get existing UserCode with the same code hash # TODO: Why is this tagged as unreachable? @@ -310,7 +304,7 @@ def request_code_execution( reason: str | None = "", ) -> Request: """Request Code execution on user code""" - user_code = self._get_or_submit_user_code(context, code).unwrap() + user_code = self._submit(context, code, exists_ok=False).unwrap() result = self._request_code_execution( context, diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index fb2d9cd6b00..a6712ee7832 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -7,6 +7,7 @@ from typing import Any # third party +from IPython.display import display import markdown import pandas as pd from pydantic import ConfigDict @@ -292,8 +293,8 @@ def _private_data(self) -> Any: def data(self) -> Any: try: return self._private_data().unwrap() - except SyftException as e: - print(e) + except SyftException: + display(SyftError(message="You have no access to the private data")) return None From 366e2672f0ffdac3b1d05450adbb1c64d9e939b7 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 28 Aug 2024 10:52:07 +0200 Subject: [PATCH 058/197] bug fix --- packages/syft/src/syft/service/sync/sync_stash.py | 2 +- packages/syft/src/syft/service/user/user_service.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py index 5e9d118634c..3a7c6824281 100644 --- a/packages/syft/src/syft/service/sync/sync_stash.py +++ b/packages/syft/src/syft/service/sync/sync_stash.py @@ -41,7 +41,7 @@ def get_latest(self, credentials: SyftVerifyKey) -> SyncState | None: credentials=credentials, sort_order="desc", limit=1, - ) + ).unwrap() if len(states) > 0: return states[0] diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 2145b520759..cd9d9b8c729 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -367,9 +367,8 @@ def get_role_for_credentials( # they could be different # TODO: This fn is cryptic -- when does each situation occur? 
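        # Roughly: a SyftVerifyKey is the public half presented by signed client
        # requests, while a SyftSigningKey is a locally held private key. A
        # hedged sketch of the relationship, assuming the credential classes
        # from server.credentials:
        #
        #     signing_key = SyftSigningKey.generate()
        #     verify_key = signing_key.verify_key  # what remote callers present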
if isinstance(credentials, SyftVerifyKey): - user = self.stash.get_by_verify_key( - credentials=credentials, verify_key=credentials - ).unwrap() + role = self.stash.get_role(credentials=credentials) + return role elif isinstance(credentials, SyftSigningKey): user = self.stash.get_by_signing_key( credentials=credentials, From f6d21a6d046640d164aecea770e8583d2e082435 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 28 Aug 2024 13:50:33 +0200 Subject: [PATCH 059/197] add basic postgres support --- .../syft/service/action/action_permissions.py | 17 +++ packages/syft/src/syft/store/db/sqlite_db.py | 42 ++++++-- packages/syft/src/syft/store/db/stash.py | 101 ++++++++++++------ 3 files changed, 117 insertions(+), 43 deletions(-) diff --git a/packages/syft/src/syft/service/action/action_permissions.py b/packages/syft/src/syft/service/action/action_permissions.py index 8177222bd1a..477b3b02bee 100644 --- a/packages/syft/src/syft/service/action/action_permissions.py +++ b/packages/syft/src/syft/service/action/action_permissions.py @@ -18,6 +18,19 @@ class ActionPermission(Enum): EXECUTE = 64 ALL_EXECUTE = 128 + @property + def as_compound(self) -> "ActionPermission": + if self in COMPOUND_ACTION_PERMISSION: + return self + elif self == ActionPermission.READ: + return ActionPermission.ALL_READ + elif self == ActionPermission.WRITE: + return ActionPermission.ALL_WRITE + elif self == ActionPermission.EXECUTE: + return ActionPermission.ALL_EXECUTE + else: + raise Exception(f"Invalid compound permission {self}") + COMPOUND_ACTION_PERMISSION = { ActionPermission.ALL_READ, @@ -64,6 +77,10 @@ def permission_string(self) -> str: return f"{self.credentials.verify}_{self.permission.name}" return f"{self.permission.name}" + @property + def compound_permission_string(self) -> str: + return self.permission.as_compound.name + def _coll_repr_(self) -> dict[str, Any]: return { "uid": str(self.uid), diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index 64a5259582a..1f4e6d4b56d 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -20,13 +20,34 @@ class DBConfig(BaseModel): - pass + reset: bool = False + + @property + def connection_string(self) -> str: + raise NotImplementedError("Subclasses must implement this method.") class SQLiteDBConfig(DBConfig): filename: str = "jsondb.sqlite" path: Path = Field(default_factory=tempfile.gettempdir) + @property + def connection_string(self) -> str: + filepath = self.path / self.filename + return f"sqlite:///{filepath.resolve()}" + + +class PostgresDBConfig(DBConfig): + host: str = "localhost" + port: int = 5432 + user: str = "postgres" + password: str = "postgres" + database: str = "postgres" + + @property + def connection_string(self) -> str: + return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" + class DBManager: def __init__( @@ -50,13 +71,10 @@ def __init__( self.config = config self.root_verify_key = root_verify_key self.server_uid = server_uid - - self.filepath = config.path / config.filename - self.path = f"sqlite:///{self.filepath.resolve()}" self.engine = create_engine( - self.path, json_serializer=dumps, json_deserializer=loads + config.connection_string, json_serializer=dumps, json_deserializer=loads ) - print(f"Connecting to {self.path}") + print(f"Connecting to {config.connection_string}") self.Session = sessionmaker(bind=self.engine) # TODO use AuthedServiceContext for session management instead of 
threading.local
@@ -67,12 +85,16 @@ def __init__(
 
     def update_settings(self) -> None:
         connection = self.engine.connect()
-        connection.execute(sa.text("PRAGMA journal_mode = WAL"))
-        connection.execute(sa.text("PRAGMA busy_timeout = 5000"))
-        connection.execute(sa.text("PRAGMA temp_store = 2"))
-        connection.execute(sa.text("PRAGMA synchronous = 1"))
+        if self.engine.dialect.name == "sqlite":
+            connection.execute(sa.text("PRAGMA journal_mode = WAL"))
+            connection.execute(sa.text("PRAGMA busy_timeout = 5000"))
+            connection.execute(sa.text("PRAGMA temp_store = 2"))
+            connection.execute(sa.text("PRAGMA synchronous = 1"))
 
     def init_tables(self) -> None:
+        if self.config.reset:
+            # drop all tables that we know about
+            Base.metadata.drop_all(bind=self.engine)
         Base.metadata.create_all(self.engine)
 
     # TODO remove
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index 1f44b187da9..de6d84a1985 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -1,6 +1,7 @@
 # stdlib
 
 # stdlib
+from functools import cache
 from typing import Any
 from typing import Generic
 from typing import get_args
@@ -13,6 +14,7 @@
 from sqlalchemy import Table
 from sqlalchemy import func
 from sqlalchemy import select
+from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import Session
 from sqlalchemy.types import JSON
 from typing_extensions import TypeVar
@@ -84,14 +86,32 @@ def session(self) -> Session:
     def _create_table(self) -> Table:
         # need to call Base.metadata.create_all(engine) to create the table
         table_name = self.object_type.__canonical_name__
+
+        fields_type = (
+            JSON if self.db.engine.dialect.name == "sqlite" else postgresql.JSONB
+        )
+        permissions_type = (
+            JSON
+            if self.db.engine.dialect.name == "sqlite"
+            else postgresql.ARRAY(sa.String)
+        )
+        storage_permissions_type = (
+            JSON
+            if self.db.engine.dialect.name == "sqlite"
+            else postgresql.ARRAY(UIDTypeDecorator)
+        )
         if table_name not in Base.metadata.tables:
             Table(
                 self.object_type.__canonical_name__,
                 Base.metadata,
                 Column("id", UIDTypeDecorator, primary_key=True, default=uuid.uuid4),
-                Column("fields", JSON, default={}),
-                Column("permissions", JSON, default=[]),
-                Column("storage_permissions", JSON, default=[]),
+                Column("fields", fields_type, default={}),
+                Column("permissions", permissions_type, default=[]),
+                Column(
+                    "storage_permissions",
+                    storage_permissions_type,
+                    default=[],
+                ),  # TODO rename and use on SyftObject fields
                 Column(
                     "created_at", sa.DateTime, server_default=sa.func.now(), index=True
@@ -101,6 +121,13 @@ def _create_table(self) -> Table:
             )
         return Base.metadata.tables[table_name]
 
+    def _drop_table(self) -> None:
+        table_name = self.object_type.__canonical_name__
+        if table_name in Base.metadata.tables:
+            Base.metadata.tables[table_name].drop(self.db.engine)
+        else:
+            raise StashException(f"Table {table_name} does not exist")
+
     def _print_query(self, stmt: sa.sql.select) -> None:
         print(
             stmt.compile(
@@ -154,6 +181,7 @@ def get_by_uid(
         stmt = self._apply_permission_filter(
             stmt, credentials=credentials, has_permission=has_permission
         )
+
         result = self.session.execute(stmt).first()
 
         if result is None:
@@ -169,7 +197,11 @@ def _get_field_filter(
         table = table if table is not None else self.table
         if field_name == "id":
             return table.c.id == field_value
-        return table.c.fields[field_name] == func.json_quote(field_value)
+
+        if self.db.engine.dialect.name == "sqlite":
+            return table.c.fields[field_name] == func.json_quote(field_value)
+        elif 
self.db.engine.dialect.name == "postgresql":
+            return sa.cast(table.c.fields[field_name], sa.String) == field_value
 
     def _get_by_fields(
         self,
@@ -319,6 +351,8 @@ def row_as_obj(self, row: Row) -> SyftT:
         # TODO make unwrappable serde
         return deserialize_json(row.fields)
 
+    # TODO add cache invalidation, ignore B019
+    @cache  # noqa: B019
     def get_role(self, credentials: SyftVerifyKey) -> ServiceRole:
         # TODO error handling
         user_table = Table("User", Base.metadata)
@@ -330,32 +364,16 @@ def get_role(self, credentials: SyftVerifyKey) -> ServiceRole:
             return ServiceRole.GUEST
         return ServiceRole[role]
 
-    def _get_permission_filter(
+    def _get_permission_filter_from_permission(
         self,
-        credentials: SyftVerifyKey,
-        permission: ActionPermission = ActionPermission.READ,
+        permission: ActionObjectPermission,
     ) -> sa.sql.elements.BinaryExpression:
-        role = self.get_role(credentials)
-        if role in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER):
-            return sa.literal(True)
-
-        permission_string = ActionObjectPermission(
-            uid=UID(),  # dummy uid, we just need the permission string
-            credentials=credentials,
-            permission=permission,
-        ).permission_string
-
-        compound_permission_map = {
-            ActionPermission.READ: ActionPermission.ALL_READ,
-            ActionPermission.WRITE: ActionPermission.ALL_WRITE,
-            ActionPermission.EXECUTE: ActionPermission.ALL_EXECUTE,
-        }
-        compound_permission_string = ActionObjectPermission(
-            uid=UID(),  # dummy uid, we just need the permission string
-            credentials=None,  # no credentials for compound permissions
-            permission=compound_permission_map[permission],
-        ).permission_string
+        permission_string = permission.permission_string
+        compound_permission_string = permission.compound_permission_string
 
+        if self.session.bind.dialect.name == "postgresql":
+            permission_string = [permission_string]  # type: ignore
+            compound_permission_string = [compound_permission_string]  # type: ignore
         return sa.or_(
             self.table.c.permissions.contains(permission_string),
             self.table.c.permissions.contains(compound_permission_string),
@@ -413,10 +431,26 @@ def _apply_permission_filter(
         permission: ActionPermission = ActionPermission.READ,
         has_permission: bool = False,
     ) -> T:
-        if not has_permission:
-            stmt = stmt.where(
-                self._get_permission_filter(credentials, permission=permission)
+        if has_permission:
+            # ignoring permissions
+            return stmt
+
+        role = self.get_role(credentials)
+        if role in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER):
+            # admins and data owners have all permissions
+            return stmt
+
+        action_object_permission = ActionObjectPermission(
+            uid=UID(),  # dummy uid, we just need the permission string
+            credentials=credentials,
+            permission=permission,
+        )
+
+        stmt = stmt.where(
+            self._get_permission_filter_from_permission(
+                permission=action_object_permission
             )
+        )
         return stmt
 
     @as_result(StashException)
@@ -602,15 +636,16 @@ def has_storage_permissions(self, permissions: list[StoragePermission]) -> bool:
             )
         )
         result = self.session.execute(stmt).first()
-        return result
+        return result is not None
 
     def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool:
         # TODO: we should use a permissions table to check all permissions at once
         # TODO: should check for compound permissions
+
        permission_filters = [
            sa.and_(
                self._get_field_filter("id", p.uid),
-               self.table.c.permissions.contains(p.permission_string),
+               self._get_permission_filter_from_permission(permission=p),
            )
            for p in permissions
        ]

        stmt = 
self.table.select().where( sa.and_( *permission_filters, - ) + ), ) result = self.session.execute(stmt).first() return result is not None From eccdb6bbb2380a60f91fc728f60ee44629147111 Mon Sep 17 00:00:00 2001 From: Madhava Jay Date: Wed, 28 Aug 2024 21:59:35 +1000 Subject: [PATCH 060/197] Made some progress on postgreql --- notebooks/api/0.8/00-load-data.ipynb | 90 ++++++++++++++++++- packages/syft/setup.cfg | 2 +- packages/syft/src/syft/orchestra.py | 4 + .../src/syft/protocol/protocol_version.json | 7 ++ packages/syft/src/syft/server/server.py | 44 +++++++-- packages/syft/src/syft/server/uvicorn.py | 5 ++ .../syft/store/postgresql_document_store.py | 69 +++++++++++--- .../src/syft/store/sqlite_document_store.py | 41 +++++---- 8 files changed, 218 insertions(+), 44 deletions(-) diff --git a/notebooks/api/0.8/00-load-data.ipynb b/notebooks/api/0.8/00-load-data.ipynb index 8c3bb05b93b..65aca80f4f5 100644 --- a/notebooks/api/0.8/00-load-data.ipynb +++ b/notebooks/api/0.8/00-load-data.ipynb @@ -1,5 +1,72 @@ { "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# !pip install psycopg[binary]==3.1.19" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# !docker run --name postgres-latest -e POSTGRES_USER=admin \\\n", + "# -e POSTGRES_PASSWORD=adminpassword -p 5432:5432 -d postgres:latest" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# import os\n", + "# os.environ[\"POSTGRESQL_DBNAME\"] = \"postgres\"\n", + "# os.environ[\"POSTGRESQL_HOST\"] = \"localhost\"\n", + "# os.environ[\"POSTGRESQL_PORT\"] = \"5432\"\n", + "# os.environ[\"POSTGRESQL_USERNAME\"] = \"admin\"\n", + "# os.environ[\"POSTGRESQL_PASSWORD\"] = \"adminpassword\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# import psycopg\n", + "# connection = psycopg.connect(\n", + "# dbname=\"postgres\",\n", + "# user=\"admin\",\n", + "# password=\"adminpassword\",\n", + "# host=\"localhost\",\n", + "# port=\"5432\",\n", + "# )\n", + "# cursor = connection.cursor()\n", + "# sql = \"select uid from User_unique_keys where uid = %s\"\n", + "# args = ['email']\n", + "# res = cursor.execute(sql, args)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + }, { "cell_type": "markdown", "metadata": {}, @@ -67,7 +134,21 @@ "outputs": [], "source": [ "# Launch a fresh datasite server named \"test-datasite-1\" in dev mode on the local machine\n", - "server = sy.orchestra.launch(name=\"test-datasite-1\", dev_mode=True, reset=True)" + "store_client_config = {\n", + " \"TYPE\": \"PostgreSQLStoreConfig\",\n", + " \"POSTGRESQL_DBNAME\": \"postgres\",\n", + " \"POSTGRESQL_HOST\": \"localhost\",\n", + " \"POSTGRESQL_PORT\": \"5432\",\n", + " \"POSTGRESQL_USERNAME\": \"admin\",\n", + " \"POSTGRESQL_PASSWORD\": \"adminpassword\",\n", + "}\n", + "\n", + "server = sy.orchestra.launch(\n", + " name=\"test-datasite-1\",\n", + " dev_mode=True,\n", + " reset=True,\n", + " store_client_config=store_client_config,\n", + ")" ] }, { @@ -711,6 +792,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ 
-721,7 +807,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.12.2" }, "toc": { "base_numbering": 1, diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index a19a1566259..ed207617a62 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -66,7 +66,7 @@ syft = jinja2==3.1.4 tenacity==8.3.0 nh3==0.2.17 - psycopg2-binary==2.9.9 + psycopg[binary]==3.1.19 install_requires = %(syft)s diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 54bf299ec07..14ce315e1f5 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -186,6 +186,7 @@ def deploy_to_python( background_tasks: bool = False, debug: bool = False, migrate: bool = False, + store_client_config: dict | None = None, ) -> ServerHandle: worker_classes = { ServerType.DATASITE: Datasite, @@ -215,6 +216,7 @@ def deploy_to_python( "background_tasks": background_tasks, "debug": debug, "migrate": migrate, + "store_client_config": store_client_config, } if port: @@ -325,6 +327,7 @@ def launch( background_tasks: bool = False, debug: bool = False, migrate: bool = False, + store_client_config: dict | None = None, ) -> ServerHandle: if dev_mode is True: thread_workers = True @@ -363,6 +366,7 @@ def launch( background_tasks=background_tasks, debug=debug, migrate=migrate, + store_client_config=store_client_config, ) display( SyftInfo( diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 047d8571c04..68ac8d861ba 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -1157,6 +1157,13 @@ "hash": "8d2dfafa01ec909c080a790cf15a8fc78e00382d3bfe6207098ceb25a60b9c53", "action": "add" } + }, + "PostgreSQLStorePartition": { + "1": { + "version": 1, + "hash": "1a807dcf54f969c53e6f46d62443d4dd83e5f6ff47fb4e9f6381c3374601c818", + "action": "add" + } } } } diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index e48c35affe8..ac9b01453e5 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -296,6 +296,23 @@ def auth_context_for_user( return cls.__server_context_registry__.get(key) +def get_external_storage_config( + store_client_config: dict | None = None, +) -> PostgreSQLStoreConfig | None: + if not store_client_config: + store_client_config_json = os.environ.get("SYFT_STORE_CLIENT_CONFIG", "{}") + store_client_config = json.loads(store_client_config_json) + + if ( + store_client_config + and "TYPE" in store_client_config + and store_client_config["TYPE"] == "PostgreSQLStoreConfig" + ): + return PostgreSQLStoreConfig.from_dict(store_client_config) + + return None + + @instrument class Server(AbstractServer): signing_key: SyftSigningKey | None @@ -336,6 +353,7 @@ def __init__( smtp_host: str | None = None, association_request_auto_approval: bool = False, background_tasks: bool = False, + store_client_config: dict | None = None, ): # 🟡 TODO 22: change our ENV variable format and default init args to make this # less horrible or add some convenience functions @@ -383,15 +401,21 @@ def __init__( if reset: self.remove_temp_dir() - use_sqlite = local_db or (processes > 0 and not is_subprocess) - document_store_config = document_store_config or self.get_default_store( - use_sqlite=use_sqlite, - store_type="Document Store", - ) - action_store_config = 
action_store_config or self.get_default_store( - use_sqlite=use_sqlite, - store_type="Action Store", - ) + # get from python constructors or env variables + external_config = get_external_storage_config(store_client_config) + if external_config: + document_store_config = external_config + action_store_config = external_config + else: + use_sqlite = local_db or (processes > 0 and not is_subprocess) + document_store_config = document_store_config or self.get_default_store( + use_sqlite=use_sqlite, + store_type="Document Store", + ) + action_store_config = action_store_config or self.get_default_store( + use_sqlite=use_sqlite, + store_type="Action Store", + ) self.init_stores( action_store_config=action_store_config, document_store_config=document_store_config, @@ -683,6 +707,7 @@ def named( in_memory_workers: bool = True, association_request_auto_approval: bool = False, background_tasks: bool = False, + store_client_config: dict | None = None, ) -> Server: uid = get_named_server_uid(name) name_hash = hashlib.sha256(name.encode("utf8")).digest() @@ -712,6 +737,7 @@ def named( reset=reset, association_request_auto_approval=association_request_auto_approval, background_tasks=background_tasks, + store_client_config=store_client_config, ) def is_root(self, credentials: SyftVerifyKey) -> bool: diff --git a/packages/syft/src/syft/server/uvicorn.py b/packages/syft/src/syft/server/uvicorn.py index 3549d7a5987..982e69f0c1e 100644 --- a/packages/syft/src/syft/server/uvicorn.py +++ b/packages/syft/src/syft/server/uvicorn.py @@ -1,6 +1,7 @@ # stdlib from collections.abc import Callable from contextlib import asynccontextmanager +import json import logging import multiprocessing import multiprocessing.synchronize @@ -170,6 +171,8 @@ def run_uvicorn( env_prefix = AppSettings.model_config.get("env_prefix", "") for key, value in kwargs.items(): key_with_prefix = f"{env_prefix}{key.upper()}" + if isinstance(value, dict): + value = json.dumps(value) os.environ[key_with_prefix] = str(value) # The `serve_server` function calls `run_uvicorn` in a separate process using `multiprocessing.Process`. 
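# The dict-valued kwargs above (e.g. store_client_config) cross the process
# boundary as JSON in environment variables. A minimal sketch of the round
# trip, matching the SYFT_ env prefix and the json.dumps call added above:
#
#     import json, os
#     os.environ["SYFT_STORE_CLIENT_CONFIG"] = json.dumps(
#         {"TYPE": "PostgreSQLStoreConfig", "POSTGRESQL_HOST": "localhost"}
#     )
#     config = json.loads(os.environ.get("SYFT_STORE_CLIENT_CONFIG", "{}"))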
@@ -213,6 +216,7 @@ def serve_server( association_request_auto_approval: bool = False, background_tasks: bool = False, debug: bool = False, + store_client_config: dict | None = None, ) -> tuple[Callable, Callable]: starting_uvicorn_event = multiprocessing.Event() @@ -240,6 +244,7 @@ def serve_server( "background_tasks": background_tasks, "debug": debug, "starting_uvicorn_event": starting_uvicorn_event, + "store_client_config": store_client_config, }, ) diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py index c6e02e0d1dc..f8de84876b6 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -1,16 +1,21 @@ # stdlib from collections import defaultdict import logging +from typing import Any +from typing import Self # third party -import psycopg2 -from psycopg2.extensions import connection -from psycopg2.extensions import cursor +import psycopg +from psycopg import Connection +from psycopg import Cursor +from psycopg.errors import DuplicateTable +from psycopg.errors import InFailedSqlTransaction from pydantic import Field # relative from ..serde.serializable import serializable from ..types.errors import SyftException +from ..types.result import as_result from .document_store import DocumentStore from .document_store import PartitionSettings from .document_store import StoreClientConfig @@ -25,8 +30,8 @@ from .sqlite_document_store import special_exception_public_message logger = logging.getLogger(__name__) -_CONNECTION_POOL_DB: dict[str, connection] = {} -_CONNECTION_POOL_CUR: dict[str, cursor] = {} +_CONNECTION_POOL_DB: dict[str, Connection] = {} +_CONNECTION_POOL_CUR: dict[str, Cursor] = {} REF_COUNTS: dict[str, int] = defaultdict(int) @@ -39,6 +44,10 @@ class PostgreSQLStoreClientConfig(StoreClientConfig): host: str port: int + # makes hashabel + class Config: + frozen = True + @serializable(canonical_name="PostgreSQLStorePartition", version=1) class PostgreSQLStorePartition(SQLiteStorePartition): @@ -66,6 +75,7 @@ def __init__( self.index_name = index_name self.settings = settings self.store_config = store_config + self.store_config_hash = hash(store_config.client_config) self._ddtype = ddtype if self.store_config.client_config: self.dbname = self.store_config.client_config.dbname @@ -73,12 +83,13 @@ def __init__( self.lock = SyftLock(NoLockingConfig()) self.create_table() REF_COUNTS[cache_key(self.dbname)] += 1 + self.subs_char = r"%s" # thanks postgresql def _connect(self) -> None: if self.store_config.client_config: - connection = psycopg2.connect( + connection = psycopg.connect( dbname=self.store_config.client_config.dbname, - user=self.store_config.client_config.user, + user=self.store_config.client_config.username, password=self.store_config.client_config.password, host=self.store_config.client_config.host, port=self.store_config.client_config.port, @@ -95,40 +106,58 @@ def create_table(self) -> None: + "sqltime TIMESTAMP NOT NULL DEFAULT NOW())" # nosec ) self.db.commit() + except DuplicateTable: + pass + except InFailedSqlTransaction: + self.db.rollback() except Exception as e: public_message = special_exception_public_message(self.table_name, e) raise SyftException.from_exception(e, public_message=public_message) @property - def db(self) -> connection: + def db(self) -> Connection: if cache_key(self.dbname) not in _CONNECTION_POOL_DB: self._connect() return _CONNECTION_POOL_DB[cache_key(self.dbname)] @property - def 
cur(self) -> cursor: - if cache_key(self.db_filename) not in _CONNECTION_POOL_CUR: + def cur(self) -> Cursor: + if cache_key(self.store_config_hash) not in _CONNECTION_POOL_CUR: _CONNECTION_POOL_CUR[cache_key(self.dbname)] = self.db.cursor() return _CONNECTION_POOL_CUR[cache_key(self.dbname)] def _close(self) -> None: self._commit() - REF_COUNTS[cache_key(self.db_filename)] -= 1 - if REF_COUNTS[cache_key(self.db_filename)] <= 0: + REF_COUNTS[cache_key(self.store_config_hash)] -= 1 + if REF_COUNTS[cache_key(self.store_config_hash)] <= 0: # once you close it seems like other object references can't re-use the # same connection self.db.close() - db_key = cache_key(self.db_filename) + db_key = cache_key(self.store_config_hash) if db_key in _CONNECTION_POOL_CUR: # NOTE if we don't remove the cursor, the cursor cache_key can clash with a future thread id del _CONNECTION_POOL_CUR[db_key] - del _CONNECTION_POOL_DB[cache_key(self.db_filename)] + del _CONNECTION_POOL_DB[cache_key(self.store_config_hash)] else: # don't close yet because another SQLiteBackingStore is probably still open pass + @as_result(SyftException) + def _execute(self, sql: str, args: list[Any] | None) -> psycopg.Cursor: + with self.lock: + cursor: psycopg.Cursor | None = None + try: + cursor = self.cur.execute(sql, args) + except InFailedSqlTransaction: + self.db.rollback() + except Exception as e: + public_message = special_exception_public_message(self.table_name, e) + raise SyftException.from_exception(e, public_message=public_message) + self.db.commit() # Commit if everything went ok + return cursor + @serializable() class PostgreSQLStoreConfig(StoreConfig): @@ -138,3 +167,15 @@ class PostgreSQLStoreConfig(StoreConfig): store_type: type[DocumentStore] = PostgreSQLDocumentStore backing_store: type[KeyValueBackingStore] = PostgreSQLBackingStore locking_config: LockingConfig = Field(default_factory=NoLockingConfig) + + @classmethod + def from_dict(cls, client_config_dict: dict) -> Self: + postgresql_client_config = PostgreSQLStoreClientConfig( + dbname=client_config_dict["POSTGRESQL_DBNAME"], + host=client_config_dict["POSTGRESQL_HOST"], + port=client_config_dict["POSTGRESQL_PORT"], + username=client_config_dict["POSTGRESQL_USERNAME"], + password=client_config_dict["POSTGRESQL_PASSWORD"], + ) + + return PostgreSQLStoreConfig(client_config=postgresql_client_config) diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py index 2cbef952862..c2a329d0cae 100644 --- a/packages/syft/src/syft/store/sqlite_document_store.py +++ b/packages/syft/src/syft/store/sqlite_document_store.py @@ -114,6 +114,7 @@ def __init__( self.lock = SyftLock(NoLockingConfig()) self.create_table() REF_COUNTS[cache_key(self.db_filename)] += 1 + self.subs_char = r"?" @property def table_name(self) -> str: @@ -134,7 +135,7 @@ def _connect(self) -> None: connection = sqlite3.connect( self.file_path, timeout=self.store_config.client_config.timeout, - check_same_thread=False, # do we need this if we use the lock? + check_same_thread=False, # do we need this if we use the lock # check_same_thread=self.store_config.client_config.check_same_thread, ) # Set journal mode to WAL. 
@@ -192,17 +193,17 @@ def _commit(self) -> None: self.db.commit() @as_result(SyftException) - def _execute(self, sql: str, *args: list[Any] | None) -> sqlite3.Cursor: + def _execute(self, sql: str, args: list[Any] | None) -> sqlite3.Cursor: with self.lock: cursor: sqlite3.Cursor | None = None # err = None try: - cursor = self.cur.execute(sql, *args) + cursor = self.cur.execute(sql, args) except Exception as e: public_message = special_exception_public_message(self.table_name, e) raise SyftException.from_exception(e, public_message=public_message) - # TODO: Which exception is safe to rollback on? + # TODO: Which exception is safe to rollback on # we should map out some more clear exceptions that can be returned # rather than halting the program like disk I/O error etc # self.db.rollback() # Roll back all changes if an exception occurs. @@ -215,22 +216,28 @@ def _set(self, key: UID, value: Any) -> None: self._update(key, value) else: insert_sql = ( - f"insert into {self.table_name} (uid, repr, value) VALUES (?, ?, ?)" # nosec - ) + f"insert into {self.table_name} (uid, repr, value) VALUES " + f"({self.subs_char}, {self.subs_char}, {self.subs_char})" + ) # nosec data = _serialize(value, to_bytes=True) self._execute(insert_sql, [str(key), _repr_debug_(value), data]).unwrap() def _update(self, key: UID, value: Any) -> None: insert_sql = ( - f"update {self.table_name} set uid = ?, repr = ?, value = ? where uid = ?" # nosec - ) + f"update {self.table_name} set uid = {self.subs_char}, " + f"repr = {self.subs_char}, value = {self.subs_char} " + f"where uid = {self.subs_char}" + ) # nosec data = _serialize(value, to_bytes=True) self._execute( insert_sql, [str(key), _repr_debug_(value), data, str(key)] ).unwrap() def _get(self, key: UID) -> Any: - select_sql = f"select * from {self.table_name} where uid = ? order by sqltime" # nosec + select_sql = ( + f"select * from {self.table_name} where uid = {self.subs_char} " + "order by sqltime" + ) # nosec cursor = self._execute(select_sql, [str(key)]).unwrap( public_message=f"Query {select_sql} failed" ) @@ -241,13 +248,11 @@ def _get(self, key: UID) -> Any: return _deserialize(data, from_bytes=True) def _exists(self, key: UID) -> bool: - select_sql = f"select uid from {self.table_name} where uid = ?" # nosec - + select_sql = f"select uid from {self.table_name} where uid = {self.subs_char}" # nosec res = self._execute(select_sql, [str(key)]) if res.is_err(): return False cursor = res.ok() - row = cursor.fetchone() # type: ignore if row is None: return False @@ -259,7 +264,7 @@ def _get_all(self) -> Any: keys = [] data = [] - res = self._execute(select_sql) + res = self._execute(select_sql, []) if res.is_err(): return {} cursor = res.ok() @@ -276,7 +281,7 @@ def _get_all(self) -> Any: def _get_all_keys(self) -> Any: select_sql = f"select uid from {self.table_name} order by sqltime" # nosec - res = self._execute(select_sql) + res = self._execute(select_sql, []) if res.is_err(): return [] cursor = res.ok() @@ -289,16 +294,16 @@ def _get_all_keys(self) -> Any: return keys def _delete(self, key: UID) -> None: - select_sql = f"delete from {self.table_name} where uid = ?" 
# nosec + select_sql = f"delete from {self.table_name} where uid = {self.subs_char}" # nosec self._execute(select_sql, [str(key)]).unwrap() def _delete_all(self) -> None: select_sql = f"delete from {self.table_name}" # nosec - self._execute(select_sql).unwrap() + self._execute(select_sql, []).unwrap() def _len(self) -> int: select_sql = f"select count(uid) from {self.table_name}" # nosec - cursor = self._execute(select_sql).unwrap() + cursor = self._execute(select_sql, []).unwrap() cnt = cursor.fetchone()[0] return cnt @@ -369,7 +374,7 @@ class SQLiteStorePartition(KeyValueStorePartition): def close(self) -> None: self.lock.acquire() try: - # I think we don't want these now, because of the REF_COUNT? + # I think we don't want these now, because of the REF_COUNT # self.data._close() # self.unique_keys._close() # self.searchable_keys._close() From 1fc12a20f0a3d7b24d97e63ee58f11566c022bdd Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 28 Aug 2024 15:34:50 +0200 Subject: [PATCH 061/197] fix sync --- .../syft/src/syft/service/job/job_stash.py | 13 ++++++++----- .../src/syft/service/sync/sync_service.py | 19 ++++++++++++++++--- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 8b71c86f734..27a2452a68f 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -644,11 +644,14 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: # if self.user_code_id is not None: dependencies.append(self.user_code_id) - output = context.server.get_service("outputservice").get_by_job_id( # type: ignore - context, self.id - ) - if output is not None: - dependencies.append(output.id) + try: + output = context.server.get_service("outputservice").get_by_job_id( # type: ignore + context, self.id + ) + if output is not None: + dependencies.append(output.id) + except NotFoundException: + pass return dependencies diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 8ac08512c25..7d9915fea2d 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -3,6 +3,8 @@ import logging from typing import Any +from syft.store.document_store_errors import NotFoundException + # relative from ...client.api import ServerIdentity from ...serde.serializable import serializable @@ -154,7 +156,12 @@ def set_object( stash = self.get_stash_for_item(context, item) creds = context.credentials - exists = stash.get_by_uid(context.credentials, item.id).ok() is not None + try: + obj = stash.get_by_uid(context.credentials, item.id).unwrap() + except (SyftException, KeyError): + obj = None + + exists = obj is not None if isinstance(item, TwinAPIEndpoint): # we need the side effect of set function @@ -300,7 +307,10 @@ def _get_job_batch( job_batch.append(log) output_service = context.server.get_service("outputservice") - output = output_service.get_by_job_id(context, job.id) + try: + output = output_service.get_by_job_id(context, job.id) + except NotFoundException: + output = None if output is not None: job_batch.append(output) @@ -370,7 +380,10 @@ def build_current_state( permissions = {} storage_permissions = {} - previous_state = self.stash.get_latest(context=context).unwrap() + try: + previous_state = self.stash.get_latest(context=context).unwrap() + except NotFoundException: + previous_state = 
None if previous_state is not None: previous_state_link = LinkedObject.from_obj( From ca9896a21da954a64dd6d015ede760ff17bcd042 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 29 Aug 2024 14:36:34 +0530 Subject: [PATCH 062/197] Add helm templates for postgres - add postgres service to devspace - pass postgres creds to backend Co-authored-by: khoaguin --- packages/grid/backend/backend.dockerfile | 9 +-- packages/grid/backend/grid/core/server.py | 4 ++ packages/grid/devspace.yaml | 14 ++-- packages/grid/helm/examples/dev/base.yaml | 7 ++ .../backend/backend-statefulset.yaml | 20 +++--- .../postgres/postgres-headless-service.yaml | 15 ++++ .../templates/postgres/postgres-secret.yaml | 17 +++++ .../templates/postgres/postgres-service.yaml | 17 +++++ .../postgres/postgres-statefuleset.yaml | 72 +++++++++++++++++++ packages/grid/helm/syft/values.yaml | 32 +++++++++ .../syft/store/postgresql_document_store.py | 9 +++ 11 files changed, 196 insertions(+), 20 deletions(-) create mode 100644 packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml create mode 100644 packages/grid/helm/syft/templates/postgres/postgres-secret.yaml create mode 100644 packages/grid/helm/syft/templates/postgres/postgres-service.yaml create mode 100644 packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml diff --git a/packages/grid/backend/backend.dockerfile b/packages/grid/backend/backend.dockerfile index 256fdfd447b..0989cc4f4f2 100644 --- a/packages/grid/backend/backend.dockerfile +++ b/packages/grid/backend/backend.dockerfile @@ -83,9 +83,10 @@ ENV \ DEFAULT_ROOT_EMAIL="info@openmined.org" \ DEFAULT_ROOT_PASSWORD="changethis" \ STACK_API_KEY="changeme" \ - MONGO_HOST="localhost" \ - MONGO_PORT="27017" \ - MONGO_USERNAME="root" \ - MONGO_PASSWORD="example" + POSTGRESQL_DBNAME="syftdb_postgres" \ + POSTGRESQL_HOST="localhost" \ + POSTGRESQL_PORT="5432" \ + POSTGRESQL_USERNAME="syft_postgres" \ + POSTGRESQL_PASSWORD="example" CMD ["bash", "./grid/start.sh"] diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py index d284c8edb4e..a447d56ed14 100644 --- a/packages/grid/backend/grid/core/server.py +++ b/packages/grid/backend/grid/core/server.py @@ -91,6 +91,10 @@ def seaweedfs_config() -> SeaweedFSConfig: store_config = ( sql_store_config() if single_container_mode else postgresql_store_config() ) + +print("----------------------------Store Config----------------------------\n") +print(store_config.model_dump()) +print("\n----------------------------Store Config----------------------------") blob_storage_config = None if single_container_mode else seaweedfs_config() queue_config = queue_config() diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 8a35cdef7ee..06efe99f4a1 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -80,12 +80,12 @@ deployments: - ./helm/examples/dev/base.yaml dev: - mongo: + postgres: labelSelector: app.kubernetes.io/name: syft - app.kubernetes.io/component: mongo + app.kubernetes.io/component: postgres ports: - - port: "27017" + - port: "5432" seaweedfs: labelSelector: app.kubernetes.io/name: syft @@ -188,8 +188,8 @@ profiles: path: dev.seaweedfs # Port Re-Mapping - op: replace - path: dev.mongo.ports[0].port - value: 27018:27017 + path: dev.postgres.ports[0].port + value: 5433:5432 - op: replace path: dev.backend.ports[0].port value: 5679:5678 @@ -251,8 +251,8 @@ profiles: value: ./helm/examples/dev/enclave.yaml # Port Re-Mapping - op: replace - path: 
dev.mongo.ports[0].port - value: 27019:27017 + path: dev.postgres.ports[0].port + value: 5434:5432 - op: replace path: dev.backend.ports[0].port value: 5680:5678 diff --git a/packages/grid/helm/examples/dev/base.yaml b/packages/grid/helm/examples/dev/base.yaml index 9b14fbe29ed..e173e9ac539 100644 --- a/packages/grid/helm/examples/dev/base.yaml +++ b/packages/grid/helm/examples/dev/base.yaml @@ -26,6 +26,13 @@ mongo: secret: rootPassword: example +postgres: + resourcesPreset: null + resources: null + + secret: + rootPassword: example + seaweedfs: resourcesPreset: null resources: null diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index 3dcefcd0f6b..693bb6820ba 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -99,18 +99,20 @@ spec: - name: REVERSE_TUNNEL_ENABLED value: "true" {{- end }} - # MongoDB - - name: MONGO_PORT - value: {{ .Values.mongo.port | quote }} - - name: MONGO_HOST - value: "mongo" - - name: MONGO_USERNAME - value: {{ .Values.mongo.username | quote }} - - name: MONGO_PASSWORD + # Postgres + - name: POSTGRESQL_PORT + value: {{ .Values.postgres.port | quote }} + - name: POSTGRESQL_HOST + value: "postgres" + - name: POSTGRESQL_USERNAME + value: {{ .Values.postgres.username | quote }} + - name: POSTGRESQL_PASSWORD valueFrom: secretKeyRef: - name: {{ .Values.mongo.secretKeyName | required "mongo.secretKeyName is required" }} + name: {{ .Values.postgres.secretKeyName | required "postgres.secretKeyName is required" }} key: rootPassword + - name: POSTGRESQL_DBNAME + value: {{ .Values.postgres.dbname | quote }} # SMTP - name: SMTP_HOST value: {{ .Values.server.smtp.host | quote }} diff --git a/packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml b/packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml new file mode 100644 index 00000000000..4855a7868ff --- /dev/null +++ b/packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Service +metadata: + name: postgres-headless + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: postgres +spec: + clusterIP: None + ports: + - name: postgres + port: 5432 + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: postgres \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/postgres/postgres-secret.yaml b/packages/grid/helm/syft/templates/postgres/postgres-secret.yaml new file mode 100644 index 00000000000..63a990c0d9a --- /dev/null +++ b/packages/grid/helm/syft/templates/postgres/postgres-secret.yaml @@ -0,0 +1,17 @@ +{{- $secretName := "postgres-secret" }} +apiVersion: v1 +kind: Secret +metadata: + name: {{ $secretName }} + labels: + {{- include "common.labels" . 
| nindent 4 }} + app.kubernetes.io/component: postgres +type: Opaque +data: + rootPassword: {{ include "common.secrets.set" (dict + "secret" $secretName + "key" "rootPassword" + "randomDefault" .Values.global.randomizedSecrets + "default" .Values.postgres.secret.rootPassword + "context" $) + }} \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/postgres/postgres-service.yaml b/packages/grid/helm/syft/templates/postgres/postgres-service.yaml new file mode 100644 index 00000000000..9cd8b156bdd --- /dev/null +++ b/packages/grid/helm/syft/templates/postgres/postgres-service.yaml @@ -0,0 +1,17 @@ +apiVersion: v1 +kind: Service +metadata: + name: postgres + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: postgres +spec: + type: ClusterIP + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: postgres + ports: + - name: postgres + port: 5432 + protocol: TCP + targetPort: 5432 \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml b/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml new file mode 100644 index 00000000000..425f9e88770 --- /dev/null +++ b/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml @@ -0,0 +1,72 @@ +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: postgres + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: postgres +spec: + replicas: 1 + updateStrategy: + type: RollingUpdate + selector: + matchLabels: + {{- include "common.selectorLabels" . | nindent 6 }} + app.kubernetes.io/component: postgres + serviceName: postgres-headless + podManagementPolicy: OrderedReady + template: + metadata: + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: postgres + {{- if .Values.postgres.podLabels }} + {{- toYaml .Values.postgres.podLabels | nindent 8 }} + {{- end }} + {{- if .Values.postgres.podAnnotations }} + annotations: {{- toYaml .Values.postgres.podAnnotations | nindent 8 }} + {{- end }} + spec: + {{- if .Values.postgres.nodeSelector }} + nodeSelector: {{- .Values.postgres.nodeSelector | toYaml | nindent 8 }} + {{- end }} + containers: + - name: postgres-container + image: postgres:13 + imagePullPolicy: Always + resources: {{ include "common.resources.set" (dict "resources" .Values.postgres.resources "preset" .Values.postgres.resourcesPreset) | nindent 12 }} + env: + - name: POSTGRES_USER + value: {{ .Values.postgres.username | required "postgres.username is required" | quote }} + - name: POSTGRES_PASSWORD + valueFrom: + secretKeyRef: + name: {{ .Values.postgres.secretKeyName | required "postgres.secretKeyName is required" }} + key: rootPassword + - name: POSTGRES_DB + value: {{ .Values.postgres.dbname | required "postgres.dbname is required" | quote }} + {{- if .Values.postgres.env }} + {{- toYaml .Values.postgres.env | nindent 12 }} + {{- end }} + volumeMounts: + - mountPath: /data/db + name: postgres-data + readOnly: false + subPath: '' + ports: + - name: postgres-port + containerPort: 5432 + terminationGracePeriodSeconds: 5 + volumeClaimTemplates: + - metadata: + name: postgres-data + labels: + {{- include "common.volumeLabels" . 
| nindent 8 }} + app.kubernetes.io/component: postgres + spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.postgres.storageSize | quote }} + diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index d43c721203e..0e7619e6791 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -44,6 +44,38 @@ mongo: # ================================================================================= +postgres: +# Postgres config + port: 5432 + username: syft_postgres + dbname: syftdb_postgres + host: postgres + + # Extra environment vars + env: null + + # Pod labels & annotations + podLabels: null + podAnnotations: null + + # Node selector for pods + nodeSelector: null + + # Pod Resource Limits + resourcesPreset: large + resources: null + + # PVC storage size + storageSize: 5Gi + + # Mongo secret name. Override this if you want to use a self-managed secret. + secretKeyName: postgres-secret + + # default/custom secret raw values + secret: + rootPassword: null +# ================================================================================= + frontend: # Extra environment vars env: null diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py index f8de84876b6..c0750dfdedb 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -48,6 +48,12 @@ class PostgreSQLStoreClientConfig(StoreClientConfig): class Config: frozen = True + def __hash__(self) -> int: + return hash((self.dbname, self.username, self.password, self.host, self.port)) + + def __str__(self) -> str: + return f"dbname={self.dbname} user={self.username} password={self.password} host={self.host} port={self.port}" + @serializable(canonical_name="PostgreSQLStorePartition", version=1) class PostgreSQLStorePartition(SQLiteStorePartition): @@ -95,6 +101,9 @@ def _connect(self) -> None: port=self.store_config.client_config.port, ) + print(f"Connected to {self.store_config.client_config.dbname}") + print("PostgreSQL database connection:", connection._check_connection_ok()) + _CONNECTION_POOL_DB[cache_key(self.dbname)] = connection def create_table(self) -> None: From bad8e5a0bbb0b26ce01813a25b846f39648256a2 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Fri, 30 Aug 2024 11:42:27 +0700 Subject: [PATCH 063/197] [syft/stores] new exception handlings for some methods in `SQLiteBackingStore` [tox] trying to delete a datasite cluster before launching --- .../src/syft/store/postgresql_document_store.py | 11 +++++++++-- .../src/syft/store/sqlite_document_store.py | 17 +++-------------- tox.ini | 2 ++ 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py index c0750dfdedb..c7dad244f71 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -158,10 +158,17 @@ def _execute(self, sql: str, args: list[Any] | None) -> psycopg.Cursor: with self.lock: cursor: psycopg.Cursor | None = None try: - cursor = self.cur.execute(sql, args) + # Ensure self.cur is a psycopg cursor object + cursor = self.cur # Assuming self.cur is already set as psycopg.Cursor + cursor.execute(sql, args) # Execute the SQL with arguments + # cursor = self.cur.execute(sql, args) except InFailedSqlTransaction: - self.db.rollback() + 
self.db.rollback() # Rollback if something went wrong + raise SyftException( + public_message=f"Transaction {sql} failed and was rolled back." + ) except Exception as e: + self.db.rollback() # Rollback on any other exception to maintain clean state public_message = special_exception_public_message(self.table_name, e) raise SyftException.from_exception(e, public_message=public_message) self.db.commit() # Commit if everything went ok diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py index c2a329d0cae..1154952def7 100644 --- a/packages/syft/src/syft/store/sqlite_document_store.py +++ b/packages/syft/src/syft/store/sqlite_document_store.py @@ -249,10 +249,7 @@ def _get(self, key: UID) -> Any: def _exists(self, key: UID) -> bool: select_sql = f"select uid from {self.table_name} where uid = {self.subs_char}" # nosec - res = self._execute(select_sql, [str(key)]) - if res.is_err(): - return False - cursor = res.ok() + cursor = self._execute(select_sql, [str(key)]).unwrap() row = cursor.fetchone() # type: ignore if row is None: return False @@ -264,11 +261,7 @@ def _get_all(self) -> Any: keys = [] data = [] - res = self._execute(select_sql, []) - if res.is_err(): - return {} - cursor = res.ok() - + cursor = self._execute(select_sql, []).unwrap() rows = cursor.fetchall() # type: ignore if rows is None: return {} @@ -281,11 +274,7 @@ def _get_all(self) -> Any: def _get_all_keys(self) -> Any: select_sql = f"select uid from {self.table_name} order by sqltime" # nosec - res = self._execute(select_sql, []) - if res.is_err(): - return [] - cursor = res.ok() - + cursor = self._execute(select_sql, []).unwrap() rows = cursor.fetchall() # type: ignore if rows is None: return [] diff --git a/tox.ini b/tox.ini index c9b2569fea8..78a3c541202 100644 --- a/tox.ini +++ b/tox.ini @@ -1089,7 +1089,9 @@ setenv= CLUSTER_HTTP_PORT={env:CLUSTER_HTTP_PORT:9082} allowlist_externals = tox + bash commands = + bash -c "CLUSTER_NAME=${CLUSTER_NAME} tox -e dev.k8s.destroy" tox -e dev.k8s.start tox -e dev.k8s.{posargs:deploy} From 85f72f4afaa4ca5c36d51fef7cb4d8f8014e3668 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Fri, 30 Aug 2024 13:58:16 +0700 Subject: [PATCH 064/197] [syft/store] fix wrong imports and ports for postgres store --- packages/grid/backend/backend.dockerfile | 2 +- packages/grid/backend/grid/core/config.py | 2 +- packages/syft/src/syft/store/postgresql_document_store.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/grid/backend/backend.dockerfile b/packages/grid/backend/backend.dockerfile index 0989cc4f4f2..1a6ccbf356e 100644 --- a/packages/grid/backend/backend.dockerfile +++ b/packages/grid/backend/backend.dockerfile @@ -87,6 +87,6 @@ ENV \ POSTGRESQL_HOST="localhost" \ POSTGRESQL_PORT="5432" \ POSTGRESQL_USERNAME="syft_postgres" \ - POSTGRESQL_PASSWORD="example" + POSTGRESQL_PASSWORD="changethis" CMD ["bash", "./grid/start.sh"] diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 0d50d9c9078..6a201629b3b 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -128,7 +128,7 @@ def get_emails_enabled(self) -> Self: CONTAINER_HOST: str = str(os.getenv("CONTAINER_HOST", "docker")) POSTGRESQL_DBNAME: str = str(os.getenv("POSTGRESQL_DBNAME", "")) POSTGRESQL_HOST: str = str(os.getenv("POSTGRESQL_HOST", "")) - POSTGRESQL_PORT: int = int(os.getenv("POSTGRESQL_PORT", 27017)) + POSTGRESQL_PORT: int = 
int(os.getenv("POSTGRESQL_PORT", 5432)) POSTGRESQL_USERNAME: str = str(os.getenv("POSTGRESQL_USERNAME", "")) POSTGRESQL_PASSWORD: str = str(os.getenv("POSTGRESQL_PASSWORD", "")) DEV_MODE: bool = True if os.getenv("DEV_MODE", "false").lower() == "true" else False diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py index c7dad244f71..aca2976938b 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -2,7 +2,6 @@ from collections import defaultdict import logging from typing import Any -from typing import Self # third party import psycopg @@ -11,6 +10,7 @@ from psycopg.errors import DuplicateTable from psycopg.errors import InFailedSqlTransaction from pydantic import Field +from typing_extensions import Self # relative from ..serde.serializable import serializable From abfc2f1e77223aea581be0859a24d488e1cb7659 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 2 Sep 2024 11:28:55 +0200 Subject: [PATCH 065/197] fix --- packages/syft/src/syft/store/db/stash.py | 34 ++++++++++++------------ 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 210c47fbeb4..6b8724919be 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -39,11 +39,11 @@ from .models import UIDTypeDecorator from .sqlite_db import DBManager -SyftT = TypeVar("SyftT", bound=SyftObject) +StashT = TypeVar("StashT", bound=SyftObject) T = TypeVar("T") -class ObjectStash(Generic[SyftT]): +class ObjectStash(Generic[StashT]): table: Table object_type: type[SyftObject] @@ -53,7 +53,7 @@ def __init__(self, store: DBManager) -> None: self.table = self._create_table() @classmethod - def get_object_type(cls) -> type[SyftT]: + def get_object_type(cls) -> type[StashT]: generic_args = get_args(cls.__orig_bases__[0]) if len(generic_args) != 1: raise TypeError("ObjectStash must have a single generic argument") @@ -113,7 +113,7 @@ def _print_query(self, stmt: sa.sql.select) -> None: def unique_fields(self) -> list[str]: return getattr(self.object_type, "__attr_unique__", []) - def is_unique(self, obj: SyftT) -> bool: + def is_unique(self, obj: StashT) -> bool: unique_fields = self.unique_fields if not unique_fields: return True @@ -148,7 +148,7 @@ def exists(self, credentials: SyftVerifyKey, uid: UID) -> bool: @as_result(SyftException, StashException, NotFoundException) def get_by_uid( self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False - ) -> SyftT: + ) -> StashT: stmt = self.table.select() stmt = stmt.where(self._get_field_filter("id", uid)) stmt = self._apply_permission_filter( @@ -206,7 +206,7 @@ def get_one_by_field( field_name: str, field_value: str, has_permission: bool = False, - ) -> SyftT: + ) -> StashT: return self.get_one_by_fields( credentials=credentials, fields={field_name: field_value}, @@ -219,7 +219,7 @@ def get_one_by_fields( credentials: SyftVerifyKey, fields: dict[str, str], has_permission: bool = False, - ) -> SyftT: + ) -> StashT: result = self._get_by_fields( credentials=credentials, fields=fields, @@ -239,7 +239,7 @@ def get_all_by_fields( limit: int | None = None, offset: int | None = None, has_permission: bool = False, - ) -> list[SyftT]: + ) -> list[StashT]: result = self._get_by_fields( credentials=credentials, fields=fields, @@ -263,7 +263,7 @@ def get_all_by_field( limit: int | None = None, 
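
A note on the stash hunk above: the SyftT -> StashT rename pairs with get_object_type, which recovers the concrete stash type from the class declaration itself via __orig_bases__. A minimal, self-contained sketch of that mechanism; BaseStash/Widget/WidgetStash are placeholder names, not syft types:

# Sketch of the get_object_type mechanism used by ObjectStash (placeholder names).
from typing import Generic, TypeVar, get_args

StashT = TypeVar("StashT")

class BaseStash(Generic[StashT]):
    @classmethod
    def object_type(cls) -> type:
        # __orig_bases__ keeps the parameterized base, e.g. BaseStash[Widget]
        generic_args = get_args(cls.__orig_bases__[0])
        if len(generic_args) != 1:
            raise TypeError("stash must have a single generic argument")
        return generic_args[0]

class Widget:
    pass

class WidgetStash(BaseStash[Widget]):
    pass

assert WidgetStash.object_type() is Widget

Because the type is read off the subclass declaration, each concrete stash only needs `class FooStash(ObjectStash[Foo]): pass` instead of repeating an object_type attribute.
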
offset: int | None = None, has_permission: bool = False, - ) -> list[SyftT]: + ) -> list[StashT]: return self.get_all_by_fields( credentials=credentials, fields={field_name: field_value}, @@ -285,7 +285,7 @@ def get_all_contains( limit: int | None = None, offset: int | None = None, has_permission: bool = False, - ) -> list[SyftT]: + ) -> list[StashT]: # TODO write filter logic, merge with get_all stmt = self.table.select().where( @@ -303,7 +303,7 @@ def get_all_contains( @as_result(SyftException, StashException, NotFoundException) def get_index( self, credentials: SyftVerifyKey, index: int, has_permission: bool = False - ) -> SyftT: + ) -> StashT: items = self.get_all( credentials, has_permission=has_permission, @@ -315,7 +315,7 @@ def get_index( raise NotFoundException(f"No item found at index {index}") return items[0] - def row_as_obj(self, row: Row) -> SyftT: + def row_as_obj(self, row: Row) -> StashT: # TODO make unwrappable serde return deserialize_json(row.fields) @@ -428,7 +428,7 @@ def get_all( sort_order: str | None = None, limit: int | None = None, offset: int | None = None, - ) -> list[SyftT]: + ) -> list[StashT]: stmt = self.table.select() stmt = self._apply_permission_filter( @@ -447,9 +447,9 @@ def get_all( def update( self, credentials: SyftVerifyKey, - obj: SyftT, + obj: StashT, has_permission: bool = False, - ) -> SyftT: + ) -> StashT: """ NOTE: We cannot do partial updates on the database, because we are using computed fields that are not known to the DB or ORM: @@ -661,11 +661,11 @@ def get_all_permissions(self) -> dict[UID, set[str]]: def set( self, credentials: SyftVerifyKey, - obj: SyftT, + obj: StashT, add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, # TODO: check the default value ignore_duplicates: bool = False, - ) -> SyftT: + ) -> StashT: uid = obj.id # check if the object already exists From ea3114c149d6d5debf6975f7d163f9f6dd944b26 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 3 Sep 2024 09:15:09 +0200 Subject: [PATCH 066/197] wip --- packages/syft/src/syft/server/server.py | 1 + packages/syft/src/syft/service/sync/sync_service.py | 3 +-- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 7efe4ab01da..ce07c3c35ae 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -879,6 +879,7 @@ def init_stores( filename=f"{self.id}_json.db", path=self.get_temp_dir("db"), ) + # json_db_config = PostgresDBConfig(reset=False) self.db = SQLiteDBManager( config=json_db_config, server_uid=self.id, diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 861b5db0a8f..bfa0a5ae042 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -3,14 +3,13 @@ import logging from typing import Any -from syft.store.document_store_errors import NotFoundException - # relative from ...client.api import ServerIdentity from ...serde.serializable import serializable from ...store.db.stash import ObjectStash from ...store.document_store import DocumentStore from ...store.document_store import NewBaseStash +from ...store.document_store_errors import NotFoundException from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.errors import SyftException From f0a3d192c644e680c8570b6c306b13cf87c5a847 Mon Sep 17 00:00:00 
2001 From: Aziz Berkay Yesilyurt Date: Wed, 4 Sep 2024 11:45:28 +0200 Subject: [PATCH 067/197] remove dict document store --- packages/syft/setup.cfg | 1 + packages/syft/src/syft/server/server.py | 30 +- .../syft/service/action/action_permissions.py | 4 + .../src/syft/service/user/user_service.py | 4 +- packages/syft/src/syft/store/db/sqlite_db.py | 3 +- .../src/syft/store/dict_document_store.py | 106 ------ packages/syft/tests/conftest.py | 8 - .../tests/syft/dataset/dataset_stash_test.py | 42 -- .../notifications/notification_stash_test.py | 86 +---- .../tests/syft/request/request_stash_test.py | 4 +- .../tests/syft/stores/action_store_test.py | 2 +- .../syft/tests/syft/stores/base_stash_test.py | 17 +- .../syft/stores/dict_document_store_test.py | 358 ------------------ .../syft/stores/mongo_document_store_test.py | 2 +- .../tests/syft/stores/store_fixtures_test.py | 98 +---- packages/syft/tests/syft/worker_test.py | 3 +- 16 files changed, 44 insertions(+), 724 deletions(-) delete mode 100644 packages/syft/src/syft/store/dict_document_store.py delete mode 100644 packages/syft/tests/syft/stores/dict_document_store_test.py diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 399a06e69f5..da0552250eb 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -102,6 +102,7 @@ dev = ruff==0.4.7 safety>=2.4.0b2 aiosmtpd==1.4.6 + dynaconf==3.2.5 telemetry = opentelemetry-api==1.27.0 diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 5719dd3c78b..8ab77c2256b 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -90,7 +90,6 @@ from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit from ..store.db.sqlite_db import SQLiteDBConfig from ..store.db.sqlite_db import SQLiteDBManager -from ..store.dict_document_store import DictStoreConfig from ..store.document_store import StoreConfig from ..store.document_store_errors import NotFoundException from ..store.document_store_errors import StashException @@ -378,13 +377,10 @@ def __init__( if reset: self.remove_temp_dir() - use_sqlite = local_db or (processes > 0 and not is_subprocess) document_store_config = document_store_config or self.get_default_store( - use_sqlite=use_sqlite, store_type="Document Store", ) action_store_config = action_store_config or self.get_default_store( - use_sqlite=use_sqlite, store_type="Action Store", ) @@ -452,21 +448,19 @@ def runs_in_docker(self) -> bool: and any("docker" in line for line in open(path)) ) - def get_default_store(self, use_sqlite: bool, store_type: str) -> StoreConfig: - if use_sqlite: - path = self.get_temp_dir("db") - file_name: str = f"{self.id}.sqlite" - if self.dev_mode: - # leave this until the logger shows this in the notebook - print(f"{store_type}'s SQLite DB path: {path/file_name}") - logger.debug(f"{store_type}'s SQLite DB path: {path/file_name}") - return SQLiteStoreConfig( - client_config=SQLiteStoreClientConfig( - filename=file_name, - path=path, - ) + def get_default_store(self, store_type: str) -> StoreConfig: + path = self.get_temp_dir("db") + file_name: str = f"{self.id}.sqlite" + if self.dev_mode: + # leave this until the logger shows this in the notebook + print(f"{store_type}'s SQLite DB path: {path/file_name}") + logger.debug(f"{store_type}'s SQLite DB path: {path/file_name}") + return SQLiteStoreConfig( + client_config=SQLiteStoreClientConfig( + filename=file_name, + path=path, ) - return DictStoreConfig() + ) def init_blob_storage(self, config: 
BlobStorageConfig | None = None) -> None: if config is None: diff --git a/packages/syft/src/syft/service/action/action_permissions.py b/packages/syft/src/syft/service/action/action_permissions.py index 477b3b02bee..ab6f9b7ce9a 100644 --- a/packages/syft/src/syft/service/action/action_permissions.py +++ b/packages/syft/src/syft/service/action/action_permissions.py @@ -17,6 +17,7 @@ class ActionPermission(Enum): ALL_WRITE = 32 EXECUTE = 64 ALL_EXECUTE = 128 + ALL_OWNER = 256 @property def as_compound(self) -> "ActionPermission": @@ -28,6 +29,8 @@ def as_compound(self) -> "ActionPermission": return ActionPermission.ALL_WRITE elif self == ActionPermission.EXECUTE: return ActionPermission.ALL_EXECUTE + elif self == ActionPermission.OWNER: + return ActionPermission.ALL_OWNER else: raise Exception(f"Invalid compound permission {self}") @@ -36,6 +39,7 @@ def as_compound(self) -> "ActionPermission": ActionPermission.ALL_READ, ActionPermission.ALL_WRITE, ActionPermission.ALL_EXECUTE, + ActionPermission.ALL_OWNER, } diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 4c9a365a16a..ec23aa6e36f 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -395,7 +395,9 @@ def search( if len(kwargs) == 0: raise SyftException(public_message="Invalid search parameters") - users = self.stash.find_all(credentials=context.credentials, **kwargs).unwrap() + users = self.stash.get_all_by_fields( + credentials=context.credentials, fields=kwargs + ).unwrap() users = [user.to(UserView) for user in users] if users is not None else [] return _paginate(users, page_size, page_index) diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index 1f4e6d4b56d..a28d5f8f752 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -2,6 +2,7 @@ from pathlib import Path import tempfile import threading +import uuid # third party from pydantic import BaseModel @@ -28,7 +29,7 @@ def connection_string(self) -> str: class SQLiteDBConfig(DBConfig): - filename: str = "jsondb.sqlite" + filename: str = Field(default_factory=lambda: f"{uuid.uuid4()}.db") path: Path = Field(default_factory=tempfile.gettempdir) @property diff --git a/packages/syft/src/syft/store/dict_document_store.py b/packages/syft/src/syft/store/dict_document_store.py deleted file mode 100644 index ca0f3e1f33a..00000000000 --- a/packages/syft/src/syft/store/dict_document_store.py +++ /dev/null @@ -1,106 +0,0 @@ -# future -from __future__ import annotations - -# stdlib -from typing import Any - -# third party -from pydantic import Field - -# relative -from ..serde.serializable import serializable -from ..server.credentials import SyftVerifyKey -from ..types import uid -from .document_store import DocumentStore -from .document_store import StoreConfig -from .kv_document_store import KeyValueBackingStore -from .kv_document_store import KeyValueStorePartition -from .locks import LockingConfig -from .locks import ThreadingLockingConfig - - -@serializable(canonical_name="DictBackingStore", version=1) -class DictBackingStore(dict, KeyValueBackingStore): # type: ignore[misc] - # TODO: fix the mypy issue - """Dictionary-based Store core logic""" - - def __init__(self, *args: Any, **kwargs: Any) -> None: - super().__init__() - self._ddtype = kwargs.get("ddtype", None) - - def __getitem__(self, key: Any) -> Any: - try: - value = 
super().__getitem__(key) - return value - except KeyError as e: - if self._ddtype: - return self._ddtype() - raise e - - -@serializable(canonical_name="DictStorePartition", version=1) -class DictStorePartition(KeyValueStorePartition): - """Dictionary-based StorePartition - - Parameters: - `settings`: PartitionSettings - PySyft specific settings, used for indexing and partitioning - `store_config`: DictStoreConfig - DictStore specific configuration - """ - - def prune(self) -> None: - self.init_store().unwrap() - - -# the base document store is already a dict but we can change it later -@serializable(canonical_name="DictDocumentStore", version=1) -class DictDocumentStore(DocumentStore): - """Dictionary-based Document Store - - Parameters: - `store_config`: DictStoreConfig - Dictionary Store specific configuration, containing the store type and the backing store type - """ - - partition_type = DictStorePartition - - def __init__( - self, - server_uid: uid, - root_verify_key: SyftVerifyKey | None, - store_config: DictStoreConfig | None = None, - ) -> None: - if store_config is None: - store_config = DictStoreConfig() - super().__init__( - server_uid=server_uid, - root_verify_key=root_verify_key, - store_config=store_config, - ) - - def reset(self) -> None: - for partition in self.partitions.values(): - partition.prune() - - -@serializable() -class DictStoreConfig(StoreConfig): - __canonical_name__ = "DictStoreConfig" - """Dictionary-based configuration - - Parameters: - `store_type`: Type[DocumentStore] - The Document type used. Default: DictDocumentStore - `backing_store`: Type[KeyValueBackingStore] - The backend type used. Default: DictBackingStore - locking_config: LockingConfig - The config used for store locking. Available options: - * NoLockingConfig: no locking, ideal for single-thread stores. - * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. - Defaults to ThreadingLockingConfig. 
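
For the record, the ddtype fallback this deletion removes behaved like collections.defaultdict, except that the default value was never stored back. A stand-alone sketch of those semantics (not syft code):

# Stand-alone sketch of the _ddtype fallback being deleted: a missing key
# yields a fresh instance of the default type instead of raising KeyError.
from typing import Any

class FallbackDict(dict):
    def __init__(self, ddtype: type | None = None) -> None:
        super().__init__()
        self._ddtype = ddtype

    def __getitem__(self, key: Any) -> Any:
        try:
            return super().__getitem__(key)
        except KeyError:
            if self._ddtype is not None:
                return self._ddtype()  # like defaultdict, but nothing is inserted
            raise

store = FallbackDict(ddtype=list)
assert store["missing"] == []  # no KeyError
assert "missing" not in store  # unlike defaultdict, the key was not created
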
- """ - - store_type: type[DocumentStore] = DictDocumentStore - backing_store: type[KeyValueBackingStore] = DictBackingStore - locking_config: LockingConfig = Field(default_factory=ThreadingLockingConfig) diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index eca68d13b12..0bd3426f2b9 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -27,10 +27,6 @@ # relative # our version of mongomock that has a fix for CodecOptions and custom TypeRegistry Support from .mongomock.mongo_client import MongoClient -from .syft.stores.store_fixtures_test import dict_action_store -from .syft.stores.store_fixtures_test import dict_document_store -from .syft.stores.store_fixtures_test import dict_queue_stash -from .syft.stores.store_fixtures_test import dict_store_partition from .syft.stores.store_fixtures_test import mongo_action_store from .syft.stores.store_fixtures_test import mongo_document_store from .syft.stores.store_fixtures_test import mongo_queue_stash @@ -317,10 +313,6 @@ def big_dataset() -> Dataset: "sqlite_document_store", "sqlite_queue_stash", "sqlite_action_store", - "dict_store_partition", - "dict_action_store", - "dict_document_store", - "dict_queue_stash", ] pytest_plugins = [ diff --git a/packages/syft/tests/syft/dataset/dataset_stash_test.py b/packages/syft/tests/syft/dataset/dataset_stash_test.py index d177aaa508e..86c44de2319 100644 --- a/packages/syft/tests/syft/dataset/dataset_stash_test.py +++ b/packages/syft/tests/syft/dataset/dataset_stash_test.py @@ -1,54 +1,12 @@ # third party import pytest -from typeguard import TypeCheckError # syft absolute from syft.service.dataset.dataset import Dataset -from syft.service.dataset.dataset_stash import ActionIDsPartitionKey -from syft.service.dataset.dataset_stash import NamePartitionKey -from syft.store.document_store import QueryKey from syft.store.document_store_errors import NotFoundException from syft.types.uid import UID -def test_dataset_namepartitionkey() -> None: - mock_obj = "dummy_name_key" - - assert NamePartitionKey.key == "name" - assert NamePartitionKey.type_ == str - - name_partition_key = NamePartitionKey.with_obj(obj=mock_obj) - - assert isinstance(name_partition_key, QueryKey) - assert name_partition_key.key == "name" - assert name_partition_key.type_ == str - assert name_partition_key.value == mock_obj - - with pytest.raises(AttributeError): - NamePartitionKey.with_obj(obj=[UID()]) - - -def test_dataset_actionidpartitionkey() -> None: - mock_obj = [UID() for _ in range(3)] - - assert ActionIDsPartitionKey.key == "action_ids" - assert ActionIDsPartitionKey.type_ == list[UID] - - action_ids_partition_key = ActionIDsPartitionKey.with_obj(obj=mock_obj) - - assert isinstance(action_ids_partition_key, QueryKey) - assert action_ids_partition_key.key == "action_ids" - assert action_ids_partition_key.type_ == list[UID] - assert action_ids_partition_key.value == mock_obj - - with pytest.raises(AttributeError): - ActionIDsPartitionKey.with_obj(obj="dummy_str") - - # Not sure what Exception should be raised here, Type or Attibute - with pytest.raises(TypeCheckError): - ActionIDsPartitionKey.with_obj(obj=["first_str", "second_str"]) - - def test_dataset_get_by_name(root_verify_key, mock_dataset_stash, mock_dataset) -> None: # retrieving existing dataset result = mock_dataset_stash.get_by_name(root_verify_key, mock_dataset.name) diff --git a/packages/syft/tests/syft/notifications/notification_stash_test.py b/packages/syft/tests/syft/notifications/notification_stash_test.py 
index b848324a2b7..9c871182f93 100644 --- a/packages/syft/tests/syft/notifications/notification_stash_test.py +++ b/packages/syft/tests/syft/notifications/notification_stash_test.py @@ -8,13 +8,7 @@ # syft absolute from syft.server.credentials import SyftSigningKey from syft.server.credentials import SyftVerifyKey -from syft.service.notification.notification_stash import ( - OrderByCreatedAtTimeStampPartitionKey, -) -from syft.service.notification.notification_stash import FromUserVerifyKeyPartitionKey from syft.service.notification.notification_stash import NotificationStash -from syft.service.notification.notification_stash import StatusPartitionKey -from syft.service.notification.notification_stash import ToUserVerifyKeyPartitionKey from syft.service.notification.notifications import Notification from syft.service.notification.notifications import NotificationExpiryStatus from syft.service.notification.notifications import NotificationStatus @@ -60,74 +54,6 @@ def add_mock_notification( return mock_notification -def test_fromuserverifykey_partitionkey() -> None: - random_verify_key = SyftSigningKey.generate().verify_key - - assert FromUserVerifyKeyPartitionKey.type_ == SyftVerifyKey - assert FromUserVerifyKeyPartitionKey.key == "from_user_verify_key" - - result = FromUserVerifyKeyPartitionKey.with_obj(random_verify_key) - - assert result.type_ == SyftVerifyKey - assert result.key == "from_user_verify_key" - - assert result.value == random_verify_key - - signing_key = SyftSigningKey.from_string(test_signing_key_string) - with pytest.raises(AttributeError): - FromUserVerifyKeyPartitionKey.with_obj(signing_key) - - -def test_touserverifykey_partitionkey() -> None: - random_verify_key = SyftSigningKey.generate().verify_key - - assert ToUserVerifyKeyPartitionKey.type_ == SyftVerifyKey - assert ToUserVerifyKeyPartitionKey.key == "to_user_verify_key" - - result = ToUserVerifyKeyPartitionKey.with_obj(random_verify_key) - - assert result.type_ == SyftVerifyKey - assert result.key == "to_user_verify_key" - assert result.value == random_verify_key - - signing_key = SyftSigningKey.from_string(test_signing_key_string) - with pytest.raises(AttributeError): - ToUserVerifyKeyPartitionKey.with_obj(signing_key) - - -def test_status_partitionkey() -> None: - assert StatusPartitionKey.key == "status" - assert StatusPartitionKey.type_ == NotificationStatus - - result1 = StatusPartitionKey.with_obj(NotificationStatus.UNREAD) - result2 = StatusPartitionKey.with_obj(NotificationStatus.READ) - - assert result1.type_ == NotificationStatus - assert result1.key == "status" - assert result1.value == NotificationStatus.UNREAD - assert result2.type_ == NotificationStatus - assert result2.key == "status" - assert result2.value == NotificationStatus.READ - - notification_expiry_status_auto = NotificationExpiryStatus(0) - - with pytest.raises(AttributeError): - StatusPartitionKey.with_obj(notification_expiry_status_auto) - - -def test_orderbycreatedattimestamp_partitionkey() -> None: - random_datetime = DateTime.now() - - assert OrderByCreatedAtTimeStampPartitionKey.key == "created_at" - assert OrderByCreatedAtTimeStampPartitionKey.type_ == DateTime - - result = OrderByCreatedAtTimeStampPartitionKey.with_obj(random_datetime) - - assert result.type_ == DateTime - assert result.key == "created_at" - assert result.value == random_datetime - - def test_get_all_inbox_for_verify_key(root_verify_key, document_store) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key @@ 
-205,12 +131,9 @@ def test_get_all_sent_for_verify_key(root_verify_key, document_store) -> None: def test_get_all_for_verify_key(root_verify_key, document_store) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key - query_key = FromUserVerifyKeyPartitionKey.with_obj(test_verify_key) test_stash = NotificationStash(store=document_store) - response = test_stash.get_all_for_verify_key( - root_verify_key, random_verify_key, query_key - ) + response = test_stash.get_all_for_verify_key(root_verify_key, random_verify_key) assert response.is_ok() @@ -221,11 +144,8 @@ def test_get_all_for_verify_key(root_verify_key, document_store) -> None: root_verify_key, test_stash, test_verify_key, random_verify_key ) - query_key2 = FromUserVerifyKeyPartitionKey.with_obj( - mock_notification.from_user_verify_key - ) response_from_verify_key = test_stash.get_all_for_verify_key( - root_verify_key, mock_notification.from_user_verify_key, query_key2 + root_verify_key, mock_notification.from_user_verify_key ) assert response_from_verify_key.is_ok() @@ -235,7 +155,7 @@ def test_get_all_for_verify_key(root_verify_key, document_store) -> None: assert result[0] == mock_notification response_from_verify_key_string = test_stash.get_all_for_verify_key( - root_verify_key, test_verify_key_string, query_key2 + root_verify_key, test_verify_key_string ) assert response_from_verify_key_string.is_ok() diff --git a/packages/syft/tests/syft/request/request_stash_test.py b/packages/syft/tests/syft/request/request_stash_test.py index a492c2f6b9f..869725e4feb 100644 --- a/packages/syft/tests/syft/request/request_stash_test.py +++ b/packages/syft/tests/syft/request/request_stash_test.py @@ -12,7 +12,6 @@ from syft.service.request.request import Request from syft.service.request.request import SubmitRequest from syft.service.request.request_stash import RequestStash -from syft.service.request.request_stash import RequestingUserVerifyKeyPartitionKey from syft.store.document_store import PartitionKey from syft.store.document_store import QueryKeys from syft.types.errors import SyftException @@ -111,9 +110,8 @@ def test_requeststash_get_all_for_verify_key_find_index_fail( guest_datasite_client: SyftClient, ) -> None: verify_key: SyftVerifyKey = guest_datasite_client.credentials.verify_key - qks = QueryKeys(qks=[RequestingUserVerifyKeyPartitionKey.with_obj(verify_key)]) - mock_error_message = f"Failed to query index or search with {qks.all[0]}" + mock_error_message = "Failed search with" def mock_find_index_or_search_keys_error( credentials: SyftVerifyKey, diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py index 375204908c1..204235a421c 100644 --- a/packages/syft/tests/syft/stores/action_store_test.py +++ b/packages/syft/tests/syft/stores/action_store_test.py @@ -7,8 +7,8 @@ # syft absolute from syft.server.credentials import SyftVerifyKey +from syft.service.action.action_permissions import ActionObjectOWNER from syft.service.action.action_store import ActionObjectEXECUTE -from syft.service.action.action_store import ActionObjectOWNER from syft.service.action.action_store import ActionObjectREAD from syft.service.action.action_store import ActionObjectWRITE from syft.types.uid import UID diff --git a/packages/syft/tests/syft/stores/base_stash_test.py b/packages/syft/tests/syft/stores/base_stash_test.py index c61806fb31b..c6a6e24714c 100644 --- a/packages/syft/tests/syft/stores/base_stash_test.py +++ 
b/packages/syft/tests/syft/stores/base_stash_test.py @@ -12,10 +12,10 @@ # syft absolute from syft.serde.serializable import serializable -from syft.store.dict_document_store import DictDocumentStore -from syft.store.document_store import NewBaseUIDStoreStash +from syft.store.db.sqlite_db import DBManager +from syft.store.db.sqlite_db import SQLiteDBConfig +from syft.store.db.stash import ObjectStash from syft.store.document_store import PartitionKey -from syft.store.document_store import PartitionSettings from syft.store.document_store import QueryKey from syft.store.document_store import QueryKeys from syft.store.document_store import UIDPartitionKey @@ -45,11 +45,8 @@ class MockObject(SyftObject): ImportancePartitionKey = PartitionKey(key="importance", type_=int) -class MockStash(NewBaseUIDStoreStash): - object_type = MockObject - settings = PartitionSettings( - name=MockObject.__canonical_name__, object_type=MockObject - ) +class MockStash(ObjectStash[MockObject]): + pass def get_object_values(obj: SyftObject) -> tuple[Any]: @@ -80,7 +77,9 @@ def create_unique( @pytest.fixture def base_stash(root_verify_key) -> MockStash: - yield MockStash(store=DictDocumentStore(UID(), root_verify_key)) + config = SQLiteDBConfig() + db_manager = DBManager(config, UID(), root_verify_key) + yield MockStash(store=db_manager) def random_sentence(faker: Faker) -> str: diff --git a/packages/syft/tests/syft/stores/dict_document_store_test.py b/packages/syft/tests/syft/stores/dict_document_store_test.py deleted file mode 100644 index e04414d666c..00000000000 --- a/packages/syft/tests/syft/stores/dict_document_store_test.py +++ /dev/null @@ -1,358 +0,0 @@ -# stdlib -from threading import Thread - -# syft absolute -from syft.store.dict_document_store import DictStorePartition -from syft.store.document_store import QueryKeys -from syft.types.uid import UID - -# relative -from .store_mocks_test import MockObjectType -from .store_mocks_test import MockSyftObject - - -def test_dict_store_partition_sanity(dict_store_partition: DictStorePartition) -> None: - res = dict_store_partition.init_store() - assert res.is_ok() - - assert hasattr(dict_store_partition, "data") - assert hasattr(dict_store_partition, "unique_keys") - assert hasattr(dict_store_partition, "searchable_keys") - - -def test_dict_store_partition_set( - root_verify_key, dict_store_partition: DictStorePartition -) -> None: - res = dict_store_partition.init_store() - assert res.is_ok() - - obj = MockSyftObject(id=UID(), data=1) - res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - - assert res.is_ok() - assert res.ok() == obj - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_err() - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=True) - assert res.is_ok() - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - obj2 = MockSyftObject(data=2) - res = dict_store_partition.set(root_verify_key, obj2, ignore_duplicates=False) - assert res.is_ok() - assert res.ok() == obj2 - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 2 - ) - - repeats = 5 - for idx in range(repeats): - obj = MockSyftObject(data=idx) - res = dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() - assert ( - 
len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 3 + idx - ) - - -def test_dict_store_partition_delete( - root_verify_key, dict_store_partition: DictStorePartition -) -> None: - res = dict_store_partition.init_store() - assert res.is_ok() - - objs = [] - repeats = 5 - for v in range(repeats): - obj = MockSyftObject(data=v) - dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - objs.append(obj) - - assert len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # random object - obj = MockSyftObject(data="bogus") - key = dict_store_partition.settings.store_key.with_obj(obj) - res = dict_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # cleanup store - for idx, v in enumerate(objs): - key = dict_store_partition.settings.store_key.with_obj(v) - res = dict_store_partition.delete(root_verify_key, key) - assert res.is_ok() - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - res = dict_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 0 - ) - - -def test_dict_store_partition_update( - root_verify_key, dict_store_partition: DictStorePartition -) -> None: - dict_store_partition.init_store() - - # add item - obj = MockSyftObject(data=1) - dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert len(dict_store_partition.all(root_verify_key).ok()) == 1 - - # fail to update missing keys - rand_obj = MockSyftObject(data="bogus") - key = dict_store_partition.settings.store_key.with_obj(rand_obj) - res = dict_store_partition.update(root_verify_key, key, obj) - assert res.is_err() - - # update the key multiple times - repeats = 5 - for v in range(repeats): - key = dict_store_partition.settings.store_key.with_obj(obj) - obj_new = MockSyftObject(data=v) - - res = dict_store_partition.update(root_verify_key, key, obj_new) - assert res.is_ok() - - # The ID should stay the same on update, unly the values are updated. 
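
The deleted assertions around this point encode an invariant that carries over to the SQL-backed stash: update replaces payload fields but never the stored id, and updating a missing key fails. A minimal model of that contract; MockObject here is a stand-in, not the test's class:

# Minimal model of the update contract: stored id is stable, payload changes.
import uuid
from dataclasses import dataclass, field

@dataclass
class MockObject:
    data: int
    id: uuid.UUID = field(default_factory=uuid.uuid4)

store: dict[uuid.UUID, MockObject] = {}

def update(key: uuid.UUID, new: MockObject) -> MockObject:
    if key not in store:
        raise KeyError(key)  # the "missing key" failure asserted above
    current = store[key]
    current.data = new.data  # copy values; never touch current.id
    return current

obj = MockObject(data=1)
store[obj.id] = obj
updated = update(obj.id, MockObject(data=5))
assert updated.id == obj.id and updated.data == 5
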
- assert ( - len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - assert ( - dict_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - == obj.id - ) - assert ( - dict_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - != obj_new.id - ) - assert ( - dict_store_partition.all( - root_verify_key, - ) - .ok()[0] - .data - == v - ) - - stored = dict_store_partition.get_all_from_store( - root_verify_key, QueryKeys(qks=[key]) - ) - assert stored.ok()[0].data == v - - -def test_dict_store_partition_set_multithreaded( - root_verify_key, - dict_store_partition: DictStorePartition, -) -> None: - thread_cnt = 3 - repeats = 5 - - dict_store_partition.init_store() - - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - for idx in range(repeats): - obj = MockObjectType(data=idx) - - for _ in range(10): - res = dict_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - stored_cnt = len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == repeats * thread_cnt - - -def test_dict_store_partition_update_multithreaded( - root_verify_key, - dict_store_partition: DictStorePartition, -) -> None: - thread_cnt = 3 - repeats = 5 - dict_store_partition.init_store() - - obj = MockSyftObject(data=0) - key = dict_store_partition.settings.store_key.with_obj(obj) - dict_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - for repeat in range(repeats): - obj = MockSyftObject(data=repeat) - - for _ in range(10): - res = dict_store_partition.update(root_verify_key, key, obj) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - -def test_dict_store_partition_set_delete_multithreaded( - root_verify_key, - dict_store_partition: DictStorePartition, -) -> None: - dict_store_partition.init_store() - - thread_cnt = 3 - repeats = 5 - - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - for idx in range(repeats): - obj = MockSyftObject(data=idx) - - for _ in range(10): - res = dict_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - key = dict_store_partition.settings.store_key.with_obj(obj) - - res = dict_store_partition.delete(root_verify_key, key) - if res.is_err(): - execution_err = res - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - stored_cnt = len( - dict_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == 0 diff --git a/packages/syft/tests/syft/stores/mongo_document_store_test.py b/packages/syft/tests/syft/stores/mongo_document_store_test.py index 95df806c189..5af4c52d626 100644 --- 
a/packages/syft/tests/syft/stores/mongo_document_store_test.py +++ b/packages/syft/tests/syft/stores/mongo_document_store_test.py @@ -7,11 +7,11 @@ # syft absolute from syft.server.credentials import SyftVerifyKey +from syft.service.action.action_permissions import ActionObjectOWNER from syft.service.action.action_permissions import ActionObjectPermission from syft.service.action.action_permissions import ActionPermission from syft.service.action.action_permissions import StoragePermission from syft.service.action.action_store import ActionObjectEXECUTE -from syft.service.action.action_store import ActionObjectOWNER from syft.service.action.action_store import ActionObjectREAD from syft.service.action.action_store import ActionObjectWRITE from syft.store.document_store import PartitionSettings diff --git a/packages/syft/tests/syft/stores/store_fixtures_test.py b/packages/syft/tests/syft/stores/store_fixtures_test.py index e462c584b19..614b35b4b67 100644 --- a/packages/syft/tests/syft/stores/store_fixtures_test.py +++ b/packages/syft/tests/syft/stores/store_fixtures_test.py @@ -13,14 +13,14 @@ from syft.server.credentials import SyftVerifyKey from syft.service.action.action_permissions import ActionObjectPermission from syft.service.action.action_permissions import ActionPermission +from syft.service.action.action_store import ActionObjectStash from syft.service.queue.queue_stash import QueueStash from syft.service.user.user import User from syft.service.user.user import UserCreate from syft.service.user.user_roles import ServiceRole from syft.service.user.user_stash import UserStash -from syft.store.dict_document_store import DictDocumentStore -from syft.store.dict_document_store import DictStoreConfig -from syft.store.dict_document_store import DictStorePartition +from syft.store.db.sqlite_db import DBManager +from syft.store.db.sqlite_db import SQLiteDBConfig from syft.store.document_store import DocumentStore from syft.store.document_store import PartitionSettings from syft.store.locks import LockingConfig @@ -62,8 +62,9 @@ def str_to_locking_config(conf: str) -> LockingConfig: def document_store_with_admin( server_uid: UID, verify_key: SyftVerifyKey ) -> DocumentStore: - document_store = DictDocumentStore( - server_uid=server_uid, root_verify_key=verify_key + config = SQLiteDBConfig() + document_store = DBManager( + server_uid=server_uid, root_verify_key=verify_key, config=config ) password = uuid.uuid4().hex @@ -215,7 +216,7 @@ def sqlite_action_store(sqlite_workspace: tuple[Path, str], request): server_uid = UID() document_store = document_store_with_admin(server_uid, ver_key) - yield SQLiteActionStore( + yield ActionObjectStash( server_uid=server_uid, store_config=store_config, root_verify_key=ver_key, @@ -310,88 +311,3 @@ def mongo_queue_stash(root_verify_key, mongo_client, request): locking_config_name=locking_config_name, ) yield mongo_queue_stash_fn(store) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def mongo_action_store(mongo_client, request): - mongo_db_name = token_hex(8) - locking_config_name = request.param - locking_config = str_to_locking_config(locking_config_name) - - mongo_config = MongoStoreClientConfig(client=mongo_client) - store_config = MongoStoreConfig( - client_config=mongo_config, db_name=mongo_db_name, locking_config=locking_config - ) - ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) - server_uid = UID() - document_store = document_store_with_admin(server_uid, ver_key) - mongo_action_store = MongoActionStore( - 
server_uid=server_uid, - store_config=store_config, - root_verify_key=ver_key, - document_store=document_store, - ) - - yield mongo_action_store - - -def dict_store_partition_fn( - root_verify_key, - locking_config_name: str = "nop", -): - locking_config = str_to_locking_config(locking_config_name) - store_config = DictStoreConfig(locking_config=locking_config) - settings = PartitionSettings(name="test", object_type=MockObjectType) - - return DictStorePartition( - UID(), root_verify_key, settings=settings, store_config=store_config - ) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def dict_store_partition(root_verify_key, request): - locking_config_name = request.param - yield dict_store_partition_fn( - root_verify_key, locking_config_name=locking_config_name - ) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def dict_action_store(request): - locking_config_name = request.param - locking_config = str_to_locking_config(locking_config_name) - - store_config = DictStoreConfig(locking_config=locking_config) - ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) - server_uid = UID() - document_store = document_store_with_admin(server_uid, ver_key) - - yield DictActionStore( - server_uid=server_uid, - store_config=store_config, - root_verify_key=ver_key, - document_store=document_store, - ) - - -def dict_document_store_fn(root_verify_key, locking_config_name: str = "nop"): - locking_config = str_to_locking_config(locking_config_name) - store_config = DictStoreConfig(locking_config=locking_config) - return DictDocumentStore(UID(), root_verify_key, store_config=store_config) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def dict_document_store(root_verify_key, request): - locking_config_name = request.param - yield dict_document_store_fn( - root_verify_key, locking_config_name=locking_config_name - ) - - -def dict_queue_stash_fn(dict_document_store): - return QueueStash(store=dict_document_store) - - -@pytest.fixture(scope="function") -def dict_queue_stash(dict_document_store): - yield dict_queue_stash_fn(dict_document_store) diff --git a/packages/syft/tests/syft/worker_test.py b/packages/syft/tests/syft/worker_test.py index 3f2e22eea31..37ae7b7da24 100644 --- a/packages/syft/tests/syft/worker_test.py +++ b/packages/syft/tests/syft/worker_test.py @@ -16,7 +16,6 @@ from syft.server.credentials import SyftVerifyKey from syft.server.worker import Worker from syft.service.action.action_object import ActionObject -from syft.service.action.action_store import DictActionStore from syft.service.context import AuthedServiceContext from syft.service.queue.queue_stash import QueueItem from syft.service.response import SyftError @@ -78,7 +77,7 @@ def test_signing_key() -> None: def test_action_store() -> None: test_signing_key = SyftSigningKey.from_string(test_signing_key_string) - action_store = DictActionStore(server_uid=UID()) + action_store = ... 
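
All of the removed per-backend fixtures collapse into one shape under the new API, mirroring the base_stash_test.py change in this same patch. A sketch under those assumptions (import paths as introduced in patch 067; MockStash stands for whatever ObjectStash[...] subclass the test module defines):

# Sketch of the replacement fixture shape (paths from patch 067; MockStash is
# a hypothetical ObjectStash[...] subclass defined by the test module).
import pytest
from syft.store.db.sqlite_db import SQLiteDBConfig, SQLiteDBManager
from syft.types.uid import UID

@pytest.fixture
def stash(root_verify_key):
    config = SQLiteDBConfig()  # default filename is a fresh <uuid4>.db in tmpdir
    db = SQLiteDBManager(config=config, server_uid=UID(), root_verify_key=root_verify_key)
    yield MockStash(store=db)
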
uid = UID() raw_data = np.array([1, 2, 3]) test_object = ActionObject.from_obj(raw_data) From ffa9f59720061b3b1368cbb090ba9326393f846c Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 4 Sep 2024 11:58:11 +0200 Subject: [PATCH 068/197] fix settings stash test --- packages/syft/src/syft/store/db/sqlite_db.py | 4 ++ packages/syft/tests/conftest.py | 5 +-- .../syft/settings/settings_stash_test.py | 43 ++++--------------- .../tests/syft/stores/store_fixtures_test.py | 4 +- 4 files changed, 15 insertions(+), 41 deletions(-) diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index a28d5f8f752..ce4b7f14fa1 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -98,6 +98,10 @@ def init_tables(self) -> None: Base.metadata.drop_all(bind=self.engine) Base.metadata.create_all(self.engine) + def reset(self) -> None: + Base.metadata.drop_all(bind=self.engine) + Base.metadata.create_all(self.engine) + # TODO remove def get_session_threading_local(self) -> Session: if not hasattr(self.thread_local, "session"): diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index 0bd3426f2b9..64d91d78151 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -27,7 +27,6 @@ # relative # our version of mongomock that has a fix for CodecOptions and custom TypeRegistry Support from .mongomock.mongo_client import MongoClient -from .syft.stores.store_fixtures_test import mongo_action_store from .syft.stores.store_fixtures_test import mongo_document_store from .syft.stores.store_fixtures_test import mongo_queue_stash from .syft.stores.store_fixtures_test import mongo_store_partition @@ -208,8 +207,7 @@ def ds_verify_key(ds_client: DatasiteClient): @pytest.fixture def document_store(worker): - yield worker.document_store - worker.document_store.reset() + yield worker.db @pytest.fixture @@ -307,7 +305,6 @@ def big_dataset() -> Dataset: "mongo_store_partition", "mongo_document_store", "mongo_queue_stash", - "mongo_action_store", "sqlite_store_partition", "sqlite_workspace", "sqlite_document_store", diff --git a/packages/syft/tests/syft/settings/settings_stash_test.py b/packages/syft/tests/syft/settings/settings_stash_test.py index 3fbbd28f9e9..bb003757c11 100644 --- a/packages/syft/tests/syft/settings/settings_stash_test.py +++ b/packages/syft/tests/syft/settings/settings_stash_test.py @@ -1,54 +1,27 @@ -# third party - # syft absolute from syft.service.settings.settings import ServerSettings from syft.service.settings.settings import ServerSettingsUpdate from syft.service.settings.settings_stash import SettingsStash -def add_mock_settings( - root_verify_key, settings_stash: SettingsStash, settings: ServerSettings -) -> ServerSettings: - # prepare: add mock settings - result = settings_stash.partition.set(root_verify_key, settings) - assert result.is_ok() - - created_settings = result.ok() - assert created_settings is not None - - return created_settings - - +# NOTE: Is this test necessary? 
It is just testing set and update methods def test_settingsstash_set( - root_verify_key, settings_stash: SettingsStash, settings: ServerSettings -) -> None: - result = settings_stash.set(root_verify_key, settings) - assert result.is_ok() - - created_settings = result.ok() - assert isinstance(created_settings, ServerSettings) - assert created_settings == settings - assert settings.id in settings_stash.partition.data - - -def test_settingsstash_update( root_verify_key, settings_stash: SettingsStash, settings: ServerSettings, update_settings: ServerSettingsUpdate, ) -> None: - # prepare: add a mock settings - mock_settings = add_mock_settings(root_verify_key, settings_stash, settings) + created_settings = settings_stash.set(root_verify_key, settings).unwrap() + assert isinstance(created_settings, ServerSettings) + assert created_settings == settings + assert settings_stash.exists(root_verify_key, settings.id) # update mock_settings according to update_settings update_kwargs = update_settings.to_dict(exclude_empty=True).items() for field_name, value in update_kwargs: - setattr(mock_settings, field_name, value) + setattr(settings, field_name, value) # update the settings in the stash - result = settings_stash.update(root_verify_key, settings=mock_settings) - - assert result.is_ok() - updated_settings = result.ok() + updated_settings = settings_stash.update(root_verify_key, obj=settings).unwrap() assert isinstance(updated_settings, ServerSettings) - assert mock_settings == updated_settings + assert settings == updated_settings diff --git a/packages/syft/tests/syft/stores/store_fixtures_test.py b/packages/syft/tests/syft/stores/store_fixtures_test.py index 614b35b4b67..b049a0a58de 100644 --- a/packages/syft/tests/syft/stores/store_fixtures_test.py +++ b/packages/syft/tests/syft/stores/store_fixtures_test.py @@ -19,8 +19,8 @@ from syft.service.user.user import UserCreate from syft.service.user.user_roles import ServiceRole from syft.service.user.user_stash import UserStash -from syft.store.db.sqlite_db import DBManager from syft.store.db.sqlite_db import SQLiteDBConfig +from syft.store.db.sqlite_db import SQLiteDBManager from syft.store.document_store import DocumentStore from syft.store.document_store import PartitionSettings from syft.store.locks import LockingConfig @@ -63,7 +63,7 @@ def document_store_with_admin( server_uid: UID, verify_key: SyftVerifyKey ) -> DocumentStore: config = SQLiteDBConfig() - document_store = DBManager( + document_store = SQLiteDBManager( server_uid=server_uid, root_verify_key=verify_key, config=config ) From 83a6ab7014233df57773f5dbe4929b84c3d52611 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 4 Sep 2024 12:00:08 +0200 Subject: [PATCH 069/197] add remaining stashes --- .../src/syft/protocol/protocol_version.json | 5 -- .../syft/src/syft/server/service_registry.py | 10 +++- .../data_subject_member_service.py | 31 ++++------- .../data_subject/data_subject_service.py | 36 +++++++------ .../syft/service/metadata/metadata_service.py | 6 --- .../service/migration/migration_service.py | 4 +- .../migration/object_migration_state.py | 46 +++-------------- .../syft/service/network/network_service.py | 26 ++++------ .../syft/service/policy/user_policy_stash.py | 24 +++------ .../src/syft/service/queue/queue_stash.py | 51 +++++++------------ .../service/worker/image_registry_stash.py | 44 ++++------------ .../syft/service/worker/worker_image_stash.py | 34 ++++--------- .../syft/service/worker/worker_pool_stash.py | 36 +++++-------- 
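
Before the per-service hunks below: every stash conversion in this patch follows the same query translation, from PartitionKey/QueryKeys objects to a plain field dictionary resolved against the JSON fields column. A syft-free sketch of why that works, using plain SQLAlchemy 2.x with a table layout mirroring the stash tables in this series:

# Syft-free sketch: one JSON "fields" column plus a field-name dict can stand
# in for the per-key PartitionKey machinery removed below.
import uuid
import sqlalchemy as sa

engine = sa.create_engine("sqlite://")
meta = sa.MetaData()
objects = sa.Table(
    "object",
    meta,
    sa.Column("id", sa.Uuid, primary_key=True, default=uuid.uuid4),
    sa.Column("fields", sa.JSON, default={}),
)
meta.create_all(engine)

with engine.begin() as conn:
    conn.execute(objects.insert(), [{"fields": {"parent": "alice", "child": "bob"}}])
    # get_all_by_fields(fields={"parent": "alice"}) reduces to a filter like:
    stmt = objects.select().where(objects.c.fields["parent"].as_string() == "alice")
    assert len(conn.execute(stmt).fetchall()) == 1

Swapping the engine URL for a postgresql+psycopg DSN should leave the query itself unchanged, which is presumably what lets the same stash code back both stores.
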
.../src/syft/service/worker/worker_stash.py | 23 +++------ packages/syft/src/syft/store/db/stash.py | 1 + 15 files changed, 123 insertions(+), 254 deletions(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index c1eba4bfc63..3121dc36d9f 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -487,11 +487,6 @@ } }, "JobItem": { - "1": { - "version": 1, - "hash": "0b32277b7d3b9bdc14a2a51cc9005f8254e7f7b6ec059ddcccbcd681a807afb6", - "action": "add" - }, "2": { "version": 2, "hash": "b087d0c62b7d304c6ca80e4fb0e8a7f2a444be8f8cba57490dc09aeb98033105", diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index ad2b313e94e..30878648707 100644 --- a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -128,7 +128,15 @@ def _construct_services(cls, server: "Server") -> dict[str, AbstractService]: # Use old DB else: svc_kwargs["store"] = server.document_store - print("Using old store:", service_cls) + + # TODO remove after all services are migrated + services_without_stash = [ + AttestationService, + MetadataService, + EnclaveService, + ] + if service_cls not in services_without_stash: + print("Using old store:", service_cls) service = service_cls(**svc_kwargs) service_dict[field_name] = service diff --git a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py index c54e8bc67ea..b9f8c851b78 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py @@ -3,10 +3,8 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey +from ...store.db.stash import ObjectStash from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys from ...store.document_store_errors import StashException from ...types.result import as_result from ..context import AuthedServiceContext @@ -14,35 +12,28 @@ from ..service import AbstractService from ..service import SERVICE_TO_TYPES from ..service import TYPE_TO_SERVICE -from .data_subject_member import ChildPartitionKey from .data_subject_member import DataSubjectMemberRelationship -from .data_subject_member import ParentPartitionKey -@serializable(canonical_name="DataSubjectMemberStash", version=1) -class DataSubjectMemberStash(NewBaseUIDStoreStash): - object_type = DataSubjectMemberRelationship - settings: PartitionSettings = PartitionSettings( - name=DataSubjectMemberRelationship.__canonical_name__, - object_type=DataSubjectMemberRelationship, - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - +@serializable(canonical_name="DataSubjectMemberSQLStash", version=1) +class DataSubjectMemberStash(ObjectStash[DataSubjectMemberRelationship]): @as_result(StashException) def get_all_for_parent( self, credentials: SyftVerifyKey, name: str ) -> list[DataSubjectMemberRelationship]: - qks = QueryKeys(qks=[ParentPartitionKey.with_obj(name)]) - return self.query_all(credentials=credentials, qks=qks).unwrap() + return self.get_all_by_fields( + credentials=credentials, + fields={"parent": 
name}, + ).unwrap() @as_result(StashException) def get_all_for_child( self, credentials: SyftVerifyKey, name: str ) -> list[DataSubjectMemberRelationship]: - qks = QueryKeys(qks=[ChildPartitionKey.with_obj(name)]) - return self.query_all(credentials=credentials, qks=qks).unwrap() + return self.get_all_by_fields( + credentials=credentials, + fields={"child": name}, + ).unwrap() @serializable(canonical_name="DataSubjectMemberService", version=1) diff --git a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py index b386fd1ec8d..b66de42e349 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py @@ -5,10 +5,8 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey +from ...store.db.stash import ObjectStash from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys from ...store.document_store_errors import StashException from ...types.result import as_result from ..context import AuthedServiceContext @@ -19,24 +17,17 @@ from ..service import service_method from .data_subject import DataSubject from .data_subject import DataSubjectCreate -from .data_subject import NamePartitionKey from .data_subject_member_service import DataSubjectMemberService -@serializable(canonical_name="DataSubjectStash", version=1) -class DataSubjectStash(NewBaseUIDStoreStash): - object_type = DataSubject - settings: PartitionSettings = PartitionSettings( - name=DataSubject.__canonical_name__, object_type=DataSubject - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - +@serializable(canonical_name="DataSubjectSQLStash", version=1) +class DataSubjectStash(ObjectStash[DataSubject]): @as_result(StashException) def get_by_name(self, credentials: SyftVerifyKey, name: str) -> DataSubject: - qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)]) - return self.query_one(credentials, qks=qks).unwrap() + return self.get_one_by_fields( + credentials=credentials, + fields={"name": name}, + ).unwrap() @as_result(StashException) def update( @@ -45,9 +36,16 @@ def update( data_subject: DataSubject, has_permission: bool = False, ) -> DataSubject: - res = self.check_type(data_subject, DataSubject).unwrap() - # we dont use and_then logic here as it is hard because of the order of the arguments - return super().update(credentials=credentials, obj=res).unwrap() + self.check_type(data_subject, DataSubject).unwrap() + return ( + super() + .update( + credentials=credentials, + obj=data_subject, + has_permission=has_permission, + ) + .unwrap() + ) @serializable(canonical_name="DataSubjectService", version=1) diff --git a/packages/syft/src/syft/service/metadata/metadata_service.py b/packages/syft/src/syft/service/metadata/metadata_service.py index 4e4e84d1364..70453d9b084 100644 --- a/packages/syft/src/syft/service/metadata/metadata_service.py +++ b/packages/syft/src/syft/service/metadata/metadata_service.py @@ -21,12 +21,6 @@ def __init__(self, store: DocumentStore) -> None: def get_metadata(self, context: AuthedServiceContext) -> ServerMetadata: return context.server.metadata # type: ignore - # @service_method(path="metadata.get_admin", name="get_admin", roles=GUEST_ROLE_LEVEL) - # def get_admin(self, context: 
AuthedServiceContext): - # user_service = context.server.get_service("userservice") - # admin_user = user_service.get_all(context=context)[0] - # return admin_user - @service_method(path="metadata.get_env", name="get_env", roles=GUEST_ROLE_LEVEL) def get_env(self, context: AuthedServiceContext) -> str: return context.server.packages diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 3ee8f294c5b..da2fc107d32 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -75,9 +75,7 @@ def register_migration_state( obj = SyftObjectMigrationState( current_version=current_version, canonical_name=canonical_name ) - return self.stash.set( - migration_state=obj, credentials=context.credentials - ).unwrap() + return self.stash.set(obj=obj, credentials=context.credentials).unwrap() @as_result(SyftException, NotFoundException) def _find_klasses_pending_for_migration( diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py index b7fa2bb7cbd..e6f744c3b0e 100644 --- a/packages/syft/src/syft/service/migration/object_migration_state.py +++ b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -15,10 +15,8 @@ from ...serde.serialize import _serialize from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseStash +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings from ...store.document_store_errors import NotFoundException from ...types.blob_storage import BlobStorageEntry from ...types.blob_storage import CreateBlobStorageEntry @@ -34,7 +32,6 @@ from ...types.transforms import make_set_default from ...types.uid import UID from ...util.util import prompt_warning_message -from ..action.action_permissions import ActionObjectPermission from ..response import SyftSuccess from ..worker.utils import DEFAULT_WORKER_POOL_NAME from ..worker.worker_image import SyftWorkerImage @@ -70,45 +67,16 @@ def supported_versions(self) -> list: KlassNamePartitionKey = PartitionKey(key="canonical_name", type_=str) -@serializable(canonical_name="SyftMigrationStateStash", version=1) -class SyftMigrationStateStash(NewBaseStash): - object_type = SyftObjectMigrationState - settings: PartitionSettings = PartitionSettings( - name=SyftObjectMigrationState.__canonical_name__, - object_type=SyftObjectMigrationState, - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - - @as_result(SyftException) - def set( # type: ignore [override] - self, - credentials: SyftVerifyKey, - migration_state: SyftObjectMigrationState, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - ignore_duplicates: bool = False, - ) -> SyftObjectMigrationState: - obj = self.check_type(migration_state, self.object_type).unwrap() - return ( - super() - .set( - credentials=credentials, - obj=obj, - add_permissions=add_permissions, - add_storage_permission=add_storage_permission, - ignore_duplicates=ignore_duplicates, - ) - .unwrap() - ) - +@serializable(canonical_name="SyftMigrationStateSQLStash", version=1) +class 
SyftMigrationStateStash(ObjectStash[SyftObjectMigrationState]): @as_result(SyftException, NotFoundException) def get_by_name( self, canonical_name: str, credentials: SyftVerifyKey ) -> SyftObjectMigrationState: - qks = KlassNamePartitionKey.with_obj(canonical_name) - return self.query_one(credentials=credentials, qks=qks).unwrap() + return self.get_one_by_fields( + credentials=credentials, + fields={"canonical_name": canonical_name}, + ).unwrap() @serializable() diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 5e4dfdbadf4..a2285901796 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -15,10 +15,9 @@ from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings from ...service.settings.settings import ServerSettings +from ...store.db.stash import ObjectStash from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings from ...store.document_store import QueryKeys from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException @@ -36,7 +35,6 @@ from ...util.util import prompt_warning_message from ...util.util import str_to_bool from ..context import AuthedServiceContext -from ..data_subject.data_subject import NamePartitionKey from ..metadata.server_metadata import ServerMetadata from ..request.request import Request from ..request.request import RequestStatus @@ -80,24 +78,18 @@ class ServerPeerAssociationStatus(Enum): PEER_NOT_FOUND = "PEER_NOT_FOUND" -@serializable(canonical_name="NetworkStash", version=1) -class NetworkStash(NewBaseUIDStoreStash): - object_type = ServerPeer - settings: PartitionSettings = PartitionSettings( - name=ServerPeer.__canonical_name__, object_type=ServerPeer - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - +@serializable(canonical_name="NetworkSQLStash", version=1) +class NetworkStash(ObjectStash[ServerPeer]): @as_result(StashException, NotFoundException) def get_by_name(self, credentials: SyftVerifyKey, name: str) -> ServerPeer: - qks = QueryKeys(qks=[NamePartitionKey.with_obj(name)]) try: - return self.query_one(credentials=credentials, qks=qks).unwrap() - except NotFoundException as exc: + return self.get_one_by_fields( + credentials=credentials, + fields={"name": name}, + ).unwrap() + except NotFoundException as e: raise NotFoundException.from_exception( - exc, public_message=f"ServerPeer with {name} not found" + e, public_message=f"ServerPeer with {name} not found" ) @as_result(StashException) diff --git a/packages/syft/src/syft/service/policy/user_policy_stash.py b/packages/syft/src/syft/service/policy/user_policy_stash.py index 38ab7f54c06..643e7e4e84a 100644 --- a/packages/syft/src/syft/service/policy/user_policy_stash.py +++ b/packages/syft/src/syft/service/policy/user_policy_stash.py @@ -3,30 +3,20 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.stash import ObjectStash from ...store.document_store_errors import 
NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result -from .policy import PolicyUserVerifyKeyPartitionKey from .policy import UserPolicy -@serializable(canonical_name="UserPolicyStash", version=1) -class UserPolicyStash(NewBaseUIDStoreStash): - object_type = UserPolicy - settings: PartitionSettings = PartitionSettings( - name=UserPolicy.__canonical_name__, object_type=UserPolicy - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - +@serializable(canonical_name="UserPolicySQLStash", version=1) +class UserPolicyStash(ObjectStash[UserPolicy]): @as_result(StashException, NotFoundException) def get_all_by_user_verify_key( self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey ) -> list[UserPolicy]: - qks = QueryKeys(qks=[PolicyUserVerifyKeyPartitionKey.with_obj(user_verify_key)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() + return self.get_all_by_fields( + credentials=credentials, + fields={"user_verify_key": str(user_verify_key)}, + ).unwrap() diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index 251f4a9fb63..6ddc2594abe 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -6,12 +6,8 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseStash +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys -from ...store.document_store import UIDPartitionKey from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject @@ -41,7 +37,7 @@ class QueueItem(SyftObject): __canonical_name__ = "QueueItem" __version__ = SYFT_OBJECT_VERSION_1 - __attr_searchable__ = ["status", "worker_pool"] + __attr_searchable__ = ["status", "worker_pool_id"] id: UID server_uid: UID @@ -64,6 +60,10 @@ def __repr__(self) -> str: def _repr_markdown_(self, wrap_as_python: bool = True, indent: int = 0) -> str: return f": {self.status}" + @property + def worker_pool_id(self) -> UID: + return self.worker_pool.object_uid + @property def is_action(self) -> bool: return self.service_path == "Action" and self.method_name == "execute" @@ -93,16 +93,8 @@ class APIEndpointQueueItem(QueueItem): service: str = "apiservice" -@serializable(canonical_name="QueueStash", version=1) -class QueueStash(NewBaseStash): - object_type = QueueItem - settings: PartitionSettings = PartitionSettings( - name=QueueItem.__canonical_name__, object_type=QueueItem - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - +@serializable(canonical_name="QueueSQLStash", version=1) +class QueueStash(ObjectStash[QueueItem]): # FIX: Check the return value for None. 
set_result is used extensively @as_result(StashException) def set_result( @@ -133,11 +125,6 @@ def set_placeholder( return super().set(credentials, item, add_permissions).unwrap() return item - @as_result(StashException) - def get_by_uid(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem: - qks = QueryKeys(qks=[UIDPartitionKey.with_obj(uid)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() - @as_result(StashException) def pop(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem | None: try: @@ -156,23 +143,23 @@ def pop_on_complete(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem: self.delete_by_uid(credentials=credentials, uid=uid) return queue_item - @as_result(StashException) - def delete_by_uid(self, credentials: SyftVerifyKey, uid: UID) -> UID: - qk = UIDPartitionKey.with_obj(uid) - super().delete(credentials=credentials, qk=qk).unwrap() - return uid - @as_result(StashException) def get_by_status( self, credentials: SyftVerifyKey, status: Status ) -> list[QueueItem]: - qks = QueryKeys(qks=StatusPartitionKey.with_obj(status)) - - return self.query_all(credentials=credentials, qks=qks).unwrap() + # TODO do we need json serialization for Status? + return self.get_all_by_fields( + credentials=credentials, + fields={"status": status}, + ).unwrap() @as_result(StashException) def _get_by_worker_pool( self, credentials: SyftVerifyKey, worker_pool: LinkedObject ) -> list[QueueItem]: - qks = QueryKeys(qks=_WorkerPoolPartitionKey.with_obj(worker_pool)) - return self.query_all(credentials=credentials, qks=qks).unwrap() + worker_pool_id = worker_pool.object_uid + + return self.get_all_by_fields( + credentials=credentials, + fields={"worker_pool_id": worker_pool_id}, + ).unwrap() diff --git a/packages/syft/src/syft/service/worker/image_registry_stash.py b/packages/syft/src/syft/service/worker/image_registry_stash.py index 5c469da0825..145270339ab 100644 --- a/packages/syft/src/syft/service/worker/image_registry_stash.py +++ b/packages/syft/src/syft/service/worker/image_registry_stash.py @@ -1,55 +1,33 @@ -# stdlib - -# third party - -# stdlib - # stdlib from typing import Literal # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException from ...types.result import as_result from .image_registry import SyftImageRegistry -__all__ = ["SyftImageRegistryStash"] - - -URLPartitionKey = PartitionKey(key="url", type_=str) - - -@serializable(canonical_name="SyftImageRegistryStash", version=1) -class SyftImageRegistryStash(NewBaseUIDStoreStash): - object_type = SyftImageRegistry - settings: PartitionSettings = PartitionSettings( - name=SyftImageRegistry.__canonical_name__, - object_type=SyftImageRegistry, - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) +@serializable(canonical_name="SyftImageRegistrySQLStash", version=1) +class SyftImageRegistryStash(ObjectStash[SyftImageRegistry]): @as_result(SyftException, StashException, NotFoundException) def get_by_url( self, credentials: SyftVerifyKey, url: str, ) -> SyftImageRegistry | 
None: - qks = QueryKeys(qks=[URLPartitionKey.with_obj(url)]) - return self.query_one(credentials=credentials, qks=qks).unwrap( - public_message=f"Image Registry with url {url} not found" - ) + return self.get_one_by_fields( + credentials=credentials, fields={"url": url} + ).unwrap() @as_result(SyftException, StashException) def delete_by_url(self, credentials: SyftVerifyKey, url: str) -> Literal[True]: - qk = URLPartitionKey.with_obj(url) - return super().delete(credentials=credentials, qk=qk).unwrap() + item = self.get_by_url(credentials=credentials, url=url).unwrap() + self.delete_by_uid(credentials=credentials, uid=item.id).unwrap() + + # TODO standardize delete return type + return True diff --git a/packages/syft/src/syft/service/worker/worker_image_stash.py b/packages/syft/src/syft/service/worker/worker_image_stash.py index 983dfe9a8d6..211f6072831 100644 --- a/packages/syft/src/syft/service/worker/worker_image_stash.py +++ b/packages/syft/src/syft/service/worker/worker_image_stash.py @@ -7,11 +7,7 @@ from ...custom_worker.config import WorkerConfig from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException @@ -20,22 +16,11 @@ from ..action.action_permissions import ActionPermission from .worker_image import SyftWorkerImage -WorkerConfigPK = PartitionKey(key="config", type_=WorkerConfig) - - -@serializable(canonical_name="SyftWorkerImageStash", version=1) -class SyftWorkerImageStash(NewBaseUIDStoreStash): - object_type = SyftWorkerImage - settings: PartitionSettings = PartitionSettings( - name=SyftWorkerImage.__canonical_name__, - object_type=SyftWorkerImage, - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) +@serializable(canonical_name="SyftWorkerImageSQLStash", version=1) +class SyftWorkerImageStash(ObjectStash[SyftWorkerImage]): @as_result(SyftException, StashException, NotFoundException) - def set( # type: ignore + def set( self, credentials: SyftVerifyKey, obj: SyftWorkerImage, @@ -43,9 +28,8 @@ def set( # type: ignore add_storage_permission: bool = True, ignore_duplicates: bool = False, ) -> SyftWorkerImage: - add_permissions = [] if add_permissions is None else add_permissions - # By default syft images have all read permission + add_permissions = [] if add_permissions is None else add_permissions add_permissions.append( ActionObjectPermission(uid=obj.id, permission=ActionPermission.ALL_READ) ) @@ -85,7 +69,11 @@ def worker_config_exists( def get_by_worker_config( self, credentials: SyftVerifyKey, config: WorkerConfig ) -> SyftWorkerImage: - qks = QueryKeys(qks=[WorkerConfigPK.with_obj(config)]) - return self.query_one(credentials=credentials, qks=qks).unwrap( + # TODO cannot search on fields containing objects + all_images = self.get_all(credentials=credentials).unwrap() + for image in all_images: + if image.config == config: + return image + raise NotFoundException( public_message=f"Worker Image with config {config} not found" ) diff --git a/packages/syft/src/syft/service/worker/worker_pool_stash.py 
b/packages/syft/src/syft/service/worker/worker_pool_stash.py index 94a4a0b8fab..36c0799296a 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_stash.py +++ b/packages/syft/src/syft/service/worker/worker_pool_stash.py @@ -5,11 +5,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash -from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys +from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result @@ -18,25 +14,17 @@ from ..action.action_permissions import ActionPermission from .worker_pool import WorkerPool -PoolNamePartitionKey = PartitionKey(key="name", type_=str) -PoolImageIDPartitionKey = PartitionKey(key="image_id", type_=UID) - - -@serializable(canonical_name="SyftWorkerPoolStash", version=1) -class SyftWorkerPoolStash(NewBaseUIDStoreStash): - object_type = WorkerPool - settings: PartitionSettings = PartitionSettings( - name=WorkerPool.__canonical_name__, - object_type=WorkerPool, - ) - - def __init__(self, store: DocumentStore) -> None: - super().__init__(store=store) +@serializable(canonical_name="SyftWorkerPoolSQLStash", version=1) +class SyftWorkerPoolStash(ObjectStash[WorkerPool]): @as_result(StashException, NotFoundException) def get_by_name(self, credentials: SyftVerifyKey, pool_name: str) -> WorkerPool: - qks = QueryKeys(qks=[PoolNamePartitionKey.with_obj(pool_name)]) - return self.query_one(credentials=credentials, qks=qks).unwrap( + result = self.get_one_by_fields( + credentials=credentials, + fields={"name": pool_name}, + ) + + return result.unwrap( public_message=f"WorkerPool with name {pool_name} not found" ) @@ -70,5 +58,7 @@ def set( def get_by_image_uid( self, credentials: SyftVerifyKey, image_uid: UID ) -> list[WorkerPool]: - qks = QueryKeys(qks=[PoolImageIDPartitionKey.with_obj(image_uid)]) - return self.query_all(credentials=credentials, qks=qks).unwrap() + return self.get_by_fields( + credentials=credentials, + fields={"image_id": image_uid.no_dash}, + ).unwrap() diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py index b2b059ffec5..b9dfa1d3e5a 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -5,11 +5,8 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore -from ...store.document_store import NewBaseUIDStoreStash +from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey -from ...store.document_store import PartitionSettings -from ...store.document_store import QueryKeys from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result @@ -22,16 +19,8 @@ WorkerContainerNamePartitionKey = PartitionKey(key="container_name", type_=str) -@serializable(canonical_name="WorkerStash", version=1) -class WorkerStash(NewBaseUIDStoreStash): - object_type = SyftWorker - settings: PartitionSettings = PartitionSettings( - name=SyftWorker.__canonical_name__, object_type=SyftWorker - ) - - def 
__init__(self, store: DocumentStore) -> None: - super().__init__(store=store) - +@serializable(canonical_name="WorkerSQLStash", version=1) +class WorkerStash(ObjectStash[SyftWorker]): @as_result(StashException) def set( self, @@ -62,8 +51,10 @@ def set( def get_worker_by_name( self, credentials: SyftVerifyKey, worker_name: str ) -> SyftWorker: - qks = QueryKeys(qks=[WorkerContainerNamePartitionKey.with_obj(worker_name)]) - return self.query_one(credentials=credentials, qks=qks).unwrap() + self.get_one_by_fields( + credentials=credentials, + fields={"container_name": worker_name}, + ).unwrap(public_message=f"Worker with name {worker_name} not found") @as_result(StashException, NotFoundException) def update_consumer_state( diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 97fb0b869d5..c47e4bfdd34 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -701,6 +701,7 @@ def set( add_storage_permission: bool = True, # TODO: check the default value ignore_duplicates: bool = False, ) -> StashT: + self.check_type(obj, self.object_type).unwrap() uid = obj.id # check if the object already exists From a7bfef204b86df45f1334a61fb7185a3c98a2a1f Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 4 Sep 2024 12:07:11 +0200 Subject: [PATCH 070/197] fix --- packages/syft/src/syft/service/worker/worker_stash.py | 9 --------- 1 file changed, 9 deletions(-) diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py index b9dfa1d3e5a..01202eb9ba2 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -47,15 +47,6 @@ def set( .unwrap() ) - @as_result(StashException, NotFoundException) - def get_worker_by_name( - self, credentials: SyftVerifyKey, worker_name: str - ) -> SyftWorker: - self.get_one_by_fields( - credentials=credentials, - fields={"container_name": worker_name}, - ).unwrap(public_message=f"Worker with name {worker_name} not found") - @as_result(StashException, NotFoundException) def update_consumer_state( self, credentials: SyftVerifyKey, worker_uid: UID, consumer_state: ConsumerState From 6bf895dd49a5047e1f2e93c82c8821292dbfbec0 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 4 Sep 2024 12:32:49 +0200 Subject: [PATCH 071/197] fixes --- packages/syft/src/syft/server/server.py | 64 ++++--------------- .../syft/src/syft/server/service_registry.py | 33 +--------- .../src/syft/service/action/action_store.py | 2 + .../src/syft/service/dataset/dataset_stash.py | 2 + packages/syft/src/syft/store/db/stash.py | 4 +- 5 files changed, 20 insertions(+), 85 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 40045ee6f9f..4bd97a27a30 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -96,7 +96,6 @@ from ..store.document_store_errors import NotFoundException from ..store.document_store_errors import StashException from ..store.linked_obj import LinkedObject -from ..store.mongo_document_store import MongoStoreConfig from ..store.sqlite_document_store import SQLiteStoreClientConfig from ..store.sqlite_document_store import SQLiteStoreConfig from ..types.datetime import DATETIME_FORMAT @@ -389,20 +388,17 @@ def __init__( if reset: self.remove_temp_dir() - use_sqlite = local_db or (processes > 0 and not is_subprocess) - document_store_config = document_store_config or 
self.get_default_store( - use_sqlite=use_sqlite, - store_type="Document Store", - ) - action_store_config = action_store_config or self.get_default_store( - use_sqlite=use_sqlite, - store_type="Action Store", - ) + # use_sqlite = local_db or (processes > 0 and not is_subprocess) + # document_store_config = document_store_config or self.get_default_store( + # use_sqlite=use_sqlite, + # store_type="Document Store", + # ) + # action_store_config = action_store_config or self.get_default_store( + # use_sqlite=use_sqlite, + # store_type="Action Store", + # ) - self.init_stores( - action_store_config=action_store_config, - document_store_config=document_store_config, - ) + self.init_stores() # construct services only after init stores self.services: ServiceRegistry = ServiceRegistry.for_server(self) @@ -880,45 +876,7 @@ def reload_user_code() -> None: def init_stores( self, - document_store_config: StoreConfig, - action_store_config: StoreConfig, ) -> None: - # We add the python id of the current server in order - # to create one connection per Server object in MongoClientCache - # so that we avoid closing the connection from a - # different thread through the garbage collection - if isinstance(document_store_config, MongoStoreConfig): - document_store_config.client_config.server_obj_python_id = id(self) - - self.document_store_config = document_store_config - self.document_store = document_store_config.store_type( - server_uid=self.id, - root_verify_key=self.verify_key, - store_config=document_store_config, - ) - - # if isinstance(action_store_config, SQLiteStoreConfig): - # self.action_store: ActionObjectStash = ActionObjectStash( - # store=self.document_store, - # ) - # elif isinstance(action_store_config, MongoStoreConfig): - # # We add the python id of the current server in order - # # to create one connection per Server object in MongoClientCache - # # so that we avoid closing the connection from a - # # different thread through the garbage collection - # action_store_config.client_config.server_obj_python_id = id(self) - - # self.action_store = ActionObjectStash( - # store=self.document_store, - # ) - # else: - # self.action_store = ActionObjectStash( - # store=self.document_store, - # ) - - self.action_store_config = action_store_config - self.queue_stash = QueueStash(store=self.document_store) - # TODO fix database filename + reset json_db_config = SQLiteDBConfig( filename=f"{self.id}_json.db", @@ -931,6 +889,8 @@ def init_stores( root_verify_key=self.verify_key, ) + self.queue_stash = QueueStash(store=self.db) + @property def job_stash(self) -> JobStash: return self.get_service("jobservice").stash diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index 30878648707..8bd877f0f3b 100644 --- a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -3,7 +3,6 @@ from dataclasses import dataclass from dataclasses import field import typing -from typing import Any from typing import TYPE_CHECKING # relative @@ -39,7 +38,6 @@ from ..service.worker.worker_image_service import SyftWorkerImageService from ..service.worker.worker_pool_service import SyftWorkerPoolService from ..service.worker.worker_service import WorkerService -from ..store.db.stash import ObjectStash if TYPE_CHECKING: # relative @@ -105,40 +103,11 @@ def get_service_classes( if issubclass(cls, AbstractService) } - @classmethod - def _uses_new_store(cls, service_cls: type[AbstractService]) -> bool: - 
stash_annotation = service_cls.__annotations__.get("stash")
-        try:
-            if issubclass(stash_annotation, ObjectStash):
-                return True
-            return False
-        except Exception:
-            return False
-
     @classmethod
     def _construct_services(cls, server: "Server") -> dict[str, AbstractService]:
         service_dict = {}
 
         for field_name, service_cls in cls.get_service_classes().items():
-            svc_kwargs: dict[str, Any] = {}
-
-            # Use new DB
-            if cls._uses_new_store(service_cls):
-                svc_kwargs["store"] = server.db
-
-            # Use old DB
-            else:
-                svc_kwargs["store"] = server.document_store
-
-                # TODO remove after all services are migrated
-                services_without_stash = [
-                    AttestationService,
-                    MetadataService,
-                    EnclaveService,
-                ]
-                if service_cls not in services_without_stash:
-                    print("Using old store:", service_cls)
-
-            service = service_cls(**svc_kwargs)
+            service = service_cls(store=server.db)
             service_dict[field_name] = service
 
         return service_dict
diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py
index b03a2760b40..c90492ed4b8 100644
--- a/packages/syft/src/syft/service/action/action_store.py
+++ b/packages/syft/src/syft/service/action/action_store.py
@@ -23,6 +23,8 @@
 
 @serializable(canonical_name="ActionObjectSQLStore", version=1)
 class ActionObjectStash(ObjectStash[ActionObject]):
+    allow_any_type = True
+
     @as_result(NotFoundException, SyftException)
     def get(
         self, uid: UID, credentials: SyftVerifyKey, has_permission: bool = False
diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py
index 3b50c91fb1f..99dfa06d64c 100644
--- a/packages/syft/src/syft/service/dataset/dataset_stash.py
+++ b/packages/syft/src/syft/service/dataset/dataset_stash.py
@@ -13,6 +13,8 @@
 @instrument
 @serializable(canonical_name="DatasetStashSQL", version=1)
 class DatasetStash(ObjectStash[Dataset]):
+    allow_set_any_type = True
+
     @as_result(StashException, NotFoundException)
     def get_by_name(self, credentials: SyftVerifyKey, name: str) -> Dataset:
         return self.get_one_by_field(
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index c47e4bfdd34..7a7f3c34d22 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -48,6 +48,7 @@ class ObjectStash(Generic[StashT]):
     table: Table
     object_type: type[SyftObject]
+    allow_any_type: bool = False
 
     def __init__(self, store: DBManager) -> None:
         self.db = store
@@ -701,7 +702,8 @@ def set(
         add_storage_permission: bool = True,  # TODO: check the default value
         ignore_duplicates: bool = False,
     ) -> StashT:
-        self.check_type(obj, self.object_type).unwrap()
+        if not self.allow_any_type:
+            self.check_type(obj, self.object_type).unwrap()
         uid = obj.id
 
         # check if the object already exists
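
[Editor's note] Taken together, the patches above reduce every service stash to the same shape: a thin ObjectStash[T] subclass whose lookups query the JSON-serialized fields column by attribute name, instead of going through PartitionKey/QueryKeys plumbing. A minimal sketch of the pattern, assuming only the ObjectStash API exercised in these diffs; Tag and TagStash are hypothetical stand-ins, not classes from this series:

    # A hypothetical example, not part of this patch series.
    from syft.serde.serializable import serializable
    from syft.server.credentials import SyftVerifyKey
    from syft.store.db.stash import ObjectStash
    from syft.store.document_store_errors import NotFoundException
    from syft.store.document_store_errors import StashException
    from syft.types.result import as_result
    from syft.types.syft_object import SYFT_OBJECT_VERSION_1
    from syft.types.syft_object import SyftObject


    @serializable()
    class Tag(SyftObject):
        __canonical_name__ = "Tag"
        __version__ = SYFT_OBJECT_VERSION_1
        __attr_searchable__ = ["name"]

        name: str


    @serializable(canonical_name="TagSQLStash", version=1)
    class TagStash(ObjectStash[Tag]):
        # Old style: QueryKeys(qks=[NamePartitionKey.with_obj(name)]) + query_one.
        # New style: filter the serialized JSON fields directly by attribute name.
        @as_result(StashException, NotFoundException)
        def get_by_name(self, credentials: SyftVerifyKey, name: str) -> Tag:
            return self.get_one_by_fields(
                credentials=credentials,
                fields={"name": name},
            ).unwrap()

Note that values are normalized to JSON-friendly forms before querying, which is why the diffs in this series pass str(verify_key) and role.name rather than the raw objects.
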
From c615b637ad004d1d6221c896e7a353733f84a2e3 Mon Sep 17 00:00:00 2001
From: Aziz Berkay Yesilyurt
Date: Wed, 4 Sep 2024 13:15:42 +0200
Subject: [PATCH 072/197] fix notification tests

---
 .../notification/notification_stash.py        | 20 +++--
 .../service/notification/notifications.py     |  1 +
 packages/syft/src/syft/store/db/stash.py      | 21 ++++-
 .../notifications/notification_stash_test.py  | 75 +++++++------------
 4 files changed, 58 insertions(+), 59 deletions(-)

diff --git a/packages/syft/src/syft/service/notification/notification_stash.py b/packages/syft/src/syft/service/notification/notification_stash.py
index 32051937f0f..e97dc253359 100644
--- a/packages/syft/src/syft/service/notification/notification_stash.py
+++ b/packages/syft/src/syft/service/notification/notification_stash.py
@@ -1,7 +1,3 @@
-# stdlib
-
-# third party
-
 # relative
 from ...serde.json_serde import serialize_json
 from ...serde.serializable import serializable
@@ -22,14 +18,18 @@ class NotificationStash(ObjectStash[Notification]):
     def get_all_inbox_for_verify_key(
         self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
     ) -> list[Notification]:
+        if not isinstance(verify_key, SyftVerifyKey | str):
+            raise AttributeError("verify_key must be of type SyftVerifyKey or str")
         return self.get_all_by_field(
-            credentials, field_name="verify_key", field_value=str(verify_key)
+            credentials, field_name="to_user_verify_key", field_value=str(verify_key)
         ).unwrap()
 
     @as_result(StashException)
     def get_all_sent_for_verify_key(
         self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
     ) -> list[Notification]:
+        if not isinstance(verify_key, SyftVerifyKey | str):
+            raise AttributeError("verify_key must be of type SyftVerifyKey or str")
         return self.get_all_by_field(
             credentials,
             field_name="from_user_verify_key",
@@ -40,8 +40,10 @@ def get_all_for_verify_key(
     def get_all_for_verify_key(
         self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
     ) -> list[Notification]:
+        if not isinstance(verify_key, SyftVerifyKey | str):
+            raise AttributeError("verify_key must be of type SyftVerifyKey or str")
         return self.get_all_by_field(
-            credentials, field_name="verify_key", field_value=str(verify_key)
+            credentials, field_name="from_user_verify_key", field_value=str(verify_key)
         ).unwrap()
 
     @as_result(StashException)
@@ -51,11 +53,13 @@ def get_all_by_verify_key_for_status(
         verify_key: SyftVerifyKey,
         status: NotificationStatus,
     ) -> list[Notification]:
+        if not isinstance(verify_key, SyftVerifyKey | str):
+            raise AttributeError("verify_key must be of type SyftVerifyKey or str")
         return self.get_all_by_fields(
             credentials,
             fields={
                 "to_user_verify_key": str(verify_key),
-                "status": status.value,
+                "status": status.name,
             },
         ).unwrap()
 
@@ -82,6 +86,8 @@ def update_notification_status(
     def delete_all_for_verify_key(
         self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
     ) -> bool:
+        if not isinstance(verify_key, SyftVerifyKey | str):
+            raise AttributeError("verify_key must be of type SyftVerifyKey or str")
         notifications = self.get_all_inbox_for_verify_key(
             credentials,
             verify_key=verify_key,
diff --git a/packages/syft/src/syft/service/notification/notifications.py b/packages/syft/src/syft/service/notification/notifications.py
index dffad4e414e..2b176a65af3 100644
--- a/packages/syft/src/syft/service/notification/notifications.py
+++ b/packages/syft/src/syft/service/notification/notifications.py
@@ -70,6 +70,7 @@ class Notification(SyftObject):
     ]
     __repr_attrs__ = ["subject", "status", "created_at", "linked_obj"]
     __table_sort_attr__ = "Created at"
+    __order_by__ = ("created_at", "asc")
 
     def _repr_html_(self) -> str:
         return f"""
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index 97fb0b869d5..d24da68da1d 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -114,10 +114,10 @@ def _create_table(self) -> Table:
             ),
             # TODO rename and use on SyftObject fields
             Column(
-                "created_at", sa.DateTime, server_default=sa.func.now(), index=True
+                "_created_at", sa.DateTime, server_default=sa.func.now(), index=True
             ),
-            Column("updated_at", sa.DateTime, server_onupdate=sa.func.now()),
-            Column("deleted_at", sa.DateTime, index=True),
+            Column("_updated_at", sa.DateTime, 
server_onupdate=sa.func.now()), + Column("_deleted_at", sa.DateTime, index=True), ) return Base.metadata.tables[table_name] @@ -398,8 +398,8 @@ def _get_order_by_col(self, order_by: str, sort_order: str | None = None) -> Col if order_by == "id": col = self.table.c.id - if order_by == "created_date" or order_by == "created_at": - col = self.table.c.created_at + if order_by == "created_date" or order_by == "_created_at": + col = self.table.c._created_at else: col = self.table.c.fields[order_by] @@ -492,6 +492,8 @@ def update( - To fix, we either need db-supported computed fields, or know in our ORM which fields should be re-computed. """ + self.check_type(obj, self.object_type).unwrap() + # TODO has_permission is not used if not self.is_unique(obj): raise StashException(f"Some fields are not unique for {type(obj).__name__}") @@ -503,7 +505,14 @@ def update( permission=ActionPermission.WRITE, has_permission=has_permission, ) - stmt = stmt.values(fields=serialize_json(obj)) + fields = serialize_json(obj) + try: + deserialize_json(fields) + except Exception as e: + raise StashException( + f"Error serializing object: {e}. Some fields are invalid." + ) + stmt = stmt.values(fields=fields) self.session.execute(stmt) self.session.commit() diff --git a/packages/syft/tests/syft/notifications/notification_stash_test.py b/packages/syft/tests/syft/notifications/notification_stash_test.py index 9c871182f93..7864fb10d19 100644 --- a/packages/syft/tests/syft/notifications/notification_stash_test.py +++ b/packages/syft/tests/syft/notifications/notification_stash_test.py @@ -59,13 +59,9 @@ def test_get_all_inbox_for_verify_key(root_verify_key, document_store) -> None: random_verify_key = random_signing_key.verify_key test_stash = NotificationStash(store=document_store) - response = test_stash.get_all_inbox_for_verify_key( + result = test_stash.get_all_inbox_for_verify_key( root_verify_key, random_verify_key - ) - - assert response.is_ok() - - result = response.ok() + ).unwrap() assert len(result) == 0 # list of mock notifications @@ -78,14 +74,11 @@ def test_get_all_inbox_for_verify_key(root_verify_key, document_store) -> None: notification_list.append(mock_notification) # returned list of notifications from stash that's sorted by created_at - response2 = test_stash.get_all_inbox_for_verify_key( + result = test_stash.get_all_inbox_for_verify_key( root_verify_key, random_verify_key - ) + ).unwrap() - assert response2.is_ok() - - result = response2.ok() - assert len(response2.value) == 5 + assert len(result) == 5 for notification in notification_list: # check if all notifications are present in the result @@ -169,28 +162,21 @@ def test_get_all_by_verify_key_for_status(root_verify_key, document_store) -> No random_verify_key = random_signing_key.verify_key test_stash = NotificationStash(store=document_store) - response = test_stash.get_all_by_verify_key_for_status( + result = test_stash.get_all_by_verify_key_for_status( root_verify_key, random_verify_key, NotificationStatus.READ - ) - - assert response.is_ok() - - result = response.ok() + ).unwrap() assert len(result) == 0 mock_notification = add_mock_notification( root_verify_key, test_stash, test_verify_key, random_verify_key ) - response2 = test_stash.get_all_by_verify_key_for_status( + result2 = test_stash.get_all_by_verify_key_for_status( root_verify_key, mock_notification.to_user_verify_key, NotificationStatus.UNREAD - ) - assert response2.is_ok() + ).unwrap() + assert len(result2) == 1 - result = response2.ok() - assert len(result) == 1 - - assert 
result[0] == mock_notification
+    assert result2[0] == mock_notification
 
     with pytest.raises(AttributeError):
         test_stash.get_all_by_verify_key_for_status(
@@ -208,7 +194,7 @@ def test_update_notification_status(root_verify_key, document_store) -> None:
             root_verify_key, uid=random_uid, status=NotificationStatus.READ
         ).unwrap()
 
-    assert exc.type is SyftException
+    assert issubclass(exc.type, SyftException)
     assert exc.value.public_message
 
     mock_notification = add_mock_notification(
@@ -234,7 +220,7 @@ def test_update_notification_status(root_verify_key, document_store) -> None:
             status=notification_expiry_status_auto,
         ).unwrap()
 
-    assert exc.type is SyftException
+    assert issubclass(exc.type, SyftException)
     assert exc.value.public_message
 
 
@@ -246,6 +232,10 @@ def test_update_notification_status_error_on_get_by_uid(
     test_stash = NotificationStash(store=document_store)
     expected_error_msg = f"No notification exists for id: {random_verify_key}"
 
+    add_mock_notification(
+        root_verify_key, test_stash, test_verify_key, random_verify_key
+    )
+
     @as_result(StashException)
     def mock_get_by_uid(root_verify_key: SyftVerifyKey, uid: UID) -> NoReturn:
         raise StashException(public_message=f"No notification exists for id: {uid}")
@@ -255,11 +245,6 @@ def mock_get_by_uid(root_verify_key: SyftVerifyKey, uid: UID) -> NoReturn:
         "get_by_uid",
         mock_get_by_uid,
     )
-
-    add_mock_notification(
-        root_verify_key, test_stash, test_verify_key, random_verify_key
-    )
-
     with pytest.raises(StashException) as exc:
         test_stash.update_notification_status(
             root_verify_key, random_verify_key, NotificationStatus.READ
@@ -274,11 +259,9 @@ def test_delete_all_for_verify_key(root_verify_key, document_store) -> None:
     random_verify_key = random_signing_key.verify_key
     test_stash = NotificationStash(store=document_store)
 
-    response = test_stash.delete_all_for_verify_key(root_verify_key, test_verify_key)
-
-    assert response.is_ok()
-
-    result = response.ok()
+    result = test_stash.delete_all_for_verify_key(
+        root_verify_key, test_verify_key
+    ).unwrap()
     assert result is True
 
     add_mock_notification(
@@ -287,23 +270,23 @@ def test_delete_all_for_verify_key(root_verify_key, document_store) -> None:
 
     inbox_before = test_stash.get_all_inbox_for_verify_key(
         root_verify_key, random_verify_key
-    ).value
+    ).unwrap()
     assert len(inbox_before) == 1
 
-    response2 = test_stash.delete_all_for_verify_key(root_verify_key, random_verify_key)
-
-    assert response2.is_ok()
-
-    result = response2.ok()
-    assert result is True
+    result2 = test_stash.delete_all_for_verify_key(
+        root_verify_key, random_verify_key
+    ).unwrap()
+    assert result2 is True
 
     inbox_after = test_stash.get_all_inbox_for_verify_key(
         root_verify_key, random_verify_key
-    ).value
+    ).unwrap()
     assert len(inbox_after) == 0
 
     with pytest.raises(AttributeError):
-        test_stash.delete_all_for_verify_key(root_verify_key, random_signing_key)
+        test_stash.delete_all_for_verify_key(
+            root_verify_key, random_signing_key
+        ).unwrap()
 
 
 def test_delete_all_for_verify_key_error_on_get_all_inbox_for_verify_key(
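
[Editor's note] The same test rewrite repeats in this and the following patches: Result-inspection assertions (.is_ok(), .ok(), .value) give way to calling .unwrap() on the happy path and asserting error paths with pytest.raises. A compact sketch of the idiom, reusing the hypothetical TagStash from the earlier note; the fixtures here are stand-ins as well:

    # A hypothetical test sketch, not part of this patch series.
    import pytest

    from syft.store.document_store_errors import NotFoundException


    def test_tag_get_by_name(root_verify_key, tag_stash, mock_tag) -> None:
        tag_stash.set(root_verify_key, mock_tag).unwrap()

        # Happy path: unwrap() returns the object or raises; no .is_ok() dance.
        found = tag_stash.get_by_name(root_verify_key, mock_tag.name).unwrap()
        assert found == mock_tag

        # Error path: the wrapped exception surfaces through unwrap().
        with pytest.raises(NotFoundException):
            tag_stash.get_by_name(root_verify_key, "does-not-exist").unwrap()
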
From 03b4e114875d3b65c4a6b8a50b0eeef37c2667db Mon Sep 17 00:00:00 2001
From: Aziz Berkay Yesilyurt
Date: Wed, 4 Sep 2024 13:18:20 +0200
Subject: [PATCH 073/197] fix request stash tests

---
 .../tests/syft/request/request_stash_test.py  | 65 -------------------
 1 file changed, 65 deletions(-)

diff --git a/packages/syft/tests/syft/request/request_stash_test.py b/packages/syft/tests/syft/request/request_stash_test.py
index 869725e4feb..a9115d5c934 100644
--- a/packages/syft/tests/syft/request/request_stash_test.py
+++ b/packages/syft/tests/syft/request/request_stash_test.py
@@ -1,9 +1,4 @@
-# stdlib
-from typing import NoReturn
-
 # third party
-import pytest
-from pytest import MonkeyPatch
 
 # syft absolute
 from syft.client.client import SyftClient
@@ -12,9 +7,6 @@
 from syft.service.request.request import Request
 from syft.service.request.request import SubmitRequest
 from syft.service.request.request_stash import RequestStash
-from syft.store.document_store import PartitionKey
-from syft.store.document_store import QueryKeys
-from syft.types.errors import SyftException
 
 
 def test_requeststash_get_all_for_verify_key_no_requests(
@@ -32,7 +24,6 @@ def test_requeststash_get_all_for_verify_key_no_requests(
     assert len(requests.ok()) == 0
 
 
-@pytest.mark.xfail
 def test_requeststash_get_all_for_verify_key_success(
     root_verify_key,
     request_stash: RequestStash,
@@ -76,59 +67,3 @@ def test_requeststash_get_all_for_verify_key_success(
         requests.ok()[1] == stash_set_result_2.ok()
         or requests.ok()[0] == stash_set_result_2.ok()
     )
-
-
-def test_requeststash_get_all_for_verify_key_fail(
-    root_verify_key,
-    request_stash: RequestStash,
-    monkeypatch: MonkeyPatch,
-    guest_datasite_client: SyftClient,
-) -> None:
-    verify_key: SyftVerifyKey = guest_datasite_client.credentials.verify_key
-    mock_error_message = (
-        "verify key not in the document store's unique or searchable keys"
-    )
-
-    def mock_query_all_error(
-        credentials: SyftVerifyKey, qks: QueryKeys, order_by: PartitionKey | None
-    ) -> NoReturn:
-        raise SyftException(public_message=mock_error_message)
-
-    monkeypatch.setattr(request_stash, "query_all", mock_query_all_error)
-
-    with pytest.raises(SyftException) as exc:
-        request_stash.get_all_for_verify_key(root_verify_key, verify_key).unwrap()
-
-    assert exc.type is SyftException
-    assert exc.value.public_message == mock_error_message
-
-
-def test_requeststash_get_all_for_verify_key_find_index_fail(
-    root_verify_key,
-    request_stash: RequestStash,
-    monkeypatch: MonkeyPatch,
-    guest_datasite_client: SyftClient,
-) -> None:
-    verify_key: SyftVerifyKey = guest_datasite_client.credentials.verify_key
-
-    mock_error_message = "Failed search with"
-
-    def mock_find_index_or_search_keys_error(
-        credentials: SyftVerifyKey,
-        index_qks: QueryKeys,
-        search_qks: QueryKeys,
-        order_by: PartitionKey | None,
-    ) -> NoReturn:
-        raise SyftException(public_message=mock_error_message)
-
-    monkeypatch.setattr(
-        request_stash.partition,
-        "find_index_or_search_keys",
-        mock_find_index_or_search_keys_error,
-    )
-
-    with pytest.raises(SyftException) as exc:
-        request_stash.get_all_for_verify_key(root_verify_key, verify_key).unwrap()
-
-    assert exc.type == SyftException
-    assert exc.value.public_message == mock_error_message

From cc1b4002ba8f2cf9e1b839a406812092a1be21ff Mon Sep 17 00:00:00 2001
From: Aziz Berkay Yesilyurt
Date: Wed, 4 Sep 2024 14:05:37 +0200
Subject: [PATCH 074/197] fix user stash tests

---
 .../syft/src/syft/service/user/user_stash.py  |  2 +-
 packages/syft/src/syft/store/db/stash.py      |  5 +++-
 .../syft/tests/syft/users/user_stash_test.py  | 29 +++++++++----------
 3 files changed, 19 insertions(+), 17 deletions(-)

diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py
index cd634f52404..f12d3f12476 100644
--- a/packages/syft/src/syft/service/user/user_stash.py
+++ b/packages/syft/src/syft/service/user/user_stash.py
@@ -66,7 +66,7 @@ def email_exists(self, email: str) -> bool:
     def get_by_role(self, credentials: SyftVerifyKey, role: ServiceRole) -> User:
         try:
             return 
self.get_one_by_field( - credentials=credentials, field_name="role", field_value=role + credentials=credentials, field_name="role", field_value=role.name ).unwrap() except NotFoundException as exc: private_msg = f"User with role {role} not found" diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index d24da68da1d..ec52d4f9eaf 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -73,6 +73,10 @@ def server_uid(self) -> UID: def root_verify_key(self) -> SyftVerifyKey: return self.db.root_verify_key + @property + def _data(self) -> list[StashT]: + return self.get_all(self.root_verify_key, has_permission=True).unwrap() + @as_result(StashException) def check_type(self, obj: T, type_: type) -> T: if not isinstance(obj, type_): @@ -434,7 +438,6 @@ def _apply_permission_filter( if has_permission: # ignoring permissions return stmt - role = self.get_role(credentials) if role in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER): # admins and data owners have all permissions diff --git a/packages/syft/tests/syft/users/user_stash_test.py b/packages/syft/tests/syft/users/user_stash_test.py index ee8e4b1edc9..9592284c2f8 100644 --- a/packages/syft/tests/syft/users/user_stash_test.py +++ b/packages/syft/tests/syft/users/user_stash_test.py @@ -14,7 +14,7 @@ def add_mock_user(root_datasite_client, user_stash: UserStash, user: User) -> User: # prepare: add mock data - result = user_stash.partition.set(root_datasite_client.credentials.verify_key, user) + result = user_stash.set(root_datasite_client.credentials.verify_key, user) assert result.is_ok() user = result.ok() @@ -26,22 +26,23 @@ def add_mock_user(root_datasite_client, user_stash: UserStash, user: User) -> Us def test_userstash_set( root_datasite_client, user_stash: UserStash, guest_user: User ) -> None: - result = user_stash.set(root_datasite_client.credentials.verify_key, guest_user) - assert result.is_ok() - - created_user = result.ok() + created_user = user_stash.set( + root_datasite_client.credentials.verify_key, guest_user + ).unwrap() assert isinstance(created_user, User) assert guest_user == created_user - assert guest_user.id in user_stash.partition.data + assert user_stash.exists( + root_datasite_client.credentials.verify_key, created_user.id + ) def test_userstash_set_duplicate( root_datasite_client, user_stash: UserStash, guest_user: User ) -> None: - result = user_stash.set(root_datasite_client.credentials.verify_key, guest_user) - assert result.is_ok() - - original_count = len(user_stash.partition.data) + result = user_stash.set( + root_datasite_client.credentials.verify_key, guest_user + ).unwrap() + original_count = len(user_stash._data) result = user_stash.set(root_datasite_client.credentials.verify_key, guest_user) assert result.is_err() @@ -49,7 +50,7 @@ def test_userstash_set_duplicate( assert type(exc) == SyftException assert exc.public_message - assert len(user_stash.partition.data) == original_count + assert len(user_stash._data) == original_count def test_userstash_get_by_uid( @@ -171,11 +172,9 @@ def test_userstash_get_by_role( # prepare: add mock data user = add_mock_user(root_datasite_client, user_stash, guest_user) - result = user_stash.get_by_role( + searched_user = user_stash.get_by_role( root_datasite_client.credentials.verify_key, role=ServiceRole.GUEST - ) - assert result.is_ok() - searched_user = result.ok() + ).unwrap() assert user == searched_user From 8163587337c16260439fc98c64a1ab6ae2b62f5f Mon Sep 17 00:00:00 2001 From: 
Aziz Berkay Yesilyurt
Date: Wed, 4 Sep 2024 14:10:17 +0200
Subject: [PATCH 075/197] fix dataset stash tests

---
 .../tests/syft/dataset/dataset_stash_test.py  | 15 ++++++--------
 packages/syft/tests/syft/dataset/fixtures.py  |  6 ++++--
 2 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/packages/syft/tests/syft/dataset/dataset_stash_test.py b/packages/syft/tests/syft/dataset/dataset_stash_test.py
index 86c44de2319..394b03e05c3 100644
--- a/packages/syft/tests/syft/dataset/dataset_stash_test.py
+++ b/packages/syft/tests/syft/dataset/dataset_stash_test.py
@@ -3,11 +3,14 @@
 
 # syft absolute
 from syft.service.dataset.dataset import Dataset
+from syft.service.dataset.dataset_stash import DatasetStash
 from syft.store.document_store_errors import NotFoundException
 from syft.types.uid import UID
 
 
-def test_dataset_get_by_name(root_verify_key, mock_dataset_stash, mock_dataset) -> None:
+def test_dataset_get_by_name(
+    root_verify_key, mock_dataset_stash: DatasetStash, mock_dataset: Dataset
+) -> None:
     # retrieving existing dataset
     result = mock_dataset_stash.get_by_name(root_verify_key, mock_dataset.name)
     assert result.is_ok()
@@ -21,7 +24,9 @@
     assert type(result.err()) is NotFoundException
 
 
-def test_dataset_search_action_ids(root_verify_key, mock_dataset_stash, mock_dataset):
+def test_dataset_search_action_ids(
+    root_verify_key, mock_dataset_stash: DatasetStash, mock_dataset
+):
     action_id = mock_dataset.assets[0].action_id
 
     result = mock_dataset_stash.search_action_ids(root_verify_key, uid=action_id)
@@ -30,12 +35,6 @@
     assert isinstance(result.ok()[0], Dataset)
     assert result.ok()[0].id == mock_dataset.id
 
-    # retrieving dataset by list of action_ids
-    result = mock_dataset_stash.search_action_ids(root_verify_key, uid=[action_id])
-    assert result.is_ok()
-    assert isinstance(result.ok()[0], Dataset)
-    assert result.ok()[0].id == mock_dataset.id
-
     # retrieving dataset by non-existing action_id
     other_action_id = UID()
     result = mock_dataset_stash.search_action_ids(root_verify_key, uid=other_action_id)
diff --git a/packages/syft/tests/syft/dataset/fixtures.py b/packages/syft/tests/syft/dataset/fixtures.py
index 9c062e756bc..bcb26bff262 100644
--- a/packages/syft/tests/syft/dataset/fixtures.py
+++ b/packages/syft/tests/syft/dataset/fixtures.py
@@ -60,7 +60,9 @@ def mock_asset(worker, root_datasite_client) -> Asset:
 
 
 @pytest.fixture
-def mock_dataset(root_verify_key, mock_dataset_stash, mock_asset) -> Dataset:
+def mock_dataset(
+    root_verify_key, mock_dataset_stash: DatasetStash, mock_asset
+) -> Dataset:
     uploader = Contributor(
         role=str(Roles.UPLOADER),
         name="test",
@@ -70,7 +72,7 @@
         id=UID(), name="test_dataset", uploader=uploader, contributors=[uploader]
     )
     mock_dataset.asset_list.append(mock_asset)
-    result = mock_dataset_stash.partition.set(root_verify_key, mock_dataset)
+    result = mock_dataset_stash.set(root_verify_key, mock_dataset)
     mock_dataset = result.ok()
     yield mock_dataset
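
[Editor's note] The base-stash patch that follows rewires the test fixture from the old DocumentStore manager to the SQLite-backed one. A minimal sketch of that wiring, assuming the SQLiteDBManager constructor signature used in the updated fixture below; MockStash is the stash class defined in base_stash_test.py itself, so this is a stand-in, not importable code:

    # A hypothetical fixture sketch mirroring the updated base_stash fixture.
    import pytest

    from syft.store.db.sqlite_db import SQLiteDBConfig
    from syft.store.db.sqlite_db import SQLiteDBManager
    from syft.types.uid import UID


    @pytest.fixture
    def sqlite_stash(root_verify_key):
        config = SQLiteDBConfig()  # random filename under tempfile.gettempdir()
        db_manager = SQLiteDBManager(config, UID(), root_verify_key)
        stash = MockStash(store=db_manager)  # MockStash from base_stash_test.py
        db_manager.init_tables()  # create the schema before first use
        yield stash

The init_tables() call matters: tables are declared lazily per stash, so the schema has to be created after the stash is constructed and before the first query.
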
From 1aabb41aa5dc761778f717f3619a190d4a14db79 Mon Sep 17 00:00:00 2001
From: Aziz Berkay Yesilyurt
Date: Wed, 4 Sep 2024 14:46:39 +0200
Subject: [PATCH 076/197] fix base stash tests

---
 packages/syft/src/syft/store/db/sqlite_db.py  |   2 +-
 packages/syft/src/syft/store/db/stash.py      |  22 +-
 .../syft/tests/syft/stores/base_stash_test.py | 227 +++++-------------
 3 files changed, 80 insertions(+), 171 deletions(-)

diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py
index ce4b7f14fa1..da3877632bf 100644
--- a/packages/syft/src/syft/store/db/sqlite_db.py
+++ b/packages/syft/src/syft/store/db/sqlite_db.py
@@ -30,7 +30,7 @@ def connection_string(self) -> str:
 class SQLiteDBConfig(DBConfig):
     filename: str = Field(default_factory=lambda: f"{uuid.uuid4()}.db")
-    path: Path = Field(default_factory=tempfile.gettempdir)
+    path: Path = Field(default_factory=lambda: Path(tempfile.gettempdir()))
 
     @property
     def connection_string(self) -> str:
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index ec52d4f9eaf..4d8069646ea 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -200,7 +200,8 @@ def _get_field_filter(
     ) -> sa.sql.elements.BinaryExpression:
         table = table if table is not None else self.table
         if field_name == "id":
-            return table.c.id == field_value
+            uid_field_value = UID(field_value)
+            return table.c.id == uid_field_value
 
         if self.db.engine.dialect.name == "sqlite":
             return table.c.fields[field_name] == func.json_quote(field_value)
@@ -480,7 +481,7 @@ def get_all(
         result = self.session.execute(stmt).all()
         return [self.row_as_obj(row) for row in result]
 
-    @as_result(StashException)
+    @as_result(StashException, NotFoundException)
     def update(
         self,
         credentials: SyftVerifyKey,
@@ -517,9 +518,12 @@ def update(
         )
         stmt = stmt.values(fields=fields)
 
-        self.session.execute(stmt)
+        result = self.session.execute(stmt)
         self.session.commit()
-
+        if result.rowcount == 0:
+            raise NotFoundException(
+                f"{self.object_type.__name__}: {obj.id} not found or no permission to update."
+            )
         return self.get_by_uid(credentials, obj.id).unwrap()
 
     def get_ownership_permissions(
@@ -532,7 +536,7 @@ def get_ownership_permissions(
             ActionObjectEXECUTE(uid=uid, credentials=credentials).permission_string,
         ]
 
-    @as_result(StashException)
+    @as_result(StashException, NotFoundException)
     def delete_by_uid(
         self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False
     ) -> UID:
@@ -543,8 +547,12 @@ def delete_by_uid(
             permission=ActionPermission.WRITE,
             has_permission=has_permission,
         )
-        self.session.execute(stmt)
+        result = self.session.execute(stmt)
         self.session.commit()
+        if result.rowcount == 0:
+            raise NotFoundException(
+                f"{self.object_type.__name__}: {uid} not found or no permission to delete."
+            )
         return uid
 
     def add_permissions(self, permissions: list[ActionObjectPermission]) -> None:
@@ -720,7 +728,7 @@ def set(
             if ignore_duplicates:
                 return obj
             unique_fields_str = ", ".join(self.unique_fields)
-            raise SyftException(
+            raise StashException(
                 public_message=f"Duplication Key Error for {obj}.\n"
                 f"The fields that should be unique are {unique_fields_str}." 
) diff --git a/packages/syft/tests/syft/stores/base_stash_test.py b/packages/syft/tests/syft/stores/base_stash_test.py index c6a6e24714c..3cdc920bd72 100644 --- a/packages/syft/tests/syft/stores/base_stash_test.py +++ b/packages/syft/tests/syft/stores/base_stash_test.py @@ -12,13 +12,10 @@ # syft absolute from syft.serde.serializable import serializable -from syft.store.db.sqlite_db import DBManager from syft.store.db.sqlite_db import SQLiteDBConfig +from syft.store.db.sqlite_db import SQLiteDBManager from syft.store.db.stash import ObjectStash from syft.store.document_store import PartitionKey -from syft.store.document_store import QueryKey -from syft.store.document_store import QueryKeys -from syft.store.document_store import UIDPartitionKey from syft.store.document_store_errors import NotFoundException from syft.store.document_store_errors import StashException from syft.types.errors import SyftException @@ -53,13 +50,6 @@ def get_object_values(obj: SyftObject) -> tuple[Any]: return tuple(obj.to_dict().values()) -def add_mock_object(root_verify_key, stash: MockStash, obj: MockObject) -> MockObject: - result = stash.set(root_verify_key, obj) - assert result.is_ok() - - return result.ok() - - T = TypeVar("T") P = ParamSpec("P") @@ -78,8 +68,10 @@ def create_unique( @pytest.fixture def base_stash(root_verify_key) -> MockStash: config = SQLiteDBConfig() - db_manager = DBManager(config, UID(), root_verify_key) - yield MockStash(store=db_manager) + db_manager = SQLiteDBManager(config, UID(), root_verify_key) + mock_stash = MockStash(store=db_manager) + db_manager.init_tables() + yield mock_stash def random_sentence(faker: Faker) -> str: @@ -118,8 +110,7 @@ def mock_objects(faker: Faker) -> list[MockObject]: def test_basestash_set( root_verify_key, base_stash: MockStash, mock_object: MockObject ) -> None: - result = add_mock_object(root_verify_key, base_stash, mock_object) - + result = base_stash.set(root_verify_key, mock_object).unwrap() assert result is not None assert result == mock_object @@ -131,11 +122,10 @@ def test_basestash_set_duplicate( MockObject(**kwargs) for kwargs in multiple_object_kwargs(faker, n=2, same=True) ) - result = base_stash.set(root_verify_key, original) - assert result.is_ok() + base_stash.set(root_verify_key, original).unwrap() - result = base_stash.set(root_verify_key, duplicate) - assert result.is_err() + with pytest.raises(StashException): + base_stash.set(root_verify_key, duplicate).unwrap() def test_basestash_set_duplicate_unique_key( @@ -156,28 +146,19 @@ def test_basestash_set_duplicate_unique_key( def test_basestash_delete( root_verify_key, base_stash: MockStash, mock_object: MockObject ) -> None: - add_mock_object(root_verify_key, base_stash, mock_object) - - result = base_stash.delete( - root_verify_key, UIDPartitionKey.with_obj(mock_object.id) - ) - assert result.is_ok() - - assert len(base_stash.get_all(root_verify_key).ok()) == 0 + base_stash.set(root_verify_key, mock_object).unwrap() + base_stash.delete_by_uid(root_verify_key, mock_object.id).unwrap() + assert len(base_stash.get_all(root_verify_key).unwrap()) == 0 def test_basestash_cannot_delete_non_existent( root_verify_key, base_stash: MockStash, mock_object: MockObject ) -> None: - add_mock_object(root_verify_key, base_stash, mock_object) + result = base_stash.set(root_verify_key, mock_object).unwrap() random_uid = create_unique(UID, [mock_object.id]) - for result in [ - base_stash.delete(root_verify_key, UIDPartitionKey.with_obj(random_uid)), - base_stash.delete_by_uid(root_verify_key, 
random_uid), - ]: - result = base_stash.delete(root_verify_key, UIDPartitionKey.with_obj(UID())) - assert result.is_err() + result = base_stash.delete_by_uid(root_verify_key, random_uid) + assert result.is_err() assert ( len( @@ -192,7 +173,7 @@ def test_basestash_cannot_delete_non_existent( def test_basestash_update( root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker ) -> None: - add_mock_object(root_verify_key, base_stash, mock_object) + result = base_stash.set(root_verify_key, mock_object).unwrap() updated_obj = mock_object.copy() updated_obj.name = faker.name() @@ -207,7 +188,7 @@ def test_basestash_update( def test_basestash_cannot_update_non_existent( root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker ) -> None: - add_mock_object(root_verify_key, base_stash, mock_object) + result = base_stash.set(root_verify_key, mock_object).unwrap() updated_obj = mock_object.copy() updated_obj.id = create_unique(UID, [mock_object.id]) @@ -226,10 +207,7 @@ def test_basestash_set_get_all( stored_objects = base_stash.get_all( root_verify_key, - ) - assert stored_objects.is_ok() - - stored_objects = stored_objects.ok() + ).unwrap() assert len(stored_objects) == len(mock_objects) stored_objects_values = {get_object_values(obj) for obj in stored_objects} @@ -240,11 +218,10 @@ def test_basestash_set_get_all( def test_basestash_get_by_uid( root_verify_key, base_stash: MockStash, mock_object: MockObject ) -> None: - add_mock_object(root_verify_key, base_stash, mock_object) + result = base_stash.set(root_verify_key, mock_object).unwrap() - result = base_stash.get_by_uid(root_verify_key, mock_object.id) - assert result.is_ok() - assert result.ok() == mock_object + result = base_stash.get_by_uid(root_verify_key, mock_object.id).unwrap() + assert result == mock_object random_uid = create_unique(UID, [mock_object.id]) bad_uid = base_stash.get_by_uid(root_verify_key, random_uid) @@ -261,12 +238,9 @@ def test_basestash_get_by_uid( def test_basestash_delete_by_uid( root_verify_key, base_stash: MockStash, mock_object: MockObject ) -> None: - add_mock_object(root_verify_key, base_stash, mock_object) - - result = base_stash.delete_by_uid(root_verify_key, mock_object.id) - assert result.is_ok() + result = base_stash.set(root_verify_key, mock_object).unwrap() - response = result.ok() + response = base_stash.delete_by_uid(root_verify_key, mock_object.id).unwrap() assert isinstance(response, UID) result = base_stash.get_by_uid(root_verify_key, mock_object.id) @@ -287,43 +261,27 @@ def test_basestash_query_one( base_stash.set(root_verify_key, obj) obj = random.choice(mock_objects) + result = base_stash.get_one_by_fields( + root_verify_key, fields={"name": obj.name} + ).unwrap() - for result in ( - base_stash.query_one_kwargs(root_verify_key, name=obj.name), - base_stash.query_one( - root_verify_key, QueryKey.from_obj(NamePartitionKey, obj.name) - ), - ): - assert result.is_ok() - assert result.ok() == obj + assert result == obj existing_names = {obj.name for obj in mock_objects} random_name = create_unique(faker.name, existing_names) - for result in ( - base_stash.query_one_kwargs(root_verify_key, name=random_name), - base_stash.query_one( - root_verify_key, QueryKey.from_obj(NamePartitionKey, random_name) - ), - ): - assert result.is_err() - assert isinstance(result.err(), NotFoundException) + with pytest.raises(NotFoundException): + result = base_stash.get_one_by_fields( + root_verify_key, fields={"name": random_name} + ).unwrap() params = {"name": obj.name, "desc": 
obj.desc} - for result in [ - base_stash.query_one_kwargs(root_verify_key, **params), - base_stash.query_one(root_verify_key, QueryKeys.from_dict(params)), - ]: - assert result.is_ok() - assert result.ok() == obj + result = base_stash.get_one_by_fields(root_verify_key, fields=params).unwrap() + assert result == obj params = {"name": random_name, "desc": random_sentence(faker)} - for result in [ - base_stash.query_one_kwargs(root_verify_key, **params), - base_stash.query_one(root_verify_key, QueryKeys.from_dict(params)), - ]: - assert result.is_err() - assert isinstance(result.err(), NotFoundException) + with pytest.raises(NotFoundException): + result = base_stash.get_one_by_fields(root_verify_key, fields=params).unwrap() def test_basestash_query_all( @@ -338,46 +296,32 @@ def test_basestash_query_all( for obj in all_objects: base_stash.set(root_verify_key, obj) - for result in [ - base_stash.query_all_kwargs(root_verify_key, desc=desc), - base_stash.query_all( - root_verify_key, QueryKey.from_obj(DescPartitionKey, desc) - ), - ]: - assert result.is_ok() - objects = result.ok() - assert len(objects) == n_same - assert all(obj.desc == desc for obj in objects) - original_object_values = {get_object_values(obj) for obj in similar_objects} - retrived_objects_values = {get_object_values(obj) for obj in objects} - assert original_object_values == retrived_objects_values + objects = base_stash.get_all_by_fields( + root_verify_key, fields={"desc": desc} + ).unwrap() + assert len(objects) == n_same + assert all(obj.desc == desc for obj in objects) + original_object_values = {get_object_values(obj) for obj in similar_objects} + retrived_objects_values = {get_object_values(obj) for obj in objects} + assert original_object_values == retrived_objects_values random_desc = create_unique( random_sentence, [obj.desc for obj in all_objects], faker ) - for result in [ - base_stash.query_all_kwargs(root_verify_key, desc=random_desc), - base_stash.query_all( - root_verify_key, QueryKey.from_obj(DescPartitionKey, random_desc) - ), - ]: - assert result.is_ok() - objects = result.ok() - assert len(objects) == 0 + + objects = base_stash.get_all_by_fields( + root_verify_key, fields={"desc": random_desc} + ).unwrap() + assert len(objects) == 0 obj = random.choice(similar_objects) params = {"name": obj.name, "desc": obj.desc} - for result in [ - base_stash.query_all_kwargs(root_verify_key, **params), - base_stash.query_all(root_verify_key, QueryKeys.from_dict(params)), - ]: - assert result.is_ok() - objects = result.ok() - assert len(objects) == sum( - 1 for obj_ in all_objects if (obj_.name, obj_.desc) == (obj.name, obj.desc) - ) - assert objects[0] == obj + objects = base_stash.get_all_by_fields(root_verify_key, fields=params).unwrap() + assert len(objects) == sum( + 1 for obj_ in all_objects if (obj_.name, obj_.desc) == (obj.name, obj.desc) + ) + assert objects[0] == obj def test_basestash_query_all_kwargs_multiple_params( @@ -396,66 +340,23 @@ def test_basestash_query_all_kwargs_multiple_params( base_stash.set(root_verify_key, obj) params = {"importance": importance, "desc": desc} - for result in [ - base_stash.query_all_kwargs(root_verify_key, **params), - base_stash.query_all(root_verify_key, QueryKeys.from_dict(params)), - ]: - assert result.is_ok() - objects = result.ok() - assert len(objects) == n_same - assert all(obj.desc == desc for obj in objects) - original_object_values = {get_object_values(obj) for obj in similar_objects} - retrived_objects_values = {get_object_values(obj) for obj in objects} - assert 
original_object_values == retrived_objects_values + objects = base_stash.get_all_by_fields(root_verify_key, fields=params).unwrap() + assert len(objects) == n_same + assert all(obj.desc == desc for obj in objects) + original_object_values = {get_object_values(obj) for obj in similar_objects} + retrived_objects_values = {get_object_values(obj) for obj in objects} + assert original_object_values == retrived_objects_values params = { "name": create_unique(faker.name, [obj.name for obj in all_objects]), "desc": random_sentence(faker), } - for result in [ - base_stash.query_all_kwargs(root_verify_key, **params), - base_stash.query_all(root_verify_key, QueryKeys.from_dict(params)), - ]: - assert result.is_ok() - objects = result.ok() - assert len(objects) == 0 + objects = base_stash.get_all_by_fields(root_verify_key, fields=params).unwrap() + assert len(objects) == 0 obj = random.choice(similar_objects) params = {"id": obj.id, "name": obj.name, "desc": obj.desc} - for result in [ - base_stash.query_all_kwargs(root_verify_key, **params), - base_stash.query_all(root_verify_key, QueryKeys.from_dict(params)), - ]: - assert result.is_ok() - objects = result.ok() - assert len(objects) == 1 - assert objects[0] == obj - - -def test_basestash_cannot_query_non_searchable( - root_verify_key, base_stash: MockStash, mock_objects: list[MockObject] -) -> None: - for obj in mock_objects: - base_stash.set(root_verify_key, obj) - - obj = random.choice(mock_objects) - - assert base_stash.query_one_kwargs(root_verify_key, value=10).is_err() - assert base_stash.query_all_kwargs(root_verify_key, value=10).is_err() - assert base_stash.query_one_kwargs( - root_verify_key, value=10, name=obj.name - ).is_err() - assert base_stash.query_all_kwargs( - root_verify_key, value=10, name=obj.name - ).is_err() - - ValuePartitionKey = PartitionKey(key="value", type_=int) - qk = ValuePartitionKey.with_obj(10) - - assert base_stash.query_one(root_verify_key, qk).is_err() - assert base_stash.query_all(root_verify_key, qk).is_err() - assert base_stash.query_all(root_verify_key, QueryKeys(qks=[qk])).is_err() - assert base_stash.query_all( - root_verify_key, QueryKeys(qks=[qk, UIDPartitionKey.with_obj(obj.id)]) - ).is_err() + objects = base_stash.get_all_by_fields(root_verify_key, fields=params).unwrap() + assert len(objects) == 1 + assert objects[0] == obj From 37e6a744eb0d16df056a913954635717659748df Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 4 Sep 2024 15:42:59 +0200 Subject: [PATCH 077/197] json serialize filter --- packages/syft/src/syft/store/db/stash.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 7a7f3c34d22..a4ef564bf15 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -2,6 +2,7 @@ # stdlib from functools import cache +from typing import Any from typing import Generic from typing import cast from typing import get_args @@ -192,17 +193,18 @@ def get_by_uid( def _get_field_filter( self, field_name: str, - field_value: str, + field_value: Any, table: Table | None = None, ) -> sa.sql.elements.BinaryExpression: table = table if table is not None else self.table if field_name == "id": return table.c.id == field_value + json_value = serialize_json(field_value) if self.db.engine.dialect.name == "sqlite": - return table.c.fields[field_name] == func.json_quote(field_value) + return table.c.fields[field_name] == func.json_quote(json_value) elif 
self.db.engine.dialect.name == "postgresql": - return sa.cast(table.c.fields[field_name], sa.String) == field_value + return sa.cast(table.c.fields[field_name], sa.String) == json_value def _get_by_fields( self, From 87b7d1a271758cd9c7a2546cec2d3caa9881aa99 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 4 Sep 2024 16:28:13 +0200 Subject: [PATCH 078/197] fix notification service tests --- packages/syft/src/syft/serde/json_serde.py | 7 ++- packages/syft/src/syft/store/db/stash.py | 18 +++++-- packages/syft/tests/syft/action_test.py | 4 +- packages/syft/tests/syft/eager_test.py | 8 +-- .../notification_service_test.py | 50 +++---------------- .../syft/service/action/action_object_test.py | 4 +- .../syft/tests/syft/users/user_stash_test.py | 15 +++--- 7 files changed, 44 insertions(+), 62 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index 3969c85c19e..af4569844aa 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -24,6 +24,7 @@ from ..types.datetime import DateTime from ..types.errors import SyftException from ..types.syft_object import BaseDateTime +from ..types.syft_object import DYNAMIC_SYFT_ATTRIBUTES from ..types.syft_object_registry import SyftObjectRegistry from ..types.uid import LineageID from ..types.uid import UID @@ -174,8 +175,12 @@ def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> dict[str, Json]: JSON_VERSION_FIELD: version, } + all_exclude_attrs = ( + set(exclude_attrs) | DEFAULT_EXCLUDE_ATTRS | set(DYNAMIC_SYFT_ATTRIBUTES) + ) + for key, type_ in obj.model_fields.items(): - if key in exclude_attrs or key in DEFAULT_EXCLUDE_ATTRS: + if key in all_exclude_attrs: continue try: result[key] = serialize_json(getattr(obj, key), type_.annotation) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 4d8069646ea..8b06f544b9b 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -555,13 +555,15 @@ def delete_by_uid( ) return uid + @as_result(NotFoundException) def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: # TODO: should do this in a single transaction # TODO add error handling for permission in permissions: - self.add_permission(permission) + self.add_permission(permission).unwrap() return None + @as_result(NotFoundException) def add_permission(self, permission: ActionObjectPermission) -> None: # TODO add error handling stmt = ( @@ -576,8 +578,12 @@ def add_permission(self, permission: ActionObjectPermission) -> None: ) ) - self.session.execute(stmt) + result = self.session.execute(stmt) self.session.commit() + if result.rowcount == 0: + raise NotFoundException( + f"{self.object_type.__name__}: {permission.uid} not found or no permission to change." 
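The json-serialize-filter change in PATCH 077/197 above makes _get_field_filter compare like with like: the filter value is serialized to JSON first, then either quoted the way SQLite quotes JSON scalars or compared against the text-cast column on PostgreSQL. A rough, self-contained sketch of that branch, with json.dumps standing in for syft's serialize_json helper:

    import json
    import sqlalchemy as sa
    from sqlalchemy import func

    def field_filter(table: sa.Table, dialect_name: str, name: str, value):
        # Serialize the comparison value the same way the stored JSON
        # fields were serialized, so string equality is well-defined.
        json_value = json.dumps(value)
        if dialect_name == "sqlite":
            # SQLite compares against the json_quote()-ed value
            return table.c.fields[name] == func.json_quote(json_value)
        if dialect_name == "postgresql":
            # PostgreSQL casts the JSON field to text and compares strings
            return sa.cast(table.c.fields[name], sa.String) == json_value
        raise ValueError(f"unsupported dialect: {dialect_name}")

Note that the patched _get_field_filter has no else branch, so an unsupported dialect would silently return None; the explicit raise above is an addition for the sketch.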
+ ) def remove_permission(self, permission: ActionObjectPermission) -> None: # TODO not threadsafe @@ -678,6 +684,7 @@ def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool: result = self.session.execute(stmt).first() return result is not None + @as_result(NotFoundException) def add_storage_permission(self, permission: StoragePermission) -> None: stmt = ( self.table.update() @@ -690,10 +697,15 @@ def add_storage_permission(self, permission: StoragePermission) -> None: ) ) ) - self.session.execute(stmt) + result = self.session.execute(stmt) self.session.commit() + if result.rowcount == 0: + raise NotFoundException( + f"{self.object_type.__name__}: {permission.uid} not found or no permission to change." + ) return None + @as_result(NotFoundException) def add_storage_permissions(self, permissions: list[StoragePermission]) -> None: for permission in permissions: self.add_storage_permission(permission) diff --git a/packages/syft/tests/syft/action_test.py b/packages/syft/tests/syft/action_test.py index c7829daa9d3..e94222a4cc0 100644 --- a/packages/syft/tests/syft/action_test.py +++ b/packages/syft/tests/syft/action_test.py @@ -26,9 +26,9 @@ def test_actionobject_method(worker): action_store = worker.get_service("actionservice").store obj = ActionObject.from_obj("abc") pointer = obj.send(root_datasite_client) - assert len(action_store.data) == 1 + assert len(action_store._data) == 1 res = pointer.capitalize() - assert len(action_store.data) == 2 + assert len(action_store._data) == 2 assert res[0] == "A" diff --git a/packages/syft/tests/syft/eager_test.py b/packages/syft/tests/syft/eager_test.py index 18fd4b85394..243a18130a2 100644 --- a/packages/syft/tests/syft/eager_test.py +++ b/packages/syft/tests/syft/eager_test.py @@ -174,7 +174,7 @@ def test_setattribute(worker, guest_client): obj_pointer.dtype = np.int32 # local object is updated - assert obj_pointer.id.id in worker.action_store.data + assert obj_pointer.id.id in worker.action_store._data assert obj_pointer.id != original_id res = root_datasite_client.api.services.action.get(obj_pointer.id) @@ -206,7 +206,7 @@ def test_getattribute(worker, guest_client): size_pointer = obj_pointer.size # check result - assert size_pointer.id.id in worker.action_store.data + assert size_pointer.id.id in worker.action_store._data assert root_datasite_client.api.services.action.get(size_pointer.id) == 6 @@ -226,7 +226,7 @@ def test_eager_method(worker, guest_client): flat_pointer = obj_pointer.flatten() - assert flat_pointer.id.id in worker.action_store.data + assert flat_pointer.id.id in worker.action_store._data # check result assert all( root_datasite_client.api.services.action.get(flat_pointer.id) @@ -250,7 +250,7 @@ def test_eager_dunder_method(worker, guest_client): first_row_pointer = obj_pointer[0] - assert first_row_pointer.id.id in worker.action_store.data + assert first_row_pointer.id.id in worker.action_store._data # check result assert all( root_datasite_client.api.services.action.get(first_row_pointer.id) diff --git a/packages/syft/tests/syft/notifications/notification_service_test.py b/packages/syft/tests/syft/notifications/notification_service_test.py index f48e77ab97d..a8319d32a80 100644 --- a/packages/syft/tests/syft/notifications/notification_service_test.py +++ b/packages/syft/tests/syft/notifications/notification_service_test.py @@ -144,20 +144,12 @@ def test_get_all_success( NotificationStatus.UNREAD, ) - @as_result(StashException) - def mock_get_all_inbox_for_verify_key(*args, **kwargs) -> list[Notification]: - 
return [expected_message] - - monkeypatch.setattr( - notification_service.stash, - "get_all_inbox_for_verify_key", - mock_get_all_inbox_for_verify_key, - ) - response = test_notification_service.get_all(authed_context) assert len(response) == 1 assert isinstance(response[0], Notification) + response[0].syft_client_verify_key = None + response[0].syft_server_location = None assert response[0] == expected_message @@ -188,9 +180,6 @@ def mock_get_all_inbox_for_verify_key( def test_get_sent_success( - root_verify_key, - monkeypatch: MonkeyPatch, - notification_service: NotificationService, authed_context: AuthedServiceContext, document_store: DocumentStore, ) -> None: @@ -207,20 +196,12 @@ def test_get_sent_success( NotificationStatus.UNREAD, ) - @as_result(StashException) - def mock_get_all_sent_for_verify_key(credentials, verify_key) -> list[Notification]: - return [expected_message] - - monkeypatch.setattr( - notification_service.stash, - "get_all_sent_for_verify_key", - mock_get_all_sent_for_verify_key, - ) - response = test_notification_service.get_all_sent(authed_context) assert len(response) == 1 assert isinstance(response[0], Notification) + response[0].syft_server_location = None + response[0].syft_client_verify_key = None assert response[0] == expected_message @@ -340,19 +321,12 @@ def test_get_all_read_success( NotificationStatus.READ, ) - def mock_get_all_by_verify_key_for_status() -> list[Notification]: - return [expected_message] - - monkeypatch.setattr( - notification_service.stash, - "get_all_by_verify_key_for_status", - mock_get_all_by_verify_key_for_status, - ) - response = test_notification_service.get_all_read(authed_context) assert len(response) == 1 assert isinstance(response[0], Notification) + response[0].syft_server_location = None + response[0].syft_client_verify_key = None assert response[0] == expected_message @@ -404,19 +378,11 @@ def test_get_all_unread_success( NotificationStatus.UNREAD, ) - @as_result(StashException) - def mock_get_all_by_verify_key_for_status() -> list[Notification]: - return [expected_message] - - monkeypatch.setattr( - notification_service.stash, - "get_all_by_verify_key_for_status", - mock_get_all_by_verify_key_for_status, - ) - response = test_notification_service.get_all_unread(authed_context) assert len(response) == 1 assert isinstance(response[0], Notification) + response[0].syft_server_location = None + response[0].syft_client_verify_key = None assert response[0] == expected_message diff --git a/packages/syft/tests/syft/service/action/action_object_test.py b/packages/syft/tests/syft/service/action/action_object_test.py index e53412d7b84..5252ec1eb33 100644 --- a/packages/syft/tests/syft/service/action/action_object_test.py +++ b/packages/syft/tests/syft/service/action/action_object_test.py @@ -511,10 +511,10 @@ def test_actionobject_syft_send_get(worker, testcase): orig_obj = testcase obj = helper_make_action_obj(orig_obj) - assert len(action_store.data) == 0 + assert len(action_store._data) == 0 ptr = obj.send(root_datasite_client) - assert len(action_store.data) == 1 + assert len(action_store._data) == 1 retrieved = ptr.get() assert obj.syft_action_data == retrieved diff --git a/packages/syft/tests/syft/users/user_stash_test.py b/packages/syft/tests/syft/users/user_stash_test.py index 9592284c2f8..584e616d093 100644 --- a/packages/syft/tests/syft/users/user_stash_test.py +++ b/packages/syft/tests/syft/users/user_stash_test.py @@ -1,5 +1,6 @@ # third party from faker import Faker +import pytest # syft absolute from 
syft.server.credentials import SyftSigningKey @@ -39,16 +40,14 @@ def test_userstash_set( def test_userstash_set_duplicate( root_datasite_client, user_stash: UserStash, guest_user: User ) -> None: - result = user_stash.set( - root_datasite_client.credentials.verify_key, guest_user - ).unwrap() + _ = user_stash.set(root_datasite_client.credentials.verify_key, guest_user).unwrap() original_count = len(user_stash._data) - result = user_stash.set(root_datasite_client.credentials.verify_key, guest_user) - assert result.is_err() - exc = result.err() - assert type(exc) == SyftException - assert exc.public_message + with pytest.raises(SyftException) as exc: + _ = user_stash.set( + root_datasite_client.credentials.verify_key, guest_user + ).unwrap() + assert exc.value.public_message assert len(user_stash._data) == original_count From 2bd627a92428760f0755db088cba86ee1bd9768e Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 4 Sep 2024 17:25:55 +0200 Subject: [PATCH 079/197] fix request service --- .../src/syft/service/blob_storage/service.py | 5 ++--- .../syft/service/code/user_code_service.py | 19 ++++++++++--------- .../syft/src/syft/service/request/request.py | 1 + packages/syft/src/syft/store/db/stash.py | 3 --- packages/syft/tests/syft/action_test.py | 4 ++-- .../service/action/action_service_test.py | 2 +- .../tests/syft/service_permission_test.py | 6 +----- .../tests/syft/stores/queue_stash_test.py | 6 ------ .../syft/tests/syft/users/user_code_test.py | 8 ++++++-- 9 files changed, 23 insertions(+), 31 deletions(-) diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index c4bc955e29d..ec335433edc 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -12,7 +12,6 @@ from ...store.blob_storage.on_disk import OnDiskBlobDeposit from ...store.blob_storage.seaweedfs import SeaweedFSBlobDeposit from ...store.document_store import DocumentStore -from ...store.document_store import UIDPartitionKey from ...types.blob_storage import AzureSecureFilePathLocation from ...types.blob_storage import BlobFileType from ...types.blob_storage import BlobStorageEntry @@ -323,8 +322,8 @@ def delete(self, context: AuthedServiceContext, uid: UID) -> SyftSuccess: public_message=f"Failed to delete blob file with id '{uid}'.
Error: {e}" ) - self.stash.delete( - context.credentials, UIDPartitionKey.with_obj(uid), has_permission=True + self.stash.delete_by_uid( + context.credentials, uid, has_permission=True ).unwrap() except Exception as e: raise SyftException( diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 86763616860..5a0d369420c 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -78,7 +78,7 @@ def submit( self, context: AuthedServiceContext, code: SubmitUserCode ) -> SyftSuccess: """Add User Code""" - user_code = self._submit(context, code, exists_ok=False) + user_code = self._submit(context, code, exists_ok=False).unwrap() return SyftSuccess( message="User Code Submitted", require_api_update=True, value=user_code ) @@ -109,16 +109,17 @@ def _submit( context.credentials, code_hash=get_code_hash(submit_code.code, context.credentials), ).unwrap() + # no exception, code exists + if exists_ok: + return existing_code + else: + raise SyftException( + public_message="UserCode with this code already exists" + ) except NotFoundException: - existing_code = None - - if not exists_ok and existing_code is not None: - raise SyftException( - public_message="UserCode with this code already exists", - ) + pass code = submit_code.to(UserCode, context=context) - result = self._post_user_code_transform_ops(context, code) if result.is_err(): @@ -301,7 +302,7 @@ def request_code_execution( reason: str | None = "", ) -> Request: """Request Code execution on user code""" - user_code = self._submit(context, code, exists_ok=False).unwrap() + user_code = self._get_or_submit_user_code(context, code).unwrap() result = self._request_code_execution( context, diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 4d48f745927..d978a7c2c75 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -367,6 +367,7 @@ class Request(SyncableSyftObject): __attr_searchable__ = [ "requesting_user_verify_key", "approving_user_verify_key", + "code_id", ] __attr_unique__ = ["request_hash"] __repr_attrs__ = [ diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index b5110b42e94..35f2537410d 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -1,7 +1,6 @@ # stdlib # stdlib -from functools import cache from typing import Any from typing import Generic from typing import cast @@ -359,8 +358,6 @@ def row_as_obj(self, row: Row) -> StashT: # TODO make unwrappable serde return deserialize_json(row.fields) - # TODO add cache invalidation, ignore B019 - @cache # noqa: B019 def get_role(self, credentials: SyftVerifyKey) -> ServiceRole: # TODO error handling user_table = Table("User", Base.metadata) diff --git a/packages/syft/tests/syft/action_test.py b/packages/syft/tests/syft/action_test.py index e94222a4cc0..d94712fa4fe 100644 --- a/packages/syft/tests/syft/action_test.py +++ b/packages/syft/tests/syft/action_test.py @@ -75,7 +75,7 @@ def test_lib_function_action(worker): assert isinstance(res, ActionObject) assert all(res == np.array([0, 0, 0])) - assert len(worker.get_service("actionservice").store.data) > 0 + assert len(worker.get_service("actionservice").store._data) > 0 def test_call_lib_function_action2(worker): @@ -90,7 +90,7 @@ def 
test_lib_class_init_action(worker): assert isinstance(res, ActionObject) assert res == np.float32(4.0) - assert len(worker.get_service("actionservice").store.data) > 0 + assert len(worker.get_service("actionservice").store._data) > 0 def test_call_lib_wo_permission(worker): diff --git a/packages/syft/tests/syft/service/action/action_service_test.py b/packages/syft/tests/syft/service/action/action_service_test.py index bb8057d4ee3..88e44cd8c55 100644 --- a/packages/syft/tests/syft/service/action/action_service_test.py +++ b/packages/syft/tests/syft/service/action/action_service_test.py @@ -22,6 +22,6 @@ def test_action_service_sanity(worker): obj = ActionObject.from_obj("abc") pointer = obj.send(root_datasite_client) - assert len(service.store.data) == 1 + assert len(service.store._data) == 1 res = pointer.capitalize() assert res[0] == "A" diff --git a/packages/syft/tests/syft/service_permission_test.py b/packages/syft/tests/syft/service_permission_test.py index afc266005f3..ceb6d63923c 100644 --- a/packages/syft/tests/syft/service_permission_test.py +++ b/packages/syft/tests/syft/service_permission_test.py @@ -9,12 +9,8 @@ @pytest.fixture def guest_mock_user(root_verify_key, user_stash, guest_user): - result = user_stash.partition.set(root_verify_key, guest_user) - assert result.is_ok() - - user = result.ok() + user = user_stash.set(root_verify_key, guest_user).unwrap() assert user is not None - yield user diff --git a/packages/syft/tests/syft/stores/queue_stash_test.py b/packages/syft/tests/syft/stores/queue_stash_test.py index 312766e7c4e..9809d4a8774 100644 --- a/packages/syft/tests/syft/stores/queue_stash_test.py +++ b/packages/syft/tests/syft/stores/queue_stash_test.py @@ -47,7 +47,6 @@ def mock_queue_object(): @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("dict_queue_stash"), pytest.lazy_fixture("sqlite_queue_stash"), pytest.lazy_fixture("mongo_queue_stash"), ], @@ -61,7 +60,6 @@ def test_queue_stash_sanity(queue: Any) -> None: @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("dict_queue_stash"), pytest.lazy_fixture("sqlite_queue_stash"), pytest.lazy_fixture("mongo_queue_stash"), ], @@ -102,7 +100,6 @@ def test_queue_stash_set_get(root_verify_key, queue: Any) -> None: @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("dict_queue_stash"), pytest.lazy_fixture("sqlite_queue_stash"), pytest.lazy_fixture("mongo_queue_stash"), ], @@ -133,7 +130,6 @@ def test_queue_stash_update(root_verify_key, queue: Any) -> None: @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("dict_queue_stash"), pytest.lazy_fixture("sqlite_queue_stash"), pytest.lazy_fixture("mongo_queue_stash"), ], @@ -176,7 +172,6 @@ def _kv_cbk(tid: int) -> None: @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("dict_queue_stash"), pytest.lazy_fixture("sqlite_queue_stash"), pytest.lazy_fixture("mongo_queue_stash"), ], @@ -220,7 +215,6 @@ def _kv_cbk(tid: int) -> None: @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("dict_queue_stash"), pytest.lazy_fixture("sqlite_queue_stash"), pytest.lazy_fixture("mongo_queue_stash"), ], diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index a1fa1a3925d..9ed439b55d1 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -16,6 +16,7 @@ from syft.service.response import SyftError from syft.service.response import SyftSuccess from syft.service.user.user import User +from syft.service.user.user import UserView from 
syft.service.user.user_roles import ServiceRole from syft.types.errors import SyftException @@ -66,14 +67,17 @@ def test_new_admin_can_list_user_code( admin = root_client.login(email=email, password=pw) - root_client.api.services.user.update(uid=admin.account.id, role=ServiceRole.ADMIN) + result: UserView = root_client.api.services.user.update( + uid=admin.account.id, role=ServiceRole.ADMIN + ) + assert result.role == ServiceRole.ADMIN if delete_original_admin: res = root_client.api.services.user.delete(root_client.account.id) assert not isinstance(res, SyftError) user_code_stash = worker.get_service("usercodeservice").stash - user_code = user_code_stash.get_all(user_code_stash.store.root_verify_key).ok() + user_code = user_code_stash.get_all(user_code_stash.root_verify_key).ok() assert len(user_code) == len(admin.code.get_all()) assert {c.id for c in user_code} == {c.id for c in admin.code} From d3360e576d29d157424bcee395532b155f38eb9b Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 5 Sep 2024 09:56:38 +0200 Subject: [PATCH 080/197] fix protocol --- .../src/syft/protocol/protocol_version.json | 443 +++++++++--------- 1 file changed, 218 insertions(+), 225 deletions(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 3121dc36d9f..a0428c76ce3 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -36,217 +36,175 @@ "action": "add" } }, - "StoreConfig": { - "1": { - "version": 1, - "hash": "a9997fce6a8a0ed2884c58b8eb9382f8554bdd18fff61f8bf0451945bcff12c7", - "action": "add" - } - }, - "MongoDict": { - "1": { - "version": 1, - "hash": "57e36f57eed75e62b29e2bac1295035a9bf2c0e3c56719dac24cb6cc685be00b", - "action": "add" - } - }, - "MongoStoreConfig": { - "1": { - "version": 1, - "hash": "53342b27d34165b7e2699f8e7ad70d13d125875e6a75e8fa18f5796428f41036", - "action": "add" - } - }, - "LinkedObject": { - "1": { - "version": 1, - "hash": "d80f5ac7f51a9383be1a3cb334d56ae50e49733ed3199f3b6b5d6febd9de410b", - "action": "add" - } - }, - "BaseConfig": { + "User": { "1": { "version": 1, - "hash": "10bd7566041d0f0a3aa295367785fdcc2c5bbf0ded984ac9230754f37496a6a7", + "hash": "2df4b68182c558dba5485a8a6867acf2a5c341b249ad67373a504098aa8c4343", "action": "add" }, "2": { "version": 2, - "hash": "890d2879ac44611db9b88ba9334a721130d0ac3aa18a303fa9e4081f14b9b8c7", + "hash": "af6fb5b2e1606e97838f4a60f0536ad95db606d455e94acbd1977df866608a2c", "action": "add" } }, - "ServiceConfig": { + "UserUpdate": { "1": { "version": 1, - "hash": "28af8a296f5ff63de50438277eaa1f4380682e6aca9f2ca28320d7a444825e88", - "action": "add" - }, - "2": { - "version": 2, - "hash": "93dfab144e0b0884c602358b3a9ce889bb29ab96e3b4adcfe3cef47a31694a9a", + "hash": "1bf6707c69b809c804fb939c7c37d787c2f6889508a4bec37d24221af2eb777a", "action": "add" } }, - "LibConfig": { + "UserCreate": { "1": { "version": 1, - "hash": "ee8f0e3f6aae81948d72e30226645e8eb5d312a6770411a1edca748168c467c0", - "action": "add" - }, - "2": { - "version": 2, - "hash": "a8a78a8d726ee9e79f95614f3d0fa5b85edc6fce7be7651715208669be93e0e3", + "hash": "49d6087e2309ba59987f3126e286e74b3a66492a08ad82fa507ea17d52ce78e3", "action": "add" } }, - "APIEndpoint": { + "UserSearch": { "1": { "version": 1, - "hash": "faa1cf9336a0d1233868c8c57745ff38c0be60399dc1acd0c0e8dd440e405dbd", + "hash": "9ac946338cca68d00d1696a57943442f062628ec3daf53077d0bdd3f72cd9fa0", "action": "add" } }, - "LibEndpoint": { + "UserView": { "1": { "version": 1, - 
"hash": "a585c83a33a019d363ae5a0c6d4197193654307c19a4829dfbf8a8cfd2c1842a", + "hash": "0b52d758e31d5889c9cd88afb467aae4a74e34a5276924e07012243c34d300fe", "action": "add" } }, - "SignedSyftAPICall": { + "UserViewPage": { "1": { "version": 1, - "hash": "2f959455f7130f4e59360b8aa58f19785b76eaa0f8a5a9188a6cbf32b31311ca", + "hash": "1cd6528d02ec180f080d5c35f0da760d8a59af9da7baaa9c17c1c7cedcc858fa", "action": "add" } }, - "SyftAPICall": { + "UserPrivateKey": { "1": { "version": 1, - "hash": "59e89e7b9ea30deaed64d1ffd9bc0769b999d3082b305428432c1f5be36c6343", + "hash": "4817d8147aba94373f320dcd90e65f097cf6e5a2ef353aa8520e23128d522b5d", "action": "add" } }, - "SyftAPIData": { + "StoreConfig": { "1": { "version": 1, - "hash": "820b279c581cafd9bb5009702d4e3db22ec3a3156676426304b9038dad260a24", + "hash": "a9997fce6a8a0ed2884c58b8eb9382f8554bdd18fff61f8bf0451945bcff12c7", "action": "add" } }, - "SyftAPI": { + "MongoDict": { "1": { "version": 1, - "hash": "cc13ab058ee36748c14b0d4bd9b9e894c7566fff09cfa4170b3eece520169f15", + "hash": "57e36f57eed75e62b29e2bac1295035a9bf2c0e3c56719dac24cb6cc685be00b", "action": "add" } }, - "User": { + "MongoStoreConfig": { "1": { "version": 1, - "hash": "2df4b68182c558dba5485a8a6867acf2a5c341b249ad67373a504098aa8c4343", - "action": "add" - }, - "2": { - "version": 2, - "hash": "af6fb5b2e1606e97838f4a60f0536ad95db606d455e94acbd1977df866608a2c", + "hash": "53342b27d34165b7e2699f8e7ad70d13d125875e6a75e8fa18f5796428f41036", "action": "add" } }, - "UserUpdate": { + "LinkedObject": { "1": { "version": 1, - "hash": "1bf6707c69b809c804fb939c7c37d787c2f6889508a4bec37d24221af2eb777a", + "hash": "d80f5ac7f51a9383be1a3cb334d56ae50e49733ed3199f3b6b5d6febd9de410b", "action": "add" } }, - "UserCreate": { + "DateTime": { "1": { "version": 1, - "hash": "49d6087e2309ba59987f3126e286e74b3a66492a08ad82fa507ea17d52ce78e3", + "hash": "394abb554114ead4d63c36e3fe83ac018dead4b21a8465174009577c46d54c58", "action": "add" } }, - "UserSearch": { + "ReplyNotification": { "1": { "version": 1, - "hash": "9ac946338cca68d00d1696a57943442f062628ec3daf53077d0bdd3f72cd9fa0", + "hash": "84102dfc59d711b03c2f3d3a6ecaca000b6835f1bbdd9af801057f7aacb5f1d0", "action": "add" } }, - "UserView": { + "Notification": { "1": { "version": 1, - "hash": "0b52d758e31d5889c9cd88afb467aae4a74e34a5276924e07012243c34d300fe", + "hash": "af4cb232bff390c431e399975f048b34da7e940ace8b23b940a3b398c91c5326", "action": "add" } }, - "UserViewPage": { + "CreateNotification": { "1": { "version": 1, - "hash": "1cd6528d02ec180f080d5c35f0da760d8a59af9da7baaa9c17c1c7cedcc858fa", + "hash": "7e426c946b7d5db6f9427960ec16042f3018091d835ca5966f3568c324a2ab53", "action": "add" } }, - "UserPrivateKey": { + "UserNotificationActivity": { "1": { "version": 1, - "hash": "4817d8147aba94373f320dcd90e65f097cf6e5a2ef353aa8520e23128d522b5d", + "hash": "422fd01c6d9af38688a9982abd34e80794a1f6ddd444cca225d77f49189847a9", "action": "add" } }, - "DateTime": { + "NotificationPreferences": { "1": { "version": 1, - "hash": "394abb554114ead4d63c36e3fe83ac018dead4b21a8465174009577c46d54c58", + "hash": "a42f06b367e7c6cbabcbf3cfcc84d1ca0873e457d972ebd060e87c9d6185f62b", "action": "add" } }, - "ReplyNotification": { + "NotifierSettings": { "1": { "version": 1, - "hash": "84102dfc59d711b03c2f3d3a6ecaca000b6835f1bbdd9af801057f7aacb5f1d0", + "hash": "65c8ab814d35fac32f68d3000756692592cc59940f30e3af3dcdfa2328755b9d", "action": "add" - } - }, - "Notification": { - "1": { - "version": 1, - "hash": "af4cb232bff390c431e399975f048b34da7e940ace8b23b940a3b398c91c5326", + }, + "2": { 
+ "version": 2, + "hash": "be8b52597fc628d1b7cd22b776ee81416e1adbb04a45188778eb0e32ed1416b4", "action": "add" } }, - "CreateNotification": { + "BaseConfig": { "1": { "version": 1, - "hash": "7e426c946b7d5db6f9427960ec16042f3018091d835ca5966f3568c324a2ab53", + "hash": "10bd7566041d0f0a3aa295367785fdcc2c5bbf0ded984ac9230754f37496a6a7", "action": "add" - } - }, - "UserNotificationActivity": { - "1": { - "version": 1, - "hash": "422fd01c6d9af38688a9982abd34e80794a1f6ddd444cca225d77f49189847a9", + }, + "2": { + "version": 2, + "hash": "890d2879ac44611db9b88ba9334a721130d0ac3aa18a303fa9e4081f14b9b8c7", "action": "add" } }, - "NotificationPreferences": { + "ServiceConfig": { "1": { "version": 1, - "hash": "a42f06b367e7c6cbabcbf3cfcc84d1ca0873e457d972ebd060e87c9d6185f62b", + "hash": "28af8a296f5ff63de50438277eaa1f4380682e6aca9f2ca28320d7a444825e88", + "action": "add" + }, + "2": { + "version": 2, + "hash": "93dfab144e0b0884c602358b3a9ce889bb29ab96e3b4adcfe3cef47a31694a9a", "action": "add" } }, - "NotifierSettings": { + "LibConfig": { "1": { "version": 1, - "hash": "65c8ab814d35fac32f68d3000756692592cc59940f30e3af3dcdfa2328755b9d", + "hash": "ee8f0e3f6aae81948d72e30226645e8eb5d312a6770411a1edca748168c467c0", "action": "add" }, "2": { "version": 2, - "hash": "be8b52597fc628d1b7cd22b776ee81416e1adbb04a45188778eb0e32ed1416b4", + "hash": "a8a78a8d726ee9e79f95614f3d0fa5b85edc6fce7be7651715208669be93e0e3", "action": "add" } }, @@ -343,6 +301,48 @@ "action": "add" } }, + "APIEndpoint": { + "1": { + "version": 1, + "hash": "faa1cf9336a0d1233868c8c57745ff38c0be60399dc1acd0c0e8dd440e405dbd", + "action": "add" + } + }, + "LibEndpoint": { + "1": { + "version": 1, + "hash": "a585c83a33a019d363ae5a0c6d4197193654307c19a4829dfbf8a8cfd2c1842a", + "action": "add" + } + }, + "SignedSyftAPICall": { + "1": { + "version": 1, + "hash": "2f959455f7130f4e59360b8aa58f19785b76eaa0f8a5a9188a6cbf32b31311ca", + "action": "add" + } + }, + "SyftAPICall": { + "1": { + "version": 1, + "hash": "59e89e7b9ea30deaed64d1ffd9bc0769b999d3082b305428432c1f5be36c6343", + "action": "add" + } + }, + "SyftAPIData": { + "1": { + "version": 1, + "hash": "820b279c581cafd9bb5009702d4e3db22ec3a3156676426304b9038dad260a24", + "action": "add" + } + }, + "SyftAPI": { + "1": { + "version": 1, + "hash": "cc13ab058ee36748c14b0d4bd9b9e894c7566fff09cfa4170b3eece520169f15", + "action": "add" + } + }, "HTTPConnection": { "1": { "version": 1, @@ -399,6 +399,125 @@ "action": "add" } }, + "BlobFile": { + "1": { + "version": 1, + "hash": "d99239100f1cb0b73c69b2ad7cab01a06909cc3a4976ba2b3b67cf6fe5e2f516", + "action": "add" + } + }, + "BlobFileOBject": { + "1": { + "version": 1, + "hash": "6c40dab2c8d2220d4fff7cc653d76cc026a856db7e2b5713b6341e255adc7ea2", + "action": "add" + } + }, + "SecureFilePathLocation": { + "1": { + "version": 1, + "hash": "ea5978b98d7773d221665b450454c9130c103a5c850669a0acd620607cd614b7", + "action": "add" + } + }, + "SeaweedSecureFilePathLocation": { + "1": { + "version": 1, + "hash": "3fc9bfc8c1b1cf660c9747e8c1fe3eb2220e78d4e3b5d6b5c5f29a07a77ebf3e", + "action": "add" + } + }, + "AzureSecureFilePathLocation": { + "1": { + "version": 1, + "hash": "090a9e962eeb655586ee966c5651d8996363969818a38f9a486fd64d33047e05", + "action": "add" + } + }, + "BlobStorageEntry": { + "1": { + "version": 1, + "hash": "afdc6a1d8a24b1ee1ed9d3e79f5bac64b4f0d9d36800f07f10be0b896470345f", + "action": "add" + } + }, + "BlobStorageMetadata": { + "1": { + "version": 1, + "hash": "9d4b61ac4ea1910c2f7c767a50a6a52544a24663548f069e79bd906f11b538e4", + "action": "add" + } 
+ }, + "CreateBlobStorageEntry": { + "1": { + "version": 1, + "hash": "ffc3cbfeade67d074dc5bf7d655a1eb8c83630076028a72b3cc4548f3b413e14", + "action": "add" + } + }, + "BlobRetrieval": { + "1": { + "version": 1, + "hash": "c422c74b89a9349742acaa848566fe18bfef1a83333458b858c074baed37a859", + "action": "add" + } + }, + "SyftObjectRetrieval": { + "1": { + "version": 1, + "hash": "b2b62447445adc4cd0b77ab59d6fa56624dd316fb50281e570daad07556b6db2", + "action": "add" + } + }, + "BlobRetrievalByURL": { + "1": { + "version": 1, + "hash": "4db0e3b7a6334d3835356d8393866711e243e360af25a95f3cc4066f032404b5", + "action": "add" + } + }, + "BlobDeposit": { + "1": { + "version": 1, + "hash": "6eb5cc57dc763126bfc6ec5a2b79d02e77eadf9d9efb1888a5c366b7799c1c24", + "action": "add" + } + }, + "WorkerSettings": { + "1": { + "version": 1, + "hash": "dca33003904a71688e5b07db65f8833eb4de8135aade7154076b8eafbb94d26b", + "action": "add" + } + }, + "HTTPServerRoute": { + "1": { + "version": 1, + "hash": "938245604a9c7e50001299afff5b669b2548364e356fed22a22780497831bf81", + "action": "add" + } + }, + "PythonServerRoute": { + "1": { + "version": 1, + "hash": "a068d8f942d55ecb6d45af88a27c6ebf208584275bf589cbc308df3f774ab9a9", + "action": "add" + } + }, + "VeilidServerRoute": { + "1": { + "version": 1, + "hash": "e676bc165601d2ede69707a4b6168ed4674f3f98887026d098a2dd4da4dfd097", + "action": "add" + } + }, + "EnclaveMetadata": { + "1": { + "version": 1, + "hash": "8d2dfafa01ec909c080a790cf15a8fc78e00382d3bfe6207098ceb25a60b9c53", + "action": "add" + } + }, "CustomEndpointActionObject": { "1": { "version": 1, @@ -673,62 +792,6 @@ "action": "add" } }, - "BlobFile": { - "1": { - "version": 1, - "hash": "d99239100f1cb0b73c69b2ad7cab01a06909cc3a4976ba2b3b67cf6fe5e2f516", - "action": "add" - } - }, - "BlobFileOBject": { - "1": { - "version": 1, - "hash": "6c40dab2c8d2220d4fff7cc653d76cc026a856db7e2b5713b6341e255adc7ea2", - "action": "add" - } - }, - "SecureFilePathLocation": { - "1": { - "version": 1, - "hash": "ea5978b98d7773d221665b450454c9130c103a5c850669a0acd620607cd614b7", - "action": "add" - } - }, - "SeaweedSecureFilePathLocation": { - "1": { - "version": 1, - "hash": "3fc9bfc8c1b1cf660c9747e8c1fe3eb2220e78d4e3b5d6b5c5f29a07a77ebf3e", - "action": "add" - } - }, - "AzureSecureFilePathLocation": { - "1": { - "version": 1, - "hash": "090a9e962eeb655586ee966c5651d8996363969818a38f9a486fd64d33047e05", - "action": "add" - } - }, - "BlobStorageEntry": { - "1": { - "version": 1, - "hash": "afdc6a1d8a24b1ee1ed9d3e79f5bac64b4f0d9d36800f07f10be0b896470345f", - "action": "add" - } - }, - "BlobStorageMetadata": { - "1": { - "version": 1, - "hash": "9d4b61ac4ea1910c2f7c767a50a6a52544a24663548f069e79bd906f11b538e4", - "action": "add" - } - }, - "CreateBlobStorageEntry": { - "1": { - "version": 1, - "hash": "ffc3cbfeade67d074dc5bf7d655a1eb8c83630076028a72b3cc4548f3b413e14", - "action": "add" - } - }, "SyftObjectMigrationState": { "1": { "version": 1, @@ -755,34 +818,6 @@ "action": "add" } }, - "BlobRetrieval": { - "1": { - "version": 1, - "hash": "c422c74b89a9349742acaa848566fe18bfef1a83333458b858c074baed37a859", - "action": "add" - } - }, - "SyftObjectRetrieval": { - "1": { - "version": 1, - "hash": "b2b62447445adc4cd0b77ab59d6fa56624dd316fb50281e570daad07556b6db2", - "action": "add" - } - }, - "BlobRetrievalByURL": { - "1": { - "version": 1, - "hash": "4db0e3b7a6334d3835356d8393866711e243e360af25a95f3cc4066f032404b5", - "action": "add" - } - }, - "BlobDeposit": { - "1": { - "version": 1, - "hash": 
"6eb5cc57dc763126bfc6ec5a2b79d02e77eadf9d9efb1888a5c366b7799c1c24", - "action": "add" - } - }, "OnDiskBlobDeposit": { "1": { "version": 1, @@ -811,13 +846,6 @@ "action": "add" } }, - "DictStoreConfig": { - "1": { - "version": 1, - "hash": "2e1365c5535fa51c22eef79f67dd6444789bc829c27881367e3050e06e2ffbfe", - "action": "add" - } - }, "NumpyArrayObject": { "1": { "version": 1, @@ -1007,34 +1035,6 @@ "action": "add" } }, - "WorkerSettings": { - "1": { - "version": 1, - "hash": "dca33003904a71688e5b07db65f8833eb4de8135aade7154076b8eafbb94d26b", - "action": "add" - } - }, - "HTTPServerRoute": { - "1": { - "version": 1, - "hash": "938245604a9c7e50001299afff5b669b2548364e356fed22a22780497831bf81", - "action": "add" - } - }, - "PythonServerRoute": { - "1": { - "version": 1, - "hash": "a068d8f942d55ecb6d45af88a27c6ebf208584275bf589cbc308df3f774ab9a9", - "action": "add" - } - }, - "VeilidServerRoute": { - "1": { - "version": 1, - "hash": "e676bc165601d2ede69707a4b6168ed4674f3f98887026d098a2dd4da4dfd097", - "action": "add" - } - }, "ServerPeer": { "1": { "version": 1, @@ -1160,13 +1160,6 @@ "hash": "ed05cb87aec832098fc464ac36cd6bceaab705463d0d2fa1b2d8e1ccc510018c", "action": "add" } - }, - "EnclaveMetadata": { - "1": { - "version": 1, - "hash": "8d2dfafa01ec909c080a790cf15a8fc78e00382d3bfe6207098ceb25a60b9c53", - "action": "add" - } } } } From 0fdc2b8e71945f1fc85134453a0422f3cc8c1075 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 5 Sep 2024 09:59:52 +0200 Subject: [PATCH 081/197] remove dataset stash any type --- packages/syft/src/syft/service/dataset/dataset_stash.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index 99dfa06d64c..3b50c91fb1f 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -13,8 +13,6 @@ @instrument @serializable(canonical_name="DatasetStashSQL", version=1) class DatasetStash(ObjectStash[Dataset]): - allow_set_any_type = True - @as_result(StashException, NotFoundException) def get_by_name(self, credentials: SyftVerifyKey, name: str) -> Dataset: return self.get_one_by_field( From e1c4e855661838bac356dc55ad5f7428b10865ca Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 5 Sep 2024 10:11:33 +0200 Subject: [PATCH 082/197] fix dependencies --- packages/syft/setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index a4015b1d895..010bc788433 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -68,6 +68,7 @@ syft = nh3==0.2.17 ipython<8.27.0 dynaconf==3.2.6 + sqlalchemy==2.0.32 install_requires = %(syft)s @@ -103,7 +104,6 @@ dev = ruff==0.4.7 safety>=2.4.0b2 aiosmtpd==1.4.6 - dynaconf==3.2.5 telemetry = opentelemetry-api==1.27.0 From 6ca11f8026d302ea2e0eff832b1b23d1835a25dd Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 5 Sep 2024 10:16:50 +0200 Subject: [PATCH 083/197] fix project tests --- .../src/syft/protocol/protocol_version.json | 1159 ++++++++++++++++- .../syft/src/syft/server/worker_settings.py | 7 +- packages/syft/src/syft/store/db/stash.py | 12 +- .../tests/syft/stores/queue_stash_test.py | 12 +- .../syft/tests/syft/users/user_code_test.py | 4 +- 5 files changed, 1168 insertions(+), 26 deletions(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index daee5706ccf..ca79d2ac467 100644 --- 
a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -1,21 +1,1164 @@ { - "1": { - "release_name": "0.9.1.json" - }, "dev": { "object_versions": { + "SyftObjectVersioned": { + "1": { + "version": 1, + "hash": "7c842dcdbb57e2528ffa690ea18c19fff3c8a591811d40cad2b19be3100e2ff4", + "action": "add" + } + }, + "BaseDateTime": { + "1": { + "version": 1, + "hash": "614db484b1950be729902b1861bd3a7b33899176507c61cef11dc0d44611cfd3", + "action": "add" + } + }, + "SyftObject": { + "1": { + "version": 1, + "hash": "bb70d874355988908d3a92a3941d6613a6995a4850be3b6a0147f4d387724406", + "action": "add" + } + }, + "PartialSyftObject": { + "1": { + "version": 1, + "hash": "19a995fcc2833f4fab24584fd99b71a80c2ef1f13c06f83af79e4482846b1656", + "action": "add" + } + }, + "ServerMetadata": { + "1": { + "version": 1, + "hash": "1691c7667eca86b20c4189e90ce4e643dd41fd3682cdb69c6308878f2a6f135c", + "action": "add" + } + }, + "User": { + "1": { + "version": 1, + "hash": "2df4b68182c558dba5485a8a6867acf2a5c341b249ad67373a504098aa8c4343", + "action": "add" + }, + "2": { + "version": 2, + "hash": "af6fb5b2e1606e97838f4a60f0536ad95db606d455e94acbd1977df866608a2c", + "action": "add" + } + }, + "UserUpdate": { + "1": { + "version": 1, + "hash": "1bf6707c69b809c804fb939c7c37d787c2f6889508a4bec37d24221af2eb777a", + "action": "add" + } + }, + "UserCreate": { + "1": { + "version": 1, + "hash": "49d6087e2309ba59987f3126e286e74b3a66492a08ad82fa507ea17d52ce78e3", + "action": "add" + } + }, + "UserSearch": { + "1": { + "version": 1, + "hash": "9ac946338cca68d00d1696a57943442f062628ec3daf53077d0bdd3f72cd9fa0", + "action": "add" + } + }, + "UserView": { + "1": { + "version": 1, + "hash": "0b52d758e31d5889c9cd88afb467aae4a74e34a5276924e07012243c34d300fe", + "action": "add" + } + }, + "UserViewPage": { + "1": { + "version": 1, + "hash": "1cd6528d02ec180f080d5c35f0da760d8a59af9da7baaa9c17c1c7cedcc858fa", + "action": "add" + } + }, + "UserPrivateKey": { + "1": { + "version": 1, + "hash": "4817d8147aba94373f320dcd90e65f097cf6e5a2ef353aa8520e23128d522b5d", + "action": "add" + } + }, + "StoreConfig": { + "1": { + "version": 1, + "hash": "a9997fce6a8a0ed2884c58b8eb9382f8554bdd18fff61f8bf0451945bcff12c7", + "action": "add" + } + }, + "MongoDict": { + "1": { + "version": 1, + "hash": "57e36f57eed75e62b29e2bac1295035a9bf2c0e3c56719dac24cb6cc685be00b", + "action": "add" + } + }, + "MongoStoreConfig": { + "1": { + "version": 1, + "hash": "53342b27d34165b7e2699f8e7ad70d13d125875e6a75e8fa18f5796428f41036", + "action": "add" + } + }, + "LinkedObject": { + "1": { + "version": 1, + "hash": "d80f5ac7f51a9383be1a3cb334d56ae50e49733ed3199f3b6b5d6febd9de410b", + "action": "add" + } + }, + "DateTime": { + "1": { + "version": 1, + "hash": "394abb554114ead4d63c36e3fe83ac018dead4b21a8465174009577c46d54c58", + "action": "add" + } + }, + "ReplyNotification": { + "1": { + "version": 1, + "hash": "84102dfc59d711b03c2f3d3a6ecaca000b6835f1bbdd9af801057f7aacb5f1d0", + "action": "add" + } + }, + "Notification": { + "1": { + "version": 1, + "hash": "af4cb232bff390c431e399975f048b34da7e940ace8b23b940a3b398c91c5326", + "action": "add" + } + }, + "CreateNotification": { + "1": { + "version": 1, + "hash": "7e426c946b7d5db6f9427960ec16042f3018091d835ca5966f3568c324a2ab53", + "action": "add" + } + }, + "UserNotificationActivity": { + "1": { + "version": 1, + "hash": "422fd01c6d9af38688a9982abd34e80794a1f6ddd444cca225d77f49189847a9", + "action": "add" + } + }, + "NotificationPreferences": { + "1": { + 
"version": 1, + "hash": "a42f06b367e7c6cbabcbf3cfcc84d1ca0873e457d972ebd060e87c9d6185f62b", + "action": "add" + } + }, + "NotifierSettings": { + "1": { + "version": 1, + "hash": "65c8ab814d35fac32f68d3000756692592cc59940f30e3af3dcdfa2328755b9d", + "action": "add" + }, + "2": { + "version": 2, + "hash": "be8b52597fc628d1b7cd22b776ee81416e1adbb04a45188778eb0e32ed1416b4", + "action": "add" + } + }, + "BaseConfig": { + "1": { + "version": 1, + "hash": "10bd7566041d0f0a3aa295367785fdcc2c5bbf0ded984ac9230754f37496a6a7", + "action": "add" + }, + "2": { + "version": 2, + "hash": "890d2879ac44611db9b88ba9334a721130d0ac3aa18a303fa9e4081f14b9b8c7", + "action": "add" + } + }, + "ServiceConfig": { + "1": { + "version": 1, + "hash": "28af8a296f5ff63de50438277eaa1f4380682e6aca9f2ca28320d7a444825e88", + "action": "add" + }, + "2": { + "version": 2, + "hash": "93dfab144e0b0884c602358b3a9ce889bb29ab96e3b4adcfe3cef47a31694a9a", + "action": "add" + } + }, + "LibConfig": { + "1": { + "version": 1, + "hash": "ee8f0e3f6aae81948d72e30226645e8eb5d312a6770411a1edca748168c467c0", + "action": "add" + }, + "2": { + "version": 2, + "hash": "a8a78a8d726ee9e79f95614f3d0fa5b85edc6fce7be7651715208669be93e0e3", + "action": "add" + } + }, + "SyftImageRegistry": { + "1": { + "version": 1, + "hash": "67e18903e41cba1afe136adf29d404b63ec04fea6e928abb2533ec4fa52b246b", + "action": "add" + } + }, + "SyftWorkerImage": { + "1": { + "version": 1, + "hash": "44da7badfbe573d5403d3ab78c077f17dbefc560b81fdf927b671815be047441", + "action": "add" + } + }, + "SyftWorker": { + "1": { + "version": 1, + "hash": "9d897f6039eabe48dfa8e8d5c5cdcb283b0375b4c64571b457777eaaf3fb1920", + "action": "add" + } + }, + "WorkerPool": { + "1": { + "version": 1, + "hash": "16efc5dd2596ae744fd611c8f46af9eaec1bd5729eb20e85e9fd2f31df402564", + "action": "add" + } + }, + "MarkdownDescription": { + "1": { + "version": 1, + "hash": "31a73f8824cad1636a55d14b6a1074cdb071d0d4e16e86baaa3d4f63a7e80134", + "action": "add" + } + }, + "HTMLObject": { + "1": { + "version": 1, + "hash": "97f2e93f5ceaa88015047186f66a17ff13df2a6b7925b41331f9e19d5a515a9f", + "action": "add" + } + }, + "PwdTokenResetConfig": { + "1": { + "version": 1, + "hash": "0415a272428f22add4896c64aa9f29c8c1d35619e2433da6564eb5f1faff39ac", + "action": "add" + } + }, + "ServerSettingsUpdate": { + "1": { + "version": 1, + "hash": "1e4260ad879ae80728c3ffae2cd1d48759abd51f9d0960d4b25855cdbb4c506b", + "action": "add" + }, + "2": { + "version": 2, + "hash": "23b2716e9dceca667e228408e2416c82f11821e322e5bccf1f83406f3d09abdc", + "action": "add" + }, + "3": { + "version": 3, + "hash": "335c7946f2e52d09c7b26f511120cd340717c74c5cca9107e84f839da993c55c", + "action": "add" + }, + "4": { + "version": 4, + "hash": "8d7a41992c39c287fcb46383bed429ce75d3c9524ced8c86b88c26dd0232e2fe", + "action": "add" + } + }, + "ServerSettings": { + "1": { + "version": 1, + "hash": "5a1e7470cbeaaae5b80ac9beecb743734f7e4e42d429a09ea8defa569a5ddff1", + "action": "add" + }, + "2": { + "version": 2, + "hash": "7727ea54e494dc9deaa0d1bd38ac8a6180bc192b74eec5659adbc338a19e21f5", + "action": "add" + }, + "3": { + "version": 3, + "hash": "997667e1cba22d151857aacc2caba6b1ca73c1648adbd03461dc74a0c0c372b3", + "action": "add" + }, + "4": { + "version": 4, + "hash": "b8067777967a0e06733433e179e549caaf501419d62f7e8474ee33b839e3890d", + "action": "add" + } + }, + "APIEndpoint": { + "1": { + "version": 1, + "hash": "faa1cf9336a0d1233868c8c57745ff38c0be60399dc1acd0c0e8dd440e405dbd", + "action": "add" + } + }, + "LibEndpoint": { + "1": { + "version": 1, + 
"hash": "a585c83a33a019d363ae5a0c6d4197193654307c19a4829dfbf8a8cfd2c1842a", + "action": "add" + } + }, + "SignedSyftAPICall": { + "1": { + "version": 1, + "hash": "2f959455f7130f4e59360b8aa58f19785b76eaa0f8a5a9188a6cbf32b31311ca", + "action": "add" + } + }, + "SyftAPICall": { + "1": { + "version": 1, + "hash": "59e89e7b9ea30deaed64d1ffd9bc0769b999d3082b305428432c1f5be36c6343", + "action": "add" + } + }, + "SyftAPIData": { + "1": { + "version": 1, + "hash": "820b279c581cafd9bb5009702d4e3db22ec3a3156676426304b9038dad260a24", + "action": "add" + } + }, + "SyftAPI": { + "1": { + "version": 1, + "hash": "cc13ab058ee36748c14b0d4bd9b9e894c7566fff09cfa4170b3eece520169f15", + "action": "add" + } + }, + "HTTPConnection": { + "1": { + "version": 1, + "hash": "bf10f81646c71069c76292b1237b4a3de1e507264392c5c591d067636ce6fb46", + "action": "add" + } + }, + "PythonConnection": { + "1": { + "version": 1, + "hash": "28010778b5e3463ff6960a0e2224818de00bc7b5e6f892192e02e399ccbe18b5", + "action": "add" + } + }, + "ActionDataEmpty": { + "1": { + "version": 1, + "hash": "e0e4a5cf18d05b6b747addc048515c6f2a5f35f0766ebaee96d898cb971e1c5b", + "action": "add" + } + }, + "ObjectNotReady": { + "1": { + "version": 1, + "hash": "8cf471e205cd0893d6aae5f0227d14db7df1c9698da08a3ab991f59132d17fe9", + "action": "add" + } + }, + "ActionDataLink": { + "1": { + "version": 1, + "hash": "3469478343439e411b761c270eec63eb3d533e459ad72d0965158c3a6cdf3b9a", + "action": "add" + } + }, + "Action": { + "1": { + "version": 1, + "hash": "021826d7c6f69bd0283d025d40661f3ffbeba8810ca94de01344f6afbdae62cd", + "action": "add" + } + }, + "ActionObject": { + "1": { + "version": 1, + "hash": "0a5f4bc343cb114a251f06686ecdbb59d74bfb3d29a098b176699deb35a1e683", + "action": "add" + } + }, + "AnyActionObject": { + "1": { + "version": 1, + "hash": "b3c44c7788c59c03fa1baeec656c2ca6e633f4cbd4b23ff7ece6ee94c38449f0", + "action": "add" + } + }, + "BlobFile": { + "1": { + "version": 1, + "hash": "d99239100f1cb0b73c69b2ad7cab01a06909cc3a4976ba2b3b67cf6fe5e2f516", + "action": "add" + } + }, + "BlobFileOBject": { + "1": { + "version": 1, + "hash": "6c40dab2c8d2220d4fff7cc653d76cc026a856db7e2b5713b6341e255adc7ea2", + "action": "add" + } + }, + "SecureFilePathLocation": { + "1": { + "version": 1, + "hash": "ea5978b98d7773d221665b450454c9130c103a5c850669a0acd620607cd614b7", + "action": "add" + } + }, + "SeaweedSecureFilePathLocation": { + "1": { + "version": 1, + "hash": "3fc9bfc8c1b1cf660c9747e8c1fe3eb2220e78d4e3b5d6b5c5f29a07a77ebf3e", + "action": "add" + } + }, + "AzureSecureFilePathLocation": { + "1": { + "version": 1, + "hash": "090a9e962eeb655586ee966c5651d8996363969818a38f9a486fd64d33047e05", + "action": "add" + } + }, + "BlobStorageEntry": { + "1": { + "version": 1, + "hash": "afdc6a1d8a24b1ee1ed9d3e79f5bac64b4f0d9d36800f07f10be0b896470345f", + "action": "add" + } + }, + "BlobStorageMetadata": { + "1": { + "version": 1, + "hash": "9d4b61ac4ea1910c2f7c767a50a6a52544a24663548f069e79bd906f11b538e4", + "action": "add" + } + }, + "CreateBlobStorageEntry": { + "1": { + "version": 1, + "hash": "ffc3cbfeade67d074dc5bf7d655a1eb8c83630076028a72b3cc4548f3b413e14", + "action": "add" + } + }, + "BlobRetrieval": { + "1": { + "version": 1, + "hash": "c422c74b89a9349742acaa848566fe18bfef1a83333458b858c074baed37a859", + "action": "add" + } + }, + "SyftObjectRetrieval": { + "1": { + "version": 1, + "hash": "b2b62447445adc4cd0b77ab59d6fa56624dd316fb50281e570daad07556b6db2", + "action": "add" + } + }, + "BlobRetrievalByURL": { + "1": { + "version": 1, + "hash": 
"4db0e3b7a6334d3835356d8393866711e243e360af25a95f3cc4066f032404b5", + "action": "add" + } + }, + "BlobDeposit": { + "1": { + "version": 1, + "hash": "6eb5cc57dc763126bfc6ec5a2b79d02e77eadf9d9efb1888a5c366b7799c1c24", + "action": "add" + } + }, + "WorkerSettings": { + "1": { + "version": 1, + "hash": "b18b575d0e633fa4adbe88f08ea39f056e6882aff5ede0334911b8309d2ef489", + "action": "add" + } + }, + "HTTPServerRoute": { + "1": { + "version": 1, + "hash": "938245604a9c7e50001299afff5b669b2548364e356fed22a22780497831bf81", + "action": "add" + } + }, + "PythonServerRoute": { + "1": { + "version": 1, + "hash": "a068d8f942d55ecb6d45af88a27c6ebf208584275bf589cbc308df3f774ab9a9", + "action": "add" + } + }, + "VeilidServerRoute": { + "1": { + "version": 1, + "hash": "e676bc165601d2ede69707a4b6168ed4674f3f98887026d098a2dd4da4dfd097", + "action": "add" + } + }, + "EnclaveMetadata": { + "1": { + "version": 1, + "hash": "8d2dfafa01ec909c080a790cf15a8fc78e00382d3bfe6207098ceb25a60b9c53", + "action": "add" + } + }, + "CustomEndpointActionObject": { + "1": { + "version": 1, + "hash": "c7addbaf2777707f3e91e5c1e092343476cd22efc4ec8617f39ccf76e61a5a14", + "action": "add" + }, + "2": { + "version": 2, + "hash": "846ba36e8737a1bec16853c9de54c4948450009278e0b76fe7e3355ef9e70089", + "action": "add" + } + }, + "DataSubject": { + "1": { + "version": 1, + "hash": "582cdf9e82b5d6915b7f09f7c0d5f08328b11a2ce9b0198e5083f1672c2e2bf5", + "action": "add" + } + }, + "DataSubjectCreate": { + "1": { + "version": 1, + "hash": "5a8423c2690d55f425bfeecc87cd4a797a75d88ebb5fbda754d4f269b62d2ceb", + "action": "add" + } + }, + "DataSubjectMemberRelationship": { + "1": { + "version": 1, + "hash": "0810483ea76ea10c8f286c6035dc0b2085291f345183be50c179f3a05a577110", + "action": "add" + } + }, + "Contributor": { + "1": { + "version": 1, + "hash": "30c32bd44098f00e0b15496be441763b6e50af8b12d3d2bef33aca6287193876", + "action": "add" + } + }, + "Asset": { + "1": { + "version": 1, + "hash": "000abc78719611c106295cf12b1690b7e5411dc1bb9db9d4afd22956da90d1f4", + "action": "add" + } + }, + "CreateAsset": { + "1": { + "version": 1, + "hash": "357d52576cb12b24fb3980342bb49a562b065c0e4419e87d34176340628c7309", + "action": "add" + } + }, + "Dataset": { + "1": { + "version": 1, + "hash": "0ca6b0b4a3aebb2c8f351668075b44951bb20d1e23a779b82109124f334ce3a4", + "action": "add" + } + }, + "DatasetPageView": { + "1": { + "version": 1, + "hash": "aa0dd69637281b80d5523b4409a2c7e89db114c9fe79c858063c6dadff8977d1", + "action": "add" + }, + "2": { + "version": 2, + "hash": "be1ca6dcd0b3aa0481ce5dce737e78432d06a78ad0c701aaf136be407c798352", + "action": "add" + } + }, + "CreateDataset": { + "1": { + "version": 1, + "hash": "7e02dfa89540c3dbebacbb13810d95cdc4e36db31d56cffed7ab54abe25716c9", + "action": "add" + } + }, + "SyftLog": { + "1": { + "version": 1, + "hash": "1bcd71e5bf3f0db3bba0996f33b6b2bde3489b9c71f11e6b30c3495c76a8f53f", + "action": "add" + } + }, "JobItem": { + "2": { + "version": 2, + "hash": "b087d0c62b7d304c6ca80e4fb0e8a7f2a444be8f8cba57490dc09aeb98033105", + "action": "add" + } + }, + "ExecutionOutput": { + "1": { + "version": 1, + "hash": "e36c71685edf5276a3427cb6749550486d3a177c1dcf73dd337ab2a73c0ce6b5", + "action": "add" + } + }, + "TwinObject": { + "1": { + "version": 1, + "hash": "4f31243fb348dbb083579afd6f638d75af010cb53d19bfba59b74afff41ccbbb", + "action": "add" + } + }, + "PolicyRule": { + "1": { + "version": 1, + "hash": "44d1ca1db97be46f66558aa1a729ff31bf8e113c6a913b11aedf9d6b6ad5b7b5", + "action": "add" + } + }, + "CreatePolicyRule": { + 
"1": { + "version": 1, + "hash": "342bb723526d445151a0435f57d251f4c1219f8ae7cca3e8e9fce52e2ee1b8b1", + "action": "add" + } + }, + "CreatePolicyRuleConstant": { + "1": { + "version": 1, + "hash": "78b54832cb0468a87013bc36bc11d4759874ca1b5065a1b711f1e5ef5d94c2df", + "action": "add" + } + }, + "Matches": { + "1": { + "version": 1, + "hash": "dd6d91ddb2ec5eaf60be2b0899ecfdb9a15f7904aa39d2f4d9bb2d7b793040e6", + "action": "add" + } + }, + "PreFill": { + "1": { + "version": 1, + "hash": "c7aefb11dc4c4569dcd1e6988371047a32a8be1b32ad46d12adba419a19769ad", + "action": "add" + } + }, + "UserOwned": { + "1": { + "version": 1, + "hash": "c8738dc3d8c2a5ef461b85a0467c3dff53dab16b54a4d12b44b1477906aef51d", + "action": "add" + } + }, + "MixedInputPolicy": { + "1": { + "version": 1, + "hash": "37bb12d950518d9579c8ec7c4cc22ac731ea82caf8c1370dd0b0a82b46462dde", + "action": "add" + } + }, + "ExactMatch": { + "1": { + "version": 1, + "hash": "5eb37edbf5e451d942e599247f3eaed923c1fe9d91eefdba02bf06503f6cc08d", + "action": "add" + } + }, + "OutputHistory": { + "1": { + "version": 1, + "hash": "9366db79d131f8c65e5a4ff12c90e2aa0c11e302debe06e46eeb93b26e2aaf61", + "action": "add" + } + }, + "OutputPolicyExecuteCount": { + "1": { + "version": 1, + "hash": "2a77e5ed5c7b0391147562651ad4061e20b11745c191fbc34cb549da37ba72dd", + "action": "add" + } + }, + "OutputPolicyExecuteOnce": { + "1": { + "version": 1, + "hash": "5589c00d127d9eb1f5ccf3a16def8219737784d57bb3bf9be5cb6d83325ef436", + "action": "add" + } + }, + "EmptyInputPolicy": { + "1": { + "version": 1, + "hash": "7ef81cfd223be0064600e1503f8b04bafc16385e27730e9319466e68a077c68b", + "action": "add" + } + }, + "UserPolicy": { + "1": { + "version": 1, + "hash": "74373bb71a334f4dcf77623ae10ff5b1c7e5b3006f65f2051ffb1e01f422f982", + "action": "add" + } + }, + "SubmitUserPolicy": { + "1": { + "version": 1, + "hash": "ec4e808eb39613bcdbbbf9ffb3267612084a9d99880a2f3bee3ef32d46329c02", + "action": "add" + } + }, + "UserCodeStatusCollection": { + "1": { + "version": 1, + "hash": "735ecf2d4abb1e7d19b2e751d880f32b01ce267ba10e417ef1b440be3d94d8f1", + "action": "add" + } + }, + "UserCode": { + "1": { + "version": 1, + "hash": "3bcd14413b9c4fbde7c5612c2ed713518340280b5cff89cf2aaaf1c77c4037a8", + "action": "add" + } + }, + "SubmitUserCode": { + "1": { + "version": 1, + "hash": "d2bb8cfe12f070b4adafded78ce01900c5409bd83f055f94b1e285745ef65a76", + "action": "add" + } + }, + "UserCodeExecutionResult": { + "1": { + "version": 1, + "hash": "1f4cbc62caac4dd193f427306405dc7a099ae744bea5830cf57149ce71c1e589", + "action": "add" + } + }, + "UserCodeExecutionOutput": { + "1": { + "version": 1, + "hash": "c1d53300a39dbbb437d7d5a1257bd175a067b1065f4099a0938fac7540035258", + "action": "add" + }, + "2": { + "version": 2, + "hash": "3e104e39b4ab53c950e61e4f7e92ce935cf96a5100de301de9bf297eb7e5787e", + "action": "add" + } + }, + "CodeHistory": { + "1": { + "version": 1, + "hash": "e3ef5346f108257828f364d22b12d9311812c9cf843200afef5dc4d9302f9b21", + "action": "add" + } + }, + "CodeHistoryView": { + "1": { + "version": 1, + "hash": "8b8b97d334b51d1ce0a9efab722411ff25caa3f12be319105954497e0a306eb2", + "action": "add" + } + }, + "CodeHistoriesDict": { + "1": { + "version": 1, + "hash": "01d7dcd4b21525a06e4484d8699a4a34a5c84f1f6026ec55e32eb30412742601", + "action": "add" + } + }, + "UsersCodeHistoriesDict": { + "1": { + "version": 1, + "hash": "4ed8b83973258ea19a1f91feb2590ff73b801be86f4296cc3db48f6929ff784c", + "action": "add" + } + }, + "SyftObjectMigrationState": { + "1": { + "version": 1, + "hash": 
"ee83315828551f18904bab18e0cac48896493620561215b04cc448e6ce5834af", + "action": "add" + } + }, + "StoreMetadata": { + "1": { + "version": 1, + "hash": "8de9a22a2765ef976bc161cb0704347d30350c085da8c8ffa876065cfca3e5fd", + "action": "add" + } + }, + "MigrationData": { + "1": { + "version": 1, + "hash": "cb96b8c8413609e1224341d1b0dd1efb08387c0ff7b0ff65eba36c0b104c9ed1", + "action": "add" + }, + "2": { + "version": 2, + "hash": "1d1b14c196221ecf6d644d7dcaa32ac9e90361b2687fa83161ff399ebc6df1bd", + "action": "add" + } + }, + "OnDiskBlobDeposit": { + "1": { + "version": 1, + "hash": "817bf1bee4a35bfa1cd25d6779a10d8d180b1b3f1e837952f81f48b9411d1970", + "action": "add" + } + }, + "RemoteConfig": { + "1": { + "version": 1, + "hash": "179d067099a178d748c6d9a0477e8de7c3b55577439669eca7150258f2409567", + "action": "add" + } + }, + "AzureRemoteConfig": { + "1": { + "version": 1, + "hash": "a143811fec0da5fd881e927643ef667c91c78a2c90519cf88da7da20738bd187", + "action": "add" + } + }, + "SeaweedFSBlobDeposit": { + "1": { + "version": 1, + "hash": "febeb2a2ce81aa2c512e4c6b611b582984042aafa0541403d4584662273a166c", + "action": "add" + } + }, + "NumpyArrayObject": { + "1": { + "version": 1, + "hash": "05dd2917b7692b3daf4e7ad083a46fa7ec7a2be8faac8d4a654809189c986443", + "action": "add" + } + }, + "NumpyScalarObject": { + "1": { + "version": 1, + "hash": "8753e5c78270a5cacbf0439447724772f4765351a4a8b58b0a5c416a6b2c8b6e", + "action": "add" + } + }, + "NumpyBoolObject": { + "1": { + "version": 1, + "hash": "331c44f8fa3d0a077f1aaad7313bae2c43b386d04def7b8bedae9fdf7690134d", + "action": "add" + } + }, + "PandasDataframeObject": { + "1": { + "version": 1, + "hash": "5e8018364cea31d5f185a901da4ab89846b02153ee7d041ee8a6d305ece31f90", + "action": "add" + } + }, + "PandasSeriesObject": { + "1": { + "version": 1, + "hash": "b8bd482bf16fc7177e9778292cd42f8835b6ced2ce8dc88908b4b8e6d7c7528f", + "action": "add" + } + }, + "Change": { + "1": { + "version": 1, + "hash": "75fb9a5cd4e76b189ebe130a421d3921a0c251947a48bbb92a2ef1c315dc3c16", + "action": "add" + } + }, + "ChangeStatus": { + "1": { + "version": 1, + "hash": "c914a6f7637b555a51b71e8e197e591f7a2e28121e29b5dd586f87e0383d179d", + "action": "add" + } + }, + "ActionStoreChange": { + "1": { + "version": 1, + "hash": "1a803bb08924b49f3114fd46e0e132f819d4d56be5e03a27e9fe90947ca26e85", + "action": "add" + } + }, + "CreateCustomImageChange": { + "1": { + "version": 1, + "hash": "c3dbea3f49979fdcc517c0d13cd02739ca2fe86b370c42496a224f142ae31562", + "action": "add" + } + }, + "CreateCustomWorkerPoolChange": { + "1": { + "version": 1, + "hash": "0355793dd58b364dcb84fff29714b6a26446bead3ba95c6d75e3200008e580f4", + "action": "add" + } + }, + "Request": { + "1": { + "version": 1, + "hash": "1d69f5f0074114f99aa29c5ee77cb20b9151e5b50e77b026f11c3632a12efadf", + "action": "add" + } + }, + "RequestInfo": { + "1": { + "version": 1, + "hash": "779562547744ebed64548f8021647292604fdf4256bf79685dfa14a1e56cc27b", + "action": "add" + } + }, + "RequestInfoFilter": { + "1": { + "version": 1, + "hash": "bb881a003032f4676321218d7cd09580f4d64fccaa1cf9e118fdcd5c73c3d3a8", + "action": "add" + } + }, + "SubmitRequest": { + "1": { + "version": 1, + "hash": "6c38b6ffd0a6f7442746e68b9ace7b21cb1dca7d2031929db5f9a302a280403f", + "action": "add" + } + }, + "ObjectMutation": { + "1": { + "version": 1, + "hash": "ce88096760ce9334599c8194ec97b0a1470651ad680d9d21b8826a0df0af2a36", + "action": "add" + } + }, + "EnumMutation": { + "1": { + "version": 1, + "hash": 
"5173fda73df17a344eb663b7692cca48bd46bf1773455439836b852cd165448c", + "action": "add" + } + }, + "UserCodeStatusChange": { + "1": { + "version": 1, + "hash": "89aaf7f1368c782e3a1b9e79988877f6eaa05ab84365f7d321b757fde7fe86e7", + "action": "add" + } + }, + "SyncedUserCodeStatusChange": { + "1": { + "version": 1, + "hash": "d9ad2d341eb645bd50d06330cd30fd4c266f93e37b9f5391d58b78365fc440e6", + "action": "add" + } + }, + "TwinAPIContextView": { + "1": { + "version": 1, + "hash": "e099eef32cb3a8a806cbdc54cc7fca96bed3d60344bd571163ec049db407938b", + "action": "add" + } + }, + "CustomAPIView": { + "1": { + "version": 1, + "hash": "769e96bebd05736ab860591670fb6da19406239b0104ddc71bd092a134335146", + "action": "add" + } + }, + "CustomApiEndpoint": { + "1": { + "version": 1, + "hash": "ec4a217585336d1b59c93c18570443a63f4fbb24d2c088fbacf80bcf389d23e8", + "action": "add" + } + }, + "PrivateAPIEndpoint": { + "1": { + "version": 1, + "hash": "6d7d143432c2811c520ab6dade005ba40173b590e5c676be04f5921b970ef938", + "action": "add" + } + }, + "PublicAPIEndpoint": { + "1": { + "version": 1, + "hash": "3bf51fc33aa8feb1abc9d0ef792e8889da31a57050430e0bd8e17f2065ff8734", + "action": "add" + } + }, + "UpdateTwinAPIEndpoint": { + "1": { + "version": 1, + "hash": "851e59412716e73c7f70a696619e0b375ce136b43f6fe2ea784747091caba5d8", + "action": "add" + } + }, + "CreateTwinAPIEndpoint": { + "1": { + "version": 1, + "hash": "3d0b84dae95ebcc6647b5aabe54e65b3c6bf957665fde57d8037806a4aac13be", + "action": "add" + } + }, + "TwinAPIEndpoint": { + "1": { + "version": 1, + "hash": "d1947b8f9c80d6c9b443e5a9f0758afa8849a5f12b9a511feefd7e4f82c374f4", + "action": "add" + } + }, + "SyncState": { + "1": { + "version": 1, + "hash": "9a3f0bb973858b55bc766c9770c4d9abcc817898f797d94a89938650c0c67868", + "action": "add" + } + }, + "ServerPeer": { + "1": { + "version": 1, + "hash": "0d5f252018e324ea0d2dcb5c2ad8bd15707220565fce4f14de7f63a8f9e4391b", + "action": "add" + } + }, + "ServerPeerUpdate": { + "1": { + "version": 1, + "hash": "0b854b57db7a18118c1fd8f31495b2ba4eeb9fbe4f24c631ff112418a94570d3", + "action": "add" + } + }, + "AssociationRequestChange": { + "1": { + "version": 1, + "hash": "0134ac0002879c85fc9ddb06bed6306a8905c8434b0a40d3a96ce24a7bd4da90", + "action": "add" + } + }, + "QueueItem": { + "1": { + "version": 1, + "hash": "1db212c46b6c56ccc5579cfe2141b693f0cd9286e2ede71210393e8455379bf1", + "action": "add" + } + }, + "ActionQueueItem": { + "1": { + "version": 1, + "hash": "396d579dfc2e2b36b9fbed2f204bffcca1bea7ee2db7175045dd3328ebf08718", + "action": "add" + } + }, + "APIEndpointQueueItem": { + "1": { + "version": 1, + "hash": "f04b3990a8d29c116d301e70df54d58f188895307a411dc13a666ff764ffd8dd", + "action": "add" + } + }, + "ZMQClientConfig": { + "1": { + "version": 1, + "hash": "36ee8f75067d5144f0ed062cdc79466caae16b7a128231d89b6b430174843bde", + "action": "add" + } + }, + "SQLiteStoreConfig": { + "1": { + "version": 1, + "hash": "ad062a5f863ae84683867d2a6a5e1d4420c010a64b88bc7b392106e33d71ac03", + "action": "add" + } + }, + "ProjectEvent": { + "1": { + "version": 1, + "hash": "dc0486c52daebd5e98c2b3b03ffd9a9a14bc3d86d8dc0c23e41ebf6c31fe2ffb", + "action": "add" + } + }, + "ProjectThreadMessage": { + "1": { + "version": 1, + "hash": "99256d7592577d1e37df94a06eabc0a287f2d79e144c51fd719315e278edb46d", + "action": "add" + } + }, + "ProjectMessage": { + "1": { + "version": 1, + "hash": "b5004b6354f71b19c81dd5f4b20bf446e0b959f5608a22707e96b944dd8175b0", + "action": "add" + } + }, + "ProjectRequestResponse": { + "1": { + "version": 1, + 
"hash": "52162a8a779a4a301d8755691bf4cf994c86b9f650f9e8c8a923b44e635b1bc0", + "action": "add" + } + }, + "ProjectRequest": { + "1": { + "version": 1, + "hash": "dc684135d5a5a48e5fc7988598c1e6e0de76cf1c5995f1c283fcf63d0eb4d24f", + "action": "add" + } + }, + "AnswerProjectPoll": { + "1": { + "version": 1, + "hash": "c83d83a5ba6cc034d5061df200b3f1d029aa770b1e13dbef959bb1790323dc6e", + "action": "add" + } + }, + "ProjectPoll": { + "1": { + "version": 1, + "hash": "ecf69b3b324e0bee9c82295796d44c4e8f796496cdc9db6d4302c2f160566466", + "action": "add" + } + }, + "Project": { + "1": { + "version": 1, + "hash": "de86a1163ddbcd1cc3cc2b1b5dfcb85a8ad9f9d4bbc759c2b1f92a0d0a2ff184", + "action": "add" + } + }, + "ProjectSubmit": { "1": { "version": 1, - "hash": "0b32277b7d3b9bdc14a2a51cc9005f8254e7f7b6ec059ddcccbcd681a807afb6", - "action": "remove" + "hash": "7555ba11ee5a814dcd9c45647300020f7359efc1081559940990cbd745936cac", + "action": "add" } }, - "DictStoreConfig": { + "Plan": { "1": { "version": 1, - "hash": "2e1365c5535fa51c22eef79f67dd6444789bc829c27881367e3050e06e2ffbfe", - "action": "remove" + "hash": "ed05cb87aec832098fc464ac36cd6bceaab705463d0d2fa1b2d8e1ccc510018c", + "action": "add" } } } diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py index 57a69d2a4eb..48b3f737590 100644 --- a/packages/syft/src/syft/server/worker_settings.py +++ b/packages/syft/src/syft/server/worker_settings.py @@ -30,8 +30,9 @@ class WorkerSettings(SyftObject): server_side_type: ServerSideType deployment_type: DeploymentType = DeploymentType.REMOTE signing_key: SyftSigningKey - document_store_config: StoreConfig - action_store_config: StoreConfig + # PR NOTE: Update these fields to use new config classes + document_store_config: StoreConfig | None = None + action_store_config: StoreConfig | None = None blob_store_config: BlobStorageConfig | None = None queue_config: QueueConfig | None = None log_level: int | None = None @@ -47,8 +48,6 @@ def from_server(cls, server: AbstractServer) -> Self: name=server.name, server_type=server.server_type, signing_key=server.signing_key, - document_store_config=server.document_store_config, - action_store_config=server.action_store_config, server_side_type=server_side_type, blob_store_config=server.blob_store_config, queue_config=server.queue_config, diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 35f2537410d..614921752ae 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -758,10 +758,20 @@ def set( self.server_uid, ) + fields = serialize_json(obj) + try: + # check if the fields are deserializable + # PR NOTE: Is this too much extra work? + deserialize_json(fields) + except Exception as e: + raise StashException( + f"Error serializing object: {e}. Some fields are invalid." 
+ ) + # create the object with the permissions stmt = self.table.insert().values( id=uid, - fields=serialize_json(obj), + fields=fields, permissions=permissions, storage_permissions=storage_permissions, ) diff --git a/packages/syft/tests/syft/stores/queue_stash_test.py b/packages/syft/tests/syft/stores/queue_stash_test.py index 9809d4a8774..68ebd47e999 100644 --- a/packages/syft/tests/syft/stores/queue_stash_test.py +++ b/packages/syft/tests/syft/stores/queue_stash_test.py @@ -64,7 +64,7 @@ def test_queue_stash_sanity(queue: Any) -> None: pytest.lazy_fixture("mongo_queue_stash"), ], ) -# @pytest.mark.flaky(reruns=3, reruns_delay=3) +# def test_queue_stash_set_get(root_verify_key, queue: Any) -> None: objs = [] repeats = 5 @@ -104,7 +104,6 @@ def test_queue_stash_set_get(root_verify_key, queue: Any) -> None: pytest.lazy_fixture("mongo_queue_stash"), ], ) -@pytest.mark.flaky(reruns=3, reruns_delay=3) def test_queue_stash_update(root_verify_key, queue: Any) -> None: obj = mock_queue_object() res = queue.set(root_verify_key, obj, ignore_duplicates=False) @@ -134,7 +133,6 @@ def test_queue_stash_update(root_verify_key, queue: Any) -> None: pytest.lazy_fixture("mongo_queue_stash"), ], ) -@pytest.mark.flaky(reruns=3, reruns_delay=3) def test_queue_set_existing_queue_threading(root_verify_key, queue: Any) -> None: thread_cnt = 3 repeats = 5 @@ -176,7 +174,6 @@ def _kv_cbk(tid: int) -> None: pytest.lazy_fixture("mongo_queue_stash"), ], ) -@pytest.mark.flaky(reruns=3, reruns_delay=3) def test_queue_update_existing_queue_threading(root_verify_key, queue: Any) -> None: thread_cnt = 3 repeats = 5 @@ -219,7 +216,6 @@ def _kv_cbk(tid: int) -> None: pytest.lazy_fixture("mongo_queue_stash"), ], ) -@pytest.mark.flaky(reruns=3, reruns_delay=3) def test_queue_set_delete_existing_queue_threading( root_verify_key, queue: Any, @@ -305,7 +301,6 @@ def _kv_cbk(tid: int) -> None: assert len(queue) == thread_cnt * repeats -@pytest.mark.flaky(reruns=3, reruns_delay=3) def test_queue_set_sqlite(root_verify_key, sqlite_workspace): def create_queue_cbk(): return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace) @@ -313,7 +308,6 @@ def create_queue_cbk(): helper_queue_set_threading(root_verify_key, create_queue_cbk) -@pytest.mark.flaky(reruns=3, reruns_delay=3) def test_queue_set_threading_mongo(root_verify_key, mongo_document_store): def create_queue_cbk(): return mongo_queue_stash_fn(mongo_document_store) @@ -363,7 +357,6 @@ def _kv_cbk(tid: int) -> None: assert execution_err is None -@pytest.mark.flaky(reruns=3, reruns_delay=3) def test_queue_update_threading_sqlite(root_verify_key, sqlite_workspace): def create_queue_cbk(): return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace) @@ -371,7 +364,6 @@ def create_queue_cbk(): helper_queue_update_threading(root_verify_key, create_queue_cbk) -@pytest.mark.flaky(reruns=3, reruns_delay=3) def test_queue_update_threading_mongo(root_verify_key, mongo_document_store): def create_queue_cbk(): return mongo_queue_stash_fn(mongo_document_store) @@ -429,7 +421,6 @@ def _kv_cbk(tid: int) -> None: assert len(queue) == 0 -@pytest.mark.flaky(reruns=3, reruns_delay=3) def test_queue_delete_threading_sqlite(root_verify_key, sqlite_workspace): def create_queue_cbk(): return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace) @@ -437,7 +428,6 @@ def create_queue_cbk(): helper_queue_set_delete_threading(root_verify_key, create_queue_cbk) -@pytest.mark.flaky(reruns=3, reruns_delay=3) def test_queue_delete_threading_mongo(root_verify_key, mongo_document_store): def 
create_queue_cbk(): return mongo_queue_stash_fn(mongo_document_store) diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index 9ed439b55d1..46287f1b6b5 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -77,9 +77,9 @@ def test_new_admin_can_list_user_code( assert not isinstance(res, SyftError) user_code_stash = worker.get_service("usercodeservice").stash - user_code = user_code_stash.get_all(user_code_stash.root_verify_key).ok() + user_code = user_code_stash._data - assert len(user_code) == len(admin.code.get_all()) + assert 1 == len(admin.code.get_all()) assert {c.id for c in user_code} == {c.id for c in admin.code} From 38212c7f7a36037ad6c4b32944b40f393350499a Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 5 Sep 2024 10:53:13 +0200 Subject: [PATCH 084/197] fix job unittests --- packages/syft/src/syft/server/server.py | 26 +++++++++++-------- .../syft/src/syft/server/worker_settings.py | 12 ++++----- .../syft/src/syft/service/network/routes.py | 3 +-- packages/syft/src/syft/store/db/sqlite_db.py | 4 +++ 4 files changed, 25 insertions(+), 20 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index e978c79dc1c..17137079fac 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -89,6 +89,7 @@ from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit +from ..store.db.sqlite_db import DBConfig from ..store.db.sqlite_db import SQLiteDBConfig from ..store.db.sqlite_db import SQLiteDBManager from ..store.document_store import StoreConfig @@ -295,6 +296,7 @@ class Server(AbstractServer): signing_key: SyftSigningKey | None required_signed_calls: bool = True packages: str + db_config: DBConfig def __init__( self, @@ -304,6 +306,7 @@ def __init__( signing_key: SyftSigningKey | SigningKey | None = None, action_store_config: StoreConfig | None = None, document_store_config: StoreConfig | None = None, + db_config: DBConfig | None = None, root_email: str | None = default_root_email, root_username: str | None = default_root_username, root_password: str | None = default_root_password, @@ -394,7 +397,16 @@ def __init__( store_type="Action Store", ) - self.init_stores() + if db_config is None: + db_config = SQLiteDBConfig( + filename=f"{self.id}_json.db", + path=self.get_temp_dir("db"), + ) + # json_db_config = PostgresDBConfig(reset=False) + + self.db_config = db_config + + self.init_stores(db_config=self.db_config) # construct services only after init stores self.services: ServiceRegistry = ServiceRegistry.for_server(self) @@ -868,17 +880,9 @@ def reload_user_code() -> None: if ti is not None: CODE_RELOADER[ti] = reload_user_code - def init_stores( - self, - ) -> None: - # TODO fix database filename + reset - json_db_config = SQLiteDBConfig( - filename=f"{self.id}_json.db", - path=self.get_temp_dir("db"), - ) - # json_db_config = PostgresDBConfig(reset=False) + def init_stores(self, db_config: DBConfig) -> None: self.db = SQLiteDBManager( - config=json_db_config, + config=db_config, server_uid=self.id, root_verify_key=self.verify_key, ) diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py index 57a69d2a4eb..70a6ddf76ca 100644 --- 
a/packages/syft/src/syft/server/worker_settings.py +++ b/packages/syft/src/syft/server/worker_settings.py @@ -13,8 +13,8 @@ from ..server.credentials import SyftSigningKey from ..service.queue.base_queue import QueueConfig from ..store.blob_storage import BlobStorageConfig -from ..store.document_store import StoreConfig -from ..types.syft_object import SYFT_OBJECT_VERSION_1 +from ..store.db.sqlite_db import DBConfig +from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftObject from ..types.uid import UID @@ -22,7 +22,7 @@ @serializable() class WorkerSettings(SyftObject): __canonical_name__ = "WorkerSettings" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 id: UID name: str @@ -30,8 +30,7 @@ class WorkerSettings(SyftObject): server_side_type: ServerSideType deployment_type: DeploymentType = DeploymentType.REMOTE signing_key: SyftSigningKey - document_store_config: StoreConfig - action_store_config: StoreConfig + db_config: DBConfig blob_store_config: BlobStorageConfig | None = None queue_config: QueueConfig | None = None log_level: int | None = None @@ -47,8 +46,7 @@ def from_server(cls, server: AbstractServer) -> Self: name=server.name, server_type=server.server_type, signing_key=server.signing_key, - document_store_config=server.document_store_config, - action_store_config=server.action_store_config, + db_config=server.db_config, server_side_type=server_side_type, blob_store_config=server.blob_store_config, queue_config=server.queue_config, diff --git a/packages/syft/src/syft/service/network/routes.py b/packages/syft/src/syft/service/network/routes.py index 5cd7a5f2136..f6de35f75fd 100644 --- a/packages/syft/src/syft/service/network/routes.py +++ b/packages/syft/src/syft/service/network/routes.py @@ -130,8 +130,7 @@ def server(self) -> AbstractServer | None: server_type=self.worker_settings.server_type, server_side_type=self.worker_settings.server_side_type, signing_key=self.worker_settings.signing_key, - document_store_config=self.worker_settings.document_store_config, - action_store_config=self.worker_settings.action_store_config, + db_config=self.worker_settings.db_config, processes=1, ) return server diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index da3877632bf..4ee509c9dac 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -13,6 +13,7 @@ from sqlalchemy.orm import sessionmaker # relative +from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...types.uid import UID from .models import Base @@ -20,6 +21,7 @@ from .utils import loads +@serializable(canonical_name="DBConfig", version=1) class DBConfig(BaseModel): reset: bool = False @@ -28,6 +30,7 @@ def connection_string(self) -> str: raise NotImplementedError("Subclasses must implement this method.") +@serializable(canonical_name="SQLiteDBConfig", version=1) class SQLiteDBConfig(DBConfig): filename: str = Field(default_factory=lambda: f"{uuid.uuid4()}.db") path: Path = Field(default_factory=lambda: Path(tempfile.gettempdir())) @@ -38,6 +41,7 @@ def connection_string(self) -> str: return f"sqlite:///{filepath.resolve()}" +@serializable(canonical_name="PostgresDBConfig", version=1) class PostgresDBConfig(DBConfig): host: str = "localhost" port: int = 5432 From 668bc3ba30c56f293fdd76e106c5f713b088177f Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 5 Sep 2024 11:01:48 +0200 Subject: [PATCH 
085/197] simple way to create stashes --- packages/syft/src/syft/store/db/sqlite_db.py | 24 ++++++++++++ packages/syft/src/syft/store/db/stash.py | 4 ++ packages/syft/tests/syft/worker_test.py | 39 ++++++++++++-------- 3 files changed, 51 insertions(+), 16 deletions(-) diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index da3877632bf..3199ef92131 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -13,6 +13,7 @@ from sqlalchemy.orm import sessionmaker # relative +from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey from ...types.uid import UID from .models import Base @@ -61,6 +62,12 @@ def __init__( self.server_uid = server_uid self.root_verify_key = root_verify_key + def init_tables(self) -> None: + pass + + def reset(self) -> None: + pass + class SQLiteDBManager(DBManager): def __init__( @@ -112,3 +119,20 @@ def get_session_threading_local(self) -> Session: @property def session(self) -> Session: return self.get_session_threading_local() + + @classmethod + def random( + cls, + *, + config: SQLiteDBConfig | None = None, + server_uid: UID | None = None, + root_verify_key: SyftVerifyKey | None = None, + ) -> "SQLiteDBManager": + root_verify_key = root_verify_key or SyftSigningKey.generate().verify_key + server_uid = server_uid or UID() + config = config or SQLiteDBConfig() + return SQLiteDBManager( + config=config, + server_uid=server_uid, + root_verify_key=root_verify_key, + ) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 614921752ae..c8c9faa9847 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -360,6 +360,10 @@ def row_as_obj(self, row: Row) -> StashT: def get_role(self, credentials: SyftVerifyKey) -> ServiceRole: # TODO error handling + if Base.metadata.tables.get("User") is None: + # if User table does not exist, we assume the user is a guest + # this happens when we create stashes in tests + return ServiceRole.GUEST user_table = Table("User", Base.metadata) stmt = select(user_table.c.fields["role"]).where( self._get_field_filter("verify_key", str(credentials), table=user_table), diff --git a/packages/syft/tests/syft/worker_test.py b/packages/syft/tests/syft/worker_test.py index 37ae7b7da24..d189cb2d025 100644 --- a/packages/syft/tests/syft/worker_test.py +++ b/packages/syft/tests/syft/worker_test.py @@ -16,6 +16,7 @@ from syft.server.credentials import SyftVerifyKey from syft.server.worker import Worker from syft.service.action.action_object import ActionObject +from syft.service.action.action_store import ActionObjectStash from syft.service.context import AuthedServiceContext from syft.service.queue.queue_stash import QueueItem from syft.service.response import SyftError @@ -23,9 +24,9 @@ from syft.service.user.user import UserCreate from syft.service.user.user import UserView from syft.service.user.user_service import UserService +from syft.store.db.sqlite_db import SQLiteDBManager from syft.types.errors import SyftException from syft.types.result import Ok -from syft.types.uid import UID test_signing_key_string = ( "b7803e90a6f3f4330afbd943cef3451c716b338b17a9cf40a0a309bc38bc366d" @@ -75,30 +76,36 @@ def test_signing_key() -> None: assert test_verify_key == test_verify_key_2 -def test_action_store() -> None: +@pytest.fixture +def action_object_stash() -> ActionObjectStash: + root_verify_key = 
SyftVerifyKey.from_string(test_verify_key_string) + db_manager = SQLiteDBManager.random(root_verify_key=root_verify_key) + stash = ActionObjectStash(store=db_manager) + stash.db.init_tables() + yield stash + + +def test_action_store(action_object_stash: ActionObjectStash) -> None: test_signing_key = SyftSigningKey.from_string(test_signing_key_string) - action_store = ... - uid = UID() + test_verify_key = test_signing_key.verify_key raw_data = np.array([1, 2, 3]) test_object = ActionObject.from_obj(raw_data) + # PR NOTE: Why `uid` was not `uid = test_object.id`? + uid = test_object.id - set_result = action_store.set( + action_object_stash.set_or_update( uid=uid, - credentials=test_signing_key, + credentials=test_verify_key, syft_object=test_object, has_result_read_permission=True, - ) - assert set_result.is_ok() - test_object_result = action_store.get(uid=uid, credentials=test_signing_key) - assert test_object_result.is_ok() - assert (test_object == test_object_result.ok()).all() + ).unwrap() + from_stash = action_object_stash.get(uid=uid, credentials=test_verify_key).unwrap() + assert (test_object == from_stash).all() test_verift_key_2 = SyftVerifyKey.from_string(test_verify_key_string_2) - test_object_result_fail = action_store.get(uid=uid, credentials=test_verift_key_2) - assert test_object_result_fail.is_err() - exc = test_object_result_fail.err() - assert type(exc) == SyftException - assert "denied" in exc.public_message + with pytest.raises(SyftException) as exc: + action_object_stash.get(uid=uid, credentials=test_verift_key_2).unwrap() + assert "denied" in exc.public_message def test_user_transform() -> None: From 61564561324a7fa6aa4f00a450e351ce1ed238f6 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 5 Sep 2024 12:04:02 +0200 Subject: [PATCH 086/197] fix queue --- packages/syft/src/syft/protocol/protocol_version.json | 6 +++--- packages/syft/src/syft/serde/json_serde.py | 5 +---- packages/syft/src/syft/server/worker_settings.py | 5 +---- packages/syft/src/syft/service/queue/queue.py | 11 ++++++----- packages/syft/src/syft/service/queue/zmq_producer.py | 8 ++++---- packages/syft/tests/syft/worker_test.py | 8 -------- 6 files changed, 15 insertions(+), 28 deletions(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index ca79d2ac467..d7f1a607404 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -484,9 +484,9 @@ } }, "WorkerSettings": { - "1": { - "version": 1, - "hash": "b18b575d0e633fa4adbe88f08ea39f056e6882aff5ede0334911b8309d2ef489", + "2": { + "version": 2, + "hash": "13c6e022b939778ab37b594dbc5094aba9f54564c90d3cb0c21115382b155bfe", "action": "add" } }, diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index af4569844aa..301788ff849 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -24,7 +24,6 @@ from ..types.datetime import DateTime from ..types.errors import SyftException from ..types.syft_object import BaseDateTime -from ..types.syft_object import DYNAMIC_SYFT_ATTRIBUTES from ..types.syft_object_registry import SyftObjectRegistry from ..types.uid import LineageID from ..types.uid import UID @@ -175,9 +174,7 @@ def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> dict[str, Json]: JSON_VERSION_FIELD: version, } - all_exclude_attrs = ( - set(exclude_attrs) | DEFAULT_EXCLUDE_ATTRS | set(DYNAMIC_SYFT_ATTRIBUTES) 
- ) + all_exclude_attrs = set(exclude_attrs) | DEFAULT_EXCLUDE_ATTRS for key, type_ in obj.model_fields.items(): if key in all_exclude_attrs: diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py index 70a6ddf76ca..c76e5617d0d 100644 --- a/packages/syft/src/syft/server/worker_settings.py +++ b/packages/syft/src/syft/server/worker_settings.py @@ -37,10 +37,7 @@ class WorkerSettings(SyftObject): @classmethod def from_server(cls, server: AbstractServer) -> Self: - if server.server_side_type: - server_side_type: str = server.server_side_type.value - else: - server_side_type = ServerSideType.HIGH_SIDE + server_side_type = server.server_side_type or ServerSideType.HIGH_SIDE return cls( id=server.id, name=server.name, diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index da1ded8bd70..8855a3c8bcb 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -5,6 +5,7 @@ from threading import Thread import time from typing import Any +from typing import cast # third party import psutil @@ -165,8 +166,7 @@ def handle_message_multiprocessing( id=worker_settings.id, name=worker_settings.name, signing_key=worker_settings.signing_key, - document_store_config=worker_settings.document_store_config, - action_store_config=worker_settings.action_store_config, + db_config=worker_settings.db_config, blob_storage_config=worker_settings.blob_store_config, server_side_type=worker_settings.server_side_type, queue_config=queue_config, @@ -251,7 +251,10 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: from ...server.server import Server queue_item = deserialize(message, from_bytes=True) + queue_item = cast(QueueItem, queue_item) worker_settings = queue_item.worker_settings + if worker_settings is None: + raise ValueError("Worker settings are missing in the queue item.") queue_config = worker_settings.queue_config queue_config.client_config.create_producer = False @@ -261,9 +264,7 @@ def handle_message(message: bytes, syft_worker_id: UID) -> None: id=worker_settings.id, name=worker_settings.name, signing_key=worker_settings.signing_key, - document_store_config=worker_settings.document_store_config, - action_store_config=worker_settings.action_store_config, - blob_storage_config=worker_settings.blob_store_config, + db_config=worker_settings.db_config, server_side_type=worker_settings.server_side_type, deployment_type=worker_settings.deployment_type, queue_config=queue_config, diff --git a/packages/syft/src/syft/service/queue/zmq_producer.py b/packages/syft/src/syft/service/queue/zmq_producer.py index 85dbb0edbf0..63d2aa6375c 100644 --- a/packages/syft/src/syft/service/queue/zmq_producer.py +++ b/packages/syft/src/syft/service/queue/zmq_producer.py @@ -160,7 +160,7 @@ def read_items(self) -> None: # Items to be queued items_to_queue = self.queue_stash.get_by_status( - self.queue_stash.partition.root_verify_key, + self.queue_stash.root_verify_key, status=Status.CREATED, ).unwrap() @@ -168,7 +168,7 @@ def read_items(self) -> None: # Queue Items that are in the processing state items_processing = self.queue_stash.get_by_status( - self.queue_stash.partition.root_verify_key, + self.queue_stash.root_verify_key, status=Status.PROCESSING, ).unwrap() @@ -284,14 +284,14 @@ def update_consumer_state_for_worker( try: try: self.worker_stash.get_by_uid( - credentials=self.worker_stash.partition.root_verify_key, + 
credentials=self.worker_stash.root_verify_key, uid=syft_worker_id, ).unwrap() except Exception: return None self.worker_stash.update_consumer_state( - credentials=self.worker_stash.partition.root_verify_key, + credentials=self.worker_stash.root_verify_key, worker_uid=syft_worker_id, consumer_state=consumer_state, ).unwrap() diff --git a/packages/syft/tests/syft/worker_test.py b/packages/syft/tests/syft/worker_test.py index 37ae7b7da24..eadc7630313 100644 --- a/packages/syft/tests/syft/worker_test.py +++ b/packages/syft/tests/syft/worker_test.py @@ -223,14 +223,6 @@ def post_add(context: Any, name: str, new_result: Any) -> Any: action_object.syft_post_hooks__["__add__"] = [] -def test_worker_serde(worker) -> None: - ser = sy.serialize(worker, to_bytes=True) - de = sy.deserialize(ser, from_bytes=True) - - assert de.signing_key == worker.signing_key - assert de.id == worker.id - - @pytest.fixture(params=[0]) def worker_with_proc(request): worker = Worker( From c6ae446307b39a09fca2090509a1bcf3924b999e Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 5 Sep 2024 13:39:43 +0200 Subject: [PATCH 087/197] fix jobservice --- packages/syft/src/syft/service/job/job_service.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index 925891877a5..f6e3389f3d8 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -221,7 +221,7 @@ def add_read_permission_job_for_code_owner( job.id, ActionPermission.READ, user_code.user_verify_key ) # TODO: make add_permission wrappable - return self.stash.add_permission(permission=permission) + return self.stash.add_permission(permission=permission).unwrap() @service_method( path="job.add_read_permission_log_for_code_owner", @@ -237,7 +237,7 @@ def add_read_permission_log_for_code_owner( ActionObjectPermission( log_id, ActionPermission.READ, user_code.user_verify_key ) - ) + ).unwrap() @service_method( path="job.create_job_for_user_code_id", From cd38c63311ddcea38832574a3f0a24a3a2c28e19 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 5 Sep 2024 15:12:16 +0200 Subject: [PATCH 088/197] make traceback clickable --- .../syft/assets/jinja/syft_exception.jinja2 | 60 +++++++++++-------- packages/syft/src/syft/types/errors.py | 1 + .../components/tabulator_template.py | 8 +++ 3 files changed, 45 insertions(+), 24 deletions(-) diff --git a/packages/syft/src/syft/assets/jinja/syft_exception.jinja2 b/packages/syft/src/syft/assets/jinja/syft_exception.jinja2 index bd4bc01635a..eab1977edb4 100644 --- a/packages/syft/src/syft/assets/jinja/syft_exception.jinja2 +++ b/packages/syft/src/syft/assets/jinja/syft_exception.jinja2 @@ -6,35 +6,47 @@
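{# The hunk below splits both trace blocks on a new dev_mode flag: dev
   builds pipe the trace text through a make_links filter and render the
   result with `safe`, so file references in the trace become clickable;
   every other build keeps the previously escaped plain text. #}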
- {% if server_trace%} -
Server Trace:
-
{{server_trace | escape}}
-
-
+ {% if server_trace %} +
Server Trace:
+
+      {% if dev_mode %}
+        {{ server_trace | make_links | safe }}
+      {% else %}
+        {{ server_trace | escape }}
       {% endif %}
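{# Only the dev_mode branch uses `| safe`, which turns off autoescaping for
   the markup that make_links produces; production stays on plain `escape`,
   presumably so server-supplied trace text can never inject HTML into the
   page for regular users. #}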
-      
Client Trace:
-
{{traceback_str | escape}}
+
+
+
+ {% endif %} +
Client Trace:
+
+      {% if dev_mode %}
+        {{ traceback_str | make_links | safe }}
+      {% else %}
+        {{ traceback_str | escape }}
+      {% endif %}
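{# make_links is not a Jinja builtin: this same patch registers it as a
   custom filter in tabulator_template.py via
   jinja_env.filters["make_links"] = make_links, where a regex matches
   `some/file.py", line N` fragments of the traceback and rewrites them
   into clickable links. #}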
+    
+ .syft-exception-trace { + display: inline; + } + \ No newline at end of file diff --git a/packages/syft/src/syft/types/errors.py b/packages/syft/src/syft/types/errors.py index 72e7181f92c..d1255921a01 100644 --- a/packages/syft/src/syft/types/errors.py +++ b/packages/syft/src/syft/types/errors.py @@ -182,6 +182,7 @@ def _repr_html_(self) -> str: message=self._private_message or self.public, traceback_str=traceback_str, display=display, + dev_mode=is_dev_mode, ) return table_html diff --git a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py index 21b4c48761e..4e93e82c45a 100644 --- a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py +++ b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py @@ -1,6 +1,7 @@ # stdlib import json import logging +import re import secrets from typing import Any @@ -20,8 +21,15 @@ logger = logging.getLogger(__name__) + +def make_links(text: str) -> str: + file_pattern = re.compile(r"([\w/.-]+\.py)\", line (\d+)") + return file_pattern.sub(r'\1, line \2', text) + + DEFAULT_ID_WIDTH = 110 jinja_env = jinja2.Environment(loader=jinja2.PackageLoader("syft", "assets/jinja")) # nosec +jinja_env.filters["make_links"] = make_links def create_tabulator_columns( From 33d654f8539b68f6e1de2419d11214a8a795dbaf Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 5 Sep 2024 16:06:34 +0200 Subject: [PATCH 089/197] postgres bug fixes --- packages/syft/src/syft/store/db/sqlite_db.py | 1 + packages/syft/src/syft/store/db/stash.py | 51 ++++++++++++++------ 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index 22acff20049..5d56914ca8f 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -107,6 +107,7 @@ def init_tables(self) -> None: if self.config.reset: # drop all tables that we know about Base.metadata.drop_all(bind=self.engine) + self.config.reset = False Base.metadata.create_all(self.engine) def reset(self) -> None: diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index c8c9faa9847..d5f445eae70 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -103,7 +103,7 @@ def _create_table(self) -> Table: storage_permissions_type = ( JSON if self.db.engine.dialect.name == "sqlite" - else postgresql.ARRAY(UIDTypeDecorator) + else postgresql.ARRAY(sa.String) ) if table_name not in Base.metadata.tables: Table( @@ -208,7 +208,7 @@ def _get_field_filter( if self.db.engine.dialect.name == "sqlite": return table.c.fields[field_name] == func.json_quote(json_value) elif self.db.engine.dialect.name == "postgresql": - return sa.cast(table.c.fields[field_name], sa.String) == json_value + return table.c.fields[field_name].astext == json_value def _get_by_fields( self, @@ -327,8 +327,13 @@ def get_all_contains( ) -> list[StashT]: # TODO write filter logic, merge with get_all + if self._is_sqlite(): + field_value = func.json_quote(field_value) + else: + field_value = [field_value] # type: ignore + stmt = self.table.select().where( - self.table.c.fields[field_name].contains(func.json_quote(field_value)), + self.table.c.fields[field_name].contains(field_value), ) stmt = self._apply_permission_filter( stmt, credentials=credentials, has_permission=has_permission @@ -570,17 +575,21 @@ def 
add_permissions(self, permissions: list[ActionObjectPermission]) -> None: @as_result(NotFoundException) def add_permission(self, permission: ActionObjectPermission) -> None: # TODO add error handling - stmt = ( - self.table.update() - .where(self.table.c.id == permission.uid) - .values( + stmt = self.table.update().where(self.table.c.id == permission.uid) + if self._is_sqlite(): + stmt = stmt.values( permissions=func.json_insert( self.table.c.permissions, "$[#]", permission.permission_string, ) ) - ) + else: + stmt = stmt.values( + permissions=func.array_append( + self.table.c.permissions, permission.permission_string + ) + ) result = self.session.execute(stmt) self.session.commit() @@ -651,11 +660,18 @@ def has_permission(self, permission: ActionObjectPermission) -> bool: def has_storage_permission(self, permission: StoragePermission) -> bool: return self.has_storage_permissions([permission]) + def _is_sqlite(self) -> bool: + return self.db.engine.dialect.name == "sqlite" + def has_storage_permissions(self, permissions: list[StoragePermission]) -> bool: permission_filters = [ sa.and_( self._get_field_filter("id", p.uid), - self.table.c.storage_permissions.contains(p.server_uid), + self.table.c.storage_permissions.contains( + p.server_uid.no_dash + if self._is_sqlite() + else [p.server_uid.no_dash] + ), ) for p in permissions ] @@ -690,17 +706,22 @@ def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool: @as_result(NotFoundException) def add_storage_permission(self, permission: StoragePermission) -> None: - stmt = ( - self.table.update() - .where(self.table.c.id == permission.uid) - .values( + stmt = self.table.update().where(self.table.c.id == permission.uid) + if self._is_sqlite(): + stmt = stmt.values( storage_permissions=func.json_insert( self.table.c.storage_permissions, "$[#]", permission.permission_string, ) ) - ) + else: + stmt = stmt.values( + permissions=func.array_append( + self.table.c.storage_permissions, permission.server_uid.no_dash + ) + ) + result = self.session.execute(stmt) self.session.commit() if result.rowcount == 0: @@ -759,7 +780,7 @@ def set( storage_permissions = [] if add_storage_permission: storage_permissions.append( - self.server_uid, + self.server_uid.no_dash, ) fields = serialize_json(obj) From ddeab6d9c0c17aefe9432e50eef4c6d57b9c6f48 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 6 Sep 2024 11:49:17 +0200 Subject: [PATCH 090/197] wip postgres integration --- .github/workflows/pr-tests-stack.yml | 4 +- packages/grid/backend/backend.dockerfile | 9 +-- packages/grid/backend/grid/core/config.py | 9 +-- packages/grid/backend/grid/core/server.py | 18 +---- packages/grid/default.env | 11 ++- packages/grid/devspace.yaml | 14 ++-- packages/grid/helm/examples/dev/base.yaml | 7 ++ .../backend/backend-statefulset.yaml | 20 +++--- .../postgres/postgres-headless-service.yaml | 15 ++++ .../templates/postgres/postgres-secret.yaml | 17 +++++ .../templates/postgres/postgres-service.yaml | 17 +++++ .../postgres/postgres-statefuleset.yaml | 72 +++++++++++++++++++ packages/grid/helm/syft/values.yaml | 32 +++++++++ packages/syft/setup.cfg | 1 + packages/syft/src/syft/server/server.py | 3 +- packages/syft/src/syft/store/db/sqlite_db.py | 8 +-- tox.ini | 2 + 17 files changed, 210 insertions(+), 49 deletions(-) create mode 100644 packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml create mode 100644 packages/grid/helm/syft/templates/postgres/postgres-secret.yaml create mode 100644 
packages/grid/helm/syft/templates/postgres/postgres-service.yaml create mode 100644 packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml diff --git a/.github/workflows/pr-tests-stack.yml b/.github/workflows/pr-tests-stack.yml index 2ad68b0730c..d9d1fc51649 100644 --- a/.github/workflows/pr-tests-stack.yml +++ b/.github/workflows/pr-tests-stack.yml @@ -60,9 +60,9 @@ jobs: if: steps.changes.outputs.stack == 'true' timeout-minutes: 60 run: | - echo "Skipping pr image test" + tox -e backend.test.basecpu # run: | - # tox -e backend.test.basecpu + # echo "Skipping pr image test" pr-tests-syft-integration: strategy: diff --git a/packages/grid/backend/backend.dockerfile b/packages/grid/backend/backend.dockerfile index c51ba31c8fd..caf5c69141b 100644 --- a/packages/grid/backend/backend.dockerfile +++ b/packages/grid/backend/backend.dockerfile @@ -84,9 +84,10 @@ ENV \ DEFAULT_ROOT_EMAIL="info@openmined.org" \ DEFAULT_ROOT_PASSWORD="changethis" \ STACK_API_KEY="changeme" \ - MONGO_HOST="localhost" \ - MONGO_PORT="27017" \ - MONGO_USERNAME="root" \ - MONGO_PASSWORD="example" + POSTGRESQL_DBNAME="syftdb_postgres" \ + POSTGRESQL_HOST="localhost" \ + POSTGRESQL_PORT="5432" \ + POSTGRESQL_USERNAME="syft_postgres" \ + POSTGRESQL_PASSWORD="changethis" CMD ["bash", "./grid/start.sh"] diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 63bda939c29..55a81aee24a 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -126,10 +126,11 @@ def get_emails_enabled(self) -> Self: # NETWORK_CHECK_INTERVAL: int = int(os.getenv("NETWORK_CHECK_INTERVAL", 60)) # DATASITE_CHECK_INTERVAL: int = int(os.getenv("DATASITE_CHECK_INTERVAL", 60)) CONTAINER_HOST: str = str(os.getenv("CONTAINER_HOST", "docker")) - MONGO_HOST: str = str(os.getenv("MONGO_HOST", "")) - MONGO_PORT: int = int(os.getenv("MONGO_PORT", 27017)) - MONGO_USERNAME: str = str(os.getenv("MONGO_USERNAME", "")) - MONGO_PASSWORD: str = str(os.getenv("MONGO_PASSWORD", "")) + POSTGRESQL_DBNAME: str = str(os.getenv("POSTGRESQL_DBNAME", "")) + POSTGRESQL_HOST: str = str(os.getenv("POSTGRESQL_HOST", "")) + POSTGRESQL_PORT: int = int(os.getenv("POSTGRESQL_PORT", 5432)) + POSTGRESQL_USERNAME: str = str(os.getenv("POSTGRESQL_USERNAME", "")) + POSTGRESQL_PASSWORD: str = str(os.getenv("POSTGRESQL_PASSWORD", "")) DEV_MODE: bool = True if os.getenv("DEV_MODE", "false").lower() == "true" else False # ZMQ stuff QUEUE_PORT: int = int(os.getenv("QUEUE_PORT", 5556)) diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py index 3f401d7e349..c8720ac3d18 100644 --- a/packages/grid/backend/grid/core/server.py +++ b/packages/grid/backend/grid/core/server.py @@ -14,8 +14,6 @@ from syft.service.queue.zmq_client import ZMQQueueConfig from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig from syft.store.blob_storage.seaweedfs import SeaweedFSConfig -from syft.store.mongo_client import MongoStoreClientConfig -from syft.store.mongo_document_store import MongoStoreConfig from syft.store.sqlite_document_store import SQLiteStoreClientConfig from syft.store.sqlite_document_store import SQLiteStoreConfig from syft.types.uid import UID @@ -36,17 +34,6 @@ def queue_config() -> ZMQQueueConfig: return queue_config -def mongo_store_config() -> MongoStoreConfig: - mongo_client_config = MongoStoreClientConfig( - hostname=settings.MONGO_HOST, - port=settings.MONGO_PORT, - username=settings.MONGO_USERNAME, - 
password=settings.MONGO_PASSWORD, - ) - - return MongoStoreConfig(client_config=mongo_client_config) - - def sql_store_config() -> SQLiteStoreConfig: client_config = SQLiteStoreClientConfig( filename=f"{UID.from_string(get_server_uid_env())}.sqlite", @@ -87,20 +74,17 @@ def seaweedfs_config() -> SeaweedFSConfig: worker_class = worker_classes[server_type] single_container_mode = settings.SINGLE_CONTAINER_MODE -store_config = sql_store_config() if single_container_mode else mongo_store_config() blob_storage_config = None if single_container_mode else seaweedfs_config() queue_config = queue_config() worker: Server = worker_class( name=server_name, server_side_type=server_side_type, - action_store_config=store_config, - document_store_config=store_config, enable_warnings=enable_warnings, blob_storage_config=blob_storage_config, local_db=single_container_mode, queue_config=queue_config, - migrate=True, + migrate=False, in_memory_workers=settings.INMEMORY_WORKERS, smtp_username=settings.SMTP_USERNAME, smtp_password=settings.SMTP_PASSWORD, diff --git a/packages/grid/default.env b/packages/grid/default.env index e1bc5c42557..46645305791 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -110,4 +110,13 @@ ENABLE_SIGNUP=False DOCKER_IMAGE_ENCLAVE_ATTESTATION=openmined/syft-enclave-attestation # Rathole Config -RATHOLE_PORT=2333 \ No newline at end of file +RATHOLE_PORT=2333 + +# PostgresSQL Config +# POSTGRESQL_IMAGE=postgres +# export POSTGRESQL_VERSION="15" +POSTGRESQL_DBNAME=syftdb_postgres +POSTGRESQL_HOST=postgres +POSTGRESQL_PORT=5432 +POSTGRESQL_USERNAME=syft_postgres +POSTGRESQL_PASSWORD=changethis \ No newline at end of file diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 43dd9d234cd..0d4250b9026 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -79,12 +79,12 @@ deployments: - ./helm/examples/dev/base.yaml dev: - mongo: + postgres: labelSelector: app.kubernetes.io/name: syft - app.kubernetes.io/component: mongo + app.kubernetes.io/component: postgres ports: - - port: "27017" + - port: "5432" seaweedfs: labelSelector: app.kubernetes.io/name: syft @@ -205,8 +205,8 @@ profiles: path: dev.seaweedfs # Port Re-Mapping - op: replace - path: dev.mongo.ports[0].port - value: 27018:27017 + path: dev.postgres.ports[0].port + value: 5433:5432 - op: replace path: dev.backend.ports[0].port value: 5679:5678 @@ -268,8 +268,8 @@ profiles: value: ./helm/examples/dev/enclave.yaml # Port Re-Mapping - op: replace - path: dev.mongo.ports[0].port - value: 27019:27017 + path: dev.postgres.ports[0].port + value: 5434:5432 - op: replace path: dev.backend.ports[0].port value: 5680:5678 diff --git a/packages/grid/helm/examples/dev/base.yaml b/packages/grid/helm/examples/dev/base.yaml index 4999ae40aed..4d290478841 100644 --- a/packages/grid/helm/examples/dev/base.yaml +++ b/packages/grid/helm/examples/dev/base.yaml @@ -28,6 +28,13 @@ mongo: secret: rootPassword: example +postgres: + resourcesPreset: null + resources: null + + secret: + rootPassword: example + seaweedfs: resourcesPreset: null resources: null diff --git a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml index 3293056ba2a..2d1a6880c33 100644 --- a/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml +++ b/packages/grid/helm/syft/templates/backend/backend-statefulset.yaml @@ -99,18 +99,20 @@ spec: - name: REVERSE_TUNNEL_ENABLED value: "true" {{- end }} - # MongoDB - - name: 
MONGO_PORT - value: {{ .Values.mongo.port | quote }} - - name: MONGO_HOST - value: "mongo" - - name: MONGO_USERNAME - value: {{ .Values.mongo.username | quote }} - - name: MONGO_PASSWORD + # Postgres + - name: POSTGRESQL_PORT + value: {{ .Values.postgres.port | quote }} + - name: POSTGRESQL_HOST + value: "postgres" + - name: POSTGRESQL_USERNAME + value: {{ .Values.postgres.username | quote }} + - name: POSTGRESQL_PASSWORD valueFrom: secretKeyRef: - name: {{ .Values.mongo.secretKeyName | required "mongo.secretKeyName is required" }} + name: {{ .Values.postgres.secretKeyName | required "postgres.secretKeyName is required" }} key: rootPassword + - name: POSTGRESQL_DBNAME + value: {{ .Values.postgres.dbname | quote }} # SMTP - name: SMTP_HOST value: {{ .Values.server.smtp.host | quote }} diff --git a/packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml b/packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml new file mode 100644 index 00000000000..4855a7868ff --- /dev/null +++ b/packages/grid/helm/syft/templates/postgres/postgres-headless-service.yaml @@ -0,0 +1,15 @@ +apiVersion: v1 +kind: Service +metadata: + name: postgres-headless + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: postgres +spec: + clusterIP: None + ports: + - name: postgres + port: 5432 + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: postgres \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/postgres/postgres-secret.yaml b/packages/grid/helm/syft/templates/postgres/postgres-secret.yaml new file mode 100644 index 00000000000..63a990c0d9a --- /dev/null +++ b/packages/grid/helm/syft/templates/postgres/postgres-secret.yaml @@ -0,0 +1,17 @@ +{{- $secretName := "postgres-secret" }} +apiVersion: v1 +kind: Secret +metadata: + name: {{ $secretName }} + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: postgres +type: Opaque +data: + rootPassword: {{ include "common.secrets.set" (dict + "secret" $secretName + "key" "rootPassword" + "randomDefault" .Values.global.randomizedSecrets + "default" .Values.postgres.secret.rootPassword + "context" $) + }} \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/postgres/postgres-service.yaml b/packages/grid/helm/syft/templates/postgres/postgres-service.yaml new file mode 100644 index 00000000000..9cd8b156bdd --- /dev/null +++ b/packages/grid/helm/syft/templates/postgres/postgres-service.yaml @@ -0,0 +1,17 @@ +apiVersion: v1 +kind: Service +metadata: + name: postgres + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: postgres +spec: + type: ClusterIP + selector: + {{- include "common.selectorLabels" . | nindent 4 }} + app.kubernetes.io/component: postgres + ports: + - name: postgres + port: 5432 + protocol: TCP + targetPort: 5432 \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml b/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml new file mode 100644 index 00000000000..425f9e88770 --- /dev/null +++ b/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml @@ -0,0 +1,72 @@ +apiVersion: apps/v1 +kind: StatefulSet +metadata: + name: postgres + labels: + {{- include "common.labels" . | nindent 4 }} + app.kubernetes.io/component: postgres +spec: + replicas: 1 + updateStrategy: + type: RollingUpdate + selector: + matchLabels: + {{- include "common.selectorLabels" . 
| nindent 6 }} + app.kubernetes.io/component: postgres + serviceName: postgres-headless + podManagementPolicy: OrderedReady + template: + metadata: + labels: + {{- include "common.labels" . | nindent 8 }} + app.kubernetes.io/component: postgres + {{- if .Values.postgres.podLabels }} + {{- toYaml .Values.postgres.podLabels | nindent 8 }} + {{- end }} + {{- if .Values.postgres.podAnnotations }} + annotations: {{- toYaml .Values.postgres.podAnnotations | nindent 8 }} + {{- end }} + spec: + {{- if .Values.postgres.nodeSelector }} + nodeSelector: {{- .Values.postgres.nodeSelector | toYaml | nindent 8 }} + {{- end }} + containers: + - name: postgres-container + image: postgres:13 + imagePullPolicy: Always + resources: {{ include "common.resources.set" (dict "resources" .Values.postgres.resources "preset" .Values.postgres.resourcesPreset) | nindent 12 }} + env: + - name: POSTGRES_USER + value: {{ .Values.postgres.username | required "postgres.username is required" | quote }} + - name: POSTGRES_PASSWORD + valueFrom: + secretKeyRef: + name: {{ .Values.postgres.secretKeyName | required "postgres.secretKeyName is required" }} + key: rootPassword + - name: POSTGRES_DB + value: {{ .Values.postgres.dbname | required "postgres.dbname is required" | quote }} + {{- if .Values.postgres.env }} + {{- toYaml .Values.postgres.env | nindent 12 }} + {{- end }} + volumeMounts: + - mountPath: /data/db + name: postgres-data + readOnly: false + subPath: '' + ports: + - name: postgres-port + containerPort: 5432 + terminationGracePeriodSeconds: 5 + volumeClaimTemplates: + - metadata: + name: postgres-data + labels: + {{- include "common.volumeLabels" . | nindent 8 }} + app.kubernetes.io/component: postgres + spec: + accessModes: + - ReadWriteOnce + resources: + requests: + storage: {{ .Values.postgres.storageSize | quote }} + diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index deb7a6f1154..910533d79c3 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -43,6 +43,38 @@ mongo: # ================================================================================= +postgres: +# Postgres config + port: 5432 + username: syft_postgres + dbname: syftdb_postgres + host: postgres + + # Extra environment vars + env: null + + # Pod labels & annotations + podLabels: null + podAnnotations: null + + # Node selector for pods + nodeSelector: null + + # Pod Resource Limits + resourcesPreset: large + resources: null + + # PVC storage size + storageSize: 5Gi + + # Mongo secret name. Override this if you want to use a self-managed secret. 
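
The chart values below feed the StatefulSet env vars above (POSTGRES_USER and POSTGRES_DB from values, POSTGRES_PASSWORD via the secret). A quick connectivity check against the deployed pod, assuming the psycopg 3 driver, the dev password from helm/examples/dev/base.yaml, and the devspace port-forward in this patch (5432, or 5433/5434 under the remapped profiles); all values here are dev-only assumptions:

    import psycopg

    conn = psycopg.connect(
        dbname="syftdb_postgres",
        user="syft_postgres",
        password="example",   # dev secret from helm/examples/dev/base.yaml
        host="localhost",     # reached via the devspace port-forward
        port=5432,
    )
    with conn.cursor() as cur:
        cur.execute("SELECT version()")
        print(cur.fetchone()[0])
    conn.close()
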
+ secretKeyName: postgres-secret + + # default/custom secret raw values + secret: + rootPassword: null +# ================================================================================= + frontend: # Extra environment vars env: null diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 010bc788433..629133ec0c4 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -69,6 +69,7 @@ syft = ipython<8.27.0 dynaconf==3.2.6 sqlalchemy==2.0.32 + psycopg2-binary==2.9.9 install_requires = %(syft)s diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 17137079fac..cc3f6c6431e 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -90,6 +90,7 @@ from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit from ..store.db.sqlite_db import DBConfig +from ..store.db.sqlite_db import PostgresDBConfig from ..store.db.sqlite_db import SQLiteDBConfig from ..store.db.sqlite_db import SQLiteDBManager from ..store.document_store import StoreConfig @@ -402,7 +403,7 @@ def __init__( filename=f"{self.id}_json.db", path=self.get_temp_dir("db"), ) - # json_db_config = PostgresDBConfig(reset=False) + db_config = PostgresDBConfig(reset=False) self.db_config = db_config diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index 5d56914ca8f..8cc260b5f88 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -44,11 +44,11 @@ def connection_string(self) -> str: @serializable(canonical_name="PostgresDBConfig", version=1) class PostgresDBConfig(DBConfig): - host: str = "localhost" + host: str = "postgres" port: int = 5432 - user: str = "postgres" - password: str = "postgres" - database: str = "postgres" + user: str = "syft_postgres" + password: str = "example" + database: str = "syftdb_postgres" @property def connection_string(self) -> str: diff --git a/tox.ini b/tox.ini index 571844a13af..5813eabbd07 100644 --- a/tox.ini +++ b/tox.ini @@ -1127,7 +1127,9 @@ setenv= CLUSTER_HTTP_PORT={env:CLUSTER_HTTP_PORT:9082} allowlist_externals = tox + bash commands = + bash -c "CLUSTER_NAME=${CLUSTER_NAME} tox -e dev.k8s.destroy" tox -e dev.k8s.start tox -e dev.k8s.{posargs:deploy} From 12a8bc49e2fc012992aae05a56519a5f6f0bd3a9 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 6 Sep 2024 12:00:09 +0200 Subject: [PATCH 091/197] comment out postgres --- packages/syft/src/syft/server/server.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index cc3f6c6431e..89691cd0bc4 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -90,7 +90,6 @@ from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit from ..store.db.sqlite_db import DBConfig -from ..store.db.sqlite_db import PostgresDBConfig from ..store.db.sqlite_db import SQLiteDBConfig from ..store.db.sqlite_db import SQLiteDBManager from ..store.document_store import StoreConfig @@ -403,7 +402,7 @@ def __init__( filename=f"{self.id}_json.db", path=self.get_temp_dir("db"), ) - db_config = PostgresDBConfig(reset=False) + # db_config = PostgresDBConfig(reset=False) self.db_config = db_config From 01c1599b54d3d6fcf25e7cb41a627f56187bca5f Mon Sep 17 00:00:00 2001 From: 
khoaguin Date: Mon, 9 Sep 2024 09:19:03 +0700 Subject: [PATCH 092/197] [k8s] stop creating the mongo pod in the cluster [chores] improve some debug printings --- packages/grid/backend/backend.dockerfile | 2 +- packages/grid/backend/grid/core/server.py | 5 +- packages/grid/default.env | 10 +-- .../mongo/mongo-headless-service.yaml | 15 ---- .../syft/templates/mongo/mongo-secret.yaml | 17 ----- .../syft/templates/mongo/mongo-service.yaml | 17 ----- .../templates/mongo/mongo-statefulset.yaml | 69 ------------------- packages/grid/helm/syft/values.yaml | 31 --------- packages/syft/src/syft/orchestra.py | 2 +- .../src/syft/protocol/protocol_version.json | 11 +++ .../syft/store/postgresql_document_store.py | 11 ++- 11 files changed, 23 insertions(+), 167 deletions(-) delete mode 100644 packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml delete mode 100644 packages/grid/helm/syft/templates/mongo/mongo-secret.yaml delete mode 100644 packages/grid/helm/syft/templates/mongo/mongo-service.yaml delete mode 100644 packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml diff --git a/packages/grid/backend/backend.dockerfile b/packages/grid/backend/backend.dockerfile index caf5c69141b..aafee9a43bb 100644 --- a/packages/grid/backend/backend.dockerfile +++ b/packages/grid/backend/backend.dockerfile @@ -88,6 +88,6 @@ ENV \ POSTGRESQL_HOST="localhost" \ POSTGRESQL_PORT="5432" \ POSTGRESQL_USERNAME="syft_postgres" \ - POSTGRESQL_PASSWORD="changethis" + POSTGRESQL_PASSWORD="example" CMD ["bash", "./grid/start.sh"] diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py index 2a3e8081865..def69e71d72 100644 --- a/packages/grid/backend/grid/core/server.py +++ b/packages/grid/backend/grid/core/server.py @@ -1,3 +1,6 @@ +# stdlib +from pprint import pprint + # syft absolute from syft.abstract_server import ServerType from syft.server.datasite import Datasite @@ -93,7 +96,7 @@ def seaweedfs_config() -> SeaweedFSConfig: ) print("----------------------------Store Config----------------------------\n") -print(store_config.model_dump()) +pprint(store_config.model_dump()) print("\n----------------------------Store Config----------------------------") blob_storage_config = None if single_container_mode else seaweedfs_config() queue_config = queue_config() diff --git a/packages/grid/default.env b/packages/grid/default.env index 38da0dc5cc8..65019098bee 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -77,14 +77,6 @@ KANIKO_VERSION="v1.23.2" # Jax JAX_ENABLE_X64=True -# Mongo -MONGO_IMAGE=mongo -MONGO_VERSION="7.0.8" -MONGO_HOST=mongo -MONGO_PORT=27017 -MONGO_USERNAME=root -MONGO_PASSWORD=example - # Redis REDIS_PORT=6379 REDIS_STORE_DB_ID=0 @@ -119,4 +111,4 @@ POSTGRESQL_DBNAME=syftdb_postgres POSTGRESQL_HOST=localhost POSTGRESQL_PORT=5432 POSTGRESQL_USERNAME=syft_postgres -POSTGRESQL_PASSWORD=changethis \ No newline at end of file +POSTGRESQL_PASSWORD=example \ No newline at end of file diff --git a/packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml b/packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml deleted file mode 100644 index 7cb97ee3592..00000000000 --- a/packages/grid/helm/syft/templates/mongo/mongo-headless-service.yaml +++ /dev/null @@ -1,15 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: mongo-headless - labels: - {{- include "common.labels" . 
| nindent 4 }} - app.kubernetes.io/component: mongo -spec: - clusterIP: None - ports: - - name: mongo - port: 27017 - selector: - {{- include "common.selectorLabels" . | nindent 4 }} - app.kubernetes.io/component: mongo diff --git a/packages/grid/helm/syft/templates/mongo/mongo-secret.yaml b/packages/grid/helm/syft/templates/mongo/mongo-secret.yaml deleted file mode 100644 index 02c58d276ca..00000000000 --- a/packages/grid/helm/syft/templates/mongo/mongo-secret.yaml +++ /dev/null @@ -1,17 +0,0 @@ -{{- $secretName := "mongo-secret" }} -apiVersion: v1 -kind: Secret -metadata: - name: {{ $secretName }} - labels: - {{- include "common.labels" . | nindent 4 }} - app.kubernetes.io/component: mongo -type: Opaque -data: - rootPassword: {{ include "common.secrets.set" (dict - "secret" $secretName - "key" "rootPassword" - "randomDefault" .Values.global.randomizedSecrets - "default" .Values.mongo.secret.rootPassword - "context" $) - }} diff --git a/packages/grid/helm/syft/templates/mongo/mongo-service.yaml b/packages/grid/helm/syft/templates/mongo/mongo-service.yaml deleted file mode 100644 index a789f4e8f86..00000000000 --- a/packages/grid/helm/syft/templates/mongo/mongo-service.yaml +++ /dev/null @@ -1,17 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: mongo - labels: - {{- include "common.labels" . | nindent 4 }} - app.kubernetes.io/component: mongo -spec: - type: ClusterIP - selector: - {{- include "common.selectorLabels" . | nindent 4 }} - app.kubernetes.io/component: mongo - ports: - - name: mongo - port: 27017 - protocol: TCP - targetPort: 27017 diff --git a/packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml b/packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml deleted file mode 100644 index 91060b90a9b..00000000000 --- a/packages/grid/helm/syft/templates/mongo/mongo-statefulset.yaml +++ /dev/null @@ -1,69 +0,0 @@ -apiVersion: apps/v1 -kind: StatefulSet -metadata: - name: mongo - labels: - {{- include "common.labels" . | nindent 4 }} - app.kubernetes.io/component: mongo -spec: - replicas: 1 - updateStrategy: - type: RollingUpdate - selector: - matchLabels: - {{- include "common.selectorLabels" . | nindent 6 }} - app.kubernetes.io/component: mongo - serviceName: mongo-headless - podManagementPolicy: OrderedReady - template: - metadata: - labels: - {{- include "common.labels" . 
| nindent 8 }} - app.kubernetes.io/component: mongo - {{- if .Values.mongo.podLabels }} - {{- toYaml .Values.mongo.podLabels | nindent 8 }} - {{- end }} - {{- if .Values.mongo.podAnnotations }} - annotations: {{- toYaml .Values.mongo.podAnnotations | nindent 8 }} - {{- end }} - spec: - {{- if .Values.mongo.nodeSelector }} - nodeSelector: {{- .Values.mongo.nodeSelector | toYaml | nindent 8 }} - {{- end }} - containers: - - name: mongo-container - image: mongo:7 - imagePullPolicy: Always - resources: {{ include "common.resources.set" (dict "resources" .Values.mongo.resources "preset" .Values.mongo.resourcesPreset) | nindent 12 }} - env: - - name: MONGO_INITDB_ROOT_USERNAME - value: {{ .Values.mongo.username | required "mongo.username is required" | quote }} - - name: MONGO_INITDB_ROOT_PASSWORD - valueFrom: - secretKeyRef: - name: {{ .Values.mongo.secretKeyName | required "mongo.secretKeyName is required" }} - key: rootPassword - {{- if .Values.mongo.env }} - {{- toYaml .Values.mongo.env | nindent 12 }} - {{- end }} - volumeMounts: - - mountPath: /data/db - name: mongo-data - readOnly: false - subPath: '' - ports: - - name: mongo-port - containerPort: 27017 - terminationGracePeriodSeconds: 5 - volumeClaimTemplates: - - metadata: - name: mongo-data - labels: - {{- include "common.volumeLabels" . | nindent 8 }} - app.kubernetes.io/component: mongo - spec: - accessModes: - - ReadWriteOnce - resources: - requests: - storage: {{ .Values.mongo.storageSize | quote }} diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index 910533d79c3..a450d7a0e4d 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -12,37 +12,6 @@ global: # ================================================================================= -mongo: - # MongoDB config - port: 27017 - username: root - - # Extra environment vars - env: null - - # Pod labels & annotations - podLabels: null - podAnnotations: null - - # Node selector for pods - nodeSelector: null - - # Pod Resource Limits - resourcesPreset: large - resources: null - - # PVC storage size - storageSize: 5Gi - - # Mongo secret name. Override this if you want to use a self-managed secret. - secretKeyName: mongo-secret - - # default/custom secret raw values - secret: - rootPassword: null - -# ================================================================================= - postgres: # Postgres config port: 5432 diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 787c518f4ff..e3c3d5682c9 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -381,7 +381,7 @@ def launch( display( SyftInfo( message=f"You have launched a development server at http://{host}:{server_handle.port}." - + "It is intended only for local use." + + " It is intended only for local use." 
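
The one-character change above exists because Python's "+" concatenation inserts nothing between its operands, so without the leading space the two sentences run together:

    base = "You have launched a development server at http://localhost:8080."
    print(base + "It is intended only for local use.")
    # ...localhost:8080.It is intended only for local use.
    print(base + " It is intended only for local use.")
    # ...localhost:8080. It is intended only for local use.
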
) ) return server_handle diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 5f9f6a8fab1..ac6f743ffe7 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -1,5 +1,16 @@ { "1": { "release_name": "0.9.1.json" + }, + "dev": { + "object_versions": { + "PostgreSQLStorePartition": { + "1": { + "version": 1, + "hash": "1a807dcf54f969c53e6f46d62443d4dd83e5f6ff47fb4e9f6381c3374601c818", + "action": "add" + } + } + } } } diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py index aca2976938b..b82f39acfad 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -100,9 +100,8 @@ def _connect(self) -> None: host=self.store_config.client_config.host, port=self.store_config.client_config.port, ) - print(f"Connected to {self.store_config.client_config.dbname}") - print("PostgreSQL database connection:", connection._check_connection_ok()) + print("PostgreSQL database connection:", connection) _CONNECTION_POOL_DB[cache_key(self.dbname)] = connection @@ -161,17 +160,17 @@ def _execute(self, sql: str, args: list[Any] | None) -> psycopg.Cursor: # Ensure self.cur is a psycopg cursor object cursor = self.cur # Assuming self.cur is already set as psycopg.Cursor cursor.execute(sql, args) # Execute the SQL with arguments - # cursor = self.cur.execute(sql, args) - except InFailedSqlTransaction: + self.db.commit() # Commit if everything went ok + except InFailedSqlTransaction as ie: self.db.rollback() # Rollback if something went wrong raise SyftException( - public_message=f"Transaction {sql} failed and was rolled back." + public_message=f"Transaction `{sql}` failed and was rolled back. \n" + f"Error: {ie}." 
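
The pattern this hunk moves to is commit-inside-try: commit only when execute succeeded, and roll back on any failure so the pooled connection is not left in an aborted transaction state. A standalone sketch of the same pattern with psycopg 3 (`run` is a stand-in name, not the project's API):

    import psycopg
    from psycopg.errors import InFailedSqlTransaction

    def run(conn: psycopg.Connection, sql: str, args: list | None = None):
        cur = conn.cursor()
        try:
            cur.execute(sql, args)
            conn.commit()       # commit if everything went ok
        except InFailedSqlTransaction:
            conn.rollback()     # clear the aborted transaction, then surface it
            raise
        except Exception:
            conn.rollback()     # keep the session usable for the next statement
            raise
        return cur
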
) except Exception as e: self.db.rollback() # Rollback on any other exception to maintain clean state public_message = special_exception_public_message(self.table_name, e) raise SyftException.from_exception(e, public_message=public_message) - self.db.commit() # Commit if everything went ok return cursor From 2fb065ee202eda05d888d1d55bdc3a4c166320c2 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 9 Sep 2024 10:44:01 +0700 Subject: [PATCH 093/197] [syft/postgre_store] fix store initialization by adding `CREATE TABLE IF NOT EXISTS` --- packages/syft/src/syft/store/postgresql_document_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py index b82f39acfad..6fa7a3641a5 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -109,7 +109,7 @@ def create_table(self) -> None: try: with self.lock: self.cur.execute( - f"create table {self.table_name} (uid VARCHAR(32) NOT NULL PRIMARY KEY, " # nosec + f"CREATE TABLE IF NOT EXISTS {self.table_name} (uid VARCHAR(32) NOT NULL PRIMARY KEY, " # nosec + "repr TEXT NOT NULL, value BYTEA NOT NULL, " # nosec + "sqltime TIMESTAMP NOT NULL DEFAULT NOW())" # nosec ) From 9980c328718717878c39580a8ad770f587a56d7b Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 9 Sep 2024 10:49:47 +0700 Subject: [PATCH 094/197] [lint] fix some linting issues --- packages/syft/src/syft/store/sqlite_document_store.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py index 1154952def7..e61748fac95 100644 --- a/packages/syft/src/syft/store/sqlite_document_store.py +++ b/packages/syft/src/syft/store/sqlite_document_store.py @@ -263,7 +263,7 @@ def _get_all(self) -> Any: cursor = self._execute(select_sql, []).unwrap() rows = cursor.fetchall() # type: ignore - if rows is None: + if not rows: return {} for row in rows: @@ -276,7 +276,7 @@ def _get_all_keys(self) -> Any: cursor = self._execute(select_sql, []).unwrap() rows = cursor.fetchall() # type: ignore - if rows is None: + if not rows: return [] keys = [UID(row[0]) for row in rows] From 6e128cfe436dc162874ddb96eb845a281fa2b932 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 9 Sep 2024 11:05:32 +0700 Subject: [PATCH 095/197] [tox] waiting for postgres instead of mongo service in k8s tests --- .../src/syft/store/sqlite_document_store.py | 18 +++++++++--------- tox.ini | 10 +++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py index e61748fac95..2359d273b34 100644 --- a/packages/syft/src/syft/store/sqlite_document_store.py +++ b/packages/syft/src/syft/store/sqlite_document_store.py @@ -216,18 +216,18 @@ def _set(self, key: UID, value: Any) -> None: self._update(key, value) else: insert_sql = ( - f"insert into {self.table_name} (uid, repr, value) VALUES " - f"({self.subs_char}, {self.subs_char}, {self.subs_char})" - ) # nosec + f"insert into {self.table_name} (uid, repr, value) VALUES " # nosec + f"({self.subs_char}, {self.subs_char}, {self.subs_char})" # nosec + ) data = _serialize(value, to_bytes=True) self._execute(insert_sql, [str(key), _repr_debug_(value), data]).unwrap() def _update(self, key: UID, value: Any) -> None: insert_sql = ( - f"update 
{self.table_name} set uid = {self.subs_char}, " - f"repr = {self.subs_char}, value = {self.subs_char} " - f"where uid = {self.subs_char}" - ) # nosec + f"update {self.table_name} set uid = {self.subs_char}, " # nosec + f"repr = {self.subs_char}, value = {self.subs_char} " # nosec + f"where uid = {self.subs_char}" # nosec + ) data = _serialize(value, to_bytes=True) self._execute( insert_sql, [str(key), _repr_debug_(value), data, str(key)] @@ -235,9 +235,9 @@ def _update(self, key: UID, value: Any) -> None: def _get(self, key: UID) -> Any: select_sql = ( - f"select * from {self.table_name} where uid = {self.subs_char} " + f"select * from {self.table_name} where uid = {self.subs_char} " # nosec "order by sqltime" - ) # nosec + ) cursor = self._execute(select_sql, [str(key)]).unwrap( public_message=f"Query {select_sql} failed" ) diff --git a/tox.ini b/tox.ini index 5813eabbd07..a31f9c41a8b 100644 --- a/tox.ini +++ b/tox.ini @@ -460,7 +460,7 @@ commands = ; sleep 30 # wait for test-datasite-1 - bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft + bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft @@ -671,12 +671,12 @@ commands = sleep 30 # wait for test gateway 1 - bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:GATEWAY_CLUSTER_NAME} --namespace syft + bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:GATEWAY_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:GATEWAY_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:GATEWAY_CLUSTER_NAME} --namespace syft # wait for test datasite 1 - bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft + bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft @@ -763,7 +763,7 @@ commands = sleep 30 # wait for test-datasite-1 - bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft + bash packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft @@ -1436,7 +1436,7 @@ commands = sleep 30 ; # wait for test-datasite-1 - bash packages/grid/scripts/wait_for.sh service mongo --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft + bash 
packages/grid/scripts/wait_for.sh service postgres --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service backend --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service proxy --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft bash packages/grid/scripts/wait_for.sh service seaweedfs --context k3d-{env:DATASITE_CLUSTER_NAME} --namespace syft From c5f0a487c9af25e7a9eca239be1bb22067ba4b1e Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 9 Sep 2024 11:22:46 +0700 Subject: [PATCH 096/197] [notebooks] revert api notebook `0-load-data` back to `dev`'s state --- notebooks/api/0.8/00-load-data.ipynb | 90 +--------------------------- 1 file changed, 2 insertions(+), 88 deletions(-) diff --git a/notebooks/api/0.8/00-load-data.ipynb b/notebooks/api/0.8/00-load-data.ipynb index 65aca80f4f5..8c3bb05b93b 100644 --- a/notebooks/api/0.8/00-load-data.ipynb +++ b/notebooks/api/0.8/00-load-data.ipynb @@ -1,72 +1,5 @@ { "cells": [ - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# !pip install psycopg[binary]==3.1.19" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# !docker run --name postgres-latest -e POSTGRES_USER=admin \\\n", - "# -e POSTGRES_PASSWORD=adminpassword -p 5432:5432 -d postgres:latest" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# import os\n", - "# os.environ[\"POSTGRESQL_DBNAME\"] = \"postgres\"\n", - "# os.environ[\"POSTGRESQL_HOST\"] = \"localhost\"\n", - "# os.environ[\"POSTGRESQL_PORT\"] = \"5432\"\n", - "# os.environ[\"POSTGRESQL_USERNAME\"] = \"admin\"\n", - "# os.environ[\"POSTGRESQL_PASSWORD\"] = \"adminpassword\"" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# import psycopg\n", - "# connection = psycopg.connect(\n", - "# dbname=\"postgres\",\n", - "# user=\"admin\",\n", - "# password=\"adminpassword\",\n", - "# host=\"localhost\",\n", - "# port=\"5432\",\n", - "# )\n", - "# cursor = connection.cursor()\n", - "# sql = \"select uid from User_unique_keys where uid = %s\"\n", - "# args = ['email']\n", - "# res = cursor.execute(sql, args)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - }, { "cell_type": "markdown", "metadata": {}, @@ -134,21 +67,7 @@ "outputs": [], "source": [ "# Launch a fresh datasite server named \"test-datasite-1\" in dev mode on the local machine\n", - "store_client_config = {\n", - " \"TYPE\": \"PostgreSQLStoreConfig\",\n", - " \"POSTGRESQL_DBNAME\": \"postgres\",\n", - " \"POSTGRESQL_HOST\": \"localhost\",\n", - " \"POSTGRESQL_PORT\": \"5432\",\n", - " \"POSTGRESQL_USERNAME\": \"admin\",\n", - " \"POSTGRESQL_PASSWORD\": \"adminpassword\",\n", - "}\n", - "\n", - "server = sy.orchestra.launch(\n", - " name=\"test-datasite-1\",\n", - " dev_mode=True,\n", - " reset=True,\n", - " store_client_config=store_client_config,\n", - ")" + "server = sy.orchestra.launch(name=\"test-datasite-1\", dev_mode=True, reset=True)" ] }, { @@ -792,11 +711,6 @@ } ], "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, "language_info": { "codemirror_mode": { "name": "ipython", 
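
For reference, the launch cell the notebook returns to is the one-liner form; under the dev defaults used throughout this series (root email info@openmined.org, password changethis) a matching smoke test looks like the sketch below. The login/land calls are standard orchestra usage, assumed unchanged by this patch:

    import syft as sy

    server = sy.orchestra.launch(name="test-datasite-1", dev_mode=True, reset=True)
    client = server.login(email="info@openmined.org", password="changethis")
    print(client)
    server.land()
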
@@ -807,7 +721,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.2" + "version": "3.12.5" }, "toc": { "base_numbering": 1, From e00edf778b3d7ef39eee0c97e3ae518327f72284 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 9 Sep 2024 11:32:10 +0700 Subject: [PATCH 097/197] [postgres] add error handling for getting postgres storage config from json --- packages/syft/src/syft/server/server.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 389668dee3e..87d8b699d40 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -302,7 +302,11 @@ def get_external_storage_config( ) -> PostgreSQLStoreConfig | None: if not store_client_config: store_client_config_json = os.environ.get("SYFT_STORE_CLIENT_CONFIG", "{}") - store_client_config = json.loads(store_client_config_json) + try: + store_client_config = json.loads(store_client_config_json) + except json.JSONDecodeError as e: + print(f"Error decoding JSON from 'SYFT_STORE_CLIENT_CONFIG': {e}") + store_client_config = {} if ( store_client_config From 55316a70767a83b488c439a36c5c0ce928a4933a Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 9 Sep 2024 10:31:52 +0200 Subject: [PATCH 098/197] fix consumer --- .../syft/src/syft/service/code/user_code.py | 19 ------------------- .../src/syft/service/queue/zmq_consumer.py | 2 +- 2 files changed, 1 insertion(+), 20 deletions(-) diff --git a/packages/syft/src/syft/service/code/user_code.py b/packages/syft/src/syft/service/code/user_code.py index b05b0b81332..161cff1af87 100644 --- a/packages/syft/src/syft/service/code/user_code.py +++ b/packages/syft/src/syft/service/code/user_code.py @@ -1698,25 +1698,6 @@ def job_increase_current_iter(current_iter: int) -> None: job.current_iter += current_iter job_service.update(context, job) - # def set_api_registry(): - # user_signing_key = [ - # x.signing_key - # for x in user_service.stash.partition.data.values() - # if x.verify_key == context.credentials - # ][0] - # data_protcol = get_data_protocol() - # user_api = server.get_api(context.credentials, data_protcol.latest_version) - # user_api.signing_key = user_signing_key - # # We hardcode a python connection here since we have access to the server - # # TODO: this is not secure - # user_api.connection = PythonConnection(server=server) - - # APIRegistry.set_api_for( - # server_uid=server.id, - # user_verify_key=context.credentials, - # api=user_api, - # ) - def launch_job(func: UserCode, **kwargs: Any) -> Job | None: # relative diff --git a/packages/syft/src/syft/service/queue/zmq_consumer.py b/packages/syft/src/syft/service/queue/zmq_consumer.py index 4de8da60494..b6861559838 100644 --- a/packages/syft/src/syft/service/queue/zmq_consumer.py +++ b/packages/syft/src/syft/service/queue/zmq_consumer.py @@ -237,7 +237,7 @@ def _set_worker_job(self, job_id: UID | None) -> None: ConsumerState.IDLE if job_id is None else ConsumerState.CONSUMING ) res = self.worker_stash.update_consumer_state( - credentials=self.worker_stash.partition.root_verify_key, + credentials=self.worker_stash.root_verify_key, worker_uid=self.syft_worker_id, consumer_state=consumer_state, ) From dc668f30ef9df93ff922a157ef9f4835515ad697 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 9 Sep 2024 16:51:08 +0700 Subject: [PATCH 099/197] [stores] refactor `_execute` to be a static method for both sqlite and postgre backing store - postgre's cursor 
is not cached and renewed whenever a command is executed Co-authored-by: Shubham Gupta --- packages/grid/backend/grid/core/server.py | 4 - .../syft/store/postgresql_document_store.py | 174 +++++++++++++++--- .../src/syft/store/sqlite_document_store.py | 79 +++++--- tox.ini | 1 - 4 files changed, 202 insertions(+), 56 deletions(-) diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py index def69e71d72..76384076c08 100644 --- a/packages/grid/backend/grid/core/server.py +++ b/packages/grid/backend/grid/core/server.py @@ -1,5 +1,4 @@ # stdlib -from pprint import pprint # syft absolute from syft.abstract_server import ServerType @@ -95,9 +94,6 @@ def seaweedfs_config() -> SeaweedFSConfig: sql_store_config() if single_container_mode else postgresql_store_config() ) -print("----------------------------Store Config----------------------------\n") -pprint(store_config.model_dump()) -print("\n----------------------------Store Config----------------------------") blob_storage_config = None if single_container_mode else seaweedfs_config() queue_config = queue_config() diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py index 6fa7a3641a5..ebaf4c32894 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -13,9 +13,12 @@ from typing_extensions import Self # relative +from ..serde.deserialize import _deserialize from ..serde.serializable import serializable +from ..serde.serialize import _serialize from ..types.errors import SyftException from ..types.result import as_result +from ..types.uid import UID from .document_store import DocumentStore from .document_store import PartitionSettings from .document_store import StoreClientConfig @@ -26,12 +29,12 @@ from .locks import SyftLock from .sqlite_document_store import SQLiteBackingStore from .sqlite_document_store import SQLiteStorePartition +from .sqlite_document_store import _repr_debug_ from .sqlite_document_store import cache_key from .sqlite_document_store import special_exception_public_message logger = logging.getLogger(__name__) _CONNECTION_POOL_DB: dict[str, Connection] = {} -_CONNECTION_POOL_CUR: dict[str, Cursor] = {} REF_COUNTS: dict[str, int] = defaultdict(int) @@ -130,10 +133,7 @@ def db(self) -> Connection: @property def cur(self) -> Cursor: - if cache_key(self.store_config_hash) not in _CONNECTION_POOL_CUR: - _CONNECTION_POOL_CUR[cache_key(self.dbname)] = self.db.cursor() - - return _CONNECTION_POOL_CUR[cache_key(self.dbname)] + return self.db.cursor() def _close(self) -> None: self._commit() @@ -143,35 +143,153 @@ def _close(self) -> None: # same connection self.db.close() - db_key = cache_key(self.store_config_hash) - if db_key in _CONNECTION_POOL_CUR: - # NOTE if we don't remove the cursor, the cursor cache_key can clash with a future thread id - del _CONNECTION_POOL_CUR[db_key] del _CONNECTION_POOL_DB[cache_key(self.store_config_hash)] else: # don't close yet because another SQLiteBackingStore is probably still open pass + @staticmethod @as_result(SyftException) - def _execute(self, sql: str, args: list[Any] | None) -> psycopg.Cursor: - with self.lock: - cursor: psycopg.Cursor | None = None - try: - # Ensure self.cur is a psycopg cursor object - cursor = self.cur # Assuming self.cur is already set as psycopg.Cursor - cursor.execute(sql, args) # Execute the SQL with arguments - self.db.commit() # Commit if everything went ok - except 
InFailedSqlTransaction as ie: - self.db.rollback() # Rollback if something went wrong - raise SyftException( - public_message=f"Transaction `{sql}` failed and was rolled back. \n" - f"Error: {ie}." - ) - except Exception as e: - self.db.rollback() # Rollback on any other exception to maintain clean state - public_message = special_exception_public_message(self.table_name, e) - raise SyftException.from_exception(e, public_message=public_message) - return cursor + def _execute( + lock: SyftLock, + cursor: Cursor, + db: Connection, + table_name: str, + sql: str, + args: list[Any] | None, + ) -> Cursor: + try: + cursor.execute(sql, args) # Execute the SQL with arguments + db.commit() # Commit if everything went ok + except InFailedSqlTransaction as ie: + db.rollback() # Rollback if something went wrong + raise SyftException( + public_message=f"Transaction `{sql}` failed and was rolled back. \n" + f"Error: {ie}." + ) + except Exception as e: + logger.debug(f"Rolling back SQL: {sql} with args: {args}") + db.rollback() # Rollback on any other exception to maintain clean state + public_message = special_exception_public_message(table_name, e) + logger.error(public_message) + raise SyftException.from_exception(e, public_message=public_message) + return cursor + + def _set(self, key: UID, value: Any) -> None: + if self._exists(key): + self._update(key, value) + else: + insert_sql = ( + f"insert into {self.table_name} (uid, repr, value) VALUES " # nosec + f"({self.subs_char}, {self.subs_char}, {self.subs_char})" # nosec + ) + data = _serialize(value, to_bytes=True) + with self.cur as cur: + self._execute( + self.lock, + cur, + self.db, + self.table_name, + insert_sql, + [str(key), _repr_debug_(value), data], + ).unwrap() + + def _update(self, key: UID, value: Any) -> None: + insert_sql = ( + f"update {self.table_name} set uid = {self.subs_char}, " # nosec + f"repr = {self.subs_char}, value = {self.subs_char} " # nosec + f"where uid = {self.subs_char}" # nosec + ) + data = _serialize(value, to_bytes=True) + with self.cur as cur: + self._execute( + self.lock, + cur, + self.db, + self.table_name, + insert_sql, + [str(key), _repr_debug_(value), data, str(key)], + ).unwrap() + + def _get(self, key: UID) -> Any: + select_sql = ( + f"select * from {self.table_name} where uid = {self.subs_char} " # nosec + "order by sqltime" + ) + with self.cur as cur: + cursor = self._execute( + self.lock, cur, self.db, self.table_name, select_sql, [str(key)] + ).unwrap(public_message=f"Query {select_sql} failed") + row = cursor.fetchone() + if row is None or len(row) == 0: + raise KeyError(f"{key} not in {type(self)}") + data = row[2] + return _deserialize(data, from_bytes=True) + + def _exists(self, key: UID) -> bool: + select_sql = f"select uid from {self.table_name} where uid = {self.subs_char}" # nosec + row = None + with self.cur as cur: + cursor = self._execute( + self.lock, cur, self.db, self.table_name, select_sql, [str(key)] + ).unwrap() + row = cursor.fetchone() # type: ignore + if row is None: + return False + return bool(row) + + def _get_all(self) -> Any: + select_sql = f"select * from {self.table_name} order by sqltime" # nosec + keys = [] + data = [] + with self.cur as cur: + cursor = self._execute( + self.lock, cur, self.db, self.table_name, select_sql, [] + ).unwrap() + rows = cursor.fetchall() # type: ignore + if not rows: + return {} + + for row in rows: + keys.append(UID(row[0])) + data.append(_deserialize(row[2], from_bytes=True)) + + return dict(zip(keys, data)) + + def _get_all_keys(self) -> Any: + 
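
Every accessor in this refactor takes a short-lived cursor through "with self.cur as cur": since the Postgres cursor is no longer cached, each statement gets a fresh cursor that psycopg closes on exiting the block. The same lifecycle in isolation (connection values are dev assumptions, as elsewhere in this series):

    import psycopg

    with psycopg.connect(
        "postgresql://syft_postgres:example@localhost:5432/syftdb_postgres"
    ) as conn:
        with conn.cursor() as cur:  # fresh cursor, closed when the block exits
            cur.execute("SELECT count(*) FROM pg_tables")
            print(cur.fetchone()[0])
    # leaving the connection block commits on success, rolls back on error
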
select_sql = f"select uid from {self.table_name} order by sqltime" # nosec + with self.cur as cur: + cursor = self._execute( + self.lock, cur, self.db, self.table_name, select_sql, [] + ).unwrap() + rows = cursor.fetchall() # type: ignore + if not rows: + return [] + keys = [UID(row[0]) for row in rows] + return keys + + def _delete(self, key: UID) -> None: + select_sql = f"delete from {self.table_name} where uid = {self.subs_char}" # nosec + with self.cur as cur: + self._execute( + self.lock, cur, self.db, self.table_name, select_sql, [str(key)] + ).unwrap() + + def _delete_all(self) -> None: + select_sql = f"delete from {self.table_name}" # nosec + with self.cur as cur: + self._execute( + self.lock, cur, self.db, self.table_name, select_sql, [] + ).unwrap() + + def _len(self) -> int: + select_sql = f"select count(uid) from {self.table_name}" # nosec + with self.cur as cur: + cursor = self._execute( + self.lock, cur, self.db, self.table_name, select_sql, [] + ).unwrap() + cnt = cursor.fetchone()[0] + return cnt @serializable() diff --git a/packages/syft/src/syft/store/sqlite_document_store.py b/packages/syft/src/syft/store/sqlite_document_store.py index 2359d273b34..82d75d68e6b 100644 --- a/packages/syft/src/syft/store/sqlite_document_store.py +++ b/packages/syft/src/syft/store/sqlite_document_store.py @@ -7,6 +7,8 @@ import logging from pathlib import Path import sqlite3 +from sqlite3 import Connection +from sqlite3 import Cursor import tempfile from typing import Any @@ -40,8 +42,8 @@ # by its filename and optionally the thread that its running in # we keep track of each SQLiteBackingStore init in REF_COUNTS # when it hits 0 we can close the connection and release the file descriptor -SQLITE_CONNECTION_POOL_DB: dict[str, sqlite3.Connection] = {} -SQLITE_CONNECTION_POOL_CUR: dict[str, sqlite3.Cursor] = {} +SQLITE_CONNECTION_POOL_DB: dict[str, Connection] = {} +SQLITE_CONNECTION_POOL_CUR: dict[str, Cursor] = {} REF_COUNTS: dict[str, int] = defaultdict(int) @@ -160,13 +162,13 @@ def create_table(self) -> None: raise SyftException.from_exception(e, public_message=public_message) @property - def db(self) -> sqlite3.Connection: + def db(self) -> Connection: if cache_key(self.db_filename) not in SQLITE_CONNECTION_POOL_DB: self._connect() return SQLITE_CONNECTION_POOL_DB[cache_key(self.db_filename)] @property - def cur(self) -> sqlite3.Cursor: + def cur(self) -> Cursor: if cache_key(self.db_filename) not in SQLITE_CONNECTION_POOL_CUR: SQLITE_CONNECTION_POOL_CUR[cache_key(self.db_filename)] = self.db.cursor() @@ -192,15 +194,22 @@ def _close(self) -> None: def _commit(self) -> None: self.db.commit() + @staticmethod @as_result(SyftException) - def _execute(self, sql: str, args: list[Any] | None) -> sqlite3.Cursor: - with self.lock: - cursor: sqlite3.Cursor | None = None - # err = None + def _execute( + lock: SyftLock, + cursor: Cursor, + db: Connection, + table_name: str, + sql: str, + args: list[Any] | None, + ) -> Cursor: + with lock: + cur: Cursor | None = None try: - cursor = self.cur.execute(sql, args) + cur = cursor.execute(sql, args) except Exception as e: - public_message = special_exception_public_message(self.table_name, e) + public_message = special_exception_public_message(table_name, e) raise SyftException.from_exception(e, public_message=public_message) # TODO: Which exception is safe to rollback on @@ -208,8 +217,8 @@ def _execute(self, sql: str, args: list[Any] | None) -> sqlite3.Cursor: # rather than halting the program like disk I/O error etc # self.db.rollback() # Roll back all 
changes if an exception occurs. # err = Err(str(e)) - self.db.commit() # Commit if everything went ok - return cursor + db.commit() # Commit if everything went ok + return cur def _set(self, key: UID, value: Any) -> None: if self._exists(key): @@ -220,7 +229,14 @@ def _set(self, key: UID, value: Any) -> None: f"({self.subs_char}, {self.subs_char}, {self.subs_char})" # nosec ) data = _serialize(value, to_bytes=True) - self._execute(insert_sql, [str(key), _repr_debug_(value), data]).unwrap() + self._execute( + self.lock, + self.cur, + self.db, + self.table_name, + insert_sql, + [str(key), _repr_debug_(value), data], + ).unwrap() def _update(self, key: UID, value: Any) -> None: insert_sql = ( @@ -230,7 +246,12 @@ def _update(self, key: UID, value: Any) -> None: ) data = _serialize(value, to_bytes=True) self._execute( - insert_sql, [str(key), _repr_debug_(value), data, str(key)] + self.lock, + self.cur, + self.db, + self.table_name, + insert_sql, + [str(key), _repr_debug_(value), data, str(key)], ).unwrap() def _get(self, key: UID) -> Any: @@ -238,9 +259,9 @@ def _get(self, key: UID) -> Any: f"select * from {self.table_name} where uid = {self.subs_char} " # nosec "order by sqltime" ) - cursor = self._execute(select_sql, [str(key)]).unwrap( - public_message=f"Query {select_sql} failed" - ) + cursor = self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [str(key)] + ).unwrap(public_message=f"Query {select_sql} failed") row = cursor.fetchone() if row is None or len(row) == 0: raise KeyError(f"{key} not in {type(self)}") @@ -249,7 +270,9 @@ def _get(self, key: UID) -> Any: def _exists(self, key: UID) -> bool: select_sql = f"select uid from {self.table_name} where uid = {self.subs_char}" # nosec - cursor = self._execute(select_sql, [str(key)]).unwrap() + cursor = self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [str(key)] + ).unwrap() row = cursor.fetchone() # type: ignore if row is None: return False @@ -261,7 +284,9 @@ def _get_all(self) -> Any: keys = [] data = [] - cursor = self._execute(select_sql, []).unwrap() + cursor = self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [] + ).unwrap() rows = cursor.fetchall() # type: ignore if not rows: return {} @@ -274,7 +299,9 @@ def _get_all(self) -> Any: def _get_all_keys(self) -> Any: select_sql = f"select uid from {self.table_name} order by sqltime" # nosec - cursor = self._execute(select_sql, []).unwrap() + cursor = self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [] + ).unwrap() rows = cursor.fetchall() # type: ignore if not rows: return [] @@ -284,15 +311,21 @@ def _get_all_keys(self) -> Any: def _delete(self, key: UID) -> None: select_sql = f"delete from {self.table_name} where uid = {self.subs_char}" # nosec - self._execute(select_sql, [str(key)]).unwrap() + self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [str(key)] + ).unwrap() def _delete_all(self) -> None: select_sql = f"delete from {self.table_name}" # nosec - self._execute(select_sql, []).unwrap() + self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [] + ).unwrap() def _len(self) -> int: select_sql = f"select count(uid) from {self.table_name}" # nosec - cursor = self._execute(select_sql, []).unwrap() + cursor = self._execute( + self.lock, self.cur, self.db, self.table_name, select_sql, [] + ).unwrap() cnt = cursor.fetchone()[0] return cnt diff --git a/tox.ini b/tox.ini index a31f9c41a8b..4004cf81fbb 100644 --- a/tox.ini +++ b/tox.ini @@ 
-1129,7 +1129,6 @@ allowlist_externals = tox bash commands = - bash -c "CLUSTER_NAME=${CLUSTER_NAME} tox -e dev.k8s.destroy" tox -e dev.k8s.start tox -e dev.k8s.{posargs:deploy} From 48c774181192faf42b44bddcc2b06937764d800b Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 9 Sep 2024 16:34:38 +0200 Subject: [PATCH 100/197] add Query, split table creation --- packages/syft/src/syft/store/db/models.py | 28 --- packages/syft/src/syft/store/db/query.py | 244 +++++++++++++++++++ packages/syft/src/syft/store/db/schema.py | 83 +++++++ packages/syft/src/syft/store/db/sqlite_db.py | 2 +- packages/syft/src/syft/store/db/stash.py | 139 +++++------ 5 files changed, 388 insertions(+), 108 deletions(-) delete mode 100644 packages/syft/src/syft/store/db/models.py create mode 100644 packages/syft/src/syft/store/db/query.py create mode 100644 packages/syft/src/syft/store/db/schema.py diff --git a/packages/syft/src/syft/store/db/models.py b/packages/syft/src/syft/store/db/models.py deleted file mode 100644 index 92dd2c2a6fa..00000000000 --- a/packages/syft/src/syft/store/db/models.py +++ /dev/null @@ -1,28 +0,0 @@ -# stdlib - -# third party -import sqlalchemy as sa -from sqlalchemy import TypeDecorator -from sqlalchemy.orm import DeclarativeBase - -# relative -from ...types.uid import UID - - -class Base(DeclarativeBase): - pass - - -class UIDTypeDecorator(TypeDecorator): - """Converts between Syft UID and UUID.""" - - impl = sa.UUID - cache_ok = True - - def process_bind_param(self, value, dialect): # type: ignore - if value is not None: - return value.value - - def process_result_value(self, value, dialect): # type: ignore - if value is not None: - return UID(value) diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py new file mode 100644 index 00000000000..95eba89909f --- /dev/null +++ b/packages/syft/src/syft/store/db/query.py @@ -0,0 +1,244 @@ +# stdlib +from abc import ABC +from abc import abstractmethod +from typing import Any +from typing import Literal + +# third party +import sqlalchemy as sa +from sqlalchemy import Column +from sqlalchemy import Dialect +from sqlalchemy import Result +from sqlalchemy import Select +from sqlalchemy import Table +from sqlalchemy import dialects +from sqlalchemy import func +from sqlalchemy import select +from sqlalchemy.orm import Session +from typing_extensions import Self + +# relative +from ...serde.json_serde import serialize_json +from ...server.credentials import SyftVerifyKey +from ...service.action.action_permissions import ActionObjectPermission +from ...service.action.action_permissions import ActionPermission +from ...service.user.user_roles import ServiceRole +from ...types.syft_object import SyftObject +from ...types.uid import UID +from .sqlite_db import OBJECT_TYPE_TO_TABLE + + +class Query(ABC): + dialect: Dialect + + def __init__(self, object_type: type[SyftObject]) -> None: + self.object_type: type = object_type + self.table: Table = OBJECT_TYPE_TO_TABLE[object_type] + self.stmt: Select = select([self.table]) + + def compile(self) -> str: + """ + Compile the query to a string, for debugging purposes. 
+ """ + return self.stmt.compile( + compile_kwargs={"literal_binds": True}, + dialect=self.dialect, + ) + + def execute(self, session: Session) -> Result: + """Execute the query using the given session.""" + return session.execute(self.stmt) + + def with_permissions( + self, + credentials: SyftVerifyKey, + role: ServiceRole, + permission: ActionPermission = ActionPermission.READ, + ) -> Self: + """Add a permission check to the query. + + If the user has a role below DATA_OWNER, the query will be filtered to only include objects + that the user has the specified permission on. + + Args: + credentials (SyftVerifyKey): user verify key + role (ServiceRole): role of the user + permission (ActionPermission, optional): Type of permission to check for. + Defaults to ActionPermission.READ. + + Returns: + Self: The query object with the permission check applied + """ + if role in (ServiceRole.ADMIN, ServiceRole.SUPER_ADMIN): + return self + + permission = ActionObjectPermission( + uid=UID(), # dummy uid, we just need the permission string + credentials=credentials, + permission=permission, + ) + + permission_clause = self._make_permissions_clause(permission) + self.stmt = self.stmt.where(permission_clause) + + return self + + def filter(self, field: str, operator: str, value: Any) -> Self: + """Add a filter to the query. + + example usage: + Query(User).filter("name", "==", "Alice") + Query(User).filter("friends", "contains", "Bob") + + Args: + field (str): Field to filter on + operator (str): Operator to use for the filter + value (Any): Value to filter on + + Raises: + ValueError: If the operator is not supported + + Returns: + Self: The query object with the filter applied + """ + if operator not in {"==", "!=", "contains"}: + raise ValueError(f"Operation {operator} not supported") + + if operator == "==": + filter = self._eq_filter(self.table, field, value) + self.stmt = self.stmt.where(filter) + elif operator == "contains": + filter = self._contains_filter(self.table, field, value) + self.stmt = self.stmt.where(filter) + + return self + + def order_by(self, field: str, order: Literal["asc", "desc"] = "asc") -> Self: + """Add an order by clause to the query. + + Args: + field (str): field to order by. + order (Literal["asc", "desc"], optional): Order to use. + Defaults to "asc". 
+ + Raises: + ValueError: If the order is not "asc" or "desc" + + Returns: + Self: The query object with the order by clause applied + """ + column = self._get_column(field) + + if order.lower() == "asc": + self.stmt = self.stmt.order_by(column) + elif order.lower() == "desc": + self.stmt = self.stmt.order_by(column.desc()) + else: + raise ValueError(f"Invalid sort order {order}") # type: ignore + + return self + + def limit(self, limit: int) -> Self: + """Add a limit clause to the query.""" + self.stmt = self.stmt.limit(limit) + return self + + def offset(self, offset: int) -> Self: + """Add an offset clause to the query.""" + self.stmt = self.stmt.offset(offset) + return self + + @abstractmethod + def _make_permissions_clause( + self, + permission: ActionObjectPermission, + ) -> sa.sql.elements.BinaryExpression: + pass + + def default_order(self) -> Self: + if hasattr(self.object_type, "__order_by__"): + field, order = self.object_type.__order_by__ + else: + field, order = "_created_at", "desc" + + return self.order_by(field, order) + + def _eq_filter( + self, + table: Table, + field: str, + value: Any, + ) -> sa.sql.elements.BinaryExpression: + if field == "id": + return table.c.id == UID(value) + + json_value = serialize_json(value) + return table.c.fields[field] == func.json_quote(json_value) + + @abstractmethod + def _contains_filter( + self, + table: Table, + field: str, + value: Any, + ) -> sa.sql.elements.BinaryExpression: + pass + + def _get_column(self, column: str) -> Column: + if column == "id": + return self.table.c.id + if column == "created_date" or column == "_created_at": + return self.table.c._created_at + elif column == "updated_date" or column == "_updated_at": + return self.table.c._updated_at + elif column == "deleted_date" or column == "_deleted_at": + return self.table.c._deleted_at + + return self.table.c.fields[column] + + +class SQLiteQuery(Query): + dialect = dialects.sqlite.dialect + + def _make_permissions_clause( + self, + permission: ActionObjectPermission, + ) -> sa.sql.elements.BinaryExpression: + permission_string = permission.permission_string + compound_permission_string = permission.compound_permission_string + return sa.or_( + self.table.c.permissions.contains(permission_string), + self.table.c.permissions.contains(compound_permission_string), + ) + + def _contains_filter( + self, + table: Table, + field: str, + value: Any, + ) -> sa.sql.elements.BinaryExpression: + field_value = serialize_json(value) + return table.c.fields[field].contains(func.json_quote(field_value)) + + +class PostgresQuery(Query): + dialect = dialects.postgresql.dialect + + def _make_permissions_clause( + self, permission: ActionObjectPermission + ) -> sa.sql.elements.BinaryExpression: + permission_string = [permission.permission_string] + compound_permission_string = [permission.compound_permission_string] + return sa.or_( + self.table.c.permissions.contains(permission_string), + self.table.c.permissions.contains(compound_permission_string), + ) + + def _contains_filter( + self, + table: Table, + field: str, + value: Any, + ) -> sa.sql.elements.BinaryExpression: + field_value = [serialize_json(value)] + return table.c.fields[field].contains(field_value) diff --git a/packages/syft/src/syft/store/db/schema.py b/packages/syft/src/syft/store/db/schema.py new file mode 100644 index 00000000000..6f687800752 --- /dev/null +++ b/packages/syft/src/syft/store/db/schema.py @@ -0,0 +1,83 @@ +# stdlib + +# stdlib +import uuid + +# third party +import sqlalchemy as sa +from sqlalchemy import 
Column
+from sqlalchemy import Dialect
+from sqlalchemy import Table
+from sqlalchemy import TypeDecorator
+from sqlalchemy.dialects import postgresql
+from sqlalchemy.orm import DeclarativeBase
+from sqlalchemy.types import JSON
+
+# relative
+from ...types.syft_object import SyftObject
+from ...types.uid import UID
+
+
+class Base(DeclarativeBase):
+    pass
+
+
+class UIDTypeDecorator(TypeDecorator):
+    """Converts between Syft UID and UUID."""
+
+    impl = sa.UUID
+    cache_ok = True
+
+    def process_bind_param(self, value, dialect):  # type: ignore
+        if value is not None:
+            return value.value
+
+    def process_result_value(self, value, dialect):  # type: ignore
+        if value is not None:
+            return UID(value)
+
+
+def create_table(
+    object_type: type[SyftObject],
+    dialect: Dialect,
+) -> Table:
+    """Create a table for a given SyftObject type, and add it to the metadata.
+
+    To create the table on the database, you must call `Base.metadata.create_all(engine)`.
+
+    Args:
+        object_type (type[SyftObject]): The type of the object to create a table for.
+        dialect (Dialect): The dialect of the database.
+
+    Returns:
+        Table: The created table.
+    """
+    table_name = object_type.__canonical_name__
+    dialect_name = dialect.name
+
+    fields_type = JSON if dialect_name == "sqlite" else postgresql.JSONB
+    permissions_type = JSON if dialect_name == "sqlite" else postgresql.ARRAY(sa.String)
+    storage_permissions_type = (
+        JSON if dialect_name == "sqlite" else postgresql.ARRAY(sa.String)
+    )
+
+    if table_name not in Base.metadata.tables:
+        Table(
+            object_type.__canonical_name__,
+            Base.metadata,
+            Column("id", UIDTypeDecorator, primary_key=True, default=uuid.uuid4),
+            Column("fields", fields_type, default={}),
+            Column("permissions", permissions_type, default=[]),
+            Column(
+                "storage_permissions",
+                storage_permissions_type,
+                default=[],
+            ),
+            Column(
+                "_created_at", sa.DateTime, server_default=sa.func.now(), index=True
+            ),
+            Column("_updated_at", sa.DateTime, server_onupdate=sa.func.now()),
+            Column("_deleted_at", sa.DateTime, index=True),
+        )
+
+    return Base.metadata.tables[table_name]
diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py
index 8cc260b5f88..509650dc801 100644
--- a/packages/syft/src/syft/store/db/sqlite_db.py
+++ b/packages/syft/src/syft/store/db/sqlite_db.py
@@ -17,7 +17,7 @@
 from ...server.credentials import SyftSigningKey
 from ...server.credentials import SyftVerifyKey
 from ...types.uid import UID
-from .models import Base
+from .schema import Base
 from .utils import dumps
 from .utils import loads
 
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index d5f445eae70..d50f27a9999 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -5,7 +5,6 @@
 from typing import Generic
 from typing import cast
 from typing import get_args
-import uuid
 
 # third party
 import sqlalchemy as sa
@@ -14,9 +13,7 @@
 from sqlalchemy import Table
 from sqlalchemy import func
 from sqlalchemy import select
-from sqlalchemy.dialects import postgresql
 from sqlalchemy.orm import Session
-from sqlalchemy.types import JSON
 from typing_extensions import TypeVar
 
 # relative
@@ -37,8 +34,11 @@
 from ...types.uid import UID
 from ..document_store_errors import NotFoundException
 from ..document_store_errors import StashException
-from .models import Base
-from .models import UIDTypeDecorator
+from .query import PostgresQuery
+from .query import Query
+from .query import SQLiteQuery
+from 
.schema import Base +from .schema import create_table from .sqlite_db import DBManager StashT = TypeVar("StashT", bound=SyftObject) @@ -53,10 +53,18 @@ class ObjectStash(Generic[StashT]): def __init__(self, store: DBManager) -> None: self.db = store self.object_type = self.get_object_type() - self.table = self._create_table() + self.table = create_table(self.object_type, self.dialect) + + @property + def dialect(self) -> sa.engine.interfaces.Dialect: + return self.db.engine.dialect @classmethod def get_object_type(cls) -> type[StashT]: + """ + Get the object type this stash is storing. This is the generic argument of the + ObjectStash class. + """ generic_args = get_args(cls.__orig_bases__[0]) if len(generic_args) != 1: raise TypeError("ObjectStash must have a single generic argument") @@ -78,6 +86,15 @@ def root_verify_key(self) -> SyftVerifyKey: def _data(self) -> list[StashT]: return self.get_all(self.root_verify_key, has_permission=True).unwrap() + def query(self) -> Query: + """Creates a query for this stash's object type.""" + if self.dialect.name == "sqlite": + return SQLiteQuery(self.object_type) + elif self.dialect.name == "postgresql": + return PostgresQuery(self.object_type) + else: + raise NotImplementedError(f"Query not implemented for {self.dialect.name}") + @as_result(StashException) def check_type(self, obj: T, type_: type) -> T: if not isinstance(obj, type_): @@ -88,44 +105,6 @@ def check_type(self, obj: T, type_: type) -> T: def session(self) -> Session: return self.db.session - def _create_table(self) -> Table: - # need to call Base.metadata.create_all(engine) to create the table - table_name = self.object_type.__canonical_name__ - - fields_type = ( - JSON if self.db.engine.dialect.name == "sqlite" else postgresql.JSONB - ) - permissons_type = ( - JSON - if self.db.engine.dialect.name == "sqlite" - else postgresql.ARRAY(sa.String) - ) - storage_permissions_type = ( - JSON - if self.db.engine.dialect.name == "sqlite" - else postgresql.ARRAY(sa.String) - ) - if table_name not in Base.metadata.tables: - Table( - self.object_type.__canonical_name__, - Base.metadata, - Column("id", UIDTypeDecorator, primary_key=True, default=uuid.uuid4), - Column("fields", fields_type, default={}), - Column("permissions", permissons_type, default=[]), - Column( - "storage_permissions", - storage_permissions_type, - default=[], - ), - # TODO rename and use on SyftObject fields - Column( - "_created_at", sa.DateTime, server_default=sa.func.now(), index=True - ), - Column("_updated_at", sa.DateTime, server_onupdate=sa.func.now()), - Column("_deleted_at", sa.DateTime, index=True), - ) - return Base.metadata.tables[table_name] - def _drop_table(self) -> None: table_name = self.object_type.__canonical_name__ if table_name in Base.metadata.tables: @@ -181,14 +160,13 @@ def exists(self, credentials: SyftVerifyKey, uid: UID) -> bool: def get_by_uid( self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False ) -> StashT: - stmt = self.table.select() - stmt = stmt.where(self._get_field_filter("id", uid)) - stmt = self._apply_permission_filter( - stmt, credentials=credentials, has_permission=has_permission - ) + query = self.query().filter("id", "==", uid) - result = self.session.execute(stmt).first() + if not has_permission: + role = self.get_role(credentials) + query = query.with_permissions(credentials, role) + result = query.execute(self.session).first() if result is None: raise NotFoundException(f"{self.object_type.__name__}: {uid} not found") return self.row_as_obj(result) @@ -221,22 
+199,25 @@ def _get_by_fields( offset: int | None = None, has_permission: bool = False, ) -> sa.Result: - table = table if table is not None else self.table - filters = [] + query = self.query() for field_name, field_value in fields.items(): - filt = self._get_field_filter(field_name, field_value, table=table) - filters.append(filt) + query = query.filter(field_name, "==", field_value) - stmt = table.select() - stmt = stmt.where(sa.and_(*filters)) - stmt = self._apply_permission_filter( - stmt, credentials=credentials, has_permission=has_permission - ) - stmt = self._apply_order_by(stmt, order_by, sort_order) - stmt = self._apply_limit_offset(stmt, limit, offset) + if not has_permission: + role = self.get_role(credentials) + query = query.with_permissions(credentials, role) - result = self.session.execute(stmt) - return result + if order_by and sort_order: + query = query.order_by(order_by, sort_order) + else: + query = query.default_order() + + if limit: + query = query.limit(limit) + if offset: + query = query.offset(offset) + + return query.execute(self.session) @as_result(SyftException, StashException, NotFoundException) def get_one_by_field( @@ -325,24 +306,23 @@ def get_all_contains( offset: int | None = None, has_permission: bool = False, ) -> list[StashT]: - # TODO write filter logic, merge with get_all + query = self.query().filter(field_name, "contains", field_value) - if self._is_sqlite(): - field_value = func.json_quote(field_value) + if not has_permission: + role = self.get_role(credentials) + query = query.with_permissions(credentials, role) + + if order_by and sort_order: + query = query.order_by(order_by, sort_order) else: - field_value = [field_value] # type: ignore + query = query.default_order() - stmt = self.table.select().where( - self.table.c.fields[field_name].contains(field_value), - ) - stmt = self._apply_permission_filter( - stmt, credentials=credentials, has_permission=has_permission - ) - stmt = self._apply_order_by(stmt, order_by, sort_order) - stmt = self._apply_limit_offset(stmt, limit, offset) + if limit: + query = query.limit(limit) + if offset: + query = query.offset(offset) - result = self.session.execute(stmt).all() - return [self.row_as_obj(row) for row in result] + return query.execute(self.session).all() @as_result(SyftException, StashException, NotFoundException) def get_index( @@ -499,13 +479,14 @@ def update( ) -> StashT: """ NOTE: We cannot do partial updates on the database, - because we are using computed fields that are not known to the DB or ORM: + because we are using computed fields that are not known to the DB: - serialize_json will add computed fields to the JSON stored in the database - If we update a single field in the JSON, the computed fields can get out of sync. - To fix, we either need db-supported computed fields, or know in our ORM which fields should be re-computed. 
""" - self.check_type(obj, self.object_type).unwrap() + if not self.allow_any_type: + self.check_type(obj, self.object_type).unwrap() # TODO has_permission is not used if not self.is_unique(obj): From 2ec83715aed7eb70b64181de0fa6d65da28bc131 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 9 Sep 2024 18:02:24 +0200 Subject: [PATCH 101/197] fix queue stash tests --- packages/syft/src/syft/store/db/stash.py | 12 + packages/syft/tests/conftest.py | 10 - .../tests/syft/stores/queue_stash_test.py | 491 ++++++------------ .../tests/syft/stores/store_fixtures_test.py | 108 +--- 4 files changed, 195 insertions(+), 426 deletions(-) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index d5f445eae70..6c4ed42f02f 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -17,6 +17,7 @@ from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Session from sqlalchemy.types import JSON +from typing_extensions import Self from typing_extensions import TypeVar # relative @@ -40,6 +41,7 @@ from .models import Base from .models import UIDTypeDecorator from .sqlite_db import DBManager +from .sqlite_db import SQLiteDBManager StashT = TypeVar("StashT", bound=SyftObject) T = TypeVar("T") @@ -66,6 +68,16 @@ def get_object_type(cls) -> type[StashT]: ) return generic_args[0] + def __len__(self) -> int: + return self.session.query(self.table).count() + + @classmethod + def random(cls, **kwargs: dict) -> Self: + db_manager = SQLiteDBManager.random(**kwargs) + stash = cls(store=db_manager) + stash.db.init_tables() + return stash + @property def server_uid(self) -> UID: return self.db.server_uid diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index 64d91d78151..1b539202716 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -27,14 +27,9 @@ # relative # our version of mongomock that has a fix for CodecOptions and custom TypeRegistry Support from .mongomock.mongo_client import MongoClient -from .syft.stores.store_fixtures_test import mongo_document_store -from .syft.stores.store_fixtures_test import mongo_queue_stash -from .syft.stores.store_fixtures_test import mongo_store_partition from .syft.stores.store_fixtures_test import sqlite_action_store from .syft.stores.store_fixtures_test import sqlite_document_store -from .syft.stores.store_fixtures_test import sqlite_queue_stash from .syft.stores.store_fixtures_test import sqlite_store_partition -from .syft.stores.store_fixtures_test import sqlite_workspace def patch_protocol_file(filepath: Path): @@ -302,13 +297,8 @@ def big_dataset() -> Dataset: __all__ = [ - "mongo_store_partition", - "mongo_document_store", - "mongo_queue_stash", "sqlite_store_partition", - "sqlite_workspace", "sqlite_document_store", - "sqlite_queue_stash", "sqlite_action_store", ] diff --git a/packages/syft/tests/syft/stores/queue_stash_test.py b/packages/syft/tests/syft/stores/queue_stash_test.py index 68ebd47e999..d4e2cd25747 100644 --- a/packages/syft/tests/syft/stores/queue_stash_test.py +++ b/packages/syft/tests/syft/stores/queue_stash_test.py @@ -1,26 +1,20 @@ # stdlib -import threading -from threading import Thread -import time -from typing import Any +from concurrent.futures import ThreadPoolExecutor # third party import pytest # syft absolute from syft.service.queue.queue_stash import QueueItem +from syft.service.queue.queue_stash import QueueStash from syft.service.worker.worker_pool import WorkerPool from 
syft.service.worker.worker_pool_service import SyftWorkerPoolService from syft.store.linked_obj import LinkedObject from syft.types.errors import SyftException from syft.types.uid import UID -# relative -from .store_fixtures_test import mongo_queue_stash_fn -from .store_fixtures_test import sqlite_queue_stash_fn - -def mock_queue_object(): +def mock_queue_object() -> QueueItem: worker_pool_obj = WorkerPool( name="mypool", image_id=UID(), @@ -47,389 +41,246 @@ def mock_queue_object(): @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), + pytest.lazy_fixture("queue_stash"), ], ) -def test_queue_stash_sanity(queue: Any) -> None: +def test_queue_stash_sanity(queue: QueueStash) -> None: assert len(queue) == 0 - assert hasattr(queue, "store") - assert hasattr(queue, "partition") @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), + pytest.lazy_fixture("queue_stash"), ], ) # -def test_queue_stash_set_get(root_verify_key, queue: Any) -> None: - objs = [] +def test_queue_stash_set_get(root_verify_key, queue: QueueStash) -> None: + objs: list[QueueItem] = [] repeats = 5 for idx in range(repeats): obj = mock_queue_object() objs.append(obj) - res = queue.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() + queue.set(root_verify_key, obj, ignore_duplicates=False).unwrap() assert len(queue) == idx + 1 with pytest.raises(SyftException): - res = queue.set(root_verify_key, obj, ignore_duplicates=False) + queue.set(root_verify_key, obj, ignore_duplicates=False).unwrap() assert len(queue) == idx + 1 assert len(queue.get_all(root_verify_key).ok()) == idx + 1 - item = queue.find_one(root_verify_key, id=obj.id) - assert item.is_ok() - assert item.ok() == obj + item = queue.get_by_uid(root_verify_key, uid=obj.id).unwrap() + assert item == obj cnt = len(objs) for obj in objs: - res = queue.find_and_delete(root_verify_key, id=obj.id) - assert res.is_ok() - + queue.delete_by_uid(root_verify_key, uid=obj.id).unwrap() cnt -= 1 assert len(queue) == cnt - item = queue.find_one(root_verify_key, id=obj.id) + item = queue.get_by_uid(root_verify_key, uid=obj.id) assert item.is_err() @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), + pytest.lazy_fixture("queue_stash"), ], ) -def test_queue_stash_update(root_verify_key, queue: Any) -> None: +def test_queue_stash_update(queue: QueueStash) -> None: + root_verify_key = queue.db.root_verify_key obj = mock_queue_object() - res = queue.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() + queue.set(root_verify_key, obj, ignore_duplicates=False).unwrap() repeats = 5 for idx in range(repeats): obj.args = [idx] - res = queue.update(root_verify_key, obj) - assert res.is_ok() + queue.update(root_verify_key, obj).unwrap() assert len(queue) == 1 - item = queue.find_one(root_verify_key, id=obj.id) - assert item.is_ok() - assert item.ok().args == [idx] + item = queue.get_by_uid(root_verify_key, uid=obj.id).unwrap() + assert item.args == [idx] - res = queue.find_and_delete(root_verify_key, id=obj.id) - assert res.is_ok() + queue.delete_by_uid(root_verify_key, uid=obj.id).unwrap() assert len(queue) == 0 @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), + pytest.lazy_fixture("queue_stash"), ], ) -def test_queue_set_existing_queue_threading(root_verify_key, 
queue: Any) -> None: - thread_cnt = 3 - repeats = 5 - - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - for _ in range(repeats): - obj = mock_queue_object() - - for _ in range(10): - res = queue.set(root_verify_key, obj, ignore_duplicates=False) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - assert len(queue) == thread_cnt * repeats +def test_queue_set_existing_queue_threading(root_verify_key, queue: QueueStash) -> None: + root_verify_key = queue.db.root_verify_key + items_to_create = 100 + with ThreadPoolExecutor(max_workers=3) as executor: + results = list( + executor.map( + lambda obj: queue.set( + root_verify_key, + mock_queue_object(), + ), + range(items_to_create), + ) + ) + assert all(res.is_ok() for res in results), "Error occurred during execution" + assert len(queue) == items_to_create @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), + pytest.lazy_fixture("queue_stash"), ], ) -def test_queue_update_existing_queue_threading(root_verify_key, queue: Any) -> None: - thread_cnt = 3 - repeats = 5 - +def test_queue_update_existing_queue_threading(queue: QueueStash) -> None: + root_verify_key = queue.db.root_verify_key obj = mock_queue_object() - queue.set(root_verify_key, obj, ignore_duplicates=False) - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - for repeat in range(repeats): - obj.args = [repeat] - for _ in range(10): - res = queue.update(root_verify_key, obj) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() + def update_queue(): + obj.args = [UID()] + res = queue.update(root_verify_key, obj) + return res - tids.append(thread) + queue.set(root_verify_key, obj, ignore_duplicates=False) - for thread in tids: - thread.join() + with ThreadPoolExecutor(max_workers=3) as executor: + # Run the update_queue function in multiple threads + results = list( + executor.map( + lambda _: update_queue(), + range(5), + ) + ) + assert all(res.is_ok() for res in results), "Error occurred during execution" - assert execution_err is None + assert len(queue) == 1 + item = queue.get_by_uid(root_verify_key, uid=obj.id).unwrap() + assert item.args != [] @pytest.mark.parametrize( "queue", [ - pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), + pytest.lazy_fixture("queue_stash"), ], ) def test_queue_set_delete_existing_queue_threading( - root_verify_key, - queue: Any, + queue: QueueStash, ) -> None: - thread_cnt = 3 - repeats = 5 - - execution_err = None - objs = [] - - for _ in range(repeats * thread_cnt): - obj = mock_queue_object() - res = queue.set(root_verify_key, obj, ignore_duplicates=False) - objs.append(obj) - - assert res.is_ok() - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - for idx in range(repeats): - item_idx = tid * repeats + idx - - for _ in range(10): - res = queue.find_and_delete(root_verify_key, id=objs[item_idx].id) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - 
thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - assert len(queue) == 0 - - -def helper_queue_set_threading(root_verify_key, create_queue_cbk) -> None: - thread_cnt = 3 - repeats = 5 - - execution_err = None - lock = threading.Lock() - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - with lock: - queue = create_queue_cbk() - - for _ in range(repeats): - obj = mock_queue_object() - - for _ in range(10): - res = queue.set(root_verify_key, obj, ignore_duplicates=False) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - queue = create_queue_cbk() - - assert execution_err is None - assert len(queue) == thread_cnt * repeats - - -def test_queue_set_sqlite(root_verify_key, sqlite_workspace): - def create_queue_cbk(): - return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace) - - helper_queue_set_threading(root_verify_key, create_queue_cbk) - - -def test_queue_set_threading_mongo(root_verify_key, mongo_document_store): - def create_queue_cbk(): - return mongo_queue_stash_fn(mongo_document_store) - - helper_queue_set_threading(root_verify_key, create_queue_cbk) - - -def helper_queue_update_threading(root_verify_key, create_queue_cbk) -> None: - thread_cnt = 3 - repeats = 5 - - queue = create_queue_cbk() - time.sleep(1) - + root_verify_key = queue.db.root_verify_key + with ThreadPoolExecutor(max_workers=3) as executor: + results = list( + executor.map( + lambda obj: queue.set( + root_verify_key, + mock_queue_object(), + ), + range(15), + ) + ) + objs = [item.unwrap() for item in results] + + results = list( + executor.map( + lambda obj: queue.delete_by_uid(root_verify_key, uid=obj.id), + objs, + ) + ) + assert all(res.is_ok() for res in results), "Error occurred during execution" + + +def test_queue_set(queue_stash: QueueStash): + root_verify_key = queue_stash.db.root_verify_key + config = queue_stash.db.config + server_uid = queue_stash.db.server_uid + + def set_in_new_thread(_): + queue_stash = QueueStash.random( + root_verify_key=root_verify_key, + config=config, + server_uid=server_uid, + ) + return queue_stash.set(root_verify_key, mock_queue_object()) + + total_repeats = 50 + with ThreadPoolExecutor(max_workers=3) as executor: + results = list( + executor.map( + set_in_new_thread, + range(total_repeats), + ) + ) + + assert all(res.is_ok() for res in results), "Error occurred during execution" + assert len(queue_stash) == total_repeats + + +def test_queue_update_threading(queue_stash: QueueStash): + root_verify_key = queue_stash.db.root_verify_key + config = queue_stash.db.config + server_uid = queue_stash.db.server_uid obj = mock_queue_object() - queue.set(root_verify_key, obj, ignore_duplicates=False) - execution_err = None - lock = threading.Lock() - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - with lock: - queue_local = create_queue_cbk() - - for repeat in range(repeats): - obj.args = [repeat] - - for _ in range(10): - res = queue_local.update(root_verify_key, obj) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - -def 
test_queue_update_threading_sqlite(root_verify_key, sqlite_workspace):
-    def create_queue_cbk():
-        return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace)
-
-    helper_queue_update_threading(root_verify_key, create_queue_cbk)
-
-
-def test_queue_update_threading_mongo(root_verify_key, mongo_document_store):
-    def create_queue_cbk():
-        return mongo_queue_stash_fn(mongo_document_store)
-
-    helper_queue_update_threading(root_verify_key, create_queue_cbk)
-
-
-def helper_queue_set_delete_threading(
-    root_verify_key,
-    create_queue_cbk,
-) -> None:
-    thread_cnt = 3
-    repeats = 5
-
-    queue = create_queue_cbk()
-    execution_err = None
-    objs = []
-
-    for _ in range(repeats * thread_cnt):
-        obj = mock_queue_object()
-        res = queue.set(root_verify_key, obj, ignore_duplicates=False)
-        objs.append(obj)
-
-        assert res.is_ok()
-
-    lock = threading.Lock()
-
-    def _kv_cbk(tid: int) -> None:
-        nonlocal execution_err
-        with lock:
-            queue = create_queue_cbk()
-        for idx in range(repeats):
-            item_idx = tid * repeats + idx
-
-            for _ in range(10):
-                res = queue.find_and_delete(root_verify_key, id=objs[item_idx].id)
-                if res.is_ok():
-                    break
-
-            if res.is_err():
-                execution_err = res
-            assert res.is_ok()
-
-    tids = []
-    for tid in range(thread_cnt):
-        thread = Thread(target=_kv_cbk, args=(tid,))
-        thread.start()
-
-        tids.append(thread)
-
-    for thread in tids:
-        thread.join()
-
-    assert execution_err is None
-    assert len(queue) == 0
-
-
-def test_queue_delete_threading_sqlite(root_verify_key, sqlite_workspace):
-    def create_queue_cbk():
-        return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace)
-
-    helper_queue_set_delete_threading(root_verify_key, create_queue_cbk)
-
-
-def test_queue_delete_threading_mongo(root_verify_key, mongo_document_store):
-    def create_queue_cbk():
-        return mongo_queue_stash_fn(mongo_document_store)
-
-    helper_queue_set_delete_threading(root_verify_key, create_queue_cbk)
+    queue_stash.set(root_verify_key, obj).unwrap()
+
+    def update_in_new_thread(_):
+        queue_stash = QueueStash.random(
+            root_verify_key=root_verify_key,
+            config=config,
+            server_uid=server_uid,
+        )
+        obj.args = [UID()]
+        return queue_stash.update(root_verify_key, obj)
+
+    total_repeats = 50
+    with ThreadPoolExecutor(max_workers=3) as executor:
+        results = list(
+            executor.map(
+                update_in_new_thread,
+                range(total_repeats),
+            )
+        )
+
+    assert all(res.is_ok() for res in results), "Error occurred during execution"
+    assert len(queue_stash) == 1
+
+
+def test_queue_delete_threading(queue_stash: QueueStash):
+    root_verify_key = queue_stash.db.root_verify_key
+    config = queue_stash.db.config
+    server_uid = queue_stash.db.server_uid
+
+    def delete_in_new_thread(obj: QueueItem):
+        queue_stash = QueueStash.random(
+            root_verify_key=root_verify_key,
+            config=config,
+            server_uid=server_uid,
+        )
+        return queue_stash.delete_by_uid(root_verify_key, uid=obj.id)
+
+    with ThreadPoolExecutor(max_workers=3) as executor:
+        results = list(
+            executor.map(
+                lambda obj: queue_stash.set(
+                    root_verify_key,
+                    mock_queue_object(),
+                ),
+                range(50),
+            )
+        )
+        objs = [item.unwrap() for item in results]
+
+        results = list(
+            executor.map(
+                delete_in_new_thread,
+                objs,
+            )
+        )
+    assert all(res.is_ok() for res in results), "Error occurred during execution"
+
+    assert len(queue_stash) == 0
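The rewritten tests above all share one pattern: drive a single stash (or a per-thread handle from QueueStash.random over the same database) from a small ThreadPoolExecutor, collect the returned Results, and assert on them afterwards. A minimal sketch of that pattern, assuming the mock_queue_object helper from this test module:

    # Sketch only: concurrent writes against one stash; mock_queue_object is
    # the helper defined in the test module above, not a public API.
    from concurrent.futures import ThreadPoolExecutor

    def exercise_concurrent_sets(stash, root_verify_key, n_items=50):
        with ThreadPoolExecutor(max_workers=3) as executor:
            # map() runs one set() per worker thread and collects the Results
            # eagerly, so assertion failures surface after all futures finish.
            results = list(
                executor.map(
                    lambda _: stash.set(root_verify_key, mock_queue_object()),
                    range(n_items),
                )
            )
        assert all(res.is_ok() for res in results)
        assert len(stash) == n_items
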
diff --git a/packages/syft/tests/syft/stores/store_fixtures_test.py b/packages/syft/tests/syft/stores/store_fixtures_test.py
index b049a0a58de..cbbe0022f49 100644
--- a/packages/syft/tests/syft/stores/store_fixtures_test.py
+++ b/packages/syft/tests/syft/stores/store_fixtures_test.py
@@ -26,10 +26,6 @@
 from syft.store.locks import LockingConfig
 from syft.store.locks import NoLockingConfig
 from syft.store.locks import ThreadingLockingConfig
-from syft.store.mongo_client import MongoStoreClientConfig
-from syft.store.mongo_document_store import MongoDocumentStore
-from syft.store.mongo_document_store import MongoStoreConfig
-from syft.store.mongo_document_store import MongoStorePartition
 from syft.store.sqlite_document_store import SQLiteDocumentStore
 from syft.store.sqlite_document_store import SQLiteStoreClientConfig
 from syft.store.sqlite_document_store import SQLiteStoreConfig
@@ -190,12 +186,17 @@ def sqlite_queue_stash_fn(
     return QueueStash(store=store)
 
 
-@pytest.fixture(scope="function", params=locking_scenarios)
-def sqlite_queue_stash(root_verify_key, sqlite_workspace: tuple[Path, str], request):
-    locking_config_name = request.param
-    yield sqlite_queue_stash_fn(
-        root_verify_key, sqlite_workspace, locking_config_name=locking_config_name
-    )
+@pytest.fixture(
+    scope="function",
+    params=[
+        "TODOsqlite_address",
+        "TODOpostgres_address",
+    ],
+)
+def queue_stash(request):
+    _ = request.param
+    stash = QueueStash.random()
+    yield stash
 
 
 @pytest.fixture(scope="function", params=locking_scenarios)
@@ -224,90 +225,5 @@ def sqlite_action_store(sqlite_workspace: tuple[Path, str], request):
     )
 
 
-def mongo_store_partition_fn(
-    mongo_client,
-    root_verify_key,
-    mongo_db_name: str = "mongo_db",
-    locking_config_name: str = "nop",
-):
-    mongo_config = MongoStoreClientConfig(client=mongo_client)
-
-    locking_config = str_to_locking_config(locking_config_name)
-
-    store_config = MongoStoreConfig(
-        client_config=mongo_config,
-        db_name=mongo_db_name,
-        locking_config=locking_config,
-    )
-    settings = PartitionSettings(name="test", object_type=MockObjectType)
-
-    return MongoStorePartition(
-        UID(), root_verify_key, settings=settings, store_config=store_config
-    )
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def mongo_store_partition(root_verify_key, mongo_client, request):
-    mongo_db_name = token_hex(8)
-    locking_config_name = request.param
-
-    partition = mongo_store_partition_fn(
-        mongo_client,
-        root_verify_key,
-        mongo_db_name=mongo_db_name,
-        locking_config_name=locking_config_name,
-    )
-    yield partition
-
-    # cleanup db
-    try:
-        mongo_client.drop_database(mongo_db_name)
-    except BaseException as e:
-        print("failed to cleanup mongo fixture", e)
-
-
-def mongo_document_store_fn(
-    mongo_client,
-    root_verify_key,
-    mongo_db_name: str = "mongo_db",
-    locking_config_name: str = "nop",
-):
-    locking_config = str_to_locking_config(locking_config_name)
-    mongo_config = MongoStoreClientConfig(client=mongo_client)
-    store_config = MongoStoreConfig(
-        client_config=mongo_config, db_name=mongo_db_name, locking_config=locking_config
-    )
-
-    mongo_client.drop_database(mongo_db_name)
-
-    return MongoDocumentStore(UID(), root_verify_key, store_config=store_config)
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def mongo_document_store(root_verify_key, mongo_client, request):
-    locking_config_name = request.param
-    mongo_db_name = token_hex(8)
-    yield mongo_document_store_fn(
-        mongo_client,
-        root_verify_key,
-        mongo_db_name=mongo_db_name,
-        locking_config_name=locking_config_name,
-    )
-
-
-def mongo_queue_stash_fn(mongo_document_store):
+def mongo_queue_stash(mongo_document_store):
     return QueueStash(store=mongo_document_store) 
- - -@pytest.fixture(scope="function", params=locking_scenarios) -def mongo_queue_stash(root_verify_key, mongo_client, request): - mongo_db_name = token_hex(8) - locking_config_name = request.param - - store = mongo_document_store_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - locking_config_name=locking_config_name, - ) - yield mongo_queue_stash_fn(store) From 6f7d3e094a89a724d1ece794aedc12c6d9f85a8d Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 9 Sep 2024 18:24:26 +0200 Subject: [PATCH 102/197] remove old methods --- .../syft/src/syft/service/api/api_stash.py | 5 +- .../src/syft/service/code/user_code_stash.py | 10 +- .../code_history/code_history_stash.py | 14 +- .../data_subject_member_service.py | 8 +- .../data_subject/data_subject_service.py | 4 +- .../src/syft/service/dataset/dataset_stash.py | 35 +-- .../syft/src/syft/service/job/job_stash.py | 24 +- .../migration/object_migration_state.py | 4 +- .../syft/service/network/network_service.py | 4 +- .../notification/notification_stash.py | 27 ++- .../src/syft/service/output/output_service.py | 15 +- .../syft/service/policy/user_policy_stash.py | 4 +- .../src/syft/service/project/project_stash.py | 10 +- .../src/syft/service/queue/queue_stash.py | 8 +- .../src/syft/service/request/request_stash.py | 10 +- .../src/syft/service/user/user_service.py | 4 +- .../syft/src/syft/service/user/user_stash.py | 30 +-- .../service/worker/image_registry_stash.py | 5 +- .../syft/service/worker/worker_pool_stash.py | 4 +- packages/syft/src/syft/store/db/query.py | 52 +++-- packages/syft/src/syft/store/db/stash.py | 212 +++++------------- .../syft/tests/syft/stores/base_stash_test.py | 36 +-- 22 files changed, 218 insertions(+), 307 deletions(-) diff --git a/packages/syft/src/syft/service/api/api_stash.py b/packages/syft/src/syft/service/api/api_stash.py index 26210891360..d88b1c37504 100644 --- a/packages/syft/src/syft/service/api/api_stash.py +++ b/packages/syft/src/syft/service/api/api_stash.py @@ -15,10 +15,9 @@ class TwinAPIEndpointStash(ObjectStash[TwinAPIEndpoint]): @as_result(StashException, NotFoundException) def get_by_path(self, credentials: SyftVerifyKey, path: str) -> TwinAPIEndpoint: # TODO standardize by returning None if endpoint doesnt exist. 
- res = self.get_one_by_field( + res = self.get_one( credentials=credentials, - field_name="path", - field_value=path, + filters={"path": path}, ).unwrap() if res is None: diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py index 052b41af7ed..ba86950d61f 100644 --- a/packages/syft/src/syft/service/code/user_code_stash.py +++ b/packages/syft/src/syft/service/code/user_code_stash.py @@ -19,18 +19,16 @@ class UserCodeStash(ObjectStash[UserCode]): @as_result(StashException, NotFoundException) def get_by_code_hash(self, credentials: SyftVerifyKey, code_hash: str) -> UserCode: - return self.get_one_by_field( + return self.get_one( credentials=credentials, - field_name="code_hash", - field_value=code_hash, + filters={"code_hash": code_hash}, ).unwrap() @as_result(StashException) def get_by_service_func_name( self, credentials: SyftVerifyKey, service_func_name: str ) -> list[UserCode]: - return self.get_all_by_field( + return self.get_all( credentials=credentials, - field_name="service_func_name", - field_value=service_func_name, + filters={"service_func_name": service_func_name}, ).unwrap() diff --git a/packages/syft/src/syft/service/code_history/code_history_stash.py b/packages/syft/src/syft/service/code_history/code_history_stash.py index fb12b738782..8921f668095 100644 --- a/packages/syft/src/syft/service/code_history/code_history_stash.py +++ b/packages/syft/src/syft/service/code_history/code_history_stash.py @@ -16,9 +16,9 @@ def get_by_service_func_name_and_verify_key( service_func_name: str, user_verify_key: SyftVerifyKey, ) -> CodeHistory: - return self.get_one_by_fields( + return self.get_one( credentials=credentials, - fields={ + filters={ "user_verify_key": str(user_verify_key), "service_func_name": service_func_name, }, @@ -28,18 +28,16 @@ def get_by_service_func_name_and_verify_key( def get_by_service_func_name( self, credentials: SyftVerifyKey, service_func_name: str ) -> list[CodeHistory]: - return self.get_all_by_field( + return self.get_all( credentials=credentials, - field_name="service_func_name", - field_value=service_func_name, + filters={"service_func_name": service_func_name}, ) @as_result(StashException) def get_by_verify_key( self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey ) -> list[CodeHistory]: - return self.get_all_by_field( + return self.get_all( credentials=credentials, - field_name="user_verify_key", - field_value=str(user_verify_key), + filters={"user_verify_key": user_verify_key}, ).unwrap() diff --git a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py index b9f8c851b78..a7aec8e6e44 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py @@ -21,18 +21,18 @@ class DataSubjectMemberStash(ObjectStash[DataSubjectMemberRelationship]): def get_all_for_parent( self, credentials: SyftVerifyKey, name: str ) -> list[DataSubjectMemberRelationship]: - return self.get_all_by_fields( + return self.get_all( credentials=credentials, - fields={"parent": name}, + filters={"parent": name}, ).unwrap() @as_result(StashException) def get_all_for_child( self, credentials: SyftVerifyKey, name: str ) -> list[DataSubjectMemberRelationship]: - return self.get_all_by_fields( + return self.get_all( credentials=credentials, - fields={"child": name}, + filters={"child": name}, ).unwrap() diff 
--git a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py index b66de42e349..72dbd6cec64 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py @@ -24,9 +24,9 @@ class DataSubjectStash(ObjectStash[DataSubject]): @as_result(StashException) def get_by_name(self, credentials: SyftVerifyKey, name: str) -> DataSubject: - return self.get_one_by_fields( + return self.get_one( credentials=credentials, - fields={"name": name}, + filters={"name": name}, ).unwrap() @as_result(StashException) diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index 3b50c91fb1f..28b6934f672 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -1,6 +1,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey +from ...store.db.query import Filter from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException @@ -15,16 +16,13 @@ class DatasetStash(ObjectStash[Dataset]): @as_result(StashException, NotFoundException) def get_by_name(self, credentials: SyftVerifyKey, name: str) -> Dataset: - return self.get_one_by_field( - credentials=credentials, field_name="name", field_value=name - ).unwrap() + return self.get_one(credentials=credentials, filters={"name": name}).unwrap() @as_result(StashException) def search_action_ids(self, credentials: SyftVerifyKey, uid: UID) -> list[Dataset]: - return self.get_all_contains( + return self.get_all( credentials=credentials, - field_name="action_ids", - field_value=uid.no_dash, + filters={"action_ids": Filter("action_ids", "contains", uid)}, ).unwrap() @as_result(StashException) @@ -33,18 +31,21 @@ def get_all( credentials: SyftVerifyKey, has_permission: bool = False, order_by: str | None = None, - sort_order: str = "asc", + sort_order: str | None = None, limit: int | None = None, offset: int | None = None, ) -> list[Dataset]: # TODO standardize soft delete and move to ObjectStash.get_all - return self.get_all_by_field( - credentials=credentials, - has_permission=has_permission, - field_name="to_be_deleted", - field_value=False, - order_by=order_by, - sort_order=sort_order, - limit=limit, - offset=offset, - ).unwrap() + return ( + super() + .get_all( + credentials=credentials, + filters={"to_be_deleted": False}, + has_permission=has_permission, + order_by=order_by, + sort_order=sort_order, + limit=limit, + offset=offset, + ) + .unwrap() + ) diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 12e2e6dea32..de1ff22392a 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -762,33 +762,35 @@ def set_result( ) def get_active(self, credentials: SyftVerifyKey) -> list[Job]: - return self.get_all_by_field( - credentials=credentials, field_name="status", field_value=JobStatus.CREATED + return self.get_all( + credentials=credentials, + filters={"status": JobStatus.CREATED}, ).unwrap() def get_by_worker(self, credentials: SyftVerifyKey, worker_id: str) -> list[Job]: - return self.get_all_by_field( - credentials=credentials, field_name="worker_id", field_value=str(worker_id) + return 
self.get_all(
+            credentials=credentials,
+            filters={"job_worker_id": worker_id},
         ).unwrap()
 
     @as_result(StashException)
     def get_by_user_code_id(
         self, credentials: SyftVerifyKey, user_code_id: UID
     ) -> list[Job]:
-        return self.get_all_by_field(
+        return self.get_all(
             credentials=credentials,
-            field_name="user_code_id",
-            field_value=str(user_code_id),
+            filters={"user_code_id": user_code_id},
         ).unwrap()
 
     @as_result(StashException)
     def get_by_parent_id(self, credentials: SyftVerifyKey, uid: UID) -> list[Job]:
-        return self.get_all_by_field(
-            credentials=credentials, field_name="parent_job_id", field_value=str(uid)
+        return self.get_all(
+            credentials=credentials,
+            filters={"parent_job_id": uid},
         ).unwrap()
 
     @as_result(StashException)
     def get_by_result_id(self, credentials: SyftVerifyKey, uid: UID) -> Job:
-        return self.get_one_by_field(
-            credentials=credentials, field_name="result_id", field_value=str(uid)
+        return self.get_one(
+            credentials=credentials, filters={"result_id": uid}
         ).unwrap()
diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py
index e6f744c3b0e..dbb9cd4df91 100644
--- a/packages/syft/src/syft/service/migration/object_migration_state.py
+++ b/packages/syft/src/syft/service/migration/object_migration_state.py
@@ -73,9 +73,9 @@ class SyftMigrationStateStash(ObjectStash[SyftObjectMigrationState]):
     def get_by_name(
         self, canonical_name: str, credentials: SyftVerifyKey
     ) -> SyftObjectMigrationState:
-        return self.get_one_by_fields(
+        return self.get_one(
             credentials=credentials,
-            fields={"canonical_name": canonical_name},
+            filters={"canonical_name": canonical_name},
         ).unwrap()
 
diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py
index a2285901796..5ab258ffe95 100644
--- a/packages/syft/src/syft/service/network/network_service.py
+++ b/packages/syft/src/syft/service/network/network_service.py
@@ -83,9 +83,9 @@ class NetworkStash(ObjectStash[ServerPeer]):
     @as_result(StashException, NotFoundException)
     def get_by_name(self, credentials: SyftVerifyKey, name: str) -> ServerPeer:
         try:
-            return self.get_one_by_fields(
+            return self.get_one(
                 credentials=credentials,
-                fields={"name": name},
+                filters={"name": name},
             ).unwrap()
         except NotFoundException as e:
             raise NotFoundException.from_exception(
diff --git a/packages/syft/src/syft/service/notification/notification_stash.py b/packages/syft/src/syft/service/notification/notification_stash.py
index e97dc253359..e343d31adfc 100644
--- a/packages/syft/src/syft/service/notification/notification_stash.py
+++ b/packages/syft/src/syft/service/notification/notification_stash.py
@@ -1,5 +1,4 @@
 # relative
-from ...serde.json_serde import serialize_json
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
 from ...store.db.stash import ObjectStash
@@ -20,8 +19,9 @@ def get_all_inbox_for_verify_key(
     ) -> list[Notification]:
         if not isinstance(verify_key, SyftVerifyKey | str):
             raise AttributeError("verify_key must be of type SyftVerifyKey or str")
-        return self.get_all_by_field(
-            credentials, field_name="to_user_verify_key", field_value=str(verify_key)
+        return self.get_all(
+            credentials,
+            filters={"to_user_verify_key": verify_key},
         ).unwrap()
 
     @as_result(StashException)
@@ -30,10 +30,9 @@ def get_all_sent_for_verify_key(
     ) -> list[Notification]:
         if not isinstance(verify_key, SyftVerifyKey | str):
             raise AttributeError("verify_key must 
be of type SyftVerifyKey or str")
-        return self.get_all_by_field(
+        return self.get_all(
             credentials,
-            field_name="from_user_verify_key",
-            field_value=str(verify_key),
+            filters={"from_user_verify_key": verify_key},
         ).unwrap()
 
     @as_result(StashException)
@@ -42,8 +41,9 @@ def get_all_for_verify_key(
     ) -> list[Notification]:
         if not isinstance(verify_key, SyftVerifyKey | str):
             raise AttributeError("verify_key must be of type SyftVerifyKey or str")
-        return self.get_all_by_field(
-            credentials, field_name="from_user_verify_key", field_value=str(verify_key)
+        return self.get_all(
+            credentials,
+            filters={"from_user_verify_key": verify_key},
         ).unwrap()
 
     @as_result(StashException)
@@ -55,9 +55,9 @@ def get_all_by_verify_key_for_status(
     ) -> list[Notification]:
         if not isinstance(verify_key, SyftVerifyKey | str):
             raise AttributeError("verify_key must be of type SyftVerifyKey or str")
-        return self.get_all_by_fields(
+        return self.get_all(
             credentials,
-            fields={
+            filters={
                 "to_user_verify_key": str(verify_key),
                 "status": status.name,
             },
@@ -70,8 +70,11 @@ def get_notification_for_linked_obj(
         linked_obj: LinkedObject,
     ) -> Notification:
         # TODO does this work?
-        return self.get_one_by_fields(
-            credentials, fields={"linked_obj": serialize_json(linked_obj)}
+        return self.get_one(
+            credentials,
+            filters={
+                "linked_obj": linked_obj,
+            },
         ).unwrap()
 
     @as_result(StashException, NotFoundException)
diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py
index 26ae4cb9d05..8de4b878355 100644
--- a/packages/syft/src/syft/service/output/output_service.py
+++ b/packages/syft/src/syft/service/output/output_service.py
@@ -187,30 +187,27 @@ class OutputStash(ObjectStash[ExecutionOutput]):
     def get_by_user_code_id(
         self, credentials: SyftVerifyKey, user_code_id: UID
     ) -> list[ExecutionOutput]:
-        return self.get_all_by_field(
+        return self.get_all(
             credentials=credentials,
-            field_name="user_code_id",
-            field_value=str(user_code_id),
+            filters={"user_code_id": user_code_id},
         ).unwrap()
 
     @as_result(StashException)
     def get_by_job_id(
         self, credentials: SyftVerifyKey, job_id: UID
     ) -> ExecutionOutput | None:
-        return self.get_one_by_field(
+        return self.get_one(
             credentials=credentials,
-            field_name="job_id",
-            field_value=str(job_id),
+            filters={"job_id": job_id},
         ).unwrap()
 
     @as_result(StashException)
     def get_by_output_policy_id(
         self, credentials: SyftVerifyKey, output_policy_id: UID
     ) -> list[ExecutionOutput]:
-        return self.get_all_by_field(
+        return self.get_all(
             credentials=credentials,
-            field_name="output_policy_id",
-            field_value=str(output_policy_id),
+            filters={"output_policy_id": output_policy_id},
         ).unwrap()
 
diff --git a/packages/syft/src/syft/service/policy/user_policy_stash.py b/packages/syft/src/syft/service/policy/user_policy_stash.py
index 643e7e4e84a..860472273b8 100644
--- a/packages/syft/src/syft/service/policy/user_policy_stash.py
+++ b/packages/syft/src/syft/service/policy/user_policy_stash.py
@@ -16,7 +16,7 @@ class UserPolicyStash(ObjectStash[UserPolicy]):
     def get_all_by_user_verify_key(
         self, credentials: SyftVerifyKey, user_verify_key: SyftVerifyKey
     ) -> list[UserPolicy]:
-        return self.get_all_by_fields(
+        return self.get_all(
             credentials=credentials,
-            fields={"user_verify_key": str(user_verify_key)},
+            filters={"user_verify_key": str(user_verify_key)},
         ).unwrap()
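Each stash in this commit swaps the get_*_by_field(s) family for two generic entry points, get_one and get_all, which take a filters mapping; plain values become equality filters, and other operators go through the new Filter dataclass added in the query.py diff further below. A sketch under those assumptions, with the stash, credentials, and uid placeholders standing in for real objects:

    # Sketch only: the consolidated stash query API used throughout this commit.
    from syft.store.db.query import Filter

    def example_queries(stash, credentials, uid):
        # One dict entry per field; entries are ANDed and compare by equality.
        one = stash.get_one(
            credentials=credentials,
            filters={"name": "my-dataset"},
        ).unwrap()
        # Non-equality operators use the Filter dataclass, e.g. the "contains"
        # lookup that DatasetStash.search_action_ids performs above.
        many = stash.get_all(
            credentials=credentials,
            filters={"action_ids": Filter("action_ids", "contains", uid)},
        ).unwrap()
        return one, many
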
diff --git a/packages/syft/src/syft/service/project/project_stash.py b/packages/syft/src/syft/service/project/project_stash.py
index 99bf5fb2397..051c293362d 100644
--- a/packages/syft/src/syft/service/project/project_stash.py
+++ b/packages/syft/src/syft/service/project/project_stash.py
@@ -20,16 +20,14 @@ class ProjectStash(ObjectStash[Project]):
     def get_all_for_verify_key(
         self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
     ) -> list[Project]:
-        return self.get_all_by_field(
+        return self.get_all(
             credentials=credentials,
-            field_name="user_verify_key",
-            field_value=str(verify_key),
+            filters={"user_verify_key": verify_key},
         ).unwrap()
 
     @as_result(StashException, NotFoundException)
     def get_by_name(self, credentials: SyftVerifyKey, project_name: str) -> Project:
-        return self.get_one_by_field(
+        return self.get_one(
             credentials=credentials,
-            field_name="name",
-            field_value=project_name,
+            filters={"name": project_name},
         ).unwrap()
diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py
index 6ddc2594abe..463addbdfdf 100644
--- a/packages/syft/src/syft/service/queue/queue_stash.py
+++ b/packages/syft/src/syft/service/queue/queue_stash.py
@@ -148,9 +148,9 @@ def get_by_status(
         self, credentials: SyftVerifyKey, status: Status
     ) -> list[QueueItem]:
         # TODO do we need json serialization for Status?
-        return self.get_all_by_fields(
+        return self.get_all(
             credentials=credentials,
-            fields={"status": status},
+            filters={"status": status},
         ).unwrap()
 
     @as_result(StashException)
@@ -159,7 +159,7 @@ def _get_by_worker_pool(
     ) -> list[QueueItem]:
         worker_pool_id = worker_pool.object_uid
 
-        return self.get_all_by_fields(
+        return self.get_all(
             credentials=credentials,
-            fields={"worker_pool_id": worker_pool_id},
+            filters={"worker_pool_id": worker_pool_id},
         ).unwrap()
diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py
index 290c9203ba1..7a2e27c603e 100644
--- a/packages/syft/src/syft/service/request/request_stash.py
+++ b/packages/syft/src/syft/service/request/request_stash.py
@@ -18,18 +18,16 @@ def get_all_for_verify_key(
         credentials: SyftVerifyKey,
         verify_key: SyftVerifyKey,
     ) -> list[Request]:
-        return self.get_all_by_field(
+        return self.get_all(
             credentials=credentials,
-            field_name="requesting_user_verify_key",
-            field_value=str(verify_key),
+            filters={"requesting_user_verify_key": verify_key},
         ).unwrap()
 
     @as_result(SyftException)
     def get_by_usercode_id(
         self, credentials: SyftVerifyKey, user_code_id: UID
     ) -> list[Request]:
-        return self.get_all_by_field(
+        return self.get_all(
             credentials=credentials,
-            field_name="code_id",
-            field_value=str(user_code_id),
+            filters={"code_id": user_code_id},
         ).unwrap()
diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py
index ec23aa6e36f..d6ecdfb6f5a 100644
--- a/packages/syft/src/syft/service/user/user_service.py
+++ b/packages/syft/src/syft/service/user/user_service.py
@@ -395,8 +395,8 @@ def search(
         if len(kwargs) == 0:
             raise SyftException(public_message="Invalid search parameters")
 
-        users = self.stash.get_all_by_fields(
-            credentials=context.credentials, fields=kwargs
+        users = self.stash.get_all(
+            credentials=context.credentials, filters=kwargs
         ).unwrap()
 
         users = [user.to(UserView) for user in users] if users is not None else []
diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py
index f12d3f12476..11a44d47832 100644
--- a/packages/syft/src/syft/service/user/user_stash.py
+++ b/packages/syft/src/syft/service/user/user_stash.py
@@ 
-44,15 +44,17 @@ def admin_user(self) -> User:
 
     @as_result(StashException, NotFoundException)
     def get_by_reset_token(self, credentials: SyftVerifyKey, token: str) -> User:
-        return self.get_one_by_field(
-            credentials=credentials, field_name="reset_token", field_value=token
-        ).unwrap()
+        return self.get_one(
+            credentials=credentials,
+            filters={"reset_token": token},
+        ).unwrap()
 
     @as_result(StashException, NotFoundException)
     def get_by_email(self, credentials: SyftVerifyKey, email: str) -> User:
-        return self.get_one_by_field(
-            credentials=credentials, field_name="email", field_value=email
-        ).unwrap()
+        return self.get_one(
+            credentials=credentials,
+            filters={"email": email},
+        ).unwrap()
 
     @as_result(StashException)
     def email_exists(self, email: str) -> bool:
@@ -65,8 +67,9 @@ def email_exists(self, email: str) -> bool:
     @as_result(StashException, NotFoundException)
     def get_by_role(self, credentials: SyftVerifyKey, role: ServiceRole) -> User:
         try:
-            return self.get_one_by_field(
-                credentials=credentials, field_name="role", field_value=role.name
+            return self.get_one(
+                credentials=credentials,
+                filters={"role": role},
             ).unwrap()
         except NotFoundException as exc:
             private_msg = f"User with role {role} not found"
@@ -77,10 +80,9 @@ def get_by_signing_key(
         self, credentials: SyftVerifyKey, signing_key: SyftSigningKey | str
     ) -> User:
         try:
-            return self.get_one_by_field(
+            return self.get_one(
                 credentials=credentials,
-                field_name="signing_key",
-                field_value=str(signing_key),
+                filters={"signing_key": signing_key},
             ).unwrap()
         except NotFoundException as exc:
             private_msg = f"User with signing key {signing_key} not found"
@@ -91,11 +93,11 @@ def get_by_verify_key(
         self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey
     ) -> User:
         try:
-            return self.get_one_by_field(
+            return self.get_one(
                 credentials=credentials,
-                field_name="verify_key",
-                field_value=str(verify_key),
+                filters={"verify_key": verify_key},
             ).unwrap()
+
         except NotFoundException as exc:
             private_msg = f"User with verify key {verify_key} not found"
             raise NotFoundException.from_exception(exc, private_message=private_msg)
diff --git a/packages/syft/src/syft/service/worker/image_registry_stash.py b/packages/syft/src/syft/service/worker/image_registry_stash.py
index 145270339ab..3ca7fe1d03f 100644
--- a/packages/syft/src/syft/service/worker/image_registry_stash.py
+++ b/packages/syft/src/syft/service/worker/image_registry_stash.py
@@ -20,8 +20,9 @@ def get_by_url(
         credentials: SyftVerifyKey,
         url: str,
     ) -> SyftImageRegistry | None:
-        return self.get_one_by_fields(
-            credentials=credentials, fields={"url": url}
+        return self.get_one(
+            credentials=credentials,
+            filters={"url": url},
         ).unwrap()
 
     @as_result(SyftException, StashException)
diff --git a/packages/syft/src/syft/service/worker/worker_pool_stash.py b/packages/syft/src/syft/service/worker/worker_pool_stash.py
index 36c0799296a..f3490ce10a7 100644
--- a/packages/syft/src/syft/service/worker/worker_pool_stash.py
+++ b/packages/syft/src/syft/service/worker/worker_pool_stash.py
@@ -19,9 +19,9 @@ class SyftWorkerPoolStash(ObjectStash[WorkerPool]):
     @as_result(StashException, NotFoundException)
     def get_by_name(self, credentials: SyftVerifyKey, pool_name: str) -> WorkerPool:
-        result = self.get_one_by_fields(
+        result = self.get_one(
             credentials=credentials,
-            fields={"name": pool_name},
+            filters={"name": pool_name},
         )
 
         return result.unwrap(
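Internally, those get_one/get_all calls compose a Query from the builder reworked below. Roughly, a sketch built from the builder methods shown in these diffs (the stash, credentials, and verify_key values are assumed to exist):

    # Sketch only: how a stash lookup composes the Query builder.
    def get_recent_for_user(stash, credentials, verify_key, n=10):
        query = (
            stash.query()  # SQLiteQuery or PostgresQuery, depending on dialect
            .filter("user_verify_key", "==", str(verify_key))
            .order_by("_created_at", "desc")  # defaults are applied if omitted
            .limit(n)
        )
        # Non-admin callers only see rows their credentials permit.
        role = stash.get_role(credentials)
        query = query.with_permissions(credentials, role)
        rows = query.execute(stash.session).all()
        return [stash.row_as_obj(row) for row in rows]
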
diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py
index 95eba89909f..326548edcef 100644
--- a/packages/syft/src/syft/store/db/query.py
+++ b/packages/syft/src/syft/store/db/query.py
@@ -1,6 +1,7 @@
 # stdlib
 from abc import ABC
 from abc import abstractmethod
+from dataclasses import dataclass
 from typing import Any
 from typing import Literal
 
@@ -28,6 +29,13 @@
 from .sqlite_db import OBJECT_TYPE_TO_TABLE
 
 
+@dataclass
+class Filter:
+    field: str
+    operator: str
+    value: Any
+
+
 class Query(ABC):
     dialect: Dialect
 
@@ -113,34 +121,50 @@ def filter(self, field: str, operator: str, value: Any) -> Self:
 
         return self
 
-    def order_by(self, field: str, order: Literal["asc", "desc"] = "asc") -> Self:
-        """Add an order by clause to the query.
+    def order_by(
+        self,
+        field: str | None = None,
+        order: Literal["asc", "desc"] | None = None,
+    ) -> Self:
+        """Add an order by clause to the query, with sensible defaults if field or order is not provided.
 
         Args:
-            field (str): field to order by.
-            order (Literal["asc", "desc"], optional): Order to use.
-                Defaults to "asc".
+            field (Optional[str]): field to order by. If None, uses the default field.
+            order (Optional[Literal["asc", "desc"]]): Order to use ("asc" or "desc").
+                Defaults to 'asc' if field is provided and order is not, or the default order otherwise.
 
         Raises:
             ValueError: If the order is not "asc" or "desc"
 
         Returns:
-            Self: The query object with the order by clause applied
+            Self: The query object with the order by clause applied.
         """
-        column = self._get_column(field)
+        # Determine the field and order defaults if not provided
+        if field is None:
+            if hasattr(self.object_type, "__order_by__"):
+                default_field, default_order = self.object_type.__order_by__
+            else:
+                default_field, default_order = "_created_at", "desc"
+            field = default_field
+        else:
+            # If field is provided but order is not, default to 'asc'
+            default_order = "asc"
+        order = order or default_order
 
+        column = self._get_column(field)
         if order.lower() == "asc":
             self.stmt = self.stmt.order_by(column)
         elif order.lower() == "desc":
             self.stmt = self.stmt.order_by(column.desc())
         else:
-            raise ValueError(f"Invalid sort order {order}")  # type: ignore
+            raise ValueError(f"Invalid sort order {order}")
 
         return self
 
-    def limit(self, limit: int) -> Self:
+    def limit(self, limit: int | None) -> Self:
         """Add a limit clause to the query."""
-        self.stmt = self.stmt.limit(limit)
+        if limit is not None:
+            self.stmt = self.stmt.limit(limit)
         return self
 
     def offset(self, offset: int) -> Self:
@@ -155,14 +179,6 @@ def _make_permissions_clause(
     ) -> sa.sql.elements.BinaryExpression:
         pass
 
-    def default_order(self) -> Self:
-        if hasattr(self.object_type, "__order_by__"):
-            field, order = self.object_type.__order_by__
-        else:
-            field, order = "_created_at", "desc"
-
-        return self.order_by(field, order)
-
     def _eq_filter(
         self,
         table: Table,
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index d50f27a9999..f619fccd38d 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -34,6 +34,7 @@
 from ...types.uid import UID
 from ..document_store_errors import NotFoundException
 from ..document_store_errors import StashException
+from .query import Filter
 from .query import PostgresQuery
 from .query import Query
 from .query import SQLiteQuery
@@ -188,142 +189,6 @@ def _get_field_filter(
         elif self.db.engine.dialect.name == "postgresql":
             return table.c.fields[field_name].astext == json_value
 
-    def _get_by_fields(
-        self,
-        credentials: SyftVerifyKey,
-        fields: dict[str, str],
-        table: Table | None = None,
-        order_by: str | 
None = None, - sort_order: str | None = None, - limit: int | None = None, - offset: int | None = None, - has_permission: bool = False, - ) -> sa.Result: - query = self.query() - for field_name, field_value in fields.items(): - query = query.filter(field_name, "==", field_value) - - if not has_permission: - role = self.get_role(credentials) - query = query.with_permissions(credentials, role) - - if order_by and sort_order: - query = query.order_by(order_by, sort_order) - else: - query = query.default_order() - - if limit: - query = query.limit(limit) - if offset: - query = query.offset(offset) - - return query.execute(self.session) - - @as_result(SyftException, StashException, NotFoundException) - def get_one_by_field( - self, - credentials: SyftVerifyKey, - field_name: str, - field_value: str, - has_permission: bool = False, - ) -> StashT: - return self.get_one_by_fields( - credentials=credentials, - fields={field_name: field_value}, - has_permission=has_permission, - ).unwrap() - - @as_result(SyftException, StashException, NotFoundException) - def get_one_by_fields( - self, - credentials: SyftVerifyKey, - fields: dict[str, str], - has_permission: bool = False, - ) -> StashT: - result = self._get_by_fields( - credentials=credentials, - fields=fields, - has_permission=has_permission, - ).first() - if result is None: - raise NotFoundException(f"{self.object_type.__name__}: not found") - return self.row_as_obj(result) - - @as_result(SyftException, StashException, NotFoundException) - def get_all_by_fields( - self, - credentials: SyftVerifyKey, - fields: dict[str, str], - order_by: str | None = None, - sort_order: str | None = None, - limit: int | None = None, - offset: int | None = None, - has_permission: bool = False, - ) -> list[StashT]: - result = self._get_by_fields( - credentials=credentials, - fields=fields, - order_by=order_by, - sort_order=sort_order, - limit=limit, - offset=offset, - has_permission=has_permission, - ).all() - - return [self.row_as_obj(row) for row in result] - - @as_result(SyftException, StashException, NotFoundException) - def get_all_by_field( - self, - credentials: SyftVerifyKey, - field_name: str, - field_value: str, - order_by: str | None = None, - sort_order: str | None = None, - limit: int | None = None, - offset: int | None = None, - has_permission: bool = False, - ) -> list[StashT]: - return self.get_all_by_fields( - credentials=credentials, - fields={field_name: field_value}, - order_by=order_by, - sort_order=sort_order, - limit=limit, - offset=offset, - has_permission=has_permission, - ).unwrap() - - @as_result(SyftException, StashException, NotFoundException) - def get_all_contains( - self, - credentials: SyftVerifyKey, - field_name: str, - field_value: str, - order_by: str | None = None, - sort_order: str | None = None, - limit: int | None = None, - offset: int | None = None, - has_permission: bool = False, - ) -> list[StashT]: - query = self.query().filter(field_name, "contains", field_value) - - if not has_permission: - role = self.get_role(credentials) - query = query.with_permissions(credentials, role) - - if order_by and sort_order: - query = query.order_by(order_by, sort_order) - else: - query = query.default_order() - - if limit: - query = query.limit(limit) - if offset: - query = query.offset(offset) - - return query.execute(self.session).all() - @as_result(SyftException, StashException, NotFoundException) def get_index( self, credentials: SyftVerifyKey, index: int, has_permission: bool = False @@ -446,30 +311,6 @@ def _apply_permission_filter( ) 
return stmt
 
-    @as_result(StashException)
-    def get_all(
-        self,
-        credentials: SyftVerifyKey,
-        has_permission: bool = False,
-        order_by: str | None = None,
-        sort_order: str | None = None,
-        limit: int | None = None,
-        offset: int | None = None,
-    ) -> list[StashT]:
-        stmt = self.table.select()
-
-        stmt = self._apply_permission_filter(
-            stmt,
-            credentials=credentials,
-            has_permission=has_permission,
-            permission=ActionPermission.READ,
-        )
-        stmt = self._apply_order_by(stmt, order_by, sort_order)
-        stmt = self._apply_limit_offset(stmt, limit, offset)
-
-        result = self.session.execute(stmt).all()
-        return [self.row_as_obj(row) for row in result]
-
     @as_result(StashException, NotFoundException)
     def update(
         self,
@@ -784,3 +625,58 @@ def set(
         self.session.execute(stmt)
         self.session.commit()
         return self.get_by_uid(credentials, uid).unwrap()
+
+    @as_result(StashException, NotFoundException)
+    def get_one(
+        self,
+        credentials: SyftVerifyKey,
+        filters: dict[str, Any] | None = None,
+        has_permission: bool = False,
+        order_by: str | None = None,
+        sort_order: str | None = None,
+        offset: int = 0,
+    ) -> StashT:
+        result = self.get_all(
+            credentials=credentials,
+            filters=filters,
+            has_permission=has_permission,
+            order_by=order_by,
+            sort_order=sort_order,
+            limit=1,
+            offset=offset,
+        ).unwrap()
+        if len(result) == 0:
+            raise NotFoundException(f"{self.object_type.__name__}: not found")
+        return result[0]
+
+    @as_result(StashException)
+    def get_all(
+        self,
+        credentials: SyftVerifyKey,
+        filters: dict[str, Any] | None = None,
+        has_permission: bool = False,
+        order_by: str | None = None,
+        sort_order: str | None = None,
+        limit: int | None = None,
+        offset: int = 0,
+    ) -> list[StashT]:
+        query = self.query()
+
+        if not has_permission:
+            role = self.get_role(credentials)
+            query = query.with_permissions(credentials, role)
+
+        if filters:
+            for field_name, field_value in filters.items():
+                # Filter values carry their own operator; plain values mean equality
+                if isinstance(field_value, Filter):
+                    query = query.filter(
+                        field_name, field_value.operator, field_value.value
+                    )
+                else:
+                    query = query.filter(field_name, "==", field_value)
+
+        query = query.order_by(order_by, sort_order).limit(limit).offset(offset)
+        result = query.execute(self.session).all()
+
+        return [self.row_as_obj(row) for row in result]
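For reference, a rough sketch of how the new filters API is meant to be called
(illustrative only: `job_stash`, `credentials`, and the field names are
assumptions, not taken from this patch):

    # a plain value is matched with equality; wrap a value in Filter to pick
    # the operator explicitly
    jobs = job_stash.get_all(
        credentials,
        filters={"status": "completed"},
        sort_order="desc",
        limit=10,
    ).unwrap()

    # the old get_all_contains becomes a Filter with the "contains" operator
    match = job_stash.get_one(
        credentials,
        filters={"name": Filter("name", "contains", "train")},
    ).unwrap()  # unwrapping re-raises NotFoundException when nothing matches

When order_by and sort_order are omitted, the reworked Query.order_by falls
back to the object type's __order_by__ (or "_created_at" descending), which is
what the removed default_order helper used to do.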
diff --git a/packages/syft/tests/syft/stores/base_stash_test.py b/packages/syft/tests/syft/stores/base_stash_test.py
index 3cdc920bd72..451db6fbd74 100644
--- a/packages/syft/tests/syft/stores/base_stash_test.py
+++ b/packages/syft/tests/syft/stores/base_stash_test.py
@@ -261,8 +261,9 @@ def test_basestash_query_one(
     base_stash.set(root_verify_key, obj)
 
     obj = random.choice(mock_objects)
 
-    result = base_stash.get_one_by_fields(
-        root_verify_key, fields={"name": obj.name}
+    result = base_stash.get_one(
+        root_verify_key,
+        filters={"name": obj.name},
     ).unwrap()
 
     assert result == obj
@@ -271,17 +272,24 @@
     random_name = create_unique(faker.name, existing_names)
 
     with pytest.raises(NotFoundException):
-        result = base_stash.get_one_by_fields(
-            root_verify_key, fields={"name": random_name}
+        result = base_stash.get_one(
+            root_verify_key,
+            filters={"name": random_name},
         ).unwrap()
 
     params = {"name": obj.name, "desc": obj.desc}
-    result = base_stash.get_one_by_fields(root_verify_key, fields=params).unwrap()
+    result = base_stash.get_one(
+        root_verify_key,
+        filters=params,
+    ).unwrap()
     assert result == obj
 
     params = {"name": random_name, "desc": random_sentence(faker)}
     with pytest.raises(NotFoundException):
-        result = base_stash.get_one_by_fields(root_verify_key, fields=params).unwrap()
+        result = base_stash.get_one(
+            root_verify_key,
+            filters=params,
+        ).unwrap()
 
 
 def test_basestash_query_all(
@@ -296,9 +304,7 @@
     for obj in all_objects:
         base_stash.set(root_verify_key, obj)
 
-    objects = base_stash.get_all_by_fields(
-        root_verify_key, fields={"desc": desc}
-    ).unwrap()
+    objects = base_stash.get_all(root_verify_key, filters={"desc": desc}).unwrap()
     assert len(objects) == n_same
     assert all(obj.desc == desc for obj in objects)
     original_object_values = {get_object_values(obj) for obj in similar_objects}
@@ -309,15 +315,15 @@
         random_sentence, [obj.desc for obj in all_objects], faker
     )
 
-    objects = base_stash.get_all_by_fields(
-        root_verify_key, fields={"desc": random_desc}
+    objects = base_stash.get_all(
+        root_verify_key, filters={"desc": random_desc}
     ).unwrap()
     assert len(objects) == 0
 
     obj = random.choice(similar_objects)
 
     params = {"name": obj.name, "desc": obj.desc}
-    objects = base_stash.get_all_by_fields(root_verify_key, fields=params).unwrap()
+    objects = base_stash.get_all(root_verify_key, filters=params).unwrap()
     assert len(objects) == sum(
         1 for obj_ in all_objects if (obj_.name, obj_.desc) == (obj.name, obj.desc)
     )
@@ -340,7 +346,7 @@ def test_basestash_query_all_kwargs_multiple_params(
     base_stash.set(root_verify_key, obj)
 
     params = {"importance": importance, "desc": desc}
-    objects = base_stash.get_all_by_fields(root_verify_key, fields=params).unwrap()
+    objects = base_stash.get_all(root_verify_key, filters=params).unwrap()
     assert len(objects) == n_same
     assert all(obj.desc == desc for obj in objects)
     original_object_values = {get_object_values(obj) for obj in similar_objects}
@@ -351,12 +357,12 @@
         "name": create_unique(faker.name, [obj.name for obj in all_objects]),
         "desc": random_sentence(faker),
     }
-    objects = base_stash.get_all_by_fields(root_verify_key, fields=params).unwrap()
+    objects = base_stash.get_all(root_verify_key, filters=params).unwrap()
     assert len(objects) == 0
 
     obj = random.choice(similar_objects)
 
     params = {"id": obj.id, "name": obj.name, "desc": obj.desc}
-    objects = base_stash.get_all_by_fields(root_verify_key, fields=params).unwrap()
+    objects = base_stash.get_all(root_verify_key, filters=params).unwrap()
     assert len(objects) == 1
     assert objects[0] == obj

From 4f884ff61671d0f9888c294d69800e347b481d0d Mon Sep 17 00:00:00 2001
From: Aziz Berkay Yesilyurt
Date: Mon, 9 Sep 2024 20:24:42 +0200
Subject: [PATCH 103/197] fix action store tests

---
 .../src/syft/protocol/releases/0.9.1.json     | 1178 -----------------
 packages/syft/src/syft/store/db/stash.py      |    5 +
 packages/syft/tests/conftest.py               |   21 +-
 .../syft/service/action/action_object_test.py |    2 +-
 .../tests/syft/stores/action_store_test.py    |  194 ++-
 .../syft/stores/mongo_document_store_test.py  |  614 ---------
 .../syft/stores/sqlite_document_store_test.py |  520 --------
 .../tests/syft/stores/store_fixtures_test.py  |  175 ---
 packages/syft/tests/syft/worker_test.py       |   10 +-
 9 files changed, 117 insertions(+), 2602 deletions(-)
 delete mode 100644 packages/syft/src/syft/protocol/releases/0.9.1.json
 delete mode 100644 packages/syft/tests/syft/stores/sqlite_document_store_test.py

diff --git a/packages/syft/src/syft/protocol/releases/0.9.1.json b/packages/syft/src/syft/protocol/releases/0.9.1.json
deleted file mode 100644
index 9c33a5d3a88..00000000000
--- a/packages/syft/src/syft/protocol/releases/0.9.1.json
+++ /dev/null
@@ -1,1178 +0,0 @@
-{
-  "1": {
-    "object_versions": {
"SyftObjectVersioned": { - "1": { - "version": 1, - "hash": "7c842dcdbb57e2528ffa690ea18c19fff3c8a591811d40cad2b19be3100e2ff4", - "action": "add" - } - }, - "BaseDateTime": { - "1": { - "version": 1, - "hash": "614db484b1950be729902b1861bd3a7b33899176507c61cef11dc0d44611cfd3", - "action": "add" - } - }, - "SyftObject": { - "1": { - "version": 1, - "hash": "bb70d874355988908d3a92a3941d6613a6995a4850be3b6a0147f4d387724406", - "action": "add" - } - }, - "PartialSyftObject": { - "1": { - "version": 1, - "hash": "19a995fcc2833f4fab24584fd99b71a80c2ef1f13c06f83af79e4482846b1656", - "action": "add" - } - }, - "ServerMetadata": { - "1": { - "version": 1, - "hash": "1691c7667eca86b20c4189e90ce4e643dd41fd3682cdb69c6308878f2a6f135c", - "action": "add" - } - }, - "StoreConfig": { - "1": { - "version": 1, - "hash": "a9997fce6a8a0ed2884c58b8eb9382f8554bdd18fff61f8bf0451945bcff12c7", - "action": "add" - } - }, - "MongoDict": { - "1": { - "version": 1, - "hash": "57e36f57eed75e62b29e2bac1295035a9bf2c0e3c56719dac24cb6cc685be00b", - "action": "add" - } - }, - "MongoStoreConfig": { - "1": { - "version": 1, - "hash": "53342b27d34165b7e2699f8e7ad70d13d125875e6a75e8fa18f5796428f41036", - "action": "add" - } - }, - "LinkedObject": { - "1": { - "version": 1, - "hash": "d80f5ac7f51a9383be1a3cb334d56ae50e49733ed3199f3b6b5d6febd9de410b", - "action": "add" - } - }, - "BaseConfig": { - "1": { - "version": 1, - "hash": "10bd7566041d0f0a3aa295367785fdcc2c5bbf0ded984ac9230754f37496a6a7", - "action": "add" - }, - "2": { - "version": 2, - "hash": "890d2879ac44611db9b88ba9334a721130d0ac3aa18a303fa9e4081f14b9b8c7", - "action": "add" - } - }, - "ServiceConfig": { - "1": { - "version": 1, - "hash": "28af8a296f5ff63de50438277eaa1f4380682e6aca9f2ca28320d7a444825e88", - "action": "add" - }, - "2": { - "version": 2, - "hash": "93dfab144e0b0884c602358b3a9ce889bb29ab96e3b4adcfe3cef47a31694a9a", - "action": "add" - } - }, - "LibConfig": { - "1": { - "version": 1, - "hash": "ee8f0e3f6aae81948d72e30226645e8eb5d312a6770411a1edca748168c467c0", - "action": "add" - }, - "2": { - "version": 2, - "hash": "a8a78a8d726ee9e79f95614f3d0fa5b85edc6fce7be7651715208669be93e0e3", - "action": "add" - } - }, - "APIEndpoint": { - "1": { - "version": 1, - "hash": "faa1cf9336a0d1233868c8c57745ff38c0be60399dc1acd0c0e8dd440e405dbd", - "action": "add" - } - }, - "LibEndpoint": { - "1": { - "version": 1, - "hash": "a585c83a33a019d363ae5a0c6d4197193654307c19a4829dfbf8a8cfd2c1842a", - "action": "add" - } - }, - "SignedSyftAPICall": { - "1": { - "version": 1, - "hash": "2f959455f7130f4e59360b8aa58f19785b76eaa0f8a5a9188a6cbf32b31311ca", - "action": "add" - } - }, - "SyftAPICall": { - "1": { - "version": 1, - "hash": "59e89e7b9ea30deaed64d1ffd9bc0769b999d3082b305428432c1f5be36c6343", - "action": "add" - } - }, - "SyftAPIData": { - "1": { - "version": 1, - "hash": "820b279c581cafd9bb5009702d4e3db22ec3a3156676426304b9038dad260a24", - "action": "add" - } - }, - "SyftAPI": { - "1": { - "version": 1, - "hash": "cc13ab058ee36748c14b0d4bd9b9e894c7566fff09cfa4170b3eece520169f15", - "action": "add" - } - }, - "User": { - "1": { - "version": 1, - "hash": "2df4b68182c558dba5485a8a6867acf2a5c341b249ad67373a504098aa8c4343", - "action": "add" - }, - "2": { - "version": 2, - "hash": "af6fb5b2e1606e97838f4a60f0536ad95db606d455e94acbd1977df866608a2c", - "action": "add" - } - }, - "UserUpdate": { - "1": { - "version": 1, - "hash": "1bf6707c69b809c804fb939c7c37d787c2f6889508a4bec37d24221af2eb777a", - "action": "add" - } - }, - "UserCreate": { - "1": { - "version": 1, - "hash": 
"49d6087e2309ba59987f3126e286e74b3a66492a08ad82fa507ea17d52ce78e3", - "action": "add" - } - }, - "UserSearch": { - "1": { - "version": 1, - "hash": "9ac946338cca68d00d1696a57943442f062628ec3daf53077d0bdd3f72cd9fa0", - "action": "add" - } - }, - "UserView": { - "1": { - "version": 1, - "hash": "0b52d758e31d5889c9cd88afb467aae4a74e34a5276924e07012243c34d300fe", - "action": "add" - } - }, - "UserViewPage": { - "1": { - "version": 1, - "hash": "1cd6528d02ec180f080d5c35f0da760d8a59af9da7baaa9c17c1c7cedcc858fa", - "action": "add" - } - }, - "UserPrivateKey": { - "1": { - "version": 1, - "hash": "4817d8147aba94373f320dcd90e65f097cf6e5a2ef353aa8520e23128d522b5d", - "action": "add" - } - }, - "DateTime": { - "1": { - "version": 1, - "hash": "394abb554114ead4d63c36e3fe83ac018dead4b21a8465174009577c46d54c58", - "action": "add" - } - }, - "ReplyNotification": { - "1": { - "version": 1, - "hash": "84102dfc59d711b03c2f3d3a6ecaca000b6835f1bbdd9af801057f7aacb5f1d0", - "action": "add" - } - }, - "Notification": { - "1": { - "version": 1, - "hash": "af4cb232bff390c431e399975f048b34da7e940ace8b23b940a3b398c91c5326", - "action": "add" - } - }, - "CreateNotification": { - "1": { - "version": 1, - "hash": "7e426c946b7d5db6f9427960ec16042f3018091d835ca5966f3568c324a2ab53", - "action": "add" - } - }, - "UserNotificationActivity": { - "1": { - "version": 1, - "hash": "422fd01c6d9af38688a9982abd34e80794a1f6ddd444cca225d77f49189847a9", - "action": "add" - } - }, - "NotificationPreferences": { - "1": { - "version": 1, - "hash": "a42f06b367e7c6cbabcbf3cfcc84d1ca0873e457d972ebd060e87c9d6185f62b", - "action": "add" - } - }, - "NotifierSettings": { - "1": { - "version": 1, - "hash": "65c8ab814d35fac32f68d3000756692592cc59940f30e3af3dcdfa2328755b9d", - "action": "add" - }, - "2": { - "version": 2, - "hash": "be8b52597fc628d1b7cd22b776ee81416e1adbb04a45188778eb0e32ed1416b4", - "action": "add" - } - }, - "SyftImageRegistry": { - "1": { - "version": 1, - "hash": "67e18903e41cba1afe136adf29d404b63ec04fea6e928abb2533ec4fa52b246b", - "action": "add" - } - }, - "SyftWorkerImage": { - "1": { - "version": 1, - "hash": "44da7badfbe573d5403d3ab78c077f17dbefc560b81fdf927b671815be047441", - "action": "add" - } - }, - "SyftWorker": { - "1": { - "version": 1, - "hash": "9d897f6039eabe48dfa8e8d5c5cdcb283b0375b4c64571b457777eaaf3fb1920", - "action": "add" - } - }, - "WorkerPool": { - "1": { - "version": 1, - "hash": "16efc5dd2596ae744fd611c8f46af9eaec1bd5729eb20e85e9fd2f31df402564", - "action": "add" - } - }, - "MarkdownDescription": { - "1": { - "version": 1, - "hash": "31a73f8824cad1636a55d14b6a1074cdb071d0d4e16e86baaa3d4f63a7e80134", - "action": "add" - } - }, - "HTMLObject": { - "1": { - "version": 1, - "hash": "97f2e93f5ceaa88015047186f66a17ff13df2a6b7925b41331f9e19d5a515a9f", - "action": "add" - } - }, - "PwdTokenResetConfig": { - "1": { - "version": 1, - "hash": "0415a272428f22add4896c64aa9f29c8c1d35619e2433da6564eb5f1faff39ac", - "action": "add" - } - }, - "ServerSettingsUpdate": { - "1": { - "version": 1, - "hash": "1e4260ad879ae80728c3ffae2cd1d48759abd51f9d0960d4b25855cdbb4c506b", - "action": "add" - }, - "2": { - "version": 2, - "hash": "23b2716e9dceca667e228408e2416c82f11821e322e5bccf1f83406f3d09abdc", - "action": "add" - }, - "3": { - "version": 3, - "hash": "335c7946f2e52d09c7b26f511120cd340717c74c5cca9107e84f839da993c55c", - "action": "add" - }, - "4": { - "version": 4, - "hash": "8d7a41992c39c287fcb46383bed429ce75d3c9524ced8c86b88c26dd0232e2fe", - "action": "add" - } - }, - "ServerSettings": { - "1": { - "version": 1, - 
"hash": "5a1e7470cbeaaae5b80ac9beecb743734f7e4e42d429a09ea8defa569a5ddff1", - "action": "add" - }, - "2": { - "version": 2, - "hash": "7727ea54e494dc9deaa0d1bd38ac8a6180bc192b74eec5659adbc338a19e21f5", - "action": "add" - }, - "3": { - "version": 3, - "hash": "997667e1cba22d151857aacc2caba6b1ca73c1648adbd03461dc74a0c0c372b3", - "action": "add" - }, - "4": { - "version": 4, - "hash": "b8067777967a0e06733433e179e549caaf501419d62f7e8474ee33b839e3890d", - "action": "add" - } - }, - "HTTPConnection": { - "1": { - "version": 1, - "hash": "bf10f81646c71069c76292b1237b4a3de1e507264392c5c591d067636ce6fb46", - "action": "add" - } - }, - "PythonConnection": { - "1": { - "version": 1, - "hash": "28010778b5e3463ff6960a0e2224818de00bc7b5e6f892192e02e399ccbe18b5", - "action": "add" - } - }, - "ActionDataEmpty": { - "1": { - "version": 1, - "hash": "e0e4a5cf18d05b6b747addc048515c6f2a5f35f0766ebaee96d898cb971e1c5b", - "action": "add" - } - }, - "ObjectNotReady": { - "1": { - "version": 1, - "hash": "8cf471e205cd0893d6aae5f0227d14db7df1c9698da08a3ab991f59132d17fe9", - "action": "add" - } - }, - "ActionDataLink": { - "1": { - "version": 1, - "hash": "3469478343439e411b761c270eec63eb3d533e459ad72d0965158c3a6cdf3b9a", - "action": "add" - } - }, - "Action": { - "1": { - "version": 1, - "hash": "021826d7c6f69bd0283d025d40661f3ffbeba8810ca94de01344f6afbdae62cd", - "action": "add" - } - }, - "ActionObject": { - "1": { - "version": 1, - "hash": "0a5f4bc343cb114a251f06686ecdbb59d74bfb3d29a098b176699deb35a1e683", - "action": "add" - } - }, - "AnyActionObject": { - "1": { - "version": 1, - "hash": "b3c44c7788c59c03fa1baeec656c2ca6e633f4cbd4b23ff7ece6ee94c38449f0", - "action": "add" - } - }, - "CustomEndpointActionObject": { - "1": { - "version": 1, - "hash": "c7addbaf2777707f3e91e5c1e092343476cd22efc4ec8617f39ccf76e61a5a14", - "action": "add" - }, - "2": { - "version": 2, - "hash": "846ba36e8737a1bec16853c9de54c4948450009278e0b76fe7e3355ef9e70089", - "action": "add" - } - }, - "DataSubject": { - "1": { - "version": 1, - "hash": "582cdf9e82b5d6915b7f09f7c0d5f08328b11a2ce9b0198e5083f1672c2e2bf5", - "action": "add" - } - }, - "DataSubjectCreate": { - "1": { - "version": 1, - "hash": "5a8423c2690d55f425bfeecc87cd4a797a75d88ebb5fbda754d4f269b62d2ceb", - "action": "add" - } - }, - "DataSubjectMemberRelationship": { - "1": { - "version": 1, - "hash": "0810483ea76ea10c8f286c6035dc0b2085291f345183be50c179f3a05a577110", - "action": "add" - } - }, - "Contributor": { - "1": { - "version": 1, - "hash": "30c32bd44098f00e0b15496be441763b6e50af8b12d3d2bef33aca6287193876", - "action": "add" - } - }, - "Asset": { - "1": { - "version": 1, - "hash": "000abc78719611c106295cf12b1690b7e5411dc1bb9db9d4afd22956da90d1f4", - "action": "add" - } - }, - "CreateAsset": { - "1": { - "version": 1, - "hash": "357d52576cb12b24fb3980342bb49a562b065c0e4419e87d34176340628c7309", - "action": "add" - } - }, - "Dataset": { - "1": { - "version": 1, - "hash": "0ca6b0b4a3aebb2c8f351668075b44951bb20d1e23a779b82109124f334ce3a4", - "action": "add" - } - }, - "DatasetPageView": { - "1": { - "version": 1, - "hash": "aa0dd69637281b80d5523b4409a2c7e89db114c9fe79c858063c6dadff8977d1", - "action": "add" - }, - "2": { - "version": 2, - "hash": "be1ca6dcd0b3aa0481ce5dce737e78432d06a78ad0c701aaf136be407c798352", - "action": "add" - } - }, - "CreateDataset": { - "1": { - "version": 1, - "hash": "7e02dfa89540c3dbebacbb13810d95cdc4e36db31d56cffed7ab54abe25716c9", - "action": "add" - } - }, - "SyftLog": { - "1": { - "version": 1, - "hash": 
"1bcd71e5bf3f0db3bba0996f33b6b2bde3489b9c71f11e6b30c3495c76a8f53f", - "action": "add" - } - }, - "JobItem": { - "1": { - "version": 1, - "hash": "0b32277b7d3b9bdc14a2a51cc9005f8254e7f7b6ec059ddcccbcd681a807afb6", - "action": "add" - }, - "2": { - "version": 2, - "hash": "b087d0c62b7d304c6ca80e4fb0e8a7f2a444be8f8cba57490dc09aeb98033105", - "action": "add" - } - }, - "ExecutionOutput": { - "1": { - "version": 1, - "hash": "e36c71685edf5276a3427cb6749550486d3a177c1dcf73dd337ab2a73c0ce6b5", - "action": "add" - } - }, - "TwinObject": { - "1": { - "version": 1, - "hash": "4f31243fb348dbb083579afd6f638d75af010cb53d19bfba59b74afff41ccbbb", - "action": "add" - } - }, - "PolicyRule": { - "1": { - "version": 1, - "hash": "44d1ca1db97be46f66558aa1a729ff31bf8e113c6a913b11aedf9d6b6ad5b7b5", - "action": "add" - } - }, - "CreatePolicyRule": { - "1": { - "version": 1, - "hash": "342bb723526d445151a0435f57d251f4c1219f8ae7cca3e8e9fce52e2ee1b8b1", - "action": "add" - } - }, - "CreatePolicyRuleConstant": { - "1": { - "version": 1, - "hash": "78b54832cb0468a87013bc36bc11d4759874ca1b5065a1b711f1e5ef5d94c2df", - "action": "add" - } - }, - "Matches": { - "1": { - "version": 1, - "hash": "dd6d91ddb2ec5eaf60be2b0899ecfdb9a15f7904aa39d2f4d9bb2d7b793040e6", - "action": "add" - } - }, - "PreFill": { - "1": { - "version": 1, - "hash": "c7aefb11dc4c4569dcd1e6988371047a32a8be1b32ad46d12adba419a19769ad", - "action": "add" - } - }, - "UserOwned": { - "1": { - "version": 1, - "hash": "c8738dc3d8c2a5ef461b85a0467c3dff53dab16b54a4d12b44b1477906aef51d", - "action": "add" - } - }, - "MixedInputPolicy": { - "1": { - "version": 1, - "hash": "37bb12d950518d9579c8ec7c4cc22ac731ea82caf8c1370dd0b0a82b46462dde", - "action": "add" - } - }, - "ExactMatch": { - "1": { - "version": 1, - "hash": "5eb37edbf5e451d942e599247f3eaed923c1fe9d91eefdba02bf06503f6cc08d", - "action": "add" - } - }, - "OutputHistory": { - "1": { - "version": 1, - "hash": "9366db79d131f8c65e5a4ff12c90e2aa0c11e302debe06e46eeb93b26e2aaf61", - "action": "add" - } - }, - "OutputPolicyExecuteCount": { - "1": { - "version": 1, - "hash": "2a77e5ed5c7b0391147562651ad4061e20b11745c191fbc34cb549da37ba72dd", - "action": "add" - } - }, - "OutputPolicyExecuteOnce": { - "1": { - "version": 1, - "hash": "5589c00d127d9eb1f5ccf3a16def8219737784d57bb3bf9be5cb6d83325ef436", - "action": "add" - } - }, - "EmptyInputPolicy": { - "1": { - "version": 1, - "hash": "7ef81cfd223be0064600e1503f8b04bafc16385e27730e9319466e68a077c68b", - "action": "add" - } - }, - "UserPolicy": { - "1": { - "version": 1, - "hash": "74373bb71a334f4dcf77623ae10ff5b1c7e5b3006f65f2051ffb1e01f422f982", - "action": "add" - } - }, - "SubmitUserPolicy": { - "1": { - "version": 1, - "hash": "ec4e808eb39613bcdbbbf9ffb3267612084a9d99880a2f3bee3ef32d46329c02", - "action": "add" - } - }, - "UserCodeStatusCollection": { - "1": { - "version": 1, - "hash": "735ecf2d4abb1e7d19b2e751d880f32b01ce267ba10e417ef1b440be3d94d8f1", - "action": "add" - } - }, - "UserCode": { - "1": { - "version": 1, - "hash": "3bcd14413b9c4fbde7c5612c2ed713518340280b5cff89cf2aaaf1c77c4037a8", - "action": "add" - } - }, - "SubmitUserCode": { - "1": { - "version": 1, - "hash": "d2bb8cfe12f070b4adafded78ce01900c5409bd83f055f94b1e285745ef65a76", - "action": "add" - } - }, - "UserCodeExecutionResult": { - "1": { - "version": 1, - "hash": "1f4cbc62caac4dd193f427306405dc7a099ae744bea5830cf57149ce71c1e589", - "action": "add" - } - }, - "UserCodeExecutionOutput": { - "1": { - "version": 1, - "hash": "c1d53300a39dbbb437d7d5a1257bd175a067b1065f4099a0938fac7540035258", 
- "action": "add" - }, - "2": { - "version": 2, - "hash": "3e104e39b4ab53c950e61e4f7e92ce935cf96a5100de301de9bf297eb7e5787e", - "action": "add" - } - }, - "CodeHistory": { - "1": { - "version": 1, - "hash": "e3ef5346f108257828f364d22b12d9311812c9cf843200afef5dc4d9302f9b21", - "action": "add" - } - }, - "CodeHistoryView": { - "1": { - "version": 1, - "hash": "8b8b97d334b51d1ce0a9efab722411ff25caa3f12be319105954497e0a306eb2", - "action": "add" - } - }, - "CodeHistoriesDict": { - "1": { - "version": 1, - "hash": "01d7dcd4b21525a06e4484d8699a4a34a5c84f1f6026ec55e32eb30412742601", - "action": "add" - } - }, - "UsersCodeHistoriesDict": { - "1": { - "version": 1, - "hash": "4ed8b83973258ea19a1f91feb2590ff73b801be86f4296cc3db48f6929ff784c", - "action": "add" - } - }, - "BlobFile": { - "1": { - "version": 1, - "hash": "d99239100f1cb0b73c69b2ad7cab01a06909cc3a4976ba2b3b67cf6fe5e2f516", - "action": "add" - } - }, - "BlobFileOBject": { - "1": { - "version": 1, - "hash": "6c40dab2c8d2220d4fff7cc653d76cc026a856db7e2b5713b6341e255adc7ea2", - "action": "add" - } - }, - "SecureFilePathLocation": { - "1": { - "version": 1, - "hash": "ea5978b98d7773d221665b450454c9130c103a5c850669a0acd620607cd614b7", - "action": "add" - } - }, - "SeaweedSecureFilePathLocation": { - "1": { - "version": 1, - "hash": "3fc9bfc8c1b1cf660c9747e8c1fe3eb2220e78d4e3b5d6b5c5f29a07a77ebf3e", - "action": "add" - } - }, - "AzureSecureFilePathLocation": { - "1": { - "version": 1, - "hash": "090a9e962eeb655586ee966c5651d8996363969818a38f9a486fd64d33047e05", - "action": "add" - } - }, - "BlobStorageEntry": { - "1": { - "version": 1, - "hash": "afdc6a1d8a24b1ee1ed9d3e79f5bac64b4f0d9d36800f07f10be0b896470345f", - "action": "add" - } - }, - "BlobStorageMetadata": { - "1": { - "version": 1, - "hash": "9d4b61ac4ea1910c2f7c767a50a6a52544a24663548f069e79bd906f11b538e4", - "action": "add" - } - }, - "CreateBlobStorageEntry": { - "1": { - "version": 1, - "hash": "ffc3cbfeade67d074dc5bf7d655a1eb8c83630076028a72b3cc4548f3b413e14", - "action": "add" - } - }, - "SyftObjectMigrationState": { - "1": { - "version": 1, - "hash": "ee83315828551f18904bab18e0cac48896493620561215b04cc448e6ce5834af", - "action": "add" - } - }, - "StoreMetadata": { - "1": { - "version": 1, - "hash": "8de9a22a2765ef976bc161cb0704347d30350c085da8c8ffa876065cfca3e5fd", - "action": "add" - } - }, - "MigrationData": { - "1": { - "version": 1, - "hash": "cb96b8c8413609e1224341d1b0dd1efb08387c0ff7b0ff65eba36c0b104c9ed1", - "action": "add" - }, - "2": { - "version": 2, - "hash": "1d1b14c196221ecf6d644d7dcaa32ac9e90361b2687fa83161ff399ebc6df1bd", - "action": "add" - } - }, - "BlobRetrieval": { - "1": { - "version": 1, - "hash": "c422c74b89a9349742acaa848566fe18bfef1a83333458b858c074baed37a859", - "action": "add" - } - }, - "SyftObjectRetrieval": { - "1": { - "version": 1, - "hash": "b2b62447445adc4cd0b77ab59d6fa56624dd316fb50281e570daad07556b6db2", - "action": "add" - } - }, - "BlobRetrievalByURL": { - "1": { - "version": 1, - "hash": "4db0e3b7a6334d3835356d8393866711e243e360af25a95f3cc4066f032404b5", - "action": "add" - } - }, - "BlobDeposit": { - "1": { - "version": 1, - "hash": "6eb5cc57dc763126bfc6ec5a2b79d02e77eadf9d9efb1888a5c366b7799c1c24", - "action": "add" - } - }, - "OnDiskBlobDeposit": { - "1": { - "version": 1, - "hash": "817bf1bee4a35bfa1cd25d6779a10d8d180b1b3f1e837952f81f48b9411d1970", - "action": "add" - } - }, - "RemoteConfig": { - "1": { - "version": 1, - "hash": "179d067099a178d748c6d9a0477e8de7c3b55577439669eca7150258f2409567", - "action": "add" - } - }, - 
"AzureRemoteConfig": { - "1": { - "version": 1, - "hash": "a143811fec0da5fd881e927643ef667c91c78a2c90519cf88da7da20738bd187", - "action": "add" - } - }, - "SeaweedFSBlobDeposit": { - "1": { - "version": 1, - "hash": "febeb2a2ce81aa2c512e4c6b611b582984042aafa0541403d4584662273a166c", - "action": "add" - } - }, - "DictStoreConfig": { - "1": { - "version": 1, - "hash": "2e1365c5535fa51c22eef79f67dd6444789bc829c27881367e3050e06e2ffbfe", - "action": "add" - } - }, - "NumpyArrayObject": { - "1": { - "version": 1, - "hash": "05dd2917b7692b3daf4e7ad083a46fa7ec7a2be8faac8d4a654809189c986443", - "action": "add" - } - }, - "NumpyScalarObject": { - "1": { - "version": 1, - "hash": "8753e5c78270a5cacbf0439447724772f4765351a4a8b58b0a5c416a6b2c8b6e", - "action": "add" - } - }, - "NumpyBoolObject": { - "1": { - "version": 1, - "hash": "331c44f8fa3d0a077f1aaad7313bae2c43b386d04def7b8bedae9fdf7690134d", - "action": "add" - } - }, - "PandasDataframeObject": { - "1": { - "version": 1, - "hash": "5e8018364cea31d5f185a901da4ab89846b02153ee7d041ee8a6d305ece31f90", - "action": "add" - } - }, - "PandasSeriesObject": { - "1": { - "version": 1, - "hash": "b8bd482bf16fc7177e9778292cd42f8835b6ced2ce8dc88908b4b8e6d7c7528f", - "action": "add" - } - }, - "Change": { - "1": { - "version": 1, - "hash": "75fb9a5cd4e76b189ebe130a421d3921a0c251947a48bbb92a2ef1c315dc3c16", - "action": "add" - } - }, - "ChangeStatus": { - "1": { - "version": 1, - "hash": "c914a6f7637b555a51b71e8e197e591f7a2e28121e29b5dd586f87e0383d179d", - "action": "add" - } - }, - "ActionStoreChange": { - "1": { - "version": 1, - "hash": "1a803bb08924b49f3114fd46e0e132f819d4d56be5e03a27e9fe90947ca26e85", - "action": "add" - } - }, - "CreateCustomImageChange": { - "1": { - "version": 1, - "hash": "c3dbea3f49979fdcc517c0d13cd02739ca2fe86b370c42496a224f142ae31562", - "action": "add" - } - }, - "CreateCustomWorkerPoolChange": { - "1": { - "version": 1, - "hash": "0355793dd58b364dcb84fff29714b6a26446bead3ba95c6d75e3200008e580f4", - "action": "add" - } - }, - "Request": { - "1": { - "version": 1, - "hash": "1d69f5f0074114f99aa29c5ee77cb20b9151e5b50e77b026f11c3632a12efadf", - "action": "add" - } - }, - "RequestInfo": { - "1": { - "version": 1, - "hash": "779562547744ebed64548f8021647292604fdf4256bf79685dfa14a1e56cc27b", - "action": "add" - } - }, - "RequestInfoFilter": { - "1": { - "version": 1, - "hash": "bb881a003032f4676321218d7cd09580f4d64fccaa1cf9e118fdcd5c73c3d3a8", - "action": "add" - } - }, - "SubmitRequest": { - "1": { - "version": 1, - "hash": "6c38b6ffd0a6f7442746e68b9ace7b21cb1dca7d2031929db5f9a302a280403f", - "action": "add" - } - }, - "ObjectMutation": { - "1": { - "version": 1, - "hash": "ce88096760ce9334599c8194ec97b0a1470651ad680d9d21b8826a0df0af2a36", - "action": "add" - } - }, - "EnumMutation": { - "1": { - "version": 1, - "hash": "5173fda73df17a344eb663b7692cca48bd46bf1773455439836b852cd165448c", - "action": "add" - } - }, - "UserCodeStatusChange": { - "1": { - "version": 1, - "hash": "89aaf7f1368c782e3a1b9e79988877f6eaa05ab84365f7d321b757fde7fe86e7", - "action": "add" - } - }, - "SyncedUserCodeStatusChange": { - "1": { - "version": 1, - "hash": "d9ad2d341eb645bd50d06330cd30fd4c266f93e37b9f5391d58b78365fc440e6", - "action": "add" - } - }, - "TwinAPIContextView": { - "1": { - "version": 1, - "hash": "e099eef32cb3a8a806cbdc54cc7fca96bed3d60344bd571163ec049db407938b", - "action": "add" - } - }, - "CustomAPIView": { - "1": { - "version": 1, - "hash": "769e96bebd05736ab860591670fb6da19406239b0104ddc71bd092a134335146", - "action": "add" - } - }, - 
"CustomApiEndpoint": { - "1": { - "version": 1, - "hash": "ec4a217585336d1b59c93c18570443a63f4fbb24d2c088fbacf80bcf389d23e8", - "action": "add" - } - }, - "PrivateAPIEndpoint": { - "1": { - "version": 1, - "hash": "6d7d143432c2811c520ab6dade005ba40173b590e5c676be04f5921b970ef938", - "action": "add" - } - }, - "PublicAPIEndpoint": { - "1": { - "version": 1, - "hash": "3bf51fc33aa8feb1abc9d0ef792e8889da31a57050430e0bd8e17f2065ff8734", - "action": "add" - } - }, - "UpdateTwinAPIEndpoint": { - "1": { - "version": 1, - "hash": "851e59412716e73c7f70a696619e0b375ce136b43f6fe2ea784747091caba5d8", - "action": "add" - } - }, - "CreateTwinAPIEndpoint": { - "1": { - "version": 1, - "hash": "3d0b84dae95ebcc6647b5aabe54e65b3c6bf957665fde57d8037806a4aac13be", - "action": "add" - } - }, - "TwinAPIEndpoint": { - "1": { - "version": 1, - "hash": "d1947b8f9c80d6c9b443e5a9f0758afa8849a5f12b9a511feefd7e4f82c374f4", - "action": "add" - } - }, - "SyncState": { - "1": { - "version": 1, - "hash": "9a3f0bb973858b55bc766c9770c4d9abcc817898f797d94a89938650c0c67868", - "action": "add" - } - }, - "WorkerSettings": { - "1": { - "version": 1, - "hash": "dca33003904a71688e5b07db65f8833eb4de8135aade7154076b8eafbb94d26b", - "action": "add" - } - }, - "HTTPServerRoute": { - "1": { - "version": 1, - "hash": "938245604a9c7e50001299afff5b669b2548364e356fed22a22780497831bf81", - "action": "add" - } - }, - "PythonServerRoute": { - "1": { - "version": 1, - "hash": "a068d8f942d55ecb6d45af88a27c6ebf208584275bf589cbc308df3f774ab9a9", - "action": "add" - } - }, - "VeilidServerRoute": { - "1": { - "version": 1, - "hash": "e676bc165601d2ede69707a4b6168ed4674f3f98887026d098a2dd4da4dfd097", - "action": "add" - } - }, - "ServerPeer": { - "1": { - "version": 1, - "hash": "0d5f252018e324ea0d2dcb5c2ad8bd15707220565fce4f14de7f63a8f9e4391b", - "action": "add" - } - }, - "ServerPeerUpdate": { - "1": { - "version": 1, - "hash": "0b854b57db7a18118c1fd8f31495b2ba4eeb9fbe4f24c631ff112418a94570d3", - "action": "add" - } - }, - "AssociationRequestChange": { - "1": { - "version": 1, - "hash": "0134ac0002879c85fc9ddb06bed6306a8905c8434b0a40d3a96ce24a7bd4da90", - "action": "add" - } - }, - "QueueItem": { - "1": { - "version": 1, - "hash": "1db212c46b6c56ccc5579cfe2141b693f0cd9286e2ede71210393e8455379bf1", - "action": "add" - } - }, - "ActionQueueItem": { - "1": { - "version": 1, - "hash": "396d579dfc2e2b36b9fbed2f204bffcca1bea7ee2db7175045dd3328ebf08718", - "action": "add" - } - }, - "APIEndpointQueueItem": { - "1": { - "version": 1, - "hash": "f04b3990a8d29c116d301e70df54d58f188895307a411dc13a666ff764ffd8dd", - "action": "add" - } - }, - "ZMQClientConfig": { - "1": { - "version": 1, - "hash": "36ee8f75067d5144f0ed062cdc79466caae16b7a128231d89b6b430174843bde", - "action": "add" - } - }, - "SQLiteStoreConfig": { - "1": { - "version": 1, - "hash": "ad062a5f863ae84683867d2a6a5e1d4420c010a64b88bc7b392106e33d71ac03", - "action": "add" - } - }, - "ProjectEvent": { - "1": { - "version": 1, - "hash": "dc0486c52daebd5e98c2b3b03ffd9a9a14bc3d86d8dc0c23e41ebf6c31fe2ffb", - "action": "add" - } - }, - "ProjectThreadMessage": { - "1": { - "version": 1, - "hash": "99256d7592577d1e37df94a06eabc0a287f2d79e144c51fd719315e278edb46d", - "action": "add" - } - }, - "ProjectMessage": { - "1": { - "version": 1, - "hash": "b5004b6354f71b19c81dd5f4b20bf446e0b959f5608a22707e96b944dd8175b0", - "action": "add" - } - }, - "ProjectRequestResponse": { - "1": { - "version": 1, - "hash": "52162a8a779a4a301d8755691bf4cf994c86b9f650f9e8c8a923b44e635b1bc0", - "action": "add" - } - }, - 
"ProjectRequest": { - "1": { - "version": 1, - "hash": "dc684135d5a5a48e5fc7988598c1e6e0de76cf1c5995f1c283fcf63d0eb4d24f", - "action": "add" - } - }, - "AnswerProjectPoll": { - "1": { - "version": 1, - "hash": "c83d83a5ba6cc034d5061df200b3f1d029aa770b1e13dbef959bb1790323dc6e", - "action": "add" - } - }, - "ProjectPoll": { - "1": { - "version": 1, - "hash": "ecf69b3b324e0bee9c82295796d44c4e8f796496cdc9db6d4302c2f160566466", - "action": "add" - } - }, - "Project": { - "1": { - "version": 1, - "hash": "de86a1163ddbcd1cc3cc2b1b5dfcb85a8ad9f9d4bbc759c2b1f92a0d0a2ff184", - "action": "add" - } - }, - "ProjectSubmit": { - "1": { - "version": 1, - "hash": "7555ba11ee5a814dcd9c45647300020f7359efc1081559940990cbd745936cac", - "action": "add" - } - }, - "Plan": { - "1": { - "version": 1, - "hash": "ed05cb87aec832098fc464ac36cd6bceaab705463d0d2fa1b2d8e1ccc510018c", - "action": "add" - } - }, - "EnclaveMetadata": { - "1": { - "version": 1, - "hash": "8d2dfafa01ec909c080a790cf15a8fc78e00382d3bfe6207098ceb25a60b9c53", - "action": "add" - } - } - } - } -} diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 6c4ed42f02f..bb69d40544e 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -667,6 +667,11 @@ def get_all_storage_permissions(self) -> dict[UID, set[UID]]: } def has_permission(self, permission: ActionObjectPermission) -> bool: + if self.get_role(permission.credentials) in ( + ServiceRole.ADMIN, + ServiceRole.DATA_OWNER, + ): + return True return self.has_permissions([permission]) def has_storage_permission(self, permission: StoragePermission) -> bool: diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index 1b539202716..a541383235d 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -22,14 +22,12 @@ from syft.protocol.data_protocol import protocol_release_dir from syft.protocol.data_protocol import stage_protocol_changes from syft.server.worker import Worker +from syft.service.queue.queue_stash import QueueStash from syft.service.user import user # relative # our version of mongomock that has a fix for CodecOptions and custom TypeRegistry Support from .mongomock.mongo_client import MongoClient -from .syft.stores.store_fixtures_test import sqlite_action_store -from .syft.stores.store_fixtures_test import sqlite_document_store -from .syft.stores.store_fixtures_test import sqlite_store_partition def patch_protocol_file(filepath: Path): @@ -296,11 +294,18 @@ def big_dataset() -> Dataset: yield dataset -__all__ = [ - "sqlite_store_partition", - "sqlite_document_store", - "sqlite_action_store", -] +@pytest.fixture( + scope="function", + params=[ + "tODOsqlite_address", + "TODOpostgres_address", + ], +) +def queue_stash(request): + _ = request.param + stash = QueueStash.random() + yield stash + pytest_plugins = [ "tests.syft.users.fixtures", diff --git a/packages/syft/tests/syft/service/action/action_object_test.py b/packages/syft/tests/syft/service/action/action_object_test.py index 5252ec1eb33..1caa0a4ad2e 100644 --- a/packages/syft/tests/syft/service/action/action_object_test.py +++ b/packages/syft/tests/syft/service/action/action_object_test.py @@ -1001,7 +1001,7 @@ def test_actionobject_syft_getattr_float_history(): @pytest.mark.skipif( - sys.platform != "linux", + sys.platform == "win32", reason="This is a hackish way to test attribute set/get, and it might fail on Windows or OSX", ) def test_actionobject_syft_getattr_np(worker): diff --git 
a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py
index 204235a421c..239ec8d33ba 100644
--- a/packages/syft/tests/syft/stores/action_store_test.py
+++ b/packages/syft/tests/syft/stores/action_store_test.py
@@ -1,25 +1,24 @@
 # stdlib
-import sys
-from typing import Any
 
 # third party
 import pytest
 
 # syft absolute
+from syft.server.credentials import SyftSigningKey
 from syft.server.credentials import SyftVerifyKey
+from syft.service.action.action_object import ActionObject
 from syft.service.action.action_permissions import ActionObjectOWNER
+from syft.service.action.action_permissions import ActionObjectPermission
 from syft.service.action.action_store import ActionObjectEXECUTE
 from syft.service.action.action_store import ActionObjectREAD
+from syft.service.action.action_store import ActionObjectStash
 from syft.service.action.action_store import ActionObjectWRITE
+from syft.service.user.user import User
+from syft.service.user.user_roles import ServiceRole
+from syft.service.user.user_stash import UserStash
+from syft.store.db.sqlite_db import DBManager
 from syft.types.uid import UID
 
-# relative
-from .store_constants_test import TEST_VERIFY_KEY_NEW_ADMIN
-from .store_constants_test import TEST_VERIFY_KEY_STRING_CLIENT
-from .store_constants_test import TEST_VERIFY_KEY_STRING_HACKER
-from .store_constants_test import TEST_VERIFY_KEY_STRING_ROOT
-from .store_mocks_test import MockSyftObject
-
 permissions = [
     ActionObjectOWNER,
     ActionObjectREAD,
@@ -28,134 +27,120 @@
 ]
 
 
-@pytest.mark.parametrize(
-    "store",
-    [
-        pytest.lazy_fixture("dict_action_store"),
-        pytest.lazy_fixture("sqlite_action_store"),
-        pytest.lazy_fixture("mongo_action_store"),
-    ],
-)
-def test_action_store_sanity(store: Any):
-    assert hasattr(store, "store_config")
-    assert hasattr(store, "settings")
-    assert hasattr(store, "data")
-    assert hasattr(store, "permissions")
-    assert hasattr(store, "root_verify_key")
-    assert store.root_verify_key.verify == TEST_VERIFY_KEY_STRING_ROOT
+def add_user(db_manager: DBManager, role: ServiceRole) -> SyftVerifyKey:
+    user_stash = UserStash(store=db_manager)
+    verify_key = SyftSigningKey.generate().verify_key
+    user_stash.set(
+        credentials=db_manager.root_verify_key,
+        obj=User(verify_key=verify_key, role=role, id=UID()),
+    ).unwrap()
+    return verify_key
+
+
+def add_test_object(
+    stash: ActionObjectStash, verify_key: SyftVerifyKey
+) -> UID:
+    # create an ActionObject, store it, and hand back its UID for the tests
+    test_object = ActionObject.from_obj([1, 2, 3])
+    uid = test_object.id
+    stash.set_or_update(
+        uid=uid,
+        credentials=verify_key,
+        syft_object=test_object,
+        has_result_read_permission=True,
+    ).unwrap()
+    return uid
 
 
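A note on the role semantics these rewritten tests rely on (the sketch below is
hypothetical usage, not part of the diff): together with the has_permission
change to stash.py earlier in this patch, keys registered as ADMIN or
DATA_OWNER pass every permission check implicitly, so only the DATA_SCIENTIST
keys ever need explicit add_permission calls:

    admin_key = add_user(stash.db, ServiceRole.ADMIN)
    ds_key = add_user(stash.db, ServiceRole.DATA_SCIENTIST)
    obj_id = add_test_object(stash, ds_key)

    # never granted explicitly, but admins hold every permission by role
    assert stash.has_permission(
        ActionObjectWRITE(uid=obj_id, credentials=admin_key)
    )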
 @pytest.mark.parametrize(
-    "store",
+    "stash",
     [
-        pytest.lazy_fixture("dict_action_store"),
-        pytest.lazy_fixture("sqlite_action_store"),
-        pytest.lazy_fixture("mongo_action_store"),
+        pytest.lazy_fixture("action_object_stash"),
     ],
 )
 @pytest.mark.parametrize("permission", permissions)
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-@pytest.mark.skipif(sys.platform == "darwin", reason="skip on mac")
-def test_action_store_test_permissions(store: Any, permission: Any):
-    client_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_CLIENT)
-    root_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT)
-    hacker_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER)
-    new_admin_key = TEST_VERIFY_KEY_NEW_ADMIN
-
-    access = permission(uid=UID(), credentials=client_key)
-    access_root = permission(uid=UID(), credentials=root_key)
-    access_hacker = permission(uid=UID(), credentials=hacker_key)
-    access_new_admin = permission(uid=UID(), credentials=new_admin_key)
-
-    # add permission
-    store.add_permission(access)
-
-    assert store.has_permission(access)
-    assert store.has_permission(access_root)
-    assert store.has_permission(access_new_admin)
-    assert not store.has_permission(access_hacker)
+def test_action_store_test_permissions(
+    stash: ActionObjectStash, permission: ActionObjectPermission
+) -> None:
+    client_key = add_user(stash.db, ServiceRole.DATA_SCIENTIST)
+    root_key = add_user(stash.db, ServiceRole.ADMIN)
+    hacker_key = add_user(stash.db, ServiceRole.DATA_SCIENTIST)
+    new_admin_key = add_user(stash.db, ServiceRole.ADMIN)
+
+    test_item_id = add_test_object(stash, client_key)
+
+    access = permission(uid=test_item_id, credentials=client_key)
+    access_root = permission(uid=test_item_id, credentials=root_key)
+    access_hacker = permission(uid=test_item_id, credentials=hacker_key)
+    access_new_admin = permission(uid=test_item_id, credentials=new_admin_key)
+
+    stash.add_permission(access)
+    assert stash.has_permission(access)
+    assert stash.has_permission(access_root)
+    assert stash.has_permission(access_new_admin)
+    assert not stash.has_permission(access_hacker)
 
     # remove permission
-    store.remove_permission(access)
+    stash.remove_permission(access)
 
-    assert not store.has_permission(access)
-    assert store.has_permission(access_root)
-    assert store.has_permission(access_new_admin)
-    assert not store.has_permission(access_hacker)
+    assert not stash.has_permission(access)
+    assert stash.has_permission(access_root)
+    assert stash.has_permission(access_new_admin)
+    assert not stash.has_permission(access_hacker)
 
-    # take ownership with new UID
-    client_uid2 = UID()
-    access = permission(uid=client_uid2, credentials=client_key)
+    # add a second object and grant the client explicit read access to it
+    item2_id = add_test_object(stash, client_key)
+    access = permission(uid=item2_id, credentials=client_key)
 
-    store.take_ownership(client_uid2, client_key)
-    assert store.has_permission(access)
-    assert store.has_permission(access_root)
-    assert store.has_permission(access_new_admin)
-    assert not store.has_permission(access_hacker)
+    stash.add_permission(ActionObjectREAD(uid=item2_id, credentials=client_key))
+    assert stash.has_permission(access)
+    assert stash.has_permission(access_root)
+    assert stash.has_permission(access_new_admin)
+    assert not stash.has_permission(access_hacker)
 
     # delete UID as hacker
-    access_hacker_ro = ActionObjectREAD(uid=UID(), credentials=hacker_key)
-    store.add_permission(access_hacker_ro)
-    res = store.delete(client_uid2, hacker_key)
+    res = stash.delete_by_uid(hacker_key, item2_id)
 
     assert res.is_err()
-    assert store.has_permission(access)
-    assert store.has_permission(access_new_admin)
-    assert store.has_permission(access_hacker_ro)
+    assert stash.has_permission(access)
+    assert stash.has_permission(access_root)
+    assert stash.has_permission(access_new_admin)
+    assert not stash.has_permission(access_hacker)
 
     # delete UID as owner
-    res = store.delete(client_uid2, client_key)
+    res = stash.delete_by_uid(client_key, item2_id)
 
     assert res.is_ok()
-    assert not store.has_permission(access)
-    assert store.has_permission(access_new_admin)
-    assert not store.has_permission(access_hacker)
+    assert not stash.has_permission(access)
+    assert stash.has_permission(access_new_admin)
+    assert not stash.has_permission(access_hacker)
 
 
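The old store.take_ownership call has no direct equivalent on the new stash;
ownership is now expressed as explicit grants. A hypothetical helper (not in
this patch) that reproduces the old behaviour by granting every permission
type from the permissions list above:

    def grant_full_access(
        stash: ActionObjectStash, uid: UID, key: SyftVerifyKey
    ) -> None:
        # equivalent of the old take_ownership: grant all four permission types
        for perm_type in (
            ActionObjectOWNER,
            ActionObjectREAD,
            ActionObjectWRITE,
            ActionObjectEXECUTE,
        ):
            stash.add_permission(perm_type(uid=uid, credentials=key))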
 @pytest.mark.parametrize(
-    "store",
+    "stash",
     [
-        pytest.lazy_fixture("dict_action_store"),
-        pytest.lazy_fixture("sqlite_action_store"),
-        pytest.lazy_fixture("mongo_action_store"),
+        pytest.lazy_fixture("action_object_stash"),
     ],
 )
-@pytest.mark.flaky(reruns=3, reruns_delay=3)
-def test_action_store_test_dataset_get(store: Any):
-    client_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_CLIENT)
-    root_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT)
-    SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER)
+def test_action_store_test_dataset_get(stash: ActionObjectStash) -> None:
+    client_key = add_user(stash.db, ServiceRole.DATA_SCIENTIST)
+    root_key = add_user(stash.db, ServiceRole.ADMIN)
 
-    permission_only_uid = UID()
-    access = ActionObjectWRITE(uid=permission_only_uid, credentials=client_key)
-    access_root = ActionObjectWRITE(uid=permission_only_uid, credentials=root_key)
-    read_permission = ActionObjectREAD(uid=permission_only_uid, credentials=client_key)
+    data_uid = add_test_object(stash, client_key)
+    access = ActionObjectWRITE(uid=data_uid, credentials=client_key)
+    access_root = ActionObjectWRITE(uid=data_uid, credentials=root_key)
+    read_permission = ActionObjectREAD(uid=data_uid, credentials=client_key)
 
     # add permission
-    store.add_permission(access)
+    stash.add_permission(access)
 
-    assert store.has_permission(access)
-    assert store.has_permission(access_root)
+    assert stash.has_permission(access)
+    assert stash.has_permission(access_root)
 
-    store.add_permission(read_permission)
-    assert store.has_permission(read_permission)
+    stash.add_permission(read_permission)
+    assert stash.has_permission(read_permission)
 
-    # check that trying to get action data that doesn't exist returns an error, even if have permissions
-    res = store.get(permission_only_uid, client_key)
-    assert res.is_err()
-
-    # add data
-    data_uid = UID()
-    obj = MockSyftObject(data=1)
-
-    res = store.set(data_uid, client_key, obj, has_result_read_permission=True)
-    assert res.is_ok()
-    res = store.get(data_uid, client_key)
-    assert res.is_ok()
-    assert res.ok() == obj
-
-    assert store.exists(data_uid)
-    res = store.delete(data_uid, client_key)
-    assert res.is_ok()
-    res = store.delete(data_uid, client_key)
+    # check that getting action data that no longer exists returns an error,
+    # even with permissions
+    stash.delete_by_uid(client_key, data_uid)
+    res = stash.get(data_uid, client_key)
     assert res.is_err()
diff --git a/packages/syft/tests/syft/stores/mongo_document_store_test.py b/packages/syft/tests/syft/stores/mongo_document_store_test.py
index 5af4c52d626..bb583be2865 100644
--- a/packages/syft/tests/syft/stores/mongo_document_store_test.py
+++ b/packages/syft/tests/syft/stores/mongo_document_store_test.py
@@ -1,6 +1,4 @@
 # stdlib
-from secrets import token_hex
-from threading import Thread
 
 # third party
 import pytest
@@ -14,11 +12,7 @@
 from syft.service.action.action_store import ActionObjectEXECUTE
 from syft.service.action.action_store import ActionObjectREAD
 from syft.service.action.action_store import ActionObjectWRITE
-from syft.store.document_store import PartitionSettings
 from syft.store.document_store import QueryKey
-from syft.store.document_store import QueryKeys
-from syft.store.mongo_client import MongoStoreClientConfig
-from syft.store.mongo_document_store import MongoStoreConfig
 from syft.store.mongo_document_store import MongoStorePartition
 from syft.types.errors import SyftException
 from syft.types.uid import UID
 
 # relative
 from ...mongomock.collection import Collection as MongoCollection
 from .store_constants_test import TEST_VERIFY_KEY_STRING_HACKER
-from .store_fixtures_test import mongo_store_partition_fn
-from .store_mocks_test import MockObjectType
 from .store_mocks_test import
MockSyftObject PERMISSIONS = [ @@ -48,554 +40,6 @@ def test_mongo_store_partition_sanity( assert hasattr(mongo_store_partition, "_permissions") -@pytest.mark.skip(reason="Test gets stuck at store.init_store()") -def test_mongo_store_partition_init_failed(root_verify_key) -> None: - # won't connect - mongo_config = MongoStoreClientConfig( - connectTimeoutMS=1, - timeoutMS=1, - ) - - store_config = MongoStoreConfig(client_config=mongo_config) - settings = PartitionSettings(name="test", object_type=MockObjectType) - - store = MongoStorePartition( - UID(), root_verify_key, settings=settings, store_config=store_config - ) - - res = store.init_store() - assert res.is_err() - - -def test_mongo_store_partition_set( - root_verify_key, mongo_store_partition: MongoStorePartition -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - - obj = MockSyftObject(data=1) - - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - - assert res.is_ok() - assert res.ok() == obj - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_err() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=True) - assert res.is_ok() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - obj2 = MockSyftObject(data=2) - res = mongo_store_partition.set(root_verify_key, obj2, ignore_duplicates=False) - assert res.is_ok() - assert res.ok() == obj2 - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 2 - ) - - repeats = 5 - for idx in range(repeats): - obj = MockSyftObject(data=idx) - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 3 + idx - ) - - -def test_mongo_store_partition_delete( - root_verify_key, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - repeats = 5 - - objs = [] - for v in range(repeats): - obj = MockSyftObject(data=v) - mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - objs.append(obj) - - assert len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # random object - obj = MockSyftObject(data="bogus") - key = mongo_store_partition.settings.store_key.with_obj(obj) - res = mongo_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # cleanup store - for idx, v in enumerate(objs): - key = mongo_store_partition.settings.store_key.with_obj(v) - res = mongo_store_partition.delete(root_verify_key, key) - assert res.is_ok() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - res = mongo_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 0 - ) - - -def test_mongo_store_partition_update( - root_verify_key, - mongo_store_partition: MongoStorePartition, -) -> None: - mongo_store_partition.init_store() - - # add item - obj = 
MockSyftObject(data=1) - mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - # fail to update missing keys - rand_obj = MockSyftObject(data="bogus") - key = mongo_store_partition.settings.store_key.with_obj(rand_obj) - res = mongo_store_partition.update(root_verify_key, key, obj) - assert res.is_err() - - # update the key multiple times - repeats = 5 - for v in range(repeats): - key = mongo_store_partition.settings.store_key.with_obj(obj) - obj_new = MockSyftObject(data=v) - - res = mongo_store_partition.update(root_verify_key, key, obj_new) - assert res.is_ok() - - # The ID should stay the same on update, only the values are updated. - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - assert ( - mongo_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - == obj.id - ) - assert ( - mongo_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - != obj_new.id - ) - assert ( - mongo_store_partition.all( - root_verify_key, - ) - .ok()[0] - .data - == v - ) - - stored = mongo_store_partition.get_all_from_store( - root_verify_key, QueryKeys(qks=[key]) - ) - assert stored.ok()[0].data == v - - -def test_mongo_store_partition_set_threading(root_verify_key, mongo_client) -> None: - thread_cnt = 3 - repeats = 5 - - execution_err = None - mongo_db_name = token_hex(8) - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - for idx in range(repeats): - obj = MockObjectType(data=idx) - - for _ in range(10): - res = mongo_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - return execution_err - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - stored_cnt = len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == thread_cnt * repeats - - -# @pytest.mark.skip( -# reason="PicklingError: Could not pickle the task to send it to the workers." 
-# ) -# def test_mongo_store_partition_set_joblib( -# root_verify_key, -# mongo_client, -# ) -> None: -# thread_cnt = 3 -# repeats = 5 -# mongo_db_name = token_hex(8) - -# def _kv_cbk(tid: int) -> None: -# for idx in range(repeats): -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# obj = MockObjectType(data=idx) - -# for _ in range(10): -# res = mongo_store_partition.set( -# root_verify_key, obj, ignore_duplicates=False -# ) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res - -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) - -# for execution_err in errs: -# assert execution_err is None - -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# stored_cnt = len( -# mongo_store_partition.all( -# root_verify_key, -# ).ok() -# ) -# assert stored_cnt == thread_cnt * repeats - - -def test_mongo_store_partition_update_threading( - root_verify_key, - mongo_client, -) -> None: - thread_cnt = 3 - repeats = 5 - - mongo_db_name = token_hex(8) - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - - obj = MockSyftObject(data=0) - key = mongo_store_partition.settings.store_key.with_obj(obj) - mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - - mongo_store_partition_local = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - for repeat in range(repeats): - obj = MockSyftObject(data=repeat) - - for _ in range(10): - res = mongo_store_partition_local.update(root_verify_key, key, obj) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - -# @pytest.mark.skip( -# reason="PicklingError: Could not pickle the task to send it to the workers." 
-# ) -# def test_mongo_store_partition_update_joblib(root_verify_key, mongo_client) -> None: -# thread_cnt = 3 -# repeats = 5 - -# mongo_db_name = token_hex(8) - -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# obj = MockSyftObject(data=0) -# key = mongo_store_partition.settings.store_key.with_obj(obj) -# mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - -# def _kv_cbk(tid: int) -> None: -# mongo_store_partition_local = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# for repeat in range(repeats): -# obj = MockSyftObject(data=repeat) - -# for _ in range(10): -# res = mongo_store_partition_local.update(root_verify_key, key, obj) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) - -# for execution_err in errs: -# assert execution_err is None - - -def test_mongo_store_partition_set_delete_threading( - root_verify_key, - mongo_client, -) -> None: - thread_cnt = 3 - repeats = 5 - execution_err = None - mongo_db_name = token_hex(8) - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - - for idx in range(repeats): - obj = MockSyftObject(data=idx) - - for _ in range(10): - res = mongo_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - key = mongo_store_partition.settings.store_key.with_obj(obj) - - res = mongo_store_partition.delete(root_verify_key, key) - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - stored_cnt = len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == 0 - - -# @pytest.mark.skip( -# reason="PicklingError: Could not pickle the task to send it to the workers." 
-# ) -# def test_mongo_store_partition_set_delete_joblib(root_verify_key, mongo_client) -> None: -# thread_cnt = 3 -# repeats = 5 -# mongo_db_name = token_hex(8) - -# def _kv_cbk(tid: int) -> None: -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, root_verify_key, mongo_db_name=mongo_db_name -# ) - -# for idx in range(repeats): -# obj = MockSyftObject(data=idx) - -# for _ in range(10): -# res = mongo_store_partition.set( -# root_verify_key, obj, ignore_duplicates=False -# ) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res - -# key = mongo_store_partition.settings.store_key.with_obj(obj) - -# res = mongo_store_partition.delete(root_verify_key, key) -# if res.is_err(): -# return res -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) -# for execution_err in errs: -# assert execution_err is None - -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# stored_cnt = len( -# mongo_store_partition.all( -# root_verify_key, -# ).ok() -# ) -# assert stored_cnt == 0 - - def test_mongo_store_partition_permissions_collection( mongo_store_partition: MongoStorePartition, ) -> None: @@ -829,64 +273,6 @@ def test_mongo_store_partition_has_permission( assert not mongo_store_partition.has_permission(permisson_hacker_2) -@pytest.mark.parametrize("permission", PERMISSIONS) -def test_mongo_store_partition_take_ownership( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, - permission: ActionObjectPermission, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - obj = MockSyftObject(data=1) - - # the guest client takes ownership of obj - mongo_store_partition.take_ownership( - uid=obj.id, credentials=guest_verify_key - ).unwrap() - assert mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=guest_verify_key) - ) - # the root client will also has the permission - assert mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=root_verify_key) - ) - assert not mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=hacker_verify_key) - ) - - # hacker or root try to take ownership of the obj and will fail - res = mongo_store_partition.take_ownership( - uid=obj.id, credentials=hacker_verify_key - ) - res_2 = mongo_store_partition.take_ownership( - uid=obj.id, credentials=root_verify_key - ) - assert res.is_err() - assert res_2.is_err() - assert ( - res.value.public_message - == res_2.value.public_message - == f"UID: {obj.id} already owned." 
- ) - - # another object - obj_2 = MockSyftObject(data=2) - # root client takes ownership - mongo_store_partition.take_ownership(uid=obj_2.id, credentials=root_verify_key) - assert mongo_store_partition.has_permission( - permission(uid=obj_2.id, credentials=root_verify_key) - ) - assert not mongo_store_partition.has_permission( - permission(uid=obj_2.id, credentials=guest_verify_key) - ) - assert not mongo_store_partition.has_permission( - permission(uid=obj_2.id, credentials=hacker_verify_key) - ) - - def test_mongo_store_partition_permissions_set( root_verify_key: SyftVerifyKey, guest_verify_key: SyftVerifyKey, diff --git a/packages/syft/tests/syft/stores/sqlite_document_store_test.py b/packages/syft/tests/syft/stores/sqlite_document_store_test.py deleted file mode 100644 index 46ee540aa9c..00000000000 --- a/packages/syft/tests/syft/stores/sqlite_document_store_test.py +++ /dev/null @@ -1,520 +0,0 @@ -# stdlib -from threading import Thread - -# third party -import pytest - -# syft absolute -from syft.store.document_store import QueryKeys -from syft.store.sqlite_document_store import SQLiteStorePartition - -# relative -from .store_fixtures_test import sqlite_store_partition_fn -from .store_mocks_test import MockObjectType -from .store_mocks_test import MockSyftObject - - -def test_sqlite_store_partition_sanity( - sqlite_store_partition: SQLiteStorePartition, -) -> None: - assert hasattr(sqlite_store_partition, "data") - assert hasattr(sqlite_store_partition, "unique_keys") - assert hasattr(sqlite_store_partition, "searchable_keys") - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_sqlite_store_partition_set( - root_verify_key, - sqlite_store_partition: SQLiteStorePartition, -) -> None: - obj = MockSyftObject(data=1) - res = sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - - assert res.is_ok() - assert res.ok() == obj - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_err() - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=True) - assert res.is_ok() - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - obj2 = MockSyftObject(data=2) - res = sqlite_store_partition.set(root_verify_key, obj2, ignore_duplicates=False) - assert res.is_ok() - assert res.ok() == obj2 - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 2 - ) - repeats = 5 - for idx in range(repeats): - obj = MockSyftObject(data=idx) - res = sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 3 + idx - ) - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_sqlite_store_partition_delete( - root_verify_key, - sqlite_store_partition: SQLiteStorePartition, -) -> None: - objs = [] - repeats = 5 - for v in range(repeats): - obj = MockSyftObject(data=v) - sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - objs.append(obj) - - assert len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # random object - obj = MockSyftObject(data="bogus") - key = sqlite_store_partition.settings.store_key.with_obj(obj) - res = sqlite_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert 
len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # cleanup store - for idx, v in enumerate(objs): - key = sqlite_store_partition.settings.store_key.with_obj(v) - res = sqlite_store_partition.delete(root_verify_key, key) - assert res.is_ok() - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - res = sqlite_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 0 - ) - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_sqlite_store_partition_update( - root_verify_key, - sqlite_store_partition: SQLiteStorePartition, -) -> None: - # add item - obj = MockSyftObject(data=1) - sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - # fail to update missing keys - rand_obj = MockSyftObject(data="bogus") - key = sqlite_store_partition.settings.store_key.with_obj(rand_obj) - res = sqlite_store_partition.update(root_verify_key, key, obj) - assert res.is_err() - - # update the key multiple times - repeats = 5 - for v in range(repeats): - key = sqlite_store_partition.settings.store_key.with_obj(obj) - obj_new = MockSyftObject(data=v) - - res = sqlite_store_partition.update(root_verify_key, key, obj_new) - assert res.is_ok() - - # The ID should stay the same on update, unly the values are updated. - assert ( - len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - assert ( - sqlite_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - == obj.id - ) - assert ( - sqlite_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - != obj_new.id - ) - assert ( - sqlite_store_partition.all( - root_verify_key, - ) - .ok()[0] - .data - == v - ) - - stored = sqlite_store_partition.get_all_from_store( - root_verify_key, QueryKeys(qks=[key]) - ) - assert stored.ok()[0].data == v - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_sqlite_store_partition_set_threading( - sqlite_workspace: tuple, - root_verify_key, -) -> None: - thread_cnt = 3 - repeats = 5 - - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - - sqlite_store_partition = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace - ) - for idx in range(repeats): - for _ in range(10): - obj = MockObjectType(data=idx) - res = sqlite_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - return execution_err - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - sqlite_store_partition = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace - ) - stored_cnt = len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == thread_cnt * repeats - - -# @pytest.mark.skip(reason="Joblib is flaky") -# def test_sqlite_store_partition_set_joblib( -# root_verify_key, -# sqlite_workspace: Tuple, -# ) -> None: -# thread_cnt = 3 -# repeats = 5 - -# def _kv_cbk(tid: int) -> None: -# for idx in range(repeats): -# sqlite_store_partition = sqlite_store_partition_fn( -# 
root_verify_key, sqlite_workspace -# ) -# obj = MockObjectType(data=idx) - -# for _ in range(10): -# res = sqlite_store_partition.set( -# root_verify_key, obj, ignore_duplicates=False -# ) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res - -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) - -# for execution_err in errs: -# assert execution_err is None - -# sqlite_store_partition = sqlite_store_partition_fn( -# root_verify_key, sqlite_workspace -# ) -# stored_cnt = len( -# sqlite_store_partition.all( -# root_verify_key, -# ).ok() -# ) -# assert stored_cnt == thread_cnt * repeats - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_sqlite_store_partition_update_threading( - root_verify_key, - sqlite_workspace: tuple, -) -> None: - thread_cnt = 3 - repeats = 5 - - sqlite_store_partition = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace - ) - obj = MockSyftObject(data=0) - key = sqlite_store_partition.settings.store_key.with_obj(obj) - sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - - sqlite_store_partition_local = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace - ) - for repeat in range(repeats): - obj = MockSyftObject(data=repeat) - - for _ in range(10): - res = sqlite_store_partition_local.update(root_verify_key, key, obj) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - -# @pytest.mark.skip(reason="Joblib is flaky") -# def test_sqlite_store_partition_update_joblib( -# root_verify_key, -# sqlite_workspace: Tuple, -# ) -> None: -# thread_cnt = 3 -# repeats = 5 - -# sqlite_store_partition = sqlite_store_partition_fn( -# root_verify_key, sqlite_workspace -# ) -# obj = MockSyftObject(data=0) -# key = sqlite_store_partition.settings.store_key.with_obj(obj) -# sqlite_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - -# def _kv_cbk(tid: int) -> None: -# sqlite_store_partition_local = sqlite_store_partition_fn( -# root_verify_key, sqlite_workspace -# ) -# for repeat in range(repeats): -# obj = MockSyftObject(data=repeat) - -# for _ in range(10): -# res = sqlite_store_partition_local.update(root_verify_key, key, obj) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) - -# for execution_err in errs: -# assert execution_err is None - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_sqlite_store_partition_set_delete_threading( - root_verify_key, - sqlite_workspace: tuple, -) -> None: - thread_cnt = 3 - repeats = 5 - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - sqlite_store_partition = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace - ) - - for idx in range(repeats): - obj = MockSyftObject(data=idx) - - for _ in range(10): - res = sqlite_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - key = sqlite_store_partition.settings.store_key.with_obj(obj) - - res = sqlite_store_partition.delete(root_verify_key, key) 
- if res.is_err(): - execution_err = res - assert res.is_ok(), res - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - sqlite_store_partition = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace - ) - stored_cnt = len( - sqlite_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == 0 - - -# @pytest.mark.skip(reason="Joblib is flaky") -# def test_sqlite_store_partition_set_delete_joblib( -# root_verify_key, -# sqlite_workspace: Tuple, -# ) -> None: -# thread_cnt = 3 -# repeats = 5 - -# def _kv_cbk(tid: int) -> None: -# sqlite_store_partition = sqlite_store_partition_fn( -# root_verify_key, sqlite_workspace -# ) - -# for idx in range(repeats): -# obj = MockSyftObject(data=idx) - -# for _ in range(10): -# res = sqlite_store_partition.set( -# root_verify_key, obj, ignore_duplicates=False -# ) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res - -# key = sqlite_store_partition.settings.store_key.with_obj(obj) - -# res = sqlite_store_partition.delete(root_verify_key, key) -# if res.is_err(): -# return res -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) -# for execution_err in errs: -# assert execution_err is None - -# sqlite_store_partition = sqlite_store_partition_fn( -# root_verify_key, sqlite_workspace -# ) -# stored_cnt = len( -# sqlite_store_partition.all( -# root_verify_key, -# ).ok() -# ) -# assert stored_cnt == 0 diff --git a/packages/syft/tests/syft/stores/store_fixtures_test.py b/packages/syft/tests/syft/stores/store_fixtures_test.py index cbbe0022f49..7225be989b7 100644 --- a/packages/syft/tests/syft/stores/store_fixtures_test.py +++ b/packages/syft/tests/syft/stores/store_fixtures_test.py @@ -1,20 +1,10 @@ # stdlib -from collections.abc import Generator -import os -from pathlib import Path -from secrets import token_hex -import tempfile import uuid -# third party -import pytest - # syft absolute from syft.server.credentials import SyftVerifyKey from syft.service.action.action_permissions import ActionObjectPermission from syft.service.action.action_permissions import ActionPermission -from syft.service.action.action_store import ActionObjectStash -from syft.service.queue.queue_stash import QueueStash from syft.service.user.user import User from syft.service.user.user import UserCreate from syft.service.user.user_roles import ServiceRole @@ -22,37 +12,11 @@ from syft.store.db.sqlite_db import SQLiteDBConfig from syft.store.db.sqlite_db import SQLiteDBManager from syft.store.document_store import DocumentStore -from syft.store.document_store import PartitionSettings -from syft.store.locks import LockingConfig -from syft.store.locks import NoLockingConfig -from syft.store.locks import ThreadingLockingConfig -from syft.store.sqlite_document_store import SQLiteDocumentStore -from syft.store.sqlite_document_store import SQLiteStoreClientConfig -from syft.store.sqlite_document_store import SQLiteStoreConfig -from syft.store.sqlite_document_store import SQLiteStorePartition from syft.types.uid import UID # relative from .store_constants_test import TEST_SIGNING_KEY_NEW_ADMIN from .store_constants_test import TEST_VERIFY_KEY_NEW_ADMIN -from .store_constants_test import TEST_VERIFY_KEY_STRING_ROOT -from .store_mocks_test import MockObjectType - -MONGO_CLIENT_CACHE = None - -locking_scenarios = [ - "nop", - "threading", -] - - 
-def str_to_locking_config(conf: str) -> LockingConfig: - if conf == "nop": - return NoLockingConfig() - elif conf == "threading": - return ThreadingLockingConfig() - else: - raise NotImplementedError(f"unknown locking config {conf}") def document_store_with_admin( @@ -88,142 +52,3 @@ def document_store_with_admin( ) return document_store - - -@pytest.fixture(scope="function") -def sqlite_workspace() -> Generator: - sqlite_db_name = token_hex(8) + ".sqlite" - root = os.getenv("SYFT_TEMP_ROOT", "syft") - sqlite_workspace_folder = Path( - tempfile.gettempdir(), root, "fixture_sqlite_workspace" - ) - sqlite_workspace_folder.mkdir(parents=True, exist_ok=True) - - db_path = sqlite_workspace_folder / sqlite_db_name - - if db_path.exists(): - db_path.unlink() - - yield sqlite_workspace_folder, sqlite_db_name - - try: - db_path.exists() and db_path.unlink() - except BaseException as e: - print("failed to cleanup sqlite db", e) - - -def sqlite_store_partition_fn( - root_verify_key, - sqlite_workspace: tuple[Path, str], - locking_config_name: str = "nop", -): - workspace, db_name = sqlite_workspace - sqlite_config = SQLiteStoreClientConfig(filename=db_name, path=workspace) - - locking_config = str_to_locking_config(locking_config_name) - store_config = SQLiteStoreConfig( - client_config=sqlite_config, locking_config=locking_config - ) - - settings = PartitionSettings(name="test", object_type=MockObjectType) - - store = SQLiteStorePartition( - UID(), root_verify_key, settings=settings, store_config=store_config - ) - - store.init_store().unwrap() - - return store - - -@pytest.fixture(scope="function", params=locking_scenarios) -def sqlite_store_partition( - root_verify_key, sqlite_workspace: tuple[Path, str], request -): - locking_config_name = request.param - store = sqlite_store_partition_fn( - root_verify_key, sqlite_workspace, locking_config_name=locking_config_name - ) - - yield store - - -def sqlite_document_store_fn( - root_verify_key, - sqlite_workspace: tuple[Path, str], - locking_config_name: str = "nop", -): - workspace, db_name = sqlite_workspace - sqlite_config = SQLiteStoreClientConfig(filename=db_name, path=workspace) - - locking_config = str_to_locking_config(locking_config_name) - store_config = SQLiteStoreConfig( - client_config=sqlite_config, locking_config=locking_config - ) - - return SQLiteDocumentStore(UID(), root_verify_key, store_config=store_config) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def sqlite_document_store(root_verify_key, sqlite_workspace: tuple[Path, str], request): - locking_config_name = request.param - store = sqlite_document_store_fn( - root_verify_key, sqlite_workspace, locking_config_name=locking_config_name - ) - yield store - - -def sqlite_queue_stash_fn( - root_verify_key, - sqlite_workspace: tuple[Path, str], - locking_config_name: str = "threading", -): - store = sqlite_document_store_fn( - root_verify_key, - sqlite_workspace, - locking_config_name=locking_config_name, - ) - return QueueStash(store=store) - - -@pytest.fixture( - scope="function", - params=[ - "tODOsqlite_address", - "TODOpostgres_address", - ], -) -def queue_stash(request): - _ = request.param - stash = QueueStash.random() - yield stash - - -@pytest.fixture(scope="function", params=locking_scenarios) -def sqlite_action_store(sqlite_workspace: tuple[Path, str], request): - workspace, db_name = sqlite_workspace - locking_config_name = request.param - - sqlite_config = SQLiteStoreClientConfig(filename=db_name, path=workspace) - - locking_config = 
str_to_locking_config(locking_config_name) - store_config = SQLiteStoreConfig( - client_config=sqlite_config, - locking_config=locking_config, - ) - - ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT) - - server_uid = UID() - document_store = document_store_with_admin(server_uid, ver_key) - - yield ActionObjectStash( - server_uid=server_uid, - store_config=store_config, - root_verify_key=ver_key, - document_store=document_store, - ) - - -def mongo_queue_stash(mongo_document_store): - return QueueStash(store=mongo_document_store) diff --git a/packages/syft/tests/syft/worker_test.py b/packages/syft/tests/syft/worker_test.py index 4d9967344e2..d6cfcfaac5a 100644 --- a/packages/syft/tests/syft/worker_test.py +++ b/packages/syft/tests/syft/worker_test.py @@ -24,6 +24,7 @@ from syft.service.user.user import UserCreate from syft.service.user.user import UserView from syft.service.user.user_service import UserService +from syft.service.user.user_stash import UserStash from syft.store.db.sqlite_db import SQLiteDBManager from syft.types.errors import SyftException from syft.types.result import Ok @@ -76,11 +77,18 @@ def test_signing_key() -> None: assert test_verify_key == test_verify_key_2 -@pytest.fixture +@pytest.fixture( + scope="function", + params=[ + "tODOsqlite_address", + "TODOpostgres_address", + ], +) def action_object_stash() -> ActionObjectStash: root_verify_key = SyftVerifyKey.from_string(test_verify_key_string) db_manager = SQLiteDBManager.random(root_verify_key=root_verify_key) stash = ActionObjectStash(store=db_manager) + _ = UserStash(store=db_manager) stash.db.init_tables() yield stash From 08ba5b2b570c8ca2b431a724e7ad1246a9d2e92e Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 9 Sep 2024 20:33:20 +0200 Subject: [PATCH 104/197] fix worker pool tests --- packages/syft/src/syft/serde/json_serde.py | 2 +- .../syft/stores/mongo_document_store_test.py | 24 ------------------- 2 files changed, 1 insertion(+), 25 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index 301788ff849..42d5ee1b2f3 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -184,7 +184,7 @@ def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> dict[str, Json]: except Exception as e: raise ValueError(f"Failed to serialize attribute {key}: {e}") - result = _serialize_searchable_attrs(obj, result) + result = _serialize_searchable_attrs(obj, result, raise_errors=False) return result diff --git a/packages/syft/tests/syft/stores/mongo_document_store_test.py b/packages/syft/tests/syft/stores/mongo_document_store_test.py index bb583be2865..a1fbee484da 100644 --- a/packages/syft/tests/syft/stores/mongo_document_store_test.py +++ b/packages/syft/tests/syft/stores/mongo_document_store_test.py @@ -30,28 +30,6 @@ ] -def test_mongo_store_partition_sanity( - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - - assert hasattr(mongo_store_partition, "_collection") - assert hasattr(mongo_store_partition, "_permissions") - - -def test_mongo_store_partition_permissions_collection( - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - - collection_permissions_status = mongo_store_partition.permissions - assert not collection_permissions_status.is_err() - collection_permissions = collection_permissions_status.ok() - assert 
isinstance(collection_permissions, MongoCollection) - - def test_mongo_store_partition_add_remove_permission( root_verify_key: SyftVerifyKey, mongo_store_partition: MongoStorePartition ) -> None: @@ -59,8 +37,6 @@ def test_mongo_store_partition_add_remove_permission( Test the add_permission and remove_permission functions of MongoStorePartition """ # setting up - res = mongo_store_partition.init_store() - assert res.is_ok() permissions_collection: MongoCollection = mongo_store_partition.permissions.ok() obj = MockSyftObject(data=1) From 726d8d10160f0aa0350efa3f5172d764b5b10b28 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 9 Sep 2024 21:06:13 +0200 Subject: [PATCH 105/197] fix some tests --- .../src/syft/service/user/user_service.py | 14 ++-- .../syft/settings/settings_service_test.py | 64 +++---------------- .../tests/syft/stores/action_store_test.py | 3 + 3 files changed, 17 insertions(+), 64 deletions(-) diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index ec23aa6e36f..bbae67a2e48 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -325,15 +325,11 @@ def get_all( page_size: int | None = 0, page_index: int | None = 0, ) -> list[UserView]: - if context.role in [ServiceRole.DATA_OWNER, ServiceRole.ADMIN]: - users = self.stash.get_all( - context.credentials, - has_permission=True, - order_by=order_by, - sort_order=sort_order, - ).unwrap() - else: - users = self.stash.get_all(context.credentials).unwrap() + users = self.stash.get_all( + context.credentials, + order_by=order_by, + sort_order=sort_order, + ).unwrap() users = [user.to(UserView) for user in users] return _paginate(users, page_size, page_index) diff --git a/packages/syft/tests/syft/settings/settings_service_test.py b/packages/syft/tests/syft/settings/settings_service_test.py index aaa7b0460fc..7d0500f8247 100644 --- a/packages/syft/tests/syft/settings/settings_service_test.py +++ b/packages/syft/tests/syft/settings/settings_service_test.py @@ -34,7 +34,6 @@ from syft.store.document_store_errors import NotFoundException from syft.store.document_store_errors import StashException from syft.types.errors import SyftException -from syft.types.result import Ok from syft.types.result import as_result @@ -100,37 +99,21 @@ def test_settingsservice_set_success( ) -> None: response = settings_service.set(authed_context, settings) assert isinstance(response, ServerSettings) + # PR NOTE do we write syft_client_verify_key and syft_server_location to the stash or not? 
+ response.syft_client_verify_key = None + response.syft_server_location = None + response.pwd_token_config.syft_client_verify_key = None + response.pwd_token_config.syft_server_location = None assert response == settings -def test_settingsservice_set_fail( - monkeypatch: MonkeyPatch, - settings_service: SettingsService, - settings: ServerSettings, - authed_context: AuthedServiceContext, -) -> None: - mock_error_message = "database failure" - - @as_result(StashException) - def mock_stash_set_error(credentials, settings: ServerSettings) -> NoReturn: - raise StashException(public_message=mock_error_message) - - monkeypatch.setattr(settings_service.stash, "set", mock_stash_set_error) - - with pytest.raises(StashException) as exc: - settings_service.set(authed_context, settings) - - assert exc.type == StashException - assert exc.value.public_message == mock_error_message - - def add_mock_settings( root_verify_key: SyftVerifyKey, settings_stash: SettingsStash, settings: ServerSettings, ) -> ServerSettings: # create a mock settings in the stash so that we can update it - result = settings_stash.partition.set(root_verify_key, settings) + result = settings_stash.set(root_verify_key, settings) assert result.is_ok() created_settings = result.ok() @@ -150,9 +133,7 @@ def test_settingsservice_update_success( notifier_stash: NotifierStash, ) -> None: # add a mock settings to the stash - mock_settings = add_mock_settings( - authed_context.credentials, settings_stash, settings - ) + mock_settings = settings_stash.set(authed_context.credentials, settings).unwrap() # get a new settings according to update_settings new_settings = deepcopy(settings) @@ -164,14 +145,6 @@ def test_settingsservice_update_success( assert new_settings != mock_settings assert mock_settings == settings - mock_stash_get_all_output = [mock_settings, mock_settings] - - def mock_stash_get_all(root_verify_key) -> Ok: - return Ok(mock_stash_get_all_output) - - monkeypatch.setattr(settings_service.stash, "get_all", mock_stash_get_all) - - # Mock the get_service method to return a mocked notifier_service with the notifier_stash class MockNotifierService: def __init__(self, stash): self.stash = stash @@ -194,12 +167,7 @@ def mock_get_service(service_name: str): # update the settings in the settings stash using settings_service response = settings_service.update(context=authed_context, settings=update_settings) - # not_updated_settings = response.ok()[1] - assert isinstance(response, SyftSuccess) - # assert ( - # not_updated_settings.to_dict() == settings.to_dict() - # ) # the second settings is not updated def test_settingsservice_update_stash_get_all_fail( @@ -208,19 +176,7 @@ def test_settingsservice_update_stash_get_all_fail( update_settings: ServerSettingsUpdate, authed_context: AuthedServiceContext, ) -> None: - mock_error_message = "database failure" - - @as_result(StashException) - def mock_stash_get_all_error(credentials) -> NoReturn: - raise StashException(public_message=mock_error_message) - - monkeypatch.setattr(settings_service.stash, "get_all", mock_stash_get_all_error) - - with pytest.raises(StashException) as exc: - settings_service.update(context=authed_context, settings=update_settings) - - assert exc.type == StashException - assert exc.value.public_message == mock_error_message + settings_service.update(context=authed_context, settings=update_settings) def test_settingsservice_update_stash_empty( @@ -230,9 +186,7 @@ def test_settingsservice_update_stash_empty( ) -> None: with pytest.raises(NotFoundException) as exc: 
settings_service.update(context=authed_context, settings=update_settings) - - assert exc.type == NotFoundException - assert exc.value.public_message == "Server settings not found" + assert exc.value.public_message == "Server settings not found" def test_settingsservice_update_fail( diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py index 239ec8d33ba..e0032877fcd 100644 --- a/packages/syft/tests/syft/stores/action_store_test.py +++ b/packages/syft/tests/syft/stores/action_store_test.py @@ -19,6 +19,9 @@ from syft.store.db.sqlite_db import DBManager from syft.types.uid import UID +# first party +from packages.syft.tests.syft.worker_test import action_object_stash # noqa: F401 + permissions = [ ActionObjectOWNER, ActionObjectREAD, From 77c8ac85e2d41a9fae94144b7827ddd401ed0dff Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 10 Sep 2024 09:53:38 +0530 Subject: [PATCH 106/197] scope the use of cursor and db connection --- packages/syft/setup.cfg | 1 + .../syft/store/postgresql_document_store.py | 101 +++++++++--------- 2 files changed, 49 insertions(+), 53 deletions(-) diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 5f6d7c4f8c9..4cc1ae70619 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -67,6 +67,7 @@ syft = tenacity==8.3.0 nh3==0.2.17 psycopg[binary]==3.1.19 + psycopg[pool]==3.1.19 ipython<8.27.0 dynaconf==3.2.6 diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py index ebaf4c32894..8b47c960782 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -1,5 +1,4 @@ # stdlib -from collections import defaultdict import logging from typing import Any @@ -35,7 +34,6 @@ logger = logging.getLogger(__name__) _CONNECTION_POOL_DB: dict[str, Connection] = {} -REF_COUNTS: dict[str, int] = defaultdict(int) # https://www.psycopg.org/docs/module.html#psycopg2.connect @@ -91,7 +89,6 @@ def __init__( self.lock = SyftLock(NoLockingConfig()) self.create_table() - REF_COUNTS[cache_key(self.dbname)] += 1 self.subs_char = r"%s" # thanks postgresql def _connect(self) -> None: @@ -103,24 +100,28 @@ def _connect(self) -> None: host=self.store_config.client_config.host, port=self.store_config.client_config.port, ) - print(f"Connected to {self.store_config.client_config.dbname}") - print("PostgreSQL database connection:", connection) - _CONNECTION_POOL_DB[cache_key(self.dbname)] = connection + print(f"Connected to {self.store_config.client_config.dbname}") + print( + "PostgreSQL database connection:", + _CONNECTION_POOL_DB[cache_key(self.dbname)], + ) def create_table(self) -> None: + db = self.db try: with self.lock: - self.cur.execute( - f"CREATE TABLE IF NOT EXISTS {self.table_name} (uid VARCHAR(32) NOT NULL PRIMARY KEY, " # nosec - + "repr TEXT NOT NULL, value BYTEA NOT NULL, " # nosec - + "sqltime TIMESTAMP NOT NULL DEFAULT NOW())" # nosec - ) - self.db.commit() + with db.cursor() as cur: + cur.execute( + f"CREATE TABLE IF NOT EXISTS {self.table_name} (uid VARCHAR(32) NOT NULL PRIMARY KEY, " # nosec + + "repr TEXT NOT NULL, value BYTEA NOT NULL, " # nosec + + "sqltime TIMESTAMP NOT NULL DEFAULT NOW())" # nosec + ) + cur.connection.commit() except DuplicateTable: pass except InFailedSqlTransaction: - self.db.rollback() + db.rollback() except Exception as e: public_message = special_exception_public_message(self.table_name, e) raise 
SyftException.from_exception(e, public_message=public_message) @@ -135,44 +136,32 @@ def db(self) -> Connection: def cur(self) -> Cursor: return self.db.cursor() - def _close(self) -> None: - self._commit() - REF_COUNTS[cache_key(self.store_config_hash)] -= 1 - if REF_COUNTS[cache_key(self.store_config_hash)] <= 0: - # once you close it seems like other object references can't re-use the - # same connection - - self.db.close() - del _CONNECTION_POOL_DB[cache_key(self.store_config_hash)] - else: - # don't close yet because another SQLiteBackingStore is probably still open - pass - @staticmethod @as_result(SyftException) def _execute( lock: SyftLock, cursor: Cursor, - db: Connection, table_name: str, sql: str, args: list[Any] | None, ) -> Cursor: - try: - cursor.execute(sql, args) # Execute the SQL with arguments - db.commit() # Commit if everything went ok - except InFailedSqlTransaction as ie: - db.rollback() # Rollback if something went wrong - raise SyftException( - public_message=f"Transaction `{sql}` failed and was rolled back. \n" - f"Error: {ie}." - ) - except Exception as e: - logger.debug(f"Rolling back SQL: {sql} with args: {args}") - db.rollback() # Rollback on any other exception to maintain clean state - public_message = special_exception_public_message(table_name, e) - logger.error(public_message) - raise SyftException.from_exception(e, public_message=public_message) + with lock: + db = cursor.connection + try: + cursor.execute(sql, args) # Execute the SQL with arguments + db.commit() # Commit if everything went ok + except InFailedSqlTransaction as ie: + db.rollback() # Rollback if something went wrong + raise SyftException( + public_message=f"Transaction `{sql}` failed and was rolled back. \n" + f"Error: {ie}." + ) + except Exception as e: + logger.debug(f"Rolling back SQL: {sql} with args: {args}") + db.rollback() # Rollback on any other exception to maintain clean state + public_message = special_exception_public_message(table_name, e) + logger.error(public_message) + raise SyftException.from_exception(e, public_message=public_message) return cursor def _set(self, key: UID, value: Any) -> None: @@ -188,7 +177,6 @@ def _set(self, key: UID, value: Any) -> None: self._execute( self.lock, cur, - self.db, self.table_name, insert_sql, [str(key), _repr_debug_(value), data], @@ -205,7 +193,6 @@ def _update(self, key: UID, value: Any) -> None: self._execute( self.lock, cur, - self.db, self.table_name, insert_sql, [str(key), _repr_debug_(value), data, str(key)], @@ -218,7 +205,7 @@ def _get(self, key: UID) -> Any: ) with self.cur as cur: cursor = self._execute( - self.lock, cur, self.db, self.table_name, select_sql, [str(key)] + self.lock, cur, self.table_name, select_sql, [str(key)] ).unwrap(public_message=f"Query {select_sql} failed") row = cursor.fetchone() if row is None or len(row) == 0: @@ -231,7 +218,7 @@ def _exists(self, key: UID) -> bool: row = None with self.cur as cur: cursor = self._execute( - self.lock, cur, self.db, self.table_name, select_sql, [str(key)] + self.lock, cur, self.table_name, select_sql, [str(key)] ).unwrap() row = cursor.fetchone() # type: ignore if row is None: @@ -244,7 +231,7 @@ def _get_all(self) -> Any: data = [] with self.cur as cur: cursor = self._execute( - self.lock, cur, self.db, self.table_name, select_sql, [] + self.lock, cur, self.table_name, select_sql, [] ).unwrap() rows = cursor.fetchall() # type: ignore if not rows: @@ -260,7 +247,7 @@ def _get_all_keys(self) -> Any: select_sql = f"select uid from {self.table_name} order by sqltime" # 
nosec with self.cur as cur: cursor = self._execute( - self.lock, cur, self.db, self.table_name, select_sql, [] + self.lock, cur, self.table_name, select_sql, [] ).unwrap() rows = cursor.fetchall() # type: ignore if not rows: @@ -272,25 +259,33 @@ def _delete(self, key: UID) -> None: select_sql = f"delete from {self.table_name} where uid = {self.subs_char}" # nosec with self.cur as cur: self._execute( - self.lock, cur, self.db, self.table_name, select_sql, [str(key)] + self.lock, cur, self.table_name, select_sql, [str(key)] ).unwrap() def _delete_all(self) -> None: select_sql = f"delete from {self.table_name}" # nosec with self.cur as cur: - self._execute( - self.lock, cur, self.db, self.table_name, select_sql, [] - ).unwrap() + self._execute(self.lock, cur, self.table_name, select_sql, []).unwrap() def _len(self) -> int: select_sql = f"select count(uid) from {self.table_name}" # nosec with self.cur as cur: cursor = self._execute( - self.lock, cur, self.db, self.table_name, select_sql, [] + self.lock, cur, self.table_name, select_sql, [] ).unwrap() cnt = cursor.fetchone()[0] return cnt + def _close(self) -> None: + self._commit() + if cache_key(self.dbname) in _CONNECTION_POOL_DB: + conn = _CONNECTION_POOL_DB[cache_key(self.dbname)] + conn.close() + _CONNECTION_POOL_DB.pop(cache_key(self.dbname), None) + + def _commit(self) -> None: + self.db.commit() + @serializable() class PostgreSQLStoreConfig(StoreConfig): From 4fd6070b866e4c7f8098cf2d9288dc5f14c35c77 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 10 Sep 2024 10:22:54 +0530 Subject: [PATCH 107/197] fix lint --- .../syft/store/postgresql_document_store.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py index 8b47c960782..1d6896c8fc3 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -141,12 +141,12 @@ def cur(self) -> Cursor: def _execute( lock: SyftLock, cursor: Cursor, + db: Connection, table_name: str, sql: str, args: list[Any] | None, ) -> Cursor: with lock: - db = cursor.connection try: cursor.execute(sql, args) # Execute the SQL with arguments db.commit() # Commit if everything went ok @@ -177,6 +177,7 @@ def _set(self, key: UID, value: Any) -> None: self._execute( self.lock, cur, + cur.connection, self.table_name, insert_sql, [str(key), _repr_debug_(value), data], @@ -193,6 +194,7 @@ def _update(self, key: UID, value: Any) -> None: self._execute( self.lock, cur, + cur.connection, self.table_name, insert_sql, [str(key), _repr_debug_(value), data, str(key)], @@ -205,7 +207,7 @@ def _get(self, key: UID) -> Any: ) with self.cur as cur: cursor = self._execute( - self.lock, cur, self.table_name, select_sql, [str(key)] + self.lock, cur, cur.connection, self.table_name, select_sql, [str(key)] ).unwrap(public_message=f"Query {select_sql} failed") row = cursor.fetchone() if row is None or len(row) == 0: @@ -218,7 +220,7 @@ def _exists(self, key: UID) -> bool: row = None with self.cur as cur: cursor = self._execute( - self.lock, cur, self.table_name, select_sql, [str(key)] + self.lock, cur, cur.connection, self.table_name, select_sql, [str(key)] ).unwrap() row = cursor.fetchone() # type: ignore if row is None: @@ -231,7 +233,7 @@ def _get_all(self) -> Any: data = [] with self.cur as cur: cursor = self._execute( - self.lock, cur, self.table_name, select_sql, [] + self.lock, cur, cur.connection, 
self.table_name, select_sql, [] ).unwrap() rows = cursor.fetchall() # type: ignore if not rows: @@ -247,7 +249,7 @@ def _get_all_keys(self) -> Any: select_sql = f"select uid from {self.table_name} order by sqltime" # nosec with self.cur as cur: cursor = self._execute( - self.lock, cur, self.table_name, select_sql, [] + self.lock, cur, cur.connection, self.table_name, select_sql, [] ).unwrap() rows = cursor.fetchall() # type: ignore if not rows: @@ -265,13 +267,15 @@ def _delete(self, key: UID) -> None: def _delete_all(self) -> None: select_sql = f"delete from {self.table_name}" # nosec with self.cur as cur: - self._execute(self.lock, cur, self.table_name, select_sql, []).unwrap() + self._execute( + self.lock, cur, cur.connection, self.table_name, select_sql, [] + ).unwrap() def _len(self) -> int: select_sql = f"select count(uid) from {self.table_name}" # nosec with self.cur as cur: cursor = self._execute( - self.lock, cur, self.table_name, select_sql, [] + self.lock, cur, cur.connection, self.table_name, select_sql, [] ).unwrap() cnt = cursor.fetchone()[0] return cnt From da16fcff9651f886b8ab31c5c684fff1b0fdbcef Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 10 Sep 2024 10:46:38 +0530 Subject: [PATCH 108/197] fix _delete args in postgres sqlstore --- packages/syft/src/syft/store/postgresql_document_store.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py index 1d6896c8fc3..2a8fc2863de 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -261,7 +261,7 @@ def _delete(self, key: UID) -> None: select_sql = f"delete from {self.table_name} where uid = {self.subs_char}" # nosec with self.cur as cur: self._execute( - self.lock, cur, self.table_name, select_sql, [str(key)] + self.lock, cur, cur.connection, self.table_name, select_sql, [str(key)] ).unwrap() def _delete_all(self) -> None: From d3e6e2b7a1b5e2090268896e0ff4987fb0962c5e Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 10 Sep 2024 09:07:03 +0200 Subject: [PATCH 109/197] fix iter and order --- packages/syft/src/syft/client/api.py | 5 +++++ packages/syft/src/syft/service/user/user.py | 1 + packages/syft/tests/syft/api_test.py | 2 +- .../tests/syft/stores/action_store_test.py | 4 ++-- .../tests/syft/users/user_service_test.py | 21 ++----------------- packages/syft/tests/syft/users/user_test.py | 1 + 6 files changed, 12 insertions(+), 22 deletions(-) diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 21c8afe0557..85fe33545fa 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -743,6 +743,11 @@ def __getitem__(self, key: str | int) -> Any: return self.get_all()[key] raise NotImplementedError + def __iter__(self) -> Any: + if hasattr(self, "get_all"): + return iter(self.get_all()) + raise NotImplementedError + def _repr_html_(self) -> Any: if self.path == "settings": return self.get()._repr_html_() diff --git a/packages/syft/src/syft/service/user/user.py b/packages/syft/src/syft/service/user/user.py index 6c24554478f..f2421173af4 100644 --- a/packages/syft/src/syft/service/user/user.py +++ b/packages/syft/src/syft/service/user/user.py @@ -105,6 +105,7 @@ class User(SyftObject): __attr_searchable__ = ["name", "email", "verify_key", "role", "reset_token"] __attr_unique__ = ["email", "signing_key", "verify_key"] 
__repr_attrs__ = ["name", "email"] + __order_by__ = ("created_at", "asc") @migrate(UserV1, User) diff --git a/packages/syft/tests/syft/api_test.py b/packages/syft/tests/syft/api_test.py index 67febf838b0..610c1acd154 100644 --- a/packages/syft/tests/syft/api_test.py +++ b/packages/syft/tests/syft/api_test.py @@ -45,7 +45,7 @@ def test_api_cache_invalidation_login(root_verify_key, worker): name="q", email="a@b.org", password="aaa", password_verify="aaa" ) guest_client = guest_client.login(email="a@b.org", password="aaa") - user_id = worker.document_store.partitions["User"].all(root_verify_key).value[-1].id + user_id = worker.root_client.users[-1].id def get_role(verify_key): users = worker.get_service("UserService").stash.get_all(root_verify_key).ok() diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py index e0032877fcd..d40fa15e8ca 100644 --- a/packages/syft/tests/syft/stores/action_store_test.py +++ b/packages/syft/tests/syft/stores/action_store_test.py @@ -19,8 +19,8 @@ from syft.store.db.sqlite_db import DBManager from syft.types.uid import UID -# first party -from packages.syft.tests.syft.worker_test import action_object_stash # noqa: F401 +# relative +from ..worker_test import action_object_stash # noqa: F401 permissions = [ ActionObjectOWNER, diff --git a/packages/syft/tests/syft/users/user_service_test.py b/packages/syft/tests/syft/users/user_service_test.py index 45e31da18fe..d97c6ab601e 100644 --- a/packages/syft/tests/syft/users/user_service_test.py +++ b/packages/syft/tests/syft/users/user_service_test.py @@ -222,23 +222,6 @@ def mock_get_all(credentials: SyftVerifyKey) -> list[User]: ) -def test_userservice_get_all_error( - monkeypatch: MonkeyPatch, - user_service: UserService, - authed_context: AuthedServiceContext, -) -> None: - @as_result(StashException) - def mock_get_all(credentials: SyftVerifyKey) -> NoReturn: - raise StashException - - monkeypatch.setattr(user_service.stash, "get_all", mock_get_all) - - with pytest.raises(StashException) as exc: - user_service.get_all(authed_context) - - assert exc.type == StashException - - def test_userservice_search( monkeypatch: MonkeyPatch, user_service: UserService, @@ -246,13 +229,13 @@ def test_userservice_search( guest_user: User, ) -> None: @as_result(SyftException) - def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> list[User]: + def get_all(credentials: SyftVerifyKey, **kwargs) -> list[User]: for key in kwargs.keys(): if hasattr(guest_user, key): return [guest_user] return [] - monkeypatch.setattr(user_service.stash, "find_all", mock_find_all) + monkeypatch.setattr(user_service.stash, "get_all", get_all) expected_output = [guest_user.to(UserView)] diff --git a/packages/syft/tests/syft/users/user_test.py b/packages/syft/tests/syft/users/user_test.py index 99d525fe857..6bbf94dd42e 100644 --- a/packages/syft/tests/syft/users/user_test.py +++ b/packages/syft/tests/syft/users/user_test.py @@ -424,6 +424,7 @@ def test_user_view_set_role(worker: Worker, guest_client: DatasiteClient) -> Non with pytest.raises(SyftException): ds_client.account.update(role="guest") + with pytest.raises(SyftException): ds_client.account.update(role="data_scientist") # now we set sheldon's role to admin. 
Only now he can change his role From eb3501b8aef58ae838c18590f4cf7780e9fbc3d1 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 10 Sep 2024 10:29:49 +0200 Subject: [PATCH 110/197] WIP migrationservice --- packages/syft/src/syft/server/server.py | 10 +++ .../syft/src/syft/server/service_registry.py | 11 +++ .../service/migration/migration_service.py | 82 ++++++------------- 3 files changed, 47 insertions(+), 56 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 89691cd0bc4..577eb80eb67 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -17,6 +17,7 @@ from time import sleep import traceback from typing import Any +from typing import TypeVar from typing import cast # third party @@ -92,6 +93,7 @@ from ..store.db.sqlite_db import DBConfig from ..store.db.sqlite_db import SQLiteDBConfig from ..store.db.sqlite_db import SQLiteDBManager +from ..store.db.stash import ObjectStash from ..store.document_store import StoreConfig from ..store.document_store_errors import NotFoundException from ..store.document_store_errors import StashException @@ -126,6 +128,8 @@ logger = logging.getLogger(__name__) +SyftT = TypeVar("SyftT", bound=SyftObject) + # if user code needs to be serded and its not available we can call this to refresh # the code for a specific server UID and thread CODE_RELOADER: dict[int, Callable] = {} @@ -917,6 +921,12 @@ def get_service_method(self, path_or_func: str | Callable) -> Callable: def get_service(self, path_or_func: str | Callable) -> AbstractService: return self.services.get_service(path_or_func) + @as_result(ValueError) + def get_stash(self, object_type: SyftT) -> ObjectStash[SyftT]: + if object_type not in self.services.stashes: + raise ValueError(f"Stash for {object_type} not found.") + return self.services.stashes[object_type] + def _get_service_method_from_path(self, path: str) -> Callable: path_list = path.split(".") method_name = path_list.pop() diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index 8bd877f0f3b..efccc52d145 100644 --- a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -4,6 +4,7 @@ from dataclasses import field import typing from typing import TYPE_CHECKING +from typing import TypeVar # relative from ..serde.serializable import serializable @@ -38,12 +39,17 @@ from ..service.worker.worker_image_service import SyftWorkerImageService from ..service.worker.worker_pool_service import SyftWorkerPoolService from ..service.worker.worker_service import WorkerService +from ..store.db.stash import ObjectStash +from ..types.syft_object import SyftObject if TYPE_CHECKING: # relative from .server import Server +StashT = TypeVar("StashT", bound=SyftObject) + + @serializable(canonical_name="ServiceRegistry", version=1) @dataclass class ServiceRegistry: @@ -82,6 +88,7 @@ class ServiceRegistry: service_path_map: dict[str, AbstractService] = field( default_factory=dict, init=False ) + stashes: dict[StashT, ObjectStash[StashT]] = {} @classmethod def for_server(cls, server: "Server") -> "ServiceRegistry": @@ -93,6 +100,10 @@ def __post_init__(self) -> None: self.services.append(service) self.service_path_map[service_cls.__name__.lower()] = service + if hasattr(service, "stash"): + stash: ObjectStash = service.stash + self.stashes[stash.object_type] = stash + @classmethod def get_service_classes( cls, diff --git 
a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index da2fc107d32..4f7ce88a149 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -7,6 +7,7 @@ # relative from ...serde.serializable import serializable +from ...store.db.stash import ObjectStash from ...store.document_store import DocumentStore from ...store.document_store import StorePartition from ...store.document_store_errors import NotFoundException @@ -15,6 +16,7 @@ from ...types.result import as_result from ...types.syft_object import SyftObject from ...types.syft_object_registry import SyftObjectRegistry +from ...types.twin_object import TwinObject from ..action.action_object import Action from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectPermission @@ -165,6 +167,7 @@ def _get_all_store_metadata( document_store_object_types: list[type[SyftObject]] | None = None, include_action_store: bool = True, ) -> dict[type[SyftObject], StoreMetadata]: + # metadata = permissions + storage permissions if document_store_object_types is None: document_store_object_types = self.store.get_partition_object_types() @@ -272,24 +275,27 @@ def _get_migration_objects( return dict(result) @as_result(SyftException) - def _search_partition_for_object( + def _search_stash_for_object( self, context: AuthedServiceContext, obj: SyftObject - ) -> StorePartition: + ) -> ObjectStash: + stashes: dict[str, ObjectStash] = { + t.__canonical_name__: stash + for t, stash in context.server.services.stashes.items() + } + klass = type(obj) mro = klass.__mro__ class_index = 0 - object_partition = None + object_stash = None while len(mro) > class_index: canonical_name = mro[class_index].__canonical_name__ - object_partition = self.store.partitions.get(canonical_name) - if object_partition is not None: + object_stash = stashes.get(canonical_name) + if object_stash is not None: break class_index += 1 - if object_partition is None: - raise SyftException( - public_message=f"Object partition not found for {klass}" - ) - return object_partition + if object_stash is None: + raise SyftException(public_message=f"Object stash not found for {klass}") + return object_stash @service_method( path="migration.create_migrated_objects", @@ -314,12 +320,9 @@ def _create_migrated_objects( ignore_existing: bool = True, ) -> SyftSuccess: for migrated_object in migrated_objects: - object_partition = self._search_partition_for_object( - context, migrated_object - ).unwrap() + stash = self._search_stash_for_object(context, migrated_object).unwrap() - # upsert the object - result = object_partition.set( + result = stash.set( context.credentials, obj=migrated_object, ) @@ -338,34 +341,18 @@ def _create_migrated_objects( result.unwrap() # this will raise the exception inside the wrapper return SyftSuccess(message="Created migrate objects!") - @service_method( - path="migration.update_migrated_objects", - name="update_migrated_objects", - roles=ADMIN_ROLE_LEVEL, - ) - def update_migrated_objects( - self, context: AuthedServiceContext, migrated_objects: list[SyftObject] - ) -> None: - self._update_migrated_objects(context, migrated_objects).unwrap() - @as_result(SyftException) def _update_migrated_objects( self, context: AuthedServiceContext, migrated_objects: list[SyftObject] ) -> SyftSuccess: for migrated_object in migrated_objects: - object_partition = 
self._search_partition_for_object( - context, migrated_object - ).unwrap() + stash = self._search_stash_for_object(context, migrated_object).unwrap() - qk = object_partition.settings.store_key.with_obj(migrated_object.id) - object_partition._update( + stash.update( context.credentials, - qk=qk, obj=migrated_object, - has_permission=True, - overwrite=True, - allow_missing_keys=True, ).unwrap() + return SyftSuccess(message="Updated migration objects!") @as_result(SyftException) @@ -403,33 +390,18 @@ def migrate_data( context: AuthedServiceContext, document_store_object_types: list[type[SyftObject]] | None = None, ) -> SyftSuccess: - # Track all object type that need migration for document store - - # get all objects, keyed by type (because we might want to have different rules for different types) - # Q: will this be tricky with the protocol???? - # A: For now we will assume that the client will have the same version - - # Then, locally we write stuff that says - # for klass, objects in migration_dict.items(): - # for object in objects: - # if isinstance(object, X): - # do something custom - # else: - # migrated_value = object.migrate_to(klass.__version__, context) - # - # migrated_values = [SyftObject] - # client.migration.write_migrated_values(migrated_values) - migration_objects = self._get_migration_objects( context, document_store_object_types ).unwrap() migrated_objects = self._migrate_objects(context, migration_objects).unwrap() self._update_migrated_objects(context, migrated_objects).unwrap() + migration_actionobjects = self._get_migration_actionobjects(context).unwrap() migrated_actionobjects = self._migrate_objects( context, migration_actionobjects ).unwrap() self._update_migrated_actionobjects(context, migrated_actionobjects).unwrap() + return SyftSuccess(message="Data upgraded to the latest version") @service_method( @@ -447,7 +419,7 @@ def _get_migration_actionobjects( self, context: AuthedServiceContext, get_all: bool = False ) -> dict[type[SyftObject], list[SyftObject]]: # Track all object types from action store - action_object_types = [Action, ActionObject] + action_object_types = [Action, ActionObject, TwinObject] action_object_types.extend(ActionObject.__subclasses__()) klass_by_canonical_name: dict[str, type[SyftObject]] = { klass.__canonical_name__: klass for klass in action_object_types @@ -457,10 +429,8 @@ def _get_migration_actionobjects( context=context, object_types=action_object_types ).unwrap() result_dict: dict[type[SyftObject], list[SyftObject]] = defaultdict(list) - action_store = context.server.action_store - action_store_objects = action_store.get_all( - context.credentials, has_permission=True - ).unwrap() + action_stash = context.server.services.action.stash + action_store_objects = action_stash.get_all(context.credentials).unwrap() for obj in action_store_objects: if get_all or type(obj) in action_object_pending_migration: From 8af186f47101c26ba09bf738a767187328a9483e Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 10 Sep 2024 11:25:08 +0200 Subject: [PATCH 111/197] fix migrations --- .../syft/src/syft/server/service_registry.py | 5 +- .../src/syft/service/action/action_service.py | 4 +- .../service/migration/migration_service.py | 142 +++++------------- packages/syft/src/syft/store/db/stash.py | 2 +- 4 files changed, 41 insertions(+), 112 deletions(-) diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index efccc52d145..0689185f4cc 100644 --- 
a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -88,7 +88,7 @@ class ServiceRegistry: service_path_map: dict[str, AbstractService] = field( default_factory=dict, init=False ) - stashes: dict[StashT, ObjectStash[StashT]] = {} + stashes: dict[StashT, ObjectStash[StashT]] = field(default_factory=dict, init=False) @classmethod def for_server(cls, server: "Server") -> "ServiceRegistry": @@ -100,7 +100,8 @@ def __post_init__(self) -> None: self.services.append(service) self.service_path_map[service_cls.__name__.lower()] = service - if hasattr(service, "stash"): + # TODO ActionService now has same stash, but interface is still different. Fix this. + if hasattr(service, "stash") and not issubclass(service_cls, ActionService): stash: ObjectStash = service.stash self.stashes[stash.object_type] = stash diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index fa37f8dd47a..ddca7e03ebd 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -56,10 +56,10 @@ @serializable(canonical_name="ActionService", version=1) class ActionService(AbstractService): - stash: ActionObjectStash - def __init__(self, store: DocumentStore) -> None: + # TODO remove self.store, use self.stash instead self.store = ActionObjectStash(store) + self.stash = self.store @service_method(path="action.np_array", name="np_array") def np_array(self, context: AuthedServiceContext, data: Any) -> Any: diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 4f7ce88a149..b9050ba4b89 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -1,6 +1,5 @@ # stdlib from collections import defaultdict -from typing import cast # syft absolute import syft @@ -9,7 +8,6 @@ from ...serde.serializable import serializable from ...store.db.stash import ObjectStash from ...store.document_store import DocumentStore -from ...store.document_store import StorePartition from ...store.document_store_errors import NotFoundException from ...types.blob_storage import BlobStorageEntry from ...types.errors import SyftException @@ -123,81 +121,35 @@ def get_all_store_metadata( include_action_store=include_action_store, ).unwrap() - @as_result(SyftException) - def _get_partition_from_type( - self, - context: AuthedServiceContext, - object_type: type[SyftObject], - ) -> StorePartition: - object_partition: ActionObjectStash | StorePartition | None = None - if issubclass(object_type, ActionObject): - object_partition = cast(ActionObjectStash, context.server.action_store) - else: - canonical_name = object_type.__canonical_name__ # type: ignore[unreachable] - object_partition = self.store.partitions.get(canonical_name) - - if object_partition is None: - raise SyftException( - public_message=f"Object partition not found for {object_type}" - ) # type: ignore - - return object_partition - - @as_result(SyftException) - def _get_store_metadata( - self, - context: AuthedServiceContext, - object_type: type[SyftObject], - ) -> StoreMetadata: - object_partition = self._get_partition_from_type(context, object_type).unwrap() - permissions = dict(object_partition.get_all_permissions().unwrap()) - storage_permissions = dict( - object_partition.get_all_storage_permissions().unwrap() - ) - return 
StoreMetadata( - object_type=object_type, - permissions=permissions, - storage_permissions=storage_permissions, - ) - @as_result(SyftException) def _get_all_store_metadata( self, context: AuthedServiceContext, document_store_object_types: list[type[SyftObject]] | None = None, - include_action_store: bool = True, ) -> dict[type[SyftObject], StoreMetadata]: # metadata = permissions + storage permissions - if document_store_object_types is None: - document_store_object_types = self.store.get_partition_object_types() - + stashes = context.server.services.stashes store_metadata = {} - for klass in document_store_object_types: - store_metadata[klass] = self._get_store_metadata(context, klass).unwrap() - if include_action_store: - store_metadata[ActionObject] = self._get_store_metadata( - context, ActionObject - ).unwrap() - return store_metadata + for klass, stash in stashes.items(): + if ( + document_store_object_types is not None + and klass not in document_store_object_types + ): + continue + store_metadata[klass] = StoreMetadata( + object_type=klass, + permissions=stash.get_all_permissions().unwrap(), + storage_permissions=stash.get_all_storage_permissions().unwrap(), + ) - @service_method( - path="migration.update_store_metadata", - name="update_store_metadata", - roles=ADMIN_ROLE_LEVEL, - ) - def update_store_metadata( - self, context: AuthedServiceContext, store_metadata: dict[type, StoreMetadata] - ) -> None: # type: ignore - self._update_store_metadata(context, store_metadata).unwrap() + return store_metadata @as_result(SyftException) def _update_store_metadata_for_klass( self, context: AuthedServiceContext, metadata: StoreMetadata ) -> None: - object_partition = self._get_partition_from_type( - context, metadata.object_type - ).unwrap() + stash = self._search_stash_for_klass(context, metadata.object_type).unwrap() permissions = [ ActionObjectPermission.from_permission_string(uid, perm_str) for uid, perm_strs in metadata.permissions.items() @@ -210,8 +162,8 @@ def _update_store_metadata_for_klass( for server_uid in server_uids ] - object_partition.add_permissions(permissions) - object_partition.add_storage_permissions(storage_permissions) + stash.add_permissions(permissions) + stash.add_storage_permissions(storage_permissions) @as_result(SyftException) def _update_store_metadata( @@ -221,21 +173,6 @@ def _update_store_metadata( for metadata in store_metadata.values(): self._update_store_metadata_for_klass(context, metadata).unwrap() - @service_method( - path="migration.get_migration_objects", - name="get_migration_objects", - roles=ADMIN_ROLE_LEVEL, - ) - def get_migration_objects( - self, - context: AuthedServiceContext, - document_store_object_types: list[type[SyftObject]] | None = None, - get_all: bool = False, - ) -> dict: - return self._get_migration_objects( - context, document_store_object_types, get_all - ).unwrap() - @as_result(SyftException) def _get_migration_objects( self, @@ -244,7 +181,7 @@ def _get_migration_objects( get_all: bool = False, ) -> dict[type[SyftObject], list[SyftObject]]: if document_store_object_types is None: - document_store_object_types = self.store.get_partition_object_types() + document_store_object_types = list(context.server.services.stashes.keys()) if get_all: klasses_to_migrate = document_store_object_types @@ -256,14 +193,12 @@ def _get_migration_objects( result = defaultdict(list) for klass in klasses_to_migrate: - canonical_name = klass.__canonical_name__ - object_partition = self.store.partitions.get(canonical_name) - if object_partition is 
None: + stash_or_err = self._search_stash_for_klass(context, klass) + if stash_or_err.is_err(): continue - objects = object_partition.all( - context.credentials, has_permission=True - ).unwrap() - for object in objects: + stash = stash_or_err.unwrap() + + for object in stash._data: actual_klass = type(object) use_klass = ( klass @@ -275,15 +210,14 @@ def _get_migration_objects( return dict(result) @as_result(SyftException) - def _search_stash_for_object( - self, context: AuthedServiceContext, obj: SyftObject + def _search_stash_for_klass( + self, context: AuthedServiceContext, klass: type[SyftObject] ) -> ObjectStash: stashes: dict[str, ObjectStash] = { t.__canonical_name__: stash for t, stash in context.server.services.stashes.items() } - klass = type(obj) mro = klass.__mro__ class_index = 0 object_stash = None @@ -320,7 +254,9 @@ def _create_migrated_objects( ignore_existing: bool = True, ) -> SyftSuccess: for migrated_object in migrated_objects: - stash = self._search_stash_for_object(context, migrated_object).unwrap() + stash = self._search_stash_for_klass( + context, type(migrated_object) + ).unwrap() result = stash.set( context.credentials, @@ -346,7 +282,9 @@ def _update_migrated_objects( self, context: AuthedServiceContext, migrated_objects: list[SyftObject] ) -> SyftSuccess: for migrated_object in migrated_objects: - stash = self._search_stash_for_object(context, migrated_object).unwrap() + stash = self._search_stash_for_klass( + context, type(migrated_object) + ).unwrap() stash.update( context.credentials, @@ -438,26 +376,16 @@ def _get_migration_actionobjects( result_dict[klass].append(obj) # type: ignore return dict(result_dict) - @service_method( - path="migration.update_migrated_actionobjects", - name="update_migrated_actionobjects", - roles=ADMIN_ROLE_LEVEL, - ) - def update_migrated_actionobjects( - self, context: AuthedServiceContext, objects: list[SyftObject] - ) -> SyftSuccess: - self._update_migrated_actionobjects(context, objects).unwrap() - return SyftSuccess(message="succesfully migrated actionobjects") - @as_result(SyftException) def _update_migrated_actionobjects( self, context: AuthedServiceContext, objects: list[SyftObject] ) -> str: - # Track all object types from action store - action_store: ActionObjectStash = context.server.action_store + action_store: ActionObjectStash = context.server.services.action.stash for obj in objects: - action_store.set( - uid=obj.id, credentials=context.credentials, syft_object=obj + action_store.set_or_update( + uid=obj.id, + credentials=context.credentials, + syft_object=obj, ).unwrap() return "success" diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index bb69d40544e..b8bdda30ae5 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -662,7 +662,7 @@ def get_all_storage_permissions(self) -> dict[UID, set[UID]]: results = self.session.execute(stmt).all() return { - UID(row.id): {(UID(uid) for uid in row.storage_permissions)} + UID(row.id): {UID(uid) for uid in row.storage_permissions} for row in results } From 151682d899e52743e3df705db76e654d749dd281 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 10 Sep 2024 11:31:38 +0200 Subject: [PATCH 112/197] fix test --- packages/syft/tests/syft/users/user_service_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/tests/syft/users/user_service_test.py b/packages/syft/tests/syft/users/user_service_test.py index d97c6ab601e..efa99ed9a50 100644 
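The one-line stash.py change in patch 111 above fixes a subtle Python pitfall: putting parentheses around a generator expression inside braces builds a set whose single element is a generator object, not a set of UIDs. A minimal sketch of the difference, with plain strings standing in for UIDs:

    ids = ["a1", "b2"]

    # buggy form: a set literal containing one generator object
    wrong = {(x.upper() for x in ids)}
    assert len(wrong) == 1 and not isinstance(next(iter(wrong)), str)

    # fixed form: a set comprehension that evaluates every element
    right = {x.upper() for x in ids}
    assert right == {"A1", "B2"}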
--- a/packages/syft/tests/syft/users/user_service_test.py +++ b/packages/syft/tests/syft/users/user_service_test.py @@ -208,7 +208,7 @@ def test_userservice_get_all_success( expected_output = [x.to(UserView) for x in mock_get_all_output] @as_result(StashException) - def mock_get_all(credentials: SyftVerifyKey) -> list[User]: + def mock_get_all(credentials: SyftVerifyKey, **kwargs) -> list[User]: return mock_get_all_output monkeypatch.setattr(user_service.stash, "get_all", mock_get_all) From c4c13443f5837a914977d30af48ce8a4269c10d5 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 10 Sep 2024 11:32:15 +0200 Subject: [PATCH 113/197] fix test --- packages/syft/tests/syft/stores/action_store_test.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py index e0032877fcd..51b118f12fa 100644 --- a/packages/syft/tests/syft/stores/action_store_test.py +++ b/packages/syft/tests/syft/stores/action_store_test.py @@ -19,8 +19,7 @@ from syft.store.db.sqlite_db import DBManager from syft.types.uid import UID -# first party -from packages.syft.tests.syft.worker_test import action_object_stash # noqa: F401 +# from packages.syft.tests.syft.worker_test import action_object_stash # noqa: F401 permissions = [ ActionObjectOWNER, From dcbc08e6384e277811f357003301d9df9efabfc9 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 10 Sep 2024 11:34:06 +0200 Subject: [PATCH 114/197] merge --- packages/syft/src/syft/client/api.py | 5 ++++ packages/syft/src/syft/service/user/user.py | 1 + packages/syft/tests/syft/api_test.py | 2 +- .../tests/syft/stores/action_store_test.py | 3 ++- .../tests/syft/users/user_service_test.py | 23 +++---------------- packages/syft/tests/syft/users/user_test.py | 1 + 6 files changed, 13 insertions(+), 22 deletions(-) diff --git a/packages/syft/src/syft/client/api.py b/packages/syft/src/syft/client/api.py index 21c8afe0557..85fe33545fa 100644 --- a/packages/syft/src/syft/client/api.py +++ b/packages/syft/src/syft/client/api.py @@ -743,6 +743,11 @@ def __getitem__(self, key: str | int) -> Any: return self.get_all()[key] raise NotImplementedError + def __iter__(self) -> Any: + if hasattr(self, "get_all"): + return iter(self.get_all()) + raise NotImplementedError + def _repr_html_(self) -> Any: if self.path == "settings": return self.get()._repr_html_() diff --git a/packages/syft/src/syft/service/user/user.py b/packages/syft/src/syft/service/user/user.py index 6c24554478f..f2421173af4 100644 --- a/packages/syft/src/syft/service/user/user.py +++ b/packages/syft/src/syft/service/user/user.py @@ -105,6 +105,7 @@ class User(SyftObject): __attr_searchable__ = ["name", "email", "verify_key", "role", "reset_token"] __attr_unique__ = ["email", "signing_key", "verify_key"] __repr_attrs__ = ["name", "email"] + __order_by__ = ("created_at", "asc") @migrate(UserV1, User) diff --git a/packages/syft/tests/syft/api_test.py b/packages/syft/tests/syft/api_test.py index 6c511b45d48..ca2f61ac147 100644 --- a/packages/syft/tests/syft/api_test.py +++ b/packages/syft/tests/syft/api_test.py @@ -45,7 +45,7 @@ def test_api_cache_invalidation_login(root_verify_key, worker): name="q", email="a@b.org", password="aaa", password_verify="aaa" ) guest_client = guest_client.login(email="a@b.org", password="aaa") - user_id = worker.document_store.partitions["User"].all(root_verify_key).value[-1].id + user_id = worker.root_client.users[-1].id def get_role(verify_key): users = 
worker.services.user.stash.get_all(root_verify_key).ok() diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py index 51b118f12fa..d40fa15e8ca 100644 --- a/packages/syft/tests/syft/stores/action_store_test.py +++ b/packages/syft/tests/syft/stores/action_store_test.py @@ -19,7 +19,8 @@ from syft.store.db.sqlite_db import DBManager from syft.types.uid import UID -# from packages.syft.tests.syft.worker_test import action_object_stash # noqa: F401 +# relative +from ..worker_test import action_object_stash # noqa: F401 permissions = [ ActionObjectOWNER, diff --git a/packages/syft/tests/syft/users/user_service_test.py b/packages/syft/tests/syft/users/user_service_test.py index 45e31da18fe..efa99ed9a50 100644 --- a/packages/syft/tests/syft/users/user_service_test.py +++ b/packages/syft/tests/syft/users/user_service_test.py @@ -208,7 +208,7 @@ def test_userservice_get_all_success( expected_output = [x.to(UserView) for x in mock_get_all_output] @as_result(StashException) - def mock_get_all(credentials: SyftVerifyKey) -> list[User]: + def mock_get_all(credentials: SyftVerifyKey, **kwargs) -> list[User]: return mock_get_all_output monkeypatch.setattr(user_service.stash, "get_all", mock_get_all) @@ -222,23 +222,6 @@ def mock_get_all(credentials: SyftVerifyKey) -> list[User]: ) -def test_userservice_get_all_error( - monkeypatch: MonkeyPatch, - user_service: UserService, - authed_context: AuthedServiceContext, -) -> None: - @as_result(StashException) - def mock_get_all(credentials: SyftVerifyKey) -> NoReturn: - raise StashException - - monkeypatch.setattr(user_service.stash, "get_all", mock_get_all) - - with pytest.raises(StashException) as exc: - user_service.get_all(authed_context) - - assert exc.type == StashException - - def test_userservice_search( monkeypatch: MonkeyPatch, user_service: UserService, @@ -246,13 +229,13 @@ def test_userservice_search( guest_user: User, ) -> None: @as_result(SyftException) - def mock_find_all(credentials: SyftVerifyKey, **kwargs) -> list[User]: + def get_all(credentials: SyftVerifyKey, **kwargs) -> list[User]: for key in kwargs.keys(): if hasattr(guest_user, key): return [guest_user] return [] - monkeypatch.setattr(user_service.stash, "find_all", mock_find_all) + monkeypatch.setattr(user_service.stash, "get_all", get_all) expected_output = [guest_user.to(UserView)] diff --git a/packages/syft/tests/syft/users/user_test.py b/packages/syft/tests/syft/users/user_test.py index 1a9830ed48c..cdc348c9f4f 100644 --- a/packages/syft/tests/syft/users/user_test.py +++ b/packages/syft/tests/syft/users/user_test.py @@ -424,6 +424,7 @@ def test_user_view_set_role(worker: Worker, guest_client: DatasiteClient) -> Non with pytest.raises(SyftException): ds_client.account.update(role="guest") + with pytest.raises(SyftException): ds_client.account.update(role="data_scientist") # now we set sheldon's role to admin. 
Only now he can change his role From 72d36d6826e426649436cdf63a5a109352cddda4 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 10 Sep 2024 15:39:11 +0530 Subject: [PATCH 115/197] add a few log statements --- packages/syft/src/syft/server/server.py | 3 +++ packages/syft/src/syft/server/uvicorn.py | 9 +++++++++ .../syft/src/syft/store/postgresql_document_store.py | 7 ++----- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 912023749c8..c31247360eb 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -652,6 +652,7 @@ def init_queue_manager(self, queue_config: QueueConfig) -> None: worker_stash=self.worker_stash, ) producer.run() + address = producer.address else: port = queue_config.client_config.queue_port @@ -917,6 +918,8 @@ def init_stores( document_store_config: StoreConfig, action_store_config: StoreConfig, ) -> None: + logger.info(f"Document store config: {document_store_config}") + logger.info(f"Action store config: {action_store_config}") self.document_store_config = document_store_config self.document_store = document_store_config.store_type( server_uid=self.id, diff --git a/packages/syft/src/syft/server/uvicorn.py b/packages/syft/src/syft/server/uvicorn.py index ee9abd5a43d..6f2617cc7b7 100644 --- a/packages/syft/src/syft/server/uvicorn.py +++ b/packages/syft/src/syft/server/uvicorn.py @@ -2,6 +2,7 @@ from collections.abc import Callable from contextlib import asynccontextmanager import json +import logging import multiprocessing import multiprocessing.synchronize import os @@ -47,6 +48,9 @@ WAIT_TIME_SECONDS = 20 +logger = logging.getLogger("uvicorn") + + class AppSettings(BaseSettings): name: str server_type: ServerType = ServerType.DATASITE @@ -62,6 +66,7 @@ class AppSettings(BaseSettings): n_consumers: int = 0 association_request_auto_approval: bool = False background_tasks: bool = False + store_client_config: dict | None = None model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None") @@ -92,6 +97,10 @@ def app_factory() -> FastAPI: worker_class = worker_classes[settings.server_type] kwargs = settings.model_dump() + + logger.info( + f"Starting server with settings: {kwargs} and worker class: {worker_class}" + ) if settings.dev_mode: print( f"WARN: private key is based on server name: {settings.name} in dev_mode. 
" diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py index 2a8fc2863de..241b311301b 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -101,11 +101,8 @@ def _connect(self) -> None: port=self.store_config.client_config.port, ) _CONNECTION_POOL_DB[cache_key(self.dbname)] = connection - print(f"Connected to {self.store_config.client_config.dbname}") - print( - "PostgreSQL database connection:", - _CONNECTION_POOL_DB[cache_key(self.dbname)], - ) + logger.info(f"Connected to {self.store_config.client_config.dbname}") + logger.info(f"PostgreSQL database connection: {connection.info.dsn}") def create_table(self) -> None: db = self.db From c51f544fad5030e29466555b1fd3f9c03d145e2c Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Tue, 10 Sep 2024 16:55:56 +0530 Subject: [PATCH 116/197] set prepare_threshold to None --- packages/syft/src/syft/store/postgresql_document_store.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py index 241b311301b..75dee5776ca 100644 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ b/packages/syft/src/syft/store/postgresql_document_store.py @@ -99,6 +99,9 @@ def _connect(self) -> None: password=self.store_config.client_config.password, host=self.store_config.client_config.host, port=self.store_config.client_config.port, + # This should default to None, + # https://www.psycopg.org/psycopg3/docs/advanced/prepare.html#using-prepared-statements-with-pgbouncer + prepare_threshold=None, ) _CONNECTION_POOL_DB[cache_key(self.dbname)] = connection logger.info(f"Connected to {self.store_config.client_config.dbname}") From ff90871d10f7a318b94906b1ae8ef7c6d82d8e68 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 10 Sep 2024 13:37:33 +0200 Subject: [PATCH 117/197] fix new query --- .../src/syft/service/dataset/dataset_stash.py | 3 +- .../syft/src/syft/service/user/user_stash.py | 2 +- packages/syft/src/syft/store/db/query.py | 58 ++++++++----------- packages/syft/src/syft/store/db/stash.py | 49 +++++++++++++--- 4 files changed, 67 insertions(+), 45 deletions(-) diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index 28b6934f672..3c5a1f203fd 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -1,7 +1,6 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.query import Filter from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException @@ -22,7 +21,7 @@ def get_by_name(self, credentials: SyftVerifyKey, name: str) -> Dataset: def search_action_ids(self, credentials: SyftVerifyKey, uid: UID) -> list[Dataset]: return self.get_all( credentials=credentials, - filters={"action_ids": Filter("action_ids", "contains", uid)}, + filters={"action_ids__contains": uid}, ).unwrap() @as_result(StashException) diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index 11a44d47832..41f561ec9af 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py 
@@ -54,7 +54,7 @@ def get_by_email(self, credentials: SyftVerifyKey, email: str) -> User: return self.get_one( credentials=credentials, filters={"email": email}, - ) + ).unwrap() @as_result(StashException) def email_exists(self, email: str) -> bool: diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py index 326548edcef..d43eb3ca102 100644 --- a/packages/syft/src/syft/store/db/query.py +++ b/packages/syft/src/syft/store/db/query.py @@ -1,20 +1,17 @@ # stdlib from abc import ABC from abc import abstractmethod -from dataclasses import dataclass +import enum from typing import Any from typing import Literal # third party import sqlalchemy as sa from sqlalchemy import Column -from sqlalchemy import Dialect from sqlalchemy import Result from sqlalchemy import Select from sqlalchemy import Table -from sqlalchemy import dialects from sqlalchemy import func -from sqlalchemy import select from sqlalchemy.orm import Session from typing_extensions import Self @@ -26,32 +23,25 @@ from ...service.user.user_roles import ServiceRole from ...types.syft_object import SyftObject from ...types.uid import UID -from .sqlite_db import OBJECT_TYPE_TO_TABLE +from .schema import Base -@dataclass -class Filter: - field: str - operator: str - value: Any +class FilterOperator(enum.Enum): + EQ = "eq" + CONTAINS = "contains" class Query(ABC): - dialect: Dialect - def __init__(self, object_type: type[SyftObject]) -> None: self.object_type: type = object_type - self.table: Table = OBJECT_TYPE_TO_TABLE[object_type] - self.stmt: Select = select([self.table]) + self.table: Table = self._get_table(object_type) + self.stmt: Select = self.table.select() - def compile(self) -> str: - """ - Compile the query to a string, for debugging purposes. - """ - return self.stmt.compile( - compile_kwargs={"literal_binds": True}, - dialect=self.dialect, - ) + def _get_table(self, object_type: type[SyftObject]) -> Table: + cname = object_type.__canonical_name__ + if cname not in Base.metadata.tables: + raise ValueError(f"Table for {cname} not found") + return Base.metadata.tables[cname] def execute(self, session: Session) -> Result: """Execute the query using the given session.""" @@ -77,7 +67,7 @@ def with_permissions( Returns: Self: The query object with the permission check applied """ - if role in (ServiceRole.ADMIN, ServiceRole.SUPER_ADMIN): + if role in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER): return self permission = ActionObjectPermission( @@ -91,11 +81,11 @@ def with_permissions( return self - def filter(self, field: str, operator: str, value: Any) -> Self: + def filter(self, field: str, operator: str | FilterOperator, value: Any) -> Self: """Add a filter to the query. 
example usage: - Query(User).filter("name", "==", "Alice") + Query(User).filter("name", "eq", "Alice") Query(User).filter("friends", "contains", "Bob") Args: @@ -109,16 +99,18 @@ def filter(self, field: str, operator: str, value: Any) -> Self: Returns: Self: The query object with the filter applied """ - if operator not in {"==", "!=", "contains"}: - raise ValueError(f"Operation {operator} not supported") + if isinstance(operator, str): + try: + operator = FilterOperator(operator.lower()) + except ValueError: + raise ValueError(f"Filter operator {operator} not supported") - if operator == "==": + if operator == FilterOperator.EQ: filter = self._eq_filter(self.table, field, value) - self.stmt = self.stmt.where(filter) - elif operator == "contains": + elif operator == FilterOperator.CONTAINS: filter = self._contains_filter(self.table, field, value) - self.stmt = self.stmt.where(filter) + self.stmt = self.stmt.where(filter) return self def order_by( @@ -214,8 +206,6 @@ def _get_column(self, column: str) -> Column: class SQLiteQuery(Query): - dialect = dialects.sqlite.dialect - def _make_permissions_clause( self, permission: ActionObjectPermission, @@ -238,8 +228,6 @@ def _contains_filter( class PostgresQuery(Query): - dialect = dialects.postgresql.dialect - def _make_permissions_clause( self, permission: ActionObjectPermission ) -> sa.sql.elements.BinaryExpression: diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index d2f00e954f5..abe5ba54d71 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -33,7 +33,6 @@ from ...types.uid import UID from ..document_store_errors import NotFoundException from ..document_store_errors import StashException -from .query import Filter from .query import PostgresQuery from .query import Query from .query import SQLiteQuery @@ -46,6 +45,22 @@ T = TypeVar("T") +def parse_filters(filter_dict: dict[str, Any] | None) -> list[tuple[str, str, Any]]: + # NOTE using django style filters, e.g. {"age__gt": 18} + if filter_dict is None: + return [] + filters = [] + for key, value in filter_dict.items(): + key_split = key.split("__") + # Operator is eq if not specified + if len(key_split) == 1: + field, operator = key, "eq" + elif len(key_split) == 2: + field, operator = key_split + filters.append((field, operator, value)) + return filters + + class ObjectStash(Generic[StashT]): table: Table object_type: type[SyftObject] @@ -171,7 +186,7 @@ def exists(self, credentials: SyftVerifyKey, uid: UID) -> bool: def get_by_uid( self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False ) -> StashT: - query = self.query().filter("id", "==", uid) + query = self.query().filter("id", "eq", uid) if not has_permission: role = self.get_role(credentials) @@ -675,17 +690,37 @@ def get_all( limit: int | None = None, offset: int = 0, ) -> list[StashT]: + """ + Get all objects from the stash, optionally filtered. + + Args: + credentials (SyftVerifyKey): credentials of the user + filters (dict[str, Any] | None, optional): dictionary of filters, + where the key is the field name and the value is the filter value. + Operators other than equals can be used in the key, + e.g. {"name": "Bob", "friends__contains": "Alice"}. Defaults to None. + has_permission (bool, optional): If True, overrides the permission check. + Defaults to False. + order_by (str | None, optional): If provided, the results will be ordered by this field. 
+ If not provided, the default order and field defined on the SyftObject.__order_by__ are used. + Defaults to None. + sort_order (str | None, optional): "asc" or "desc" If not defined, + the default order defined on the SyftObject.__order_by__ is used. + Defaults to None. + limit (int | None, optional): limit the number of results. Defaults to None. + offset (int, optional): offset the results. Defaults to 0. + + Returns: + list[StashT]: list of objects. + """ query = self.query() if not has_permission: role = self.get_role(credentials) query = query.with_permissions(credentials, role) - if filters: - for field_name, field_value in filters.items(): - if isinstance(field_value, Filter): - query = query.filter(field_name, field_value.op, field_value.value) - query = query.filter(field_name, "==", field_value) + for field_name, operator, field_value in parse_filters(filters): + query = query.filter(field_name, operator, field_value) query = query.order_by(order_by, sort_order).limit(limit).offset(offset) result = query.execute(self.session).all() From 2609c0d59a9a3e772ca7724814642c0e691995a6 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 10 Sep 2024 13:37:42 +0200 Subject: [PATCH 118/197] fix more tests --- packages/syft/src/syft/server/server.py | 4 + packages/syft/src/syft/service/user/user.py | 2 - packages/syft/src/syft/types/syft_object.py | 2 +- .../syft/settings/settings_service_test.py | 15 +- .../syft/stores/mongo_document_store_test.py | 407 ------------------ 5 files changed, 9 insertions(+), 421 deletions(-) delete mode 100644 packages/syft/tests/syft/stores/mongo_document_store_test.py diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 852d84a3d0b..e952c0f4633 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -351,6 +351,7 @@ def __init__( self.server_side_type = ServerSideType(server_side_type) self.client_cache: dict = {} self.peer_client_cache: dict = {} + self._settings = None if isinstance(server_type, str): server_type = ServerType(server_type) @@ -968,6 +969,8 @@ def update_self(self, settings: ServerSettings) -> None: # it should be removed once the settings are refactored and the inconsistencies between # settings and services are resolved. 
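The get_all docstring in patch 117 above documents the full query surface. A hypothetical call combining filters, ordering, and pagination (the stash instance, credentials, and field values are placeholders, not taken from the diff):

    # a sketch, assuming a populated dataset stash and a valid verify key
    datasets = dataset_stash.get_all(
        credentials=verify_key,
        filters={"name": "mnist", "action_ids__contains": action_uid},
        order_by="name",
        sort_order="asc",
        limit=10,
        offset=0,
    ).unwrap()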
def get_settings(self) -> ServerSettings | None: + if self._settings: + return self._settings if self.signing_key is None: raise ValueError(f"{self} has no signing key") @@ -979,6 +982,7 @@ def get_settings(self) -> ServerSettings | None: if len(settings) > 0: setting = settings[0] self.update_self(setting) + self._settings = setting return setting else: return None diff --git a/packages/syft/src/syft/service/user/user.py b/packages/syft/src/syft/service/user/user.py index f2421173af4..0fb8a87e7fd 100644 --- a/packages/syft/src/syft/service/user/user.py +++ b/packages/syft/src/syft/service/user/user.py @@ -76,7 +76,6 @@ class User(SyftObject): # version __canonical_name__ = "User" __version__ = SYFT_OBJECT_VERSION_2 - __order_by__ = ("email", "asc") id: UID | None = None # type: ignore[assignment] @@ -105,7 +104,6 @@ class User(SyftObject): __attr_searchable__ = ["name", "email", "verify_key", "role", "reset_token"] __attr_unique__ = ["email", "signing_key", "verify_key"] __repr_attrs__ = ["name", "email"] - __order_by__ = ("created_at", "asc") @migrate(UserV1, User) diff --git a/packages/syft/src/syft/types/syft_object.py b/packages/syft/src/syft/types/syft_object.py index a6fe14f6faa..f6a4d3233cb 100644 --- a/packages/syft/src/syft/types/syft_object.py +++ b/packages/syft/src/syft/types/syft_object.py @@ -410,7 +410,7 @@ def make_id(cls, values: Any) -> Any: values["id"] = id_field.annotation() return values - __order_by__: ClassVar[tuple[str, str]] = ("created_date", "desc") + __order_by__: ClassVar[tuple[str, str]] = ("_created_at", "asc") __attr_searchable__: ClassVar[ list[str] ] = [] # keys which can be searched in the ORM diff --git a/packages/syft/tests/syft/settings/settings_service_test.py b/packages/syft/tests/syft/settings/settings_service_test.py index 7d0500f8247..0c9306aaff0 100644 --- a/packages/syft/tests/syft/settings/settings_service_test.py +++ b/packages/syft/tests/syft/settings/settings_service_test.py @@ -104,6 +104,8 @@ def test_settingsservice_set_success( response.syft_server_location = None response.pwd_token_config.syft_client_verify_key = None response.pwd_token_config.syft_server_location = None + response.welcome_markdown.syft_client_verify_key = None + response.welcome_markdown.syft_server_location = None assert response == settings @@ -170,15 +172,6 @@ def mock_get_service(service_name: str): assert isinstance(response, SyftSuccess) -def test_settingsservice_update_stash_get_all_fail( - monkeypatch: MonkeyPatch, - settings_service: SettingsService, - update_settings: ServerSettingsUpdate, - authed_context: AuthedServiceContext, -) -> None: - settings_service.update(context=authed_context, settings=update_settings) - - def test_settingsservice_update_stash_empty( settings_service: SettingsService, update_settings: ServerSettingsUpdate, @@ -202,7 +195,7 @@ def test_settingsservice_update_fail( mock_stash_get_all_output = [settings, settings] @as_result(StashException) - def mock_stash_get_all(credentials) -> list[ServerSettings]: + def mock_stash_get_all(credentials, **kwargs) -> list[ServerSettings]: return mock_stash_get_all_output monkeypatch.setattr(settings_service.stash, "get_all", mock_stash_get_all) @@ -210,7 +203,7 @@ def mock_stash_get_all(credentials) -> list[ServerSettings]: mock_update_error_message = "Failed to update obj ServerMetadata" @as_result(StashException) - def mock_stash_update_error(credentials, settings: ServerSettings) -> NoReturn: + def mock_stash_update_error(credentials, obj: ServerSettings) -> NoReturn: raise 
StashException(public_message=mock_update_error_message) monkeypatch.setattr(settings_service.stash, "update", mock_stash_update_error) diff --git a/packages/syft/tests/syft/stores/mongo_document_store_test.py b/packages/syft/tests/syft/stores/mongo_document_store_test.py deleted file mode 100644 index a1fbee484da..00000000000 --- a/packages/syft/tests/syft/stores/mongo_document_store_test.py +++ /dev/null @@ -1,407 +0,0 @@ -# stdlib - -# third party -import pytest - -# syft absolute -from syft.server.credentials import SyftVerifyKey -from syft.service.action.action_permissions import ActionObjectOWNER -from syft.service.action.action_permissions import ActionObjectPermission -from syft.service.action.action_permissions import ActionPermission -from syft.service.action.action_permissions import StoragePermission -from syft.service.action.action_store import ActionObjectEXECUTE -from syft.service.action.action_store import ActionObjectREAD -from syft.service.action.action_store import ActionObjectWRITE -from syft.store.document_store import QueryKey -from syft.store.mongo_document_store import MongoStorePartition -from syft.types.errors import SyftException -from syft.types.uid import UID - -# relative -from ...mongomock.collection import Collection as MongoCollection -from .store_constants_test import TEST_VERIFY_KEY_STRING_HACKER -from .store_mocks_test import MockSyftObject - -PERMISSIONS = [ - ActionObjectOWNER, - ActionObjectREAD, - ActionObjectWRITE, - ActionObjectEXECUTE, -] - - -def test_mongo_store_partition_add_remove_permission( - root_verify_key: SyftVerifyKey, mongo_store_partition: MongoStorePartition -) -> None: - """ - Test the add_permission and remove_permission functions of MongoStorePartition - """ - # setting up - permissions_collection: MongoCollection = mongo_store_partition.permissions.ok() - obj = MockSyftObject(data=1) - - # add the first permission - obj_read_permission = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.READ, credentials=root_verify_key - ) - mongo_store_partition.add_permission(obj_read_permission) - find_res_1 = permissions_collection.find_one({"_id": obj_read_permission.uid}) - assert find_res_1 is not None - assert len(find_res_1["permissions"]) == 1 - assert find_res_1["permissions"] == { - obj_read_permission.permission_string, - } - - # add the second permission - obj_write_permission = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.WRITE, credentials=root_verify_key - ) - mongo_store_partition.add_permission(obj_write_permission) - - find_res_2 = permissions_collection.find_one({"_id": obj.id}) - assert find_res_2 is not None - assert len(find_res_2["permissions"]) == 2 - assert find_res_2["permissions"] == { - obj_read_permission.permission_string, - obj_write_permission.permission_string, - } - - # add duplicated permission - mongo_store_partition.add_permission(obj_write_permission) - find_res_3 = permissions_collection.find_one({"_id": obj.id}) - assert len(find_res_3["permissions"]) == 2 - assert find_res_3["permissions"] == find_res_2["permissions"] - - # remove the write permission - mongo_store_partition.remove_permission(obj_write_permission) - find_res_4 = permissions_collection.find_one({"_id": obj.id}) - assert len(find_res_4["permissions"]) == 1 - assert find_res_1["permissions"] == { - obj_read_permission.permission_string, - } - - # remove a non-existent permission - with pytest.raises(SyftException): - mongo_store_partition.remove_permission( - ActionObjectPermission( - uid=obj.id, - 
permission=ActionPermission.OWNER, - credentials=root_verify_key, - ) - ) - find_res_5 = permissions_collection.find_one({"_id": obj.id}) - assert len(find_res_5["permissions"]) == 1 - assert find_res_1["permissions"] == { - obj_read_permission.permission_string, - } - - # there is only one permission object - assert permissions_collection.count_documents({}) == 1 - - # add permissions in a loop - new_permissions = [] - repeats = 5 - for idx in range(1, repeats + 1): - new_obj = MockSyftObject(data=idx) - new_obj_read_permission = ActionObjectPermission( - uid=new_obj.id, - permission=ActionPermission.READ, - credentials=root_verify_key, - ) - new_permissions.append(new_obj_read_permission) - mongo_store_partition.add_permission(new_obj_read_permission) - assert permissions_collection.count_documents({}) == 1 + idx - - # remove all the permissions added in the loop - for permission in new_permissions: - mongo_store_partition.remove_permission(permission) - - assert permissions_collection.count_documents({}) == 1 - - -def test_mongo_store_partition_add_remove_storage_permission( - root_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - """ - Test the add_storage_permission and remove_storage_permission functions of MongoStorePartition - """ - - obj = MockSyftObject(data=1) - - storage_permission = StoragePermission( - uid=obj.id, - server_uid=UID(), - ) - assert not mongo_store_partition.has_storage_permission(storage_permission) - mongo_store_partition.add_storage_permission(storage_permission) - assert mongo_store_partition.has_storage_permission(storage_permission) - mongo_store_partition.remove_storage_permission(storage_permission) - assert not mongo_store_partition.has_storage_permission(storage_permission) - - obj2 = MockSyftObject(data=1) - mongo_store_partition.set(root_verify_key, obj2, add_storage_permission=False) - storage_permission3 = StoragePermission( - uid=obj2.id, server_uid=mongo_store_partition.server_uid - ) - assert not mongo_store_partition.has_storage_permission(storage_permission3) - - obj3 = MockSyftObject(data=1) - mongo_store_partition.set(root_verify_key, obj3, add_storage_permission=True) - storage_permission4 = StoragePermission( - uid=obj3.id, server_uid=mongo_store_partition.server_uid - ) - assert mongo_store_partition.has_storage_permission(storage_permission4) - - -def test_mongo_store_partition_add_permissions( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - permissions_collection: MongoCollection = mongo_store_partition.permissions.ok() - obj = MockSyftObject(data=1) - - # add multiple permissions for the first object - permission_1 = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.WRITE, credentials=root_verify_key - ) - permission_2 = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.OWNER, credentials=root_verify_key - ) - permission_3 = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.READ, credentials=guest_verify_key - ) - permissions: list[ActionObjectPermission] = [ - permission_1, - permission_2, - permission_3, - ] - mongo_store_partition.add_permissions(permissions) - - # check if the permissions have been added properly - assert permissions_collection.count_documents({}) == 1 - find_res = permissions_collection.find_one({"_id": obj.id}) - assert find_res is not None - assert len(find_res["permissions"]) == 3 - - 
# add permissions for the second object - obj_2 = MockSyftObject(data=2) - permission_4 = ActionObjectPermission( - uid=obj_2.id, permission=ActionPermission.READ, credentials=root_verify_key - ) - permission_5 = ActionObjectPermission( - uid=obj_2.id, permission=ActionPermission.WRITE, credentials=root_verify_key - ) - mongo_store_partition.add_permissions([permission_4, permission_5]) - - assert permissions_collection.count_documents({}) == 2 - find_res_2 = permissions_collection.find_one({"_id": obj_2.id}) - assert find_res_2 is not None - assert len(find_res_2["permissions"]) == 2 - - -@pytest.mark.parametrize("permission", PERMISSIONS) -def test_mongo_store_partition_has_permission( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, - permission: ActionObjectPermission, -) -> None: - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - - res = mongo_store_partition.init_store() - assert res.is_ok() - - # root permission - obj = MockSyftObject(data=1) - permission_root = permission(uid=obj.id, credentials=root_verify_key) - permission_client = permission(uid=obj.id, credentials=guest_verify_key) - permission_hacker = permission(uid=obj.id, credentials=hacker_verify_key) - mongo_store_partition.add_permission(permission_root) - # only the root user has access to this permission - assert mongo_store_partition.has_permission(permission_root) - assert not mongo_store_partition.has_permission(permission_client) - assert not mongo_store_partition.has_permission(permission_hacker) - - # client permission for another object - obj_2 = MockSyftObject(data=2) - permission_client_2 = permission(uid=obj_2.id, credentials=guest_verify_key) - permission_root_2 = permission(uid=obj_2.id, credentials=root_verify_key) - permisson_hacker_2 = permission(uid=obj_2.id, credentials=hacker_verify_key) - mongo_store_partition.add_permission(permission_client_2) - # the root (admin) and guest client should have this permission - assert mongo_store_partition.has_permission(permission_root_2) - assert mongo_store_partition.has_permission(permission_client_2) - assert not mongo_store_partition.has_permission(permisson_hacker_2) - - # remove permissions - mongo_store_partition.remove_permission(permission_root) - assert not mongo_store_partition.has_permission(permission_root) - assert not mongo_store_partition.has_permission(permission_client) - assert not mongo_store_partition.has_permission(permission_hacker) - - mongo_store_partition.remove_permission(permission_client_2) - assert not mongo_store_partition.has_permission(permission_root_2) - assert not mongo_store_partition.has_permission(permission_client_2) - assert not mongo_store_partition.has_permission(permisson_hacker_2) - - -def test_mongo_store_partition_permissions_set( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - """ - Test the permissions functionalities when using MongoStorePartition._set function - """ - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - res = mongo_store_partition.init_store() - assert res.is_ok() - - # set the object to mongo_store_partition.collection - obj = MockSyftObject(data=1) - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() - assert res.ok() == obj - - # check if the corresponding permissions has been added to the permissions - # collection after the root client claim it 
- pemissions_collection = mongo_store_partition.permissions.ok() - assert isinstance(pemissions_collection, MongoCollection) - permissions = pemissions_collection.find_one({"_id": obj.id}) - assert permissions is not None - assert isinstance(permissions["permissions"], set) - assert len(permissions["permissions"]) == 4 - for permission in PERMISSIONS: - assert mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=root_verify_key) - ) - - # the hacker tries to set duplicated object but should not be able to claim it - res_2 = mongo_store_partition.set(guest_verify_key, obj, ignore_duplicates=True) - assert res_2.is_ok() - for permission in PERMISSIONS: - assert not mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=hacker_verify_key) - ) - assert mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=root_verify_key) - ) - - -def test_mongo_store_partition_permissions_get_all( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - # set several objects for the root and guest client - num_root_objects: int = 5 - num_guest_objects: int = 3 - for i in range(num_root_objects): - obj = MockSyftObject(data=i) - mongo_store_partition.set( - credentials=root_verify_key, obj=obj, ignore_duplicates=False - ) - for i in range(num_guest_objects): - obj = MockSyftObject(data=i) - mongo_store_partition.set( - credentials=guest_verify_key, obj=obj, ignore_duplicates=False - ) - - assert ( - len(mongo_store_partition.all(root_verify_key).ok()) - == num_root_objects + num_guest_objects - ) - assert len(mongo_store_partition.all(guest_verify_key).ok()) == num_guest_objects - assert len(mongo_store_partition.all(hacker_verify_key).ok()) == 0 - - -def test_mongo_store_partition_permissions_delete( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - collection: MongoCollection = mongo_store_partition.collection.ok() - pemissions_collection: MongoCollection = mongo_store_partition.permissions.ok() - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - - # the root client set an object - obj = MockSyftObject(data=1) - mongo_store_partition.set( - credentials=root_verify_key, obj=obj, ignore_duplicates=False - ) - qk: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj) - # guest or hacker can't delete it - assert not mongo_store_partition.delete(guest_verify_key, qk).is_ok() - assert not mongo_store_partition.delete(hacker_verify_key, qk).is_ok() - # only the root client can delete it - assert mongo_store_partition.delete(root_verify_key, qk).is_ok() - # check if the object and its permission have been deleted - assert collection.count_documents({}) == 0 - assert pemissions_collection.count_documents({}) == 0 - - # the guest client set an object - obj_2 = MockSyftObject(data=2) - mongo_store_partition.set( - credentials=guest_verify_key, obj=obj_2, ignore_duplicates=False - ) - qk_2: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj_2) - # the hacker can't delete it - assert not mongo_store_partition.delete(hacker_verify_key, qk_2).is_ok() - # the guest client can delete it - assert mongo_store_partition.delete(guest_verify_key, qk_2).is_ok() - 
assert collection.count_documents({}) == 0 - assert pemissions_collection.count_documents({}) == 0 - - # the guest client set another object - obj_3 = MockSyftObject(data=3) - mongo_store_partition.set( - credentials=guest_verify_key, obj=obj_3, ignore_duplicates=False - ) - qk_3: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj_3) - # the root client also has the permission to delete it - assert mongo_store_partition.delete(root_verify_key, qk_3).is_ok() - assert collection.count_documents({}) == 0 - assert pemissions_collection.count_documents({}) == 0 - - -def test_mongo_store_partition_permissions_update( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - # the root client set an object - obj = MockSyftObject(data=1) - mongo_store_partition.set( - credentials=root_verify_key, obj=obj, ignore_duplicates=False - ) - assert len(mongo_store_partition.all(credentials=root_verify_key).ok()) == 1 - - qk: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj) - permsissions: MongoCollection = mongo_store_partition.permissions.ok() - repeats = 5 - - for v in range(repeats): - # the guest client should not have permission to update obj - obj_new = MockSyftObject(data=v) - res = mongo_store_partition.update( - credentials=guest_verify_key, qk=qk, obj=obj_new - ) - assert res.is_err() - # the root client has the permission to update obj - res = mongo_store_partition.update( - credentials=root_verify_key, qk=qk, obj=obj_new - ) - assert res.is_ok() - # the id of the object in the permission collection should not be changed - assert permsissions.find_one(qk.as_dict_mongo)["_id"] == obj.id From 064bb2d96b55ae072cec2d62c1d18d38691f62f0 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 10 Sep 2024 13:48:05 +0200 Subject: [PATCH 119/197] fix all unit tests --- packages/syft/src/syft/service/dataset/dataset_stash.py | 6 +++++- packages/syft/src/syft/service/job/job_stash.py | 2 +- .../src/syft/service/notification/notification_stash.py | 2 +- packages/syft/src/syft/store/db/stash.py | 7 +++++++ packages/syft/tests/syft/dataset/dataset_stash_test.py | 2 +- 5 files changed, 15 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index 3c5a1f203fd..b91ea5a0c98 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -33,13 +33,17 @@ def get_all( sort_order: str | None = None, limit: int | None = None, offset: int | None = None, + filters: dict | None = None, ) -> list[Dataset]: # TODO standardize soft delete and move to ObjectStash.get_all + default_filters = {"to_be_deleted": False} + filters = filters or {} + filters.update(default_filters) return ( super() .get_all( credentials=credentials, - filters={"to_be_deleted": False}, + filters=filters, has_permission=has_permission, order_by=order_by, sort_order=sort_order, diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 5fe04661b33..fb85cd810d7 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -778,7 +778,7 @@ def get_by_user_code_id( ) -> list[Job]: return self.get_all( credentials=credentials, - filter={"user_code_id": user_code_id}, + filters={"user_code_id": 
user_code_id}, ).unwrap() @as_result(StashException) diff --git a/packages/syft/src/syft/service/notification/notification_stash.py b/packages/syft/src/syft/service/notification/notification_stash.py index e343d31adfc..eba855856be 100644 --- a/packages/syft/src/syft/service/notification/notification_stash.py +++ b/packages/syft/src/syft/service/notification/notification_stash.py @@ -43,7 +43,7 @@ def get_all_for_verify_key( raise AttributeError("verify_key must be of type SyftVerifyKey or str") return self.get_all( credentials, - filters={"fromuser_verify_key": verify_key}, + filters={"from_user_verify_key": verify_key}, ).unwrap() @as_result(StashException) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index abe5ba54d71..4900b4400ac 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -218,11 +218,18 @@ def _get_field_filter( def get_index( self, credentials: SyftVerifyKey, index: int, has_permission: bool = False ) -> StashT: + order_by, sort_order = self.object_type.__order_by__ + if index < 0: + index = -1 - index + sort_order = "desc" if sort_order == "asc" else "asc" + items = self.get_all( credentials, has_permission=has_permission, limit=1, offset=index, + order_by=order_by, + sort_order=sort_order, ).unwrap() if len(items) == 0: diff --git a/packages/syft/tests/syft/dataset/dataset_stash_test.py b/packages/syft/tests/syft/dataset/dataset_stash_test.py index 394b03e05c3..bfec6e00895 100644 --- a/packages/syft/tests/syft/dataset/dataset_stash_test.py +++ b/packages/syft/tests/syft/dataset/dataset_stash_test.py @@ -43,5 +43,5 @@ def test_dataset_search_action_ids( # passing random object random_obj = object() - with pytest.raises(AttributeError): + with pytest.raises(ValueError): result = mock_dataset_stash.search_action_ids(root_verify_key, uid=random_obj) From d92a9f0c77e8cbbd306cd54d6632441779ff3984 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 10 Sep 2024 13:56:04 +0200 Subject: [PATCH 120/197] refactor get_one, add validation to limit offset --- packages/syft/src/syft/store/db/query.py | 12 +++++- packages/syft/src/syft/store/db/stash.py | 48 ++++++++++++++++++------ 2 files changed, 46 insertions(+), 14 deletions(-) diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py index d43eb3ca102..348daeda0f5 100644 --- a/packages/syft/src/syft/store/db/query.py +++ b/packages/syft/src/syft/store/db/query.py @@ -155,12 +155,20 @@ def order_by( def limit(self, limit: int | None) -> Self: """Add a limit clause to the query.""" - if limit is not None: - self.stmt = self.stmt.limit(limit) + if limit is None: + return self + + if limit < 0: + raise ValueError("Limit must be a positive integer") + self.stmt = self.stmt.limit(limit) + return self def offset(self, offset: int) -> Self: """Add an offset clause to the query.""" + if offset < 0: + raise ValueError("Offset must be a positive integer") + self.stmt = self.stmt.offset(offset) return self diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index abe5ba54d71..39e5366e29b 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -666,18 +666,43 @@ def get_one( sort_order: str | None = None, offset: int = 0, ) -> StashT: - result = self.get_all( - credentials=credentials, - filters=filters, - has_permission=has_permission, - order_by=order_by, - sort_order=sort_order, - limit=1, - offset=offset, - 
From d92a9f0c77e8cbbd306cd54d6632441779ff3984 Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Tue, 10 Sep 2024 13:56:04 +0200
Subject: [PATCH 120/197] refactor get_one, add validation to limit offset

---
 packages/syft/src/syft/store/db/query.py | 12 +++++-
 packages/syft/src/syft/store/db/stash.py | 48 ++++++++++++++++++------
 2 files changed, 46 insertions(+), 14 deletions(-)

diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py
index d43eb3ca102..348daeda0f5 100644
--- a/packages/syft/src/syft/store/db/query.py
+++ b/packages/syft/src/syft/store/db/query.py
@@ -155,12 +155,20 @@ def order_by(
 
     def limit(self, limit: int | None) -> Self:
         """Add a limit clause to the query."""
-        if limit is not None:
-            self.stmt = self.stmt.limit(limit)
+        if limit is None:
+            return self
+
+        if limit < 0:
+            raise ValueError("Limit must be a positive integer")
+        self.stmt = self.stmt.limit(limit)
+
         return self
 
     def offset(self, offset: int) -> Self:
         """Add an offset clause to the query."""
+        if offset < 0:
+            raise ValueError("Offset must be a positive integer")
+
         self.stmt = self.stmt.offset(offset)
         return self
 
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index abe5ba54d71..39e5366e29b 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -666,18 +666,43 @@ def get_one(
         sort_order: str | None = None,
         offset: int = 0,
     ) -> StashT:
-        result = self.get_all(
-            credentials=credentials,
-            filters=filters,
-            has_permission=has_permission,
-            order_by=order_by,
-            sort_order=sort_order,
-            limit=1,
-            offset=offset,
-        ).unwrap()
-        if len(result) == 0:
+        """
+        Get the first object from the stash, optionally filtered.
+
+        Args:
+            credentials (SyftVerifyKey): credentials of the user
+            filters (dict[str, Any] | None, optional): dictionary of filters,
+                where the key is the field name and the value is the filter value.
+                Operators other than equals can be used in the key,
+                e.g. {"name": "Bob", "friends__contains": "Alice"}. Defaults to None.
+            has_permission (bool, optional): If True, overrides the permission check.
+                Defaults to False.
+            order_by (str | None, optional): If provided, the results will be ordered by this field.
+                If not provided, the default order and field defined on the SyftObject.__order_by__ are used.
+                Defaults to None.
+            sort_order (str | None, optional): "asc" or "desc". If not defined,
+                the default order defined on the SyftObject.__order_by__ is used.
+                Defaults to None.
+            offset (int, optional): offset the results. Defaults to 0.
+
+        Returns:
+            StashT: the first object matching the query.
+        """
+        query = self.query()
+
+        if not has_permission:
+            role = self.get_role(credentials)
+            query = query.with_permissions(credentials, role)
+
+        for field_name, operator, field_value in parse_filters(filters):
+            query = query.filter(field_name, operator, field_value)
+
+        query = query.order_by(order_by, sort_order).offset(offset)
+        result = query.execute(self.session).first()
+        if result is None:
             raise NotFoundException(f"{self.object_type.__name__}: not found")
-        return result[0]
+
+        return self.row_as_obj(result)
 
     @as_result(StashException)
     def get_all(
@@ -724,5 +749,4 @@ def get_all(
 
         query = query.order_by(order_by, sort_order).limit(limit).offset(offset)
         result = query.execute(self.session).all()
-
        return [self.row_as_obj(row) for row in result]
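The new limit/offset guards fail fast in Python instead of surfacing as a database error at execution time. A standalone sketch of the equivalent behavior against a throwaway SQLAlchemy statement (table and column names are made up for illustration):

from sqlalchemy import column, select, table

jobs = table("job", column("id"))

def apply_limit(stmt, limit: int | None):
    # None means "emit no LIMIT clause at all"; negative values are rejected
    # up front, matching the patched Query.limit above.
    if limit is None:
        return stmt
    if limit < 0:
        raise ValueError("Limit must be a positive integer")
    return stmt.limit(limit)

stmt = apply_limit(select(jobs.c.id), 10)    # SELECT job.id FROM job LIMIT 10
stmt = apply_limit(select(jobs.c.id), None)  # no LIMIT clause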
From 00d69acf7dac7d3b170251aa290e2289efda04fe Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Tue, 10 Sep 2024 14:05:36 +0200
Subject: [PATCH 121/197] fix dataset get_all

---
 .../src/syft/service/dataset/dataset_stash.py | 4 ++
 packages/syft/src/syft/store/db/stash.py | 45 -------------------
 2 files changed, 4 insertions(+), 45 deletions(-)

diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py
index b91ea5a0c98..b56f4743604 100644
--- a/packages/syft/src/syft/service/dataset/dataset_stash.py
+++ b/packages/syft/src/syft/service/dataset/dataset_stash.py
@@ -39,6 +39,10 @@ def get_all(
         default_filters = {"to_be_deleted": False}
         filters = filters or {}
         filters.update(default_filters)
+
+        if offset is None:
+            offset = 0
+
         return (
             super()
             .get_all(
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index 018d2b4dadf..ddfd37edd55 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -6,7 +6,6 @@
 
 # third party
 import sqlalchemy as sa
-from sqlalchemy import Column
 from sqlalchemy import Row
 from sqlalchemy import Table
 from sqlalchemy import func
@@ -270,50 +269,6 @@ def _get_permission_filter_from_permisson(
             self.table.c.permissions.contains(compound_permission_string),
         )
 
-    def _apply_limit_offset(
-        self,
-        stmt: T,
-        limit: int | None = None,
-        offset: int | None = None,
-    ) -> T:
-        if offset is not None:
-            stmt = stmt.offset(offset)
-        if limit is not None:
-            stmt = stmt.limit(limit)
-        return stmt
-
-    def _get_order_by_col(self, order_by: str, sort_order: str | None = None) -> Column:
-        # TODO connect+rename created_date to created_at
-        if sort_order is None:
-            sort_order = "asc"
-
-        if order_by == "id":
-            col = self.table.c.id
-        if order_by == "created_date" or order_by == "_created_at":
-            col = self.table.c._created_at
-        else:
-            col = self.table.c.fields[order_by]
-
-        return col.desc() if sort_order.lower() == "desc" else col.asc()
-
-    def _apply_order_by(
-        self,
-        stmt: T,
-        order_by: str | None = None,
-        sort_order: str | None = None,
-    ) -> T:
-        if order_by is None:
-            order_by, default_sort_order = self.object_type.__order_by__
-            sort_order = sort_order or default_sort_order
-
-        order_by_col = self._get_order_by_col(order_by, sort_order)
-
-        if order_by == "id":
-            return stmt.order_by(order_by_col)
-        else:
-            secondary_order_by = self._get_order_by_col("id", sort_order)
-            return stmt.order_by(order_by_col, secondary_order_by)
-
     def _apply_permission_filter(
         self,
         stmt: T,
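The deleted ordering helpers are superseded by Query.order_by, but the behavior they encoded is worth keeping in mind: a secondary sort on id keeps pagination deterministic when the primary column has ties. A standalone sketch under that assumption (table and columns are made up):

from sqlalchemy import column, select, table

jobs = table("job", column("id"), column("created_at"))

# Ties on created_at are broken by id, so LIMIT/OFFSET pages never shuffle.
stmt = select(jobs.c.id).order_by(jobs.c.created_at.asc(), jobs.c.id.asc())
print(stmt)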
From 7ee43e91519c142d2c1b02c8dedcaddc6500a955 Mon Sep 17 00:00:00 2001
From: Aziz Berkay Yesilyurt
Date: Tue, 10 Sep 2024 14:09:58 +0200
Subject: [PATCH 122/197] implement reset password flow for python and fix it

---
 packages/syft/src/syft/client/client.py | 43 ++++++++++++++++++
 .../syft/src/syft/service/user/user_stash.py | 2 +-
 packages/syft/src/syft/store/db/stash.py | 45 -------------------
 .../tests/syft/users/user_service_test.py | 20 +++++++++
 4 files changed, 64 insertions(+), 46 deletions(-)

diff --git a/packages/syft/src/syft/client/client.py b/packages/syft/src/syft/client/client.py
index 692bd017565..d6fd164f44f 100644
--- a/packages/syft/src/syft/client/client.py
+++ b/packages/syft/src/syft/client/client.py
@@ -605,6 +605,49 @@ def login(
         result = post_process_result(result, unwrap_on_success=True)
         return result
 
+    def forgot_password(
+        self,
+        email: str,
+    ) -> SyftSigningKey | None:
+        credentials = {"email": email}
+        if self.proxy_target_uid:
+            obj = forward_message_to_proxy(
+                self.make_call,
+                proxy_target_uid=self.proxy_target_uid,
+                path="forgot_password",
+                kwargs=credentials,
+            )
+        else:
+            response = self.server.services.user.forgot_password(
+                context=ServerServiceContext(server=self.server), email=email
+            )
+            obj = post_process_result(response, unwrap_on_success=True)
+
+        return obj
+
+    def reset_password(
+        self,
+        token: str,
+        new_password: str,
+    ) -> SyftSigningKey | None:
+        payload = {"token": token, "new_password": new_password}
+        if self.proxy_target_uid:
+            obj = forward_message_to_proxy(
+                self.make_call,
+                proxy_target_uid=self.proxy_target_uid,
+                path="reset_password",
+                kwargs=payload,
+            )
+        else:
+            response = self.server.services.user.reset_password(
+                context=ServerServiceContext(server=self.server),
+                token=token,
+                new_password=new_password,
+            )
+            obj = post_process_result(response, unwrap_on_success=True)
+
+        return obj
+
     def register(self, new_user: UserCreate) -> SyftSigningKey | None:
         if self.proxy_target_uid:
             response = forward_message_to_proxy(
diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py
index 41f561ec9af..1a4a97fa3a1 100644
--- a/packages/syft/src/syft/service/user/user_stash.py
+++ b/packages/syft/src/syft/service/user/user_stash.py
@@ -47,7 +47,7 @@ def get_by_reset_token(self, credentials: SyftVerifyKey, token: str) -> User:
         return self.get_one(
             credentials=credentials,
             filters={"reset_token": token},
-        )
+        ).unwrap()
 
     @as_result(StashException, NotFoundException)
     def get_by_email(self, credentials: SyftVerifyKey, email: str) -> User:
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index 4900b4400ac..64c97139741 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -6,7 +6,6 @@
 
 # third party
 import sqlalchemy as sa
-from sqlalchemy import Column
 from sqlalchemy import Row
 from sqlalchemy import Table
 from sqlalchemy import func
@@ -270,50 +269,6 @@ def _get_permission_filter_from_permisson(
             self.table.c.permissions.contains(compound_permission_string),
         )
 
-    def _apply_limit_offset(
-        self,
-        stmt: T,
-        limit: int | None = None,
-        offset: int | None = None,
-    ) -> T:
-        if offset is not None:
-            stmt = stmt.offset(offset)
-        if limit is not None:
-            stmt = stmt.limit(limit)
-        return stmt
-
-    def _get_order_by_col(self, order_by: str, sort_order: str | None = None) -> Column:
-        # TODO connect+rename created_date to created_at
-        if sort_order is None:
-            sort_order = "asc"
-
-        if order_by == "id":
-            col = self.table.c.id
-        if order_by == "created_date" or order_by == "_created_at":
-            col = self.table.c._created_at
-        else:
-            col = self.table.c.fields[order_by]
-
-        return col.desc() if sort_order.lower() == "desc" else col.asc()
-
-    def _apply_order_by(
-        self,
-        stmt: T,
-        order_by: str | None = None,
-        sort_order: str | None = None,
-    ) -> T:
-        if order_by is None:
-            order_by, default_sort_order = self.object_type.__order_by__
-            sort_order = sort_order or default_sort_order
-
-        order_by_col = self._get_order_by_col(order_by, sort_order)
-
-        if order_by == "id":
-            return stmt.order_by(order_by_col)
-        else:
-            secondary_order_by = self._get_order_by_col("id", sort_order)
-            return stmt.order_by(order_by_col, secondary_order_by)
-
     def _apply_permission_filter(
         self,
         stmt: T,
diff --git a/packages/syft/tests/syft/users/user_service_test.py b/packages/syft/tests/syft/users/user_service_test.py
index efa99ed9a50..95b05bb6be7 100644
--- a/packages/syft/tests/syft/users/user_service_test.py
+++ b/packages/syft/tests/syft/users/user_service_test.py
@@ -10,6 +10,7 @@
 
 # syft absolute
 from syft import orchestra
+from syft.client.client import SyftClient
 from syft.server.credentials import SyftVerifyKey
 from syft.server.worker import Worker
 from syft.service.context import AuthedServiceContext
@@ -773,3 +774,22 @@ def test_userservice_update_via_client_with_mixed_args():
         email="new_user@openmined.org", password="newpassword"
     )
     assert user_client.account.name == "User name"
+
+
+def test_reset_password():
+    server = orchestra.launch(name="datasite-test", reset=True)
+
+    datasite_client = server.login(email="info@openmined.org", password="changethis")
+    datasite_client.register(
+        email="new_syft_user@openmined.org",
+        password="verysecurepassword",
+        password_verify="verysecurepassword",
+        name="New User",
+    )
+    guest_client: SyftClient = server.login_as_guest()
+    guest_client.forgot_password(email="new_syft_user@openmined.org")
+    temp_token = datasite_client.users.request_password_reset(
+        datasite_client.notifications[-1].linked_obj.resolve.id
+    )
+    guest_client.reset_password(token=temp_token, new_password="Password123")
+    server.login(email="new_syft_user@openmined.org", password="Password123")
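The new client methods compose into a three-step flow. In notebook form, mirroring the test above (all names come from the patch, so this is a usage sketch rather than new API):

import syft as sy

server = sy.orchestra.launch(name="datasite-test", reset=True)
admin = server.login(email="info@openmined.org", password="changethis")
guest = server.login_as_guest()

# 1. the user asks for a reset; 2. the admin converts the resulting
# notification into a temporary token; 3. the user exchanges the token
# for a new password.
guest.forgot_password(email="new_syft_user@openmined.org")
token = admin.users.request_password_reset(
    admin.notifications[-1].linked_obj.resolve.id
)
guest.reset_password(token=token, new_password="Password123")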
From b6c2381e90f52b23f5285eb0fec928dd996a7c46 Mon Sep 17 00:00:00 2001
From: Aziz Berkay Yesilyurt
Date: Tue, 10 Sep 2024 14:21:44 +0200
Subject: [PATCH 123/197] fix some linting errors

---
 .../src/syft/service/code_history/code_history_service.py | 2 +-
 .../syft/src/syft/service/code_history/code_history_stash.py | 2 +-
 packages/syft/src/syft/service/dataset/dataset_service.py | 4 +---
 packages/syft/src/syft/service/dataset/dataset_stash.py | 4 ++--
 packages/syft/src/syft/service/migration/migration_service.py | 1 -
 packages/syft/src/syft/service/network/network_service.py | 1 -
 packages/syft/src/syft/service/notification/notifications.py | 2 +-
 packages/syft/src/syft/service/project/project.py | 2 +-
 packages/syft/src/syft/service/settings/settings_service.py | 2 +-
 packages/syft/src/syft/service/worker/image_registry_stash.py | 2 +-
 packages/syft/src/syft/store/db/query.py | 4 ++--
 packages/syft/src/syft/store/db/stash.py | 2 +-
 12 files changed, 12 insertions(+), 16 deletions(-)

diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py
index a3d06bdb4fa..0943978759f 100644
--- a/packages/syft/src/syft/service/code_history/code_history_service.py
+++ b/packages/syft/src/syft/service/code_history/code_history_service.py
@@ -46,7 +46,7 @@ def submit_version(
         if isinstance(code, SubmitUserCode):
             code = context.server.services.user_code._submit(
                 context=context, submit_code=code
-            )
+            ).unwrap()
 
         try:
             code_history = self.stash.get_by_service_func_name_and_verify_key(
diff --git a/packages/syft/src/syft/service/code_history/code_history_stash.py b/packages/syft/src/syft/service/code_history/code_history_stash.py
index 8921f668095..14ca4e89720 100644
--- a/packages/syft/src/syft/service/code_history/code_history_stash.py
+++ b/packages/syft/src/syft/service/code_history/code_history_stash.py
@@ -31,7 +31,7 @@ def get_by_service_func_name(
         return self.get_all(
             credentials=credentials,
             filters={"service_func_name": service_func_name},
-        )
+        ).unwrap()
 
     @as_result(StashException)
     def get_by_verify_key(
diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py
index 8e4cd59a193..de6fdf9cb90 100644
--- a/packages/syft/src/syft/service/dataset/dataset_service.py
+++ b/packages/syft/src/syft/service/dataset/dataset_service.py
@@ -116,13 +116,11 @@ def get_all(
         page_index: int | None = 0,
     ) -> DatasetPageView | DictTuple[str, Dataset]:
         """Get a Dataset"""
-        datasets = self.stash.get_all(context.credentials).unwrap()
+        datasets = self.stash.get_all_active(context.credentials).unwrap()
 
         for dataset in datasets:
             if context.server is not None:
                 dataset.server_uid = context.server.id
-            if dataset.to_be_deleted:
-                datasets.remove(dataset)
 
         return _paginate_dataset_collection(
             datasets=datasets, page_size=page_size, page_index=page_index
diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py
index b56f4743604..bc7af00afdc 100644
--- a/packages/syft/src/syft/service/dataset/dataset_stash.py
+++ b/packages/syft/src/syft/service/dataset/dataset_stash.py
@@ -19,13 +19,13 @@ def get_by_name(self, credentials: SyftVerifyKey, name: str) -> Dataset:
 
     @as_result(StashException)
     def search_action_ids(self, credentials: SyftVerifyKey, uid: UID) -> list[Dataset]:
-        return self.get_all(
+        return self.get_all_active(
             credentials=credentials,
             filters={"action_ids__contains": uid},
         ).unwrap()
 
     @as_result(StashException)
-    def get_all(
+    def get_all_active(
         self,
         credentials: SyftVerifyKey,
         has_permission: bool = False,
diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py
index b9050ba4b89..f150f5e72ed 100644
--- a/packages/syft/src/syft/service/migration/migration_service.py
+++ b/packages/syft/src/syft/service/migration/migration_service.py
@@ -118,7 +118,6 @@ def get_all_store_metadata(
         return self._get_all_store_metadata(
             context,
             document_store_object_types=document_store_object_types,
-            include_action_store=include_action_store,
         ).unwrap()
 
     @as_result(SyftException)
diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py
index 51e729e3d6e..28f28db75a8 100644
--- a/packages/syft/src/syft/service/network/network_service.py
+++ b/packages/syft/src/syft/service/network/network_service.py
@@ -486,7 +486,6 @@ def set_reverse_tunnel_config(
     def delete_peer_by_id(self, context: AuthedServiceContext, uid: UID) -> SyftSuccess:
         """Delete Server Peer"""
         peer_to_delete = self.stash.get_by_uid(context.credentials, uid).unwrap()
-        peer_to_delete = cast(ServerPeer, peer_to_delete)
 
         server_side_type = cast(ServerType, context.server.server_type)
         if server_side_type.value == ServerType.GATEWAY.value:
diff --git a/packages/syft/src/syft/service/notification/notifications.py b/packages/syft/src/syft/service/notification/notifications.py
index 2b176a65af3..3a552df424b 100644
--- a/packages/syft/src/syft/service/notification/notifications.py
+++ b/packages/syft/src/syft/service/notification/notifications.py
@@ -61,7 +61,7 @@ class Notification(SyftObject):
     linked_obj: LinkedObject | None = None
     notifier_types: list[NOTIFIERS] = []
     email_template: type[EmailTemplate] | None = None
-    replies: list[ReplyNotification] | None = []
+    replies: list[ReplyNotification] = []
 
     __attr_searchable__ = [
         "from_user_verify_key",
diff --git a/packages/syft/src/syft/service/project/project.py b/packages/syft/src/syft/service/project/project.py
index 9876a597198..5ed21f5007d 100644
--- a/packages/syft/src/syft/service/project/project.py
+++ b/packages/syft/src/syft/service/project/project.py
@@ -1252,7 +1252,7 @@ def send(self, return_all_projects: bool = False) -> Project | list[Project]:
 
         if return_all_projects:
             return list(projects_map.values())
-        return projects_map[leader.id]
+        return projects_map[leader.id]  # type: ignore
 
     def _pre_submit_checks(self, clients: list[SyftClient]) -> bool:
         try:
diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py
index 868237afc25..b067644342f 100644
--- a/packages/syft/src/syft/service/settings/settings_service.py
+++ b/packages/syft/src/syft/service/settings/settings_service.py
@@ -180,7 +180,7 @@ def set_server_side_type_dangerous(
         ).unwrap()
         if len(current_settings) > 0:
             new_settings = current_settings[0]
-            new_settings.server_side_type = server_side_type
+            new_settings.server_side_type = ServerSideType(server_side_type)
             updated_settings = self.stash.update(
                 context.credentials, new_settings
             ).unwrap()
diff --git a/packages/syft/src/syft/service/worker/image_registry_stash.py b/packages/syft/src/syft/service/worker/image_registry_stash.py
index 3ca7fe1d03f..cfb71b9848b 100644
--- a/packages/syft/src/syft/service/worker/image_registry_stash.py
+++ b/packages/syft/src/syft/service/worker/image_registry_stash.py
@@ -19,7 +19,7 @@ def get_by_url(
         self,
         credentials: SyftVerifyKey,
         url: str,
-    ) -> SyftImageRegistry | None:
+    ) -> SyftImageRegistry:
         return self.get_one(
             credentials=credentials,
             filters={"url": url},
diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py
index 348daeda0f5..bec3cf5b1ad 100644
--- a/packages/syft/src/syft/store/db/query.py
+++ b/packages/syft/src/syft/store/db/query.py
@@ -70,13 +70,13 @@ def with_permissions(
         if role in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER):
             return self
 
-        permission = ActionObjectPermission(
+        ao_permission = ActionObjectPermission(
             uid=UID(),  # dummy uid, we just need the permission string
             credentials=credentials,
             permission=permission,
         )
 
-        permission_clause = self._make_permissions_clause(permission)
+        permission_clause = self._make_permissions_clause(ao_permission)
         self.stmt = self.stmt.where(permission_clause)
 
         return self
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index ddfd37edd55..de5849b2723 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -554,7 +554,7 @@ def _get_permissions_for_uid(self, uid: UID) -> set[str]:
         stmt = select(self.table.c.permissions).where(self.table.c.id == uid)
         result = self.session.execute(stmt).scalar_one_or_none()
         if result is None:
-            return NotFoundException(f"No permissions found for uid: {uid}")
+            raise NotFoundException(f"No permissions found for uid: {uid}")
         return set(result)
 
     @as_result(StashException)
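Several of these lint fixes add a missing .unwrap(): the stash getters return a Result-style wrapper, and forgetting to unwrap leaks the wrapper into callers that expect the plain object. Sketched here with the third-party result package for illustration; syft's as_result decorator is assumed to behave analogously:

from result import Err, Ok, Result

def get_by_name(name: str) -> Result[str, str]:
    # Hypothetical getter: Ok on success, Err on failure.
    return Ok(name.title()) if name else Err("not found")

wrapped = get_by_name("alice")                    # Ok('Alice') -- still wrapped
assert get_by_name("alice").unwrap() == "Alice"   # the actual value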
From 5998d58e19da82802227c8dd75073b1b8aafae91 Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Tue, 10 Sep 2024 16:22:41 +0200
Subject: [PATCH 124/197] fix is_unique, get_role

---
 .../syft/src/syft/custom_worker/config.py | 5 +
 packages/syft/src/syft/serde/json_serde.py | 21 +-
 .../src/syft/service/action/action_object.py | 1 +
 .../src/syft/service/worker/worker_image.py | 14 +-
 packages/syft/src/syft/store/db/query.py | 92 ++++-
 packages/syft/src/syft/store/db/stash.py | 362 +++++++++---------
 6 files changed, 297 insertions(+), 198 deletions(-)

diff --git a/packages/syft/src/syft/custom_worker/config.py b/packages/syft/src/syft/custom_worker/config.py
index 1cbdb44c488..6410c990eac 100644
--- a/packages/syft/src/syft/custom_worker/config.py
+++ b/packages/syft/src/syft/custom_worker/config.py
@@ -14,6 +14,7 @@
 
 # relative
 from ..serde.serializable import serializable
+from ..serde.serialize import _serialize
 from ..service.response import SyftSuccess
 from ..types.base import SyftBaseModel
 from ..types.errors import SyftException
@@ -83,6 +84,10 @@ def merged_custom_cmds(self, sep: str = ";") -> str:
 class WorkerConfig(SyftBaseModel):
     pass
 
+    def hash(self) -> str:
+        _bytes = _serialize(self, to_bytes=True, for_hashing=True)
+        return sha256(_bytes).digest().hex()
+
 
 @serializable(canonical_name="CustomWorkerConfig", version=1)
 class CustomWorkerConfig(WorkerConfig):
diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py
index 42d5ee1b2f3..000ff04085a 100644
--- a/packages/syft/src/syft/serde/json_serde.py
+++ b/packages/syft/src/syft/serde/json_serde.py
@@ -35,8 +35,8 @@
 JSON_VERSION_FIELD = "__version__"
 JSON_DATA_FIELD = "data"
 
-
-Json = str | int | float | bool | None | list["Json"] | dict[str, "Json"]
+JsonPrimitive = str | int | float | bool | None
+Json = JsonPrimitive | list["Json"] | dict[str, "Json"]
 
 
 class JSONSerdeError(SyftException):
@@ -179,10 +179,7 @@ def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> dict[str, Json]:
     for key, type_ in obj.model_fields.items():
         if key in all_exclude_attrs:
             continue
-        try:
-            result[key] = serialize_json(getattr(obj, key), type_.annotation)
-        except Exception as e:
-            raise ValueError(f"Failed to serialize attribute {key}: {e}")
+        result[key] = serialize_json(getattr(obj, key), type_.annotation)
 
     result = _serialize_searchable_attrs(obj, result, raise_errors=False)
 
@@ -206,7 +203,7 @@ def _serialize_searchable_attrs(
     obj: pydantic.BaseModel, obj_dict: dict[str, Json], raise_errors: bool = True
 ) -> dict[str, Json]:
     """
-    Add searchable attrs to the serialized object dict, if they are not already present.
+    Add searchable attrs and unique attrs to the serialized object dict, if they are not already present.
 
     Needed for adding non-field attributes (like @property)
 
     Args:
@@ -222,7 +219,10 @@ def _serialize_searchable_attrs(
         dict[str, Json]: Serialized object dict including searchable attributes.
     """
     searchable_attrs: list[str] = getattr(obj, "__attr_searchable__", [])
-    for attr in searchable_attrs:
+    unique_attrs: list[str] = getattr(obj, "__attr_unique__", [])
+
+    attrs_to_add = set(searchable_attrs) | set(unique_attrs)
+    for attr in attrs_to_add:
         if attr not in obj_dict:
             try:
                 value = getattr(obj, attr)
@@ -451,3 +451,8 @@ def deserialize_json(value: Json, annotation: Any = None) -> Any:
         return _deserialize_from_json_bytes(value)
     else:
         raise ValueError(f"Cannot deserialize {value} to {annotation}")
+
+
+def is_json_primitive(value: Any) -> bool:
+    serialized = serialize_json(value, validate=False)
+    return isinstance(serialized, JsonPrimitive)  # type: ignore
diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py
index d7e566b733a..ec4cbc93f0f 100644
--- a/packages/syft/src/syft/service/action/action_object.py
+++ b/packages/syft/src/syft/service/action/action_object.py
@@ -256,6 +256,7 @@ class ActionObjectPointer:
     "deleted_date",  # syft
     "to_mongo",  # syft 🟡 TODO 23: Add composeable / inheritable object passthrough attrs
     "__attr_searchable__",  # syft
+    "__attr_unique__",  # syft
     "__canonical_name__",  # syft
     "__version__",  # syft
     "__args__",  # pydantic
diff --git a/packages/syft/src/syft/service/worker/worker_image.py b/packages/syft/src/syft/service/worker/worker_image.py
index e5c110a6e0e..85eaf38d00f 100644
--- a/packages/syft/src/syft/service/worker/worker_image.py
+++ b/packages/syft/src/syft/service/worker/worker_image.py
@@ -17,8 +17,14 @@ class SyftWorkerImage(SyftObject):
     __canonical_name__ = "SyftWorkerImage"
     __version__ = SYFT_OBJECT_VERSION_1
 
-    __attr_unique__ = ["config"]
-    __attr_searchable__ = ["config", "image_hash", "created_by"]
+    __attr_unique__ = ["config_hash"]
+    __attr_searchable__ = [
+        "config",
+        "image_hash",
+        "created_by",
+        "config_hash",
+    ]
+
     __repr_attrs__ = [
         "image_identifier",
         "image_hash",
@@ -35,6 +41,10 @@ class SyftWorkerImage(SyftObject):
     image_hash: str | None = None
     built_at: DateTime | None = None
 
+    @property
+    def config_hash(self) -> str:
+        return self.config.hash()
+
     @property
     def is_built(self) -> bool:
         """Returns True if the image has been built or is prebuilt."""
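The switch from a unique "config" to a unique "config_hash" matters because uniqueness is now checked on serialized JSON fields: a hex digest is a JSON primitive that can be compared, a config object is not. An illustrative stand-in for WorkerConfig.hash() (the helper below is hypothetical; the real method serializes the whole config object):

from hashlib import sha256

def config_hash(serialized_config: bytes) -> str:
    # Equal configs produce equal digests, so the digest can serve as the
    # unique key in place of the unhashable config object itself.
    return sha256(serialized_config).digest().hex()

assert config_hash(b"python:3.12") == config_hash(b"python:3.12")
assert config_hash(b"python:3.12") != config_hash(b"python:3.11")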
diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py
index 348daeda0f5..29955165cb1 100644
--- a/packages/syft/src/syft/store/db/query.py
+++ b/packages/syft/src/syft/store/db/query.py
@@ -8,6 +8,7 @@
 # third party
 import sqlalchemy as sa
 from sqlalchemy import Column
+from sqlalchemy import Dialect
 from sqlalchemy import Result
 from sqlalchemy import Select
 from sqlalchemy import Table
@@ -43,6 +44,23 @@ def _get_table(self, object_type: type[SyftObject]) -> Table:
             raise ValueError(f"Table for {cname} not found")
         return Base.metadata.tables[cname]
 
+    @staticmethod
+    def get_query_class(dialect: str | Dialect) -> "type[Query]":
+        if isinstance(dialect, Dialect):
+            dialect = dialect.name
+
+        if dialect == "sqlite":
+            return SQLiteQuery
+        elif dialect == "postgresql":
+            return PostgresQuery
+        else:
+            raise ValueError(f"Unsupported dialect {dialect}")
+
+    @classmethod
+    def create(cls, object_type: type[SyftObject], dialect: str | Dialect) -> "Query":
+        query_class = cls.get_query_class(dialect)
+        return query_class(object_type)
+
     def execute(self, session: Session) -> Result:
         """Execute the query using the given session."""
         return session.execute(self.stmt)
@@ -99,6 +117,73 @@ def filter(self, field: str, operator: str | FilterOperator, value: Any) -> Self
         Returns:
             Self: The query object with the filter applied
         """
+        filter = self._create_filter_clause(self.table, field, operator, value)
+        self.stmt = self.stmt.where(filter)
+        return self
+
+    def filter_and(self, *filters: tuple[str, str | FilterOperator, Any]) -> Self:
+        """Add filters to the query using an AND clause.
+
+        example usage:
+        Query(User).filter_and(
+            ("name", "eq", "Alice"),
+            ("age", "eq", 30),
+        )
+
+        Args:
+            field (str): Field to filter on
+            operator (str): Operator to use for the filter
+            value (Any): Value to filter on
+
+        Raises:
+            ValueError: If the operator is not supported
+
+        Returns:
+            Self: The query object with the filter applied
+        """
+        filter_clauses = [
+            self._create_filter_clause(self.table, field, operator, value)
+            for field, operator, value in filters
+        ]
+
+        self.stmt = self.stmt.where(sa.and_(*filter_clauses))
+        return self
+
+    def filter_or(self, *filters: tuple[str, str | FilterOperator, Any]) -> Self:
+        """Add filters to the query using an OR clause.
+
+        example usage:
+        Query(User).filter_or(
+            ("name", "eq", "Alice"),
+            ("age", "eq", 30),
+        )
+
+        Args:
+            field (str): Field to filter on
+            operator (str): Operator to use for the filter
+            value (Any): Value to filter on
+
+        Raises:
+            ValueError: If the operator is not supported
+
+        Returns:
+            Self: The query object with the filter applied
+        """
+        filter_clauses = [
+            self._create_filter_clause(self.table, field, operator, value)
+            for field, operator, value in filters
+        ]
+
+        self.stmt = self.stmt.where(sa.or_(*filter_clauses))
+        return self
+
+    def _create_filter_clause(
+        self,
+        table: Table,
+        field: str,
+        operator: str | FilterOperator,
+        value: Any,
+    ) -> sa.sql.elements.BinaryExpression:
         if isinstance(operator, str):
             try:
                 operator = FilterOperator(operator.lower())
@@ -106,12 +191,9 @@ def filter(self, field: str, operator: str | FilterOperator, value: Any) -> Self
                 raise ValueError(f"Filter operator {operator} not supported")
 
         if operator == FilterOperator.EQ:
-            filter = self._eq_filter(self.table, field, value)
+            return self._eq_filter(table, field, value)
         elif operator == FilterOperator.CONTAINS:
-            filter = self._contains_filter(self.table, field, value)
-
-        self.stmt = self.stmt.where(filter)
-        return self
+            return self._contains_filter(table, field, value)
 
     def order_by(
         self,
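The new filter_and/filter_or combinators compile each (field, operator, value) tuple into one clause and combine them with a single AND or OR. The sketch below shows the same shape with plain columns for readability; the real implementation routes through _create_filter_clause and JSON-backed fields, so names here are illustrative:

import sqlalchemy as sa
from sqlalchemy import column, select, table

users = table("user", column("name"), column("age"))

# What filter_or(("name", "eq", "Alice"), ("age", "eq", 30)) amounts to:
clauses = [users.c.name == "Alice", users.c.age == 30]
stmt = select(users.c.name).where(sa.or_(*clauses))
print(stmt)  # SELECT "user".name FROM "user" WHERE name = :name_1 OR age = :age_1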
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index ddfd37edd55..37f9fe6414b 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -16,6 +16,7 @@
 
 # relative
 from ...serde.json_serde import deserialize_json
+from ...serde.json_serde import is_json_primitive
 from ...serde.json_serde import serialize_json
 from ...server.credentials import SyftVerifyKey
 from ...service.action.action_permissions import ActionObjectEXECUTE
@@ -32,9 +33,7 @@
 from ...types.uid import UID
 from ..document_store_errors import NotFoundException
 from ..document_store_errors import StashException
-from .query import PostgresQuery
 from .query import Query
-from .query import SQLiteQuery
 from .schema import Base
 from .schema import create_table
 from .sqlite_db import DBManager
@@ -99,6 +98,9 @@ def random(cls, **kwargs: dict) -> Self:
         stash.db.init_tables()
         return stash
 
+    def _is_sqlite(self) -> bool:
+        return self.db.engine.dialect.name == "sqlite"
+
     @property
     def server_uid(self) -> UID:
         return self.db.server_uid
@@ -111,14 +113,10 @@ def root_verify_key(self) -> SyftVerifyKey:
     def _data(self) -> list[StashT]:
         return self.get_all(self.root_verify_key, has_permission=True).unwrap()
 
-    def query(self) -> Query:
-        """Creates a query for this stash's object type."""
-        if self.dialect.name == "sqlite":
-            return SQLiteQuery(self.object_type)
-        elif self.dialect.name == "postgresql":
-            return PostgresQuery(self.object_type)
-        else:
-            raise NotImplementedError(f"Query not implemented for {self.dialect.name}")
+    def query(self, object_type: type[SyftObject] | None = None) -> Query:
+        """Creates a query for this stash's object type and SQL dialect."""
+        object_type = object_type or self.object_type
+        return Query.create(object_type, self.dialect)
 
     @as_result(StashException)
     def check_type(self, obj: T, type_: type) -> T:
@@ -153,20 +151,25 @@ def is_unique(self, obj: StashT) -> bool:
         unique_fields = self.unique_fields
         if not unique_fields:
             return True
+
         filters = []
-        for filter_name in unique_fields:
-            field_value = getattr(obj, filter_name, None)
+        for field_name in unique_fields:
+            field_value = getattr(obj, field_name, None)
+            if not is_json_primitive(field_value):
+                raise StashException(
+                    f"Cannot check uniqueness of non-primitive field {field_name}"
+                )
             if field_value is None:
                 continue
-            filt = self._get_field_filter(
-                field_name=filter_name,
-                # is the str cast correct? how to handle SyftVerifyKey?
-                field_value=str(field_value),
-            )
-            filters.append(filt)
+            filters.append((field_name, "eq", field_value))
+
+        query = self.query()
+        query = query.filter_or(
+            *filters,
+        )
+
+        results = query.execute(self.session).all()
 
-        stmt = self.table.select().where(sa.or_(*filters))
-        results = self.session.execute(stmt).all()
         if len(results) > 1:
             return False
         elif len(results) == 1:
@@ -175,26 +178,20 @@ def is_unique(self, obj: StashT) -> bool:
             return result.id == obj.id
         return True
 
     def exists(self, credentials: SyftVerifyKey, uid: UID) -> bool:
+        # TODO should be @as_result
         # TODO needs credentials check?
         # TODO use COUNT(*) instead of SELECT
-        stmt = self.table.select().where(self._get_field_filter("id", uid))
-        result = self.session.execute(stmt).first()
+        query = self.query().filter("id", "eq", uid)
+        result = query.execute(self.session).first()
         return result is not None
 
     @as_result(SyftException, StashException, NotFoundException)
     def get_by_uid(
         self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False
     ) -> StashT:
-        query = self.query().filter("id", "eq", uid)
-
-        if not has_permission:
-            role = self.get_role(credentials)
-            query = query.with_permissions(credentials, role)
-
-        result = query.execute(self.session).first()
-        if result is None:
-            raise NotFoundException(f"{self.object_type.__name__}: {uid} not found")
-        return self.row_as_obj(result)
+        return self.get_one(
+            credentials=credentials, filters={"id": uid}, has_permission=has_permission
+        ).unwrap()
 
     def _get_field_filter(
         self,
@@ -240,19 +237,21 @@ def row_as_obj(self, row: Row) -> StashT:
         return deserialize_json(row.fields)
 
     def get_role(self, credentials: SyftVerifyKey) -> ServiceRole:
+        # relative
+        from ...service.user.user import User
+
         # TODO error handling
         if Base.metadata.tables.get("User") is None:
             # if User table does not exist, we assume the user is a guest
             # this happens when we create stashes in tests
             return ServiceRole.GUEST
-        user_table = Table("User", Base.metadata)
-        stmt = select(user_table.c.fields["role"]).where(
-            self._get_field_filter("verify_key", str(credentials), table=user_table),
-        )
-        role = self.session.scalar(stmt)
-        if role is None:
+
+        query = self.query(User).filter("verify_key", "eq", credentials)
+        user = query.execute(self.session).first()
+        if user is None:
             return ServiceRole.GUEST
-        return ServiceRole[role]
+
+        return self.row_as_obj(user).role
 
     def _get_permission_filter_from_permisson(
         self,
@@ -298,52 +297,6 @@ def _apply_permission_filter(
         )
         return stmt
 
-    @as_result(StashException, NotFoundException)
-    def update(
-        self,
-        credentials: SyftVerifyKey,
-        obj: StashT,
-        has_permission: bool = False,
-    ) -> StashT:
-        """
-        NOTE: We cannot do partial updates on the database,
-        because we are using computed fields that are not known to the DB:
-        - serialize_json will add computed fields to the JSON stored in the database
-        - If we update a single field in the JSON, the computed fields can get out of sync.
-        - To fix, we either need db-supported computed fields, or know in our ORM which fields should be re-computed.
-        """
-
-        if not self.allow_any_type:
-            self.check_type(obj, self.object_type).unwrap()
-
-        # TODO has_permission is not used
-        if not self.is_unique(obj):
-            raise StashException(f"Some fields are not unique for {type(obj).__name__}")
-
-        stmt = self.table.update().where(self._get_field_filter("id", obj.id))
-        stmt = self._apply_permission_filter(
-            stmt,
-            credentials=credentials,
-            permission=ActionPermission.WRITE,
-            has_permission=has_permission,
-        )
-        fields = serialize_json(obj)
-        try:
-            deserialize_json(fields)
-        except Exception as e:
-            raise StashException(
-                f"Error serializing object: {e}. Some fields are invalid."
-            )
-        stmt = stmt.values(fields=fields)
-
-        result = self.session.execute(stmt)
-        self.session.commit()
-        if result.rowcount == 0:
-            raise NotFoundException(
-                f"{self.object_type.__name__}: {obj.id} not found or no permission to update."
-            )
-        return self.get_by_uid(credentials, obj.id).unwrap()
-
     def get_ownership_permissions(
         self, uid: UID, credentials: SyftVerifyKey
     ) -> list[str]:
@@ -354,25 +307,6 @@ def get_ownership_permissions(
             ActionObjectEXECUTE(uid=uid, credentials=credentials).permission_string,
         ]
 
-    @as_result(StashException, NotFoundException)
-    def delete_by_uid(
-        self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False
-    ) -> UID:
-        stmt = self.table.delete().where(self._get_field_filter("id", uid))
-        stmt = self._apply_permission_filter(
-            stmt,
-            credentials=credentials,
-            permission=ActionPermission.WRITE,
-            has_permission=has_permission,
-        )
-        result = self.session.execute(stmt)
-        self.session.commit()
-        if result.rowcount == 0:
-            raise NotFoundException(
-                f"{self.object_type.__name__}: {uid} not found or no permission to delete."
-            )
-        return uid
-
     @as_result(NotFoundException)
     def add_permissions(self, permissions: list[ActionObjectPermission]) -> None:
         # TODO: should do this in a single transaction
@@ -425,44 +359,6 @@ def add_permissions(self, permissions: list[ActionObjectPermission]) -> None:
         self.session.commit()
         return None
 
-    def remove_storage_permission(self, permission: StoragePermission) -> None:
-        # TODO not threadsafe
-        try:
-            permissions = self._get_storage_permissions_for_uid(permission.uid).unwrap()
-            permissions.remove(permission.server_uid)
-        except (NotFoundException, KeyError):
-            # TODO add error handling to permissions
-            return None
-
-        stmt = (
-            self.table.update()
-            .where(self.table.c.id == permission.uid)
-            .values(storage_permissions=[str(uid) for uid in permissions])
-        )
-        self.session.execute(stmt)
-        self.session.commit()
-        return None
-
-    @as_result(StashException)
-    def _get_storage_permissions_for_uid(self, uid: UID) -> set[UID]:
-        stmt = select(self.table.c.id, self.table.c.storage_permissions).where(
-            self.table.c.id == uid
-        )
-        result = self.session.execute(stmt).first()
-        if result is None:
-            raise NotFoundException(f"No storage permissions found for uid: {uid}")
-        return {UID(uid) for uid in result.storage_permissions}
-
-    @as_result(StashException)
-    def get_all_storage_permissions(self) -> dict[UID, set[UID]]:
-        stmt = select(self.table.c.id, self.table.c.storage_permissions)
-        results = self.session.execute(stmt).all()
-
-        return {
-            UID(row.id): {UID(uid) for uid in row.storage_permissions}
-            for row in results
-        }
-
     def has_permission(self, permission: ActionObjectPermission) -> bool:
         if self.get_role(permission.credentials) in (
             ServiceRole.ADMIN,
@@ -471,33 +367,6 @@ def has_permission(self, permission: ActionObjectPermission) -> bool:
             return True
         return self.has_permissions([permission])
 
-    def has_storage_permission(self, permission: StoragePermission) -> bool:
-        return self.has_storage_permissions([permission])
-
-    def _is_sqlite(self) -> bool:
-        return self.db.engine.dialect.name == "sqlite"
-
-    def has_storage_permissions(self, permissions: list[StoragePermission]) -> bool:
-        permission_filters = [
-            sa.and_(
-                self._get_field_filter("id", p.uid),
-                self.table.c.storage_permissions.contains(
-                    p.server_uid.no_dash
-                    if self._is_sqlite()
-                    else [p.server_uid.no_dash]
-                ),
-            )
-            for p in permissions
-        ]
-
-        stmt = self.table.select().where(
-            sa.and_(
-                *permission_filters,
-            )
-        )
-        result = self.session.execute(stmt).first()
-        return result is not None
-
     def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool:
         # TODO: we should use a permissions table to check all permissions at once
         # TODO: should check for compound permissions
@@ -518,6 +387,33 @@ def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool:
         result = self.session.execute(stmt).first()
         return result is not None
 
+    @as_result(StashException)
+    def _get_permissions_for_uid(self, uid: UID) -> set[str]:
+        stmt = select(self.table.c.permissions).where(self.table.c.id == uid)
+        result = self.session.execute(stmt).scalar_one_or_none()
+        if result is None:
+            raise NotFoundException(f"No permissions found for uid: {uid}")
+        return set(result)
+
+    @as_result(StashException)
+    def get_all_permissions(self) -> dict[UID, set[str]]:
+        stmt = select(self.table.c.id, self.table.c.permissions)
+        results = self.session.execute(stmt).all()
+        return {UID(row.id): set(row.permissions) for row in results}
+
+    def has_storage_permission(self, permission: StoragePermission) -> bool:
+        return self.has_storage_permissions([permission])
+
+    @as_result(StashException)
+    def get_all_storage_permissions(self) -> dict[UID, set[UID]]:
+        stmt = select(self.table.c.id, self.table.c.storage_permissions)
+        results = self.session.execute(stmt).all()
+
+        return {
+            UID(row.id): {UID(uid) for uid in row.storage_permissions}
+            for row in results
+        }
+
     @as_result(NotFoundException)
     def add_storage_permission(self, permission: StoragePermission) -> None:
         stmt = self.table.update().where(self.table.c.id == permission.uid)
@@ -544,24 +440,59 @@ def add_storage_permission(self, permission: StoragePermission) -> None:
         )
         return None
 
-    @as_result(NotFoundException)
-    def add_storage_permissions(self, permissions: list[StoragePermission]) -> None:
-        for permission in permissions:
-            self.add_storage_permission(permission)
+    def has_storage_permissions(self, permissions: list[StoragePermission]) -> bool:
+        permission_filters = [
+            sa.and_(
+                self._get_field_filter("id", p.uid),
+                self.table.c.storage_permissions.contains(
+                    p.server_uid.no_dash
+                    if self._is_sqlite()
+                    else [p.server_uid.no_dash]
+                ),
+            )
+            for p in permissions
+        ]
+
+        stmt = self.table.select().where(
+            sa.and_(
+                *permission_filters,
+            )
+        )
+        result = self.session.execute(stmt).first()
+        return result is not None
+
+    def remove_storage_permission(self, permission: StoragePermission) -> None:
+        # TODO not threadsafe
+        try:
+            permissions = self._get_storage_permissions_for_uid(permission.uid).unwrap()
+            permissions.remove(permission.server_uid)
+        except (NotFoundException, KeyError):
+            # TODO add error handling to permissions
+            return None
+
+        stmt = (
+            self.table.update()
+            .where(self.table.c.id == permission.uid)
+            .values(storage_permissions=[str(uid) for uid in permissions])
+        )
+        self.session.execute(stmt)
+        self.session.commit()
+        return None
 
     @as_result(StashException)
-    def _get_permissions_for_uid(self, uid: UID) -> set[str]:
-        stmt = select(self.table.c.permissions).where(self.table.c.id == uid)
-        result = self.session.execute(stmt).scalar_one_or_none()
+    def _get_storage_permissions_for_uid(self, uid: UID) -> set[UID]:
+        stmt = select(self.table.c.id, self.table.c.storage_permissions).where(
+            self.table.c.id == uid
+        )
+        result = self.session.execute(stmt).first()
         if result is None:
-            return NotFoundException(f"No permissions found for uid: {uid}")
-        return set(result)
+            raise NotFoundException(f"No storage permissions found for uid: {uid}")
+        return {UID(uid) for uid in result.storage_permissions}
 
     @as_result(StashException)
-    def get_all_permissions(self) -> dict[UID, set[str]]:
-        stmt = select(self.table.c.id, self.table.c.permissions)
-        results = self.session.execute(stmt).all()
-        return {UID(row.id): set(row.permissions) for row in results}
+    @as_result(NotFoundException)
+    def add_storage_permissions(self, permissions: list[StoragePermission]) -> None:
+        for permission in permissions:
+            self.add_storage_permission(permission).unwrap()
 
     @as_result(SyftException, StashException)
     def set(
@@ -618,6 +549,71 @@ def set(
         self.session.commit()
         return self.get_by_uid(credentials, uid).unwrap()
 
+    @as_result(StashException, NotFoundException)
+    def update(
+        self,
+        credentials: SyftVerifyKey,
+        obj: StashT,
+        has_permission: bool = False,
+    ) -> StashT:
+        """
+        NOTE: We cannot do partial updates on the database,
+        because we are using computed fields that are not known to the DB:
+        - serialize_json will add computed fields to the JSON stored in the database
+        - If we update a single field in the JSON, the computed fields can get out of sync.
+        - To fix, we either need db-supported computed fields, or know in our ORM which fields should be re-computed.
+        """
+
+        if not self.allow_any_type:
+            self.check_type(obj, self.object_type).unwrap()
+
+        # TODO has_permission is not used
+        if not self.is_unique(obj):
+            raise StashException(f"Some fields are not unique for {type(obj).__name__}")
+
+        stmt = self.table.update().where(self._get_field_filter("id", obj.id))
+        stmt = self._apply_permission_filter(
+            stmt,
+            credentials=credentials,
+            permission=ActionPermission.WRITE,
+            has_permission=has_permission,
+        )
+        fields = serialize_json(obj)
+        try:
+            deserialize_json(fields)
+        except Exception as e:
+            raise StashException(
+                f"Error serializing object: {e}. Some fields are invalid."
+            )
+        stmt = stmt.values(fields=fields)
+
+        result = self.session.execute(stmt)
+        self.session.commit()
+        if result.rowcount == 0:
+            raise NotFoundException(
+                f"{self.object_type.__name__}: {obj.id} not found or no permission to update."
+            )
+        return self.get_by_uid(credentials, obj.id).unwrap()
+
+    @as_result(StashException, NotFoundException)
+    def delete_by_uid(
+        self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False
+    ) -> UID:
+        stmt = self.table.delete().where(self._get_field_filter("id", uid))
+        stmt = self._apply_permission_filter(
+            stmt,
+            credentials=credentials,
+            permission=ActionPermission.WRITE,
+            has_permission=has_permission,
+        )
+        result = self.session.execute(stmt)
+        self.session.commit()
+        if result.rowcount == 0:
+            raise NotFoundException(
+                f"{self.object_type.__name__}: {uid} not found or no permission to delete."
+            )
+        return uid
+
     @as_result(StashException)
     def get_one(
         self,
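The is_unique rewrite above explains why unique fields must serialize to JSON primitives: uniqueness is now checked by comparing serialized JSON field values, and only primitives compare reliably. A minimal sketch of the constraint (helper name is illustrative; the real check round-trips through serialize_json):

def is_json_primitive_value(value) -> bool:
    # str/int/float/bool/None are comparable as JSON scalars; dicts and
    # lists (like the old SyftWorkerImage "config" field) are not.
    return value is None or isinstance(value, (str, int, float, bool))

assert is_json_primitive_value("abc123")
assert not is_json_primitive_value({"dockerfile": "FROM python:3.12"})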
From a094d64fc0e0a2c3725d778a478232074f40aa8e Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Tue, 10 Sep 2024 16:48:17 +0200
Subject: [PATCH 125/197] comments

---
 packages/syft/src/syft/store/db/query.py |   1 +
 packages/syft/src/syft/store/db/stash.py | 396 ++++++++++++-----------
 2 files changed, 200 insertions(+), 197 deletions(-)

diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py
index 966f8753920..8dc13f1ccb5 100644
--- a/packages/syft/src/syft/store/db/query.py
+++ b/packages/syft/src/syft/store/db/query.py
@@ -58,6 +58,7 @@ def get_query_class(dialect: str | Dialect) -> "type[Query]":
 
     @classmethod
     def create(cls, object_type: type[SyftObject], dialect: str | Dialect) -> "Query":
+        """Create a query object for the given object type and dialect."""
         query_class = cls.get_query_class(dialect)
         return query_class(object_type)
 
diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py
index b2cf5aa8896..2d1a9399a9f 100644
--- a/packages/syft/src/syft/store/db/stash.py
+++ b/packages/syft/src/syft/store/db/stash.py
@@ -297,203 +297,6 @@ def _apply_permission_filter(
         )
         return stmt
 
-    def get_ownership_permissions(
-        self, uid: UID, credentials: SyftVerifyKey
-    ) -> list[str]:
-        return [
-            ActionObjectOWNER(uid=uid, credentials=credentials).permission_string,
-            ActionObjectWRITE(uid=uid, credentials=credentials).permission_string,
-            ActionObjectREAD(uid=uid, credentials=credentials).permission_string,
-            ActionObjectEXECUTE(uid=uid, credentials=credentials).permission_string,
-        ]
-
-    @as_result(NotFoundException)
-    def add_permissions(self, permissions: list[ActionObjectPermission]) -> None:
-        # TODO: should do this in a single transaction
-        # TODO add error handling
-        for permission in permissions:
-            self.add_permission(permission).unwrap()
-        return None
-
-    @as_result(NotFoundException)
-    def add_permission(self, permission: ActionObjectPermission) -> None:
-        # TODO add error handling
-        stmt = self.table.update().where(self.table.c.id == permission.uid)
-        if self._is_sqlite():
-            stmt = stmt.values(
-                permissions=func.json_insert(
-                    self.table.c.permissions,
-                    "$[#]",
-                    permission.permission_string,
-                )
-            )
-        else:
-            stmt = stmt.values(
-                permissions=func.array_append(
-                    self.table.c.permissions, permission.permission_string
-                )
-            )
-
-        result = self.session.execute(stmt)
-        self.session.commit()
-        if result.rowcount == 0:
-            raise NotFoundException(
-                f"{self.object_type.__name__}: {permission.uid} not found or no permission to change."
-            )
-
-    def remove_permission(self, permission: ActionObjectPermission) -> None:
-        # TODO not threadsafe
-        try:
-            permissions = self._get_permissions_for_uid(permission.uid).unwrap()
-            permissions.remove(permission.permission_string)
-        except (NotFoundException, KeyError):
-            # TODO add error handling to permissions
-            return None
-
-        stmt = (
-            self.table.update()
-            .where(self.table.c.id == permission.uid)
-            .values(permissions=list(permissions))
-        )
-        self.session.execute(stmt)
-        self.session.commit()
-        return None
-
-    def has_permission(self, permission: ActionObjectPermission) -> bool:
-        if self.get_role(permission.credentials) in (
-            ServiceRole.ADMIN,
-            ServiceRole.DATA_OWNER,
-        ):
-            return True
-        return self.has_permissions([permission])
-
-    def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool:
-        # TODO: we should use a permissions table to check all permissions at once
-        # TODO: should check for compound permissions
-
-        permission_filters = [
-            sa.and_(
-                self._get_field_filter("id", p.uid),
-                self._get_permission_filter_from_permisson(permission=p),
-            )
-            for p in permissions
-        ]
-
-        stmt = self.table.select().where(
-            sa.and_(
-                *permission_filters,
-            ),
-        )
-        result = self.session.execute(stmt).first()
-        return result is not None
-
-    @as_result(StashException)
-    def _get_permissions_for_uid(self, uid: UID) -> set[str]:
-        stmt = select(self.table.c.permissions).where(self.table.c.id == uid)
-        result = self.session.execute(stmt).scalar_one_or_none()
-        if result is None:
-            raise NotFoundException(f"No permissions found for uid: {uid}")
-        return set(result)
-
-    @as_result(StashException)
-    def get_all_permissions(self) -> dict[UID, set[str]]:
-        stmt = select(self.table.c.id, self.table.c.permissions)
-        results = self.session.execute(stmt).all()
-        return {UID(row.id): set(row.permissions) for row in results}
-
-    def has_storage_permission(self, permission: StoragePermission) -> bool:
-        return self.has_storage_permissions([permission])
-
-    @as_result(StashException)
-    def get_all_storage_permissions(self) -> dict[UID, set[UID]]:
-        stmt = select(self.table.c.id, self.table.c.storage_permissions)
-        results = self.session.execute(stmt).all()
-
-        return {
-            UID(row.id): {UID(uid) for uid in row.storage_permissions}
-            for row in results
-        }
-
-    @as_result(NotFoundException)
-    def add_storage_permission(self, permission: StoragePermission) -> None:
-        stmt = self.table.update().where(self.table.c.id == permission.uid)
-        if self._is_sqlite():
-            stmt = stmt.values(
-                storage_permissions=func.json_insert(
-                    self.table.c.storage_permissions,
-                    "$[#]",
-                    permission.permission_string,
-                )
-            )
-        else:
-            stmt = stmt.values(
-                permissions=func.array_append(
-                    self.table.c.storage_permissions, permission.server_uid.no_dash
-                )
-            )
-
-        result = self.session.execute(stmt)
-        self.session.commit()
-        if result.rowcount == 0:
-            raise NotFoundException(
-                f"{self.object_type.__name__}: {permission.uid} not found or no permission to change."
-            )
-        return None
-
-    def has_storage_permissions(self, permissions: list[StoragePermission]) -> bool:
-        permission_filters = [
-            sa.and_(
-                self._get_field_filter("id", p.uid),
-                self.table.c.storage_permissions.contains(
-                    p.server_uid.no_dash
-                    if self._is_sqlite()
-                    else [p.server_uid.no_dash]
-                ),
-            )
-            for p in permissions
-        ]
-
-        stmt = self.table.select().where(
-            sa.and_(
-                *permission_filters,
-            )
-        )
-        result = self.session.execute(stmt).first()
-        return result is not None
-
-    def remove_storage_permission(self, permission: StoragePermission) -> None:
-        # TODO not threadsafe
-        try:
-            permissions = self._get_storage_permissions_for_uid(permission.uid).unwrap()
-            permissions.remove(permission.server_uid)
-        except (NotFoundException, KeyError):
-            # TODO add error handling to permissions
-            return None
-
-        stmt = (
-            self.table.update()
-            .where(self.table.c.id == permission.uid)
-            .values(storage_permissions=[str(uid) for uid in permissions])
-        )
-        self.session.execute(stmt)
-        self.session.commit()
-        return None
-
-    @as_result(StashException)
-    def _get_storage_permissions_for_uid(self, uid: UID) -> set[UID]:
-        stmt = select(self.table.c.id, self.table.c.storage_permissions).where(
-            self.table.c.id == uid
-        )
-        result = self.session.execute(stmt).first()
-        if result is None:
-            raise NotFoundException(f"No storage permissions found for uid: {uid}")
-        return {UID(uid) for uid in result.storage_permissions}
-
-    @as_result(NotFoundException)
-    def add_storage_permissions(self, permissions: list[StoragePermission]) -> None:
-        for permission in permissions:
-            self.add_storage_permission(permission).unwrap()
-
     @as_result(SyftException, StashException)
     def set(
         self,
@@ -708,3 +511,202 @@ def get_all(
         query = query.order_by(order_by, sort_order).limit(limit).offset(offset)
         result = query.execute(self.session).all()
         return [self.row_as_obj(row) for row in result]
+
+    # PERMISSIONS
+    def get_ownership_permissions(
+        self, uid: UID, credentials: SyftVerifyKey
+    ) -> list[str]:
+        return [
+            ActionObjectOWNER(uid=uid, credentials=credentials).permission_string,
+            ActionObjectWRITE(uid=uid, credentials=credentials).permission_string,
+            ActionObjectREAD(uid=uid, credentials=credentials).permission_string,
+            ActionObjectEXECUTE(uid=uid, credentials=credentials).permission_string,
+        ]
+
+    @as_result(NotFoundException)
+    def add_permissions(self, permissions: list[ActionObjectPermission]) -> None:
+        # TODO: should do this in a single transaction
+        # TODO add error handling
+        for permission in permissions:
+            self.add_permission(permission).unwrap()
+        return None
+
+    @as_result(NotFoundException)
+    def add_permission(self, permission: ActionObjectPermission) -> None:
+        # TODO add error handling
+        stmt = self.table.update().where(self.table.c.id == permission.uid)
+        if self._is_sqlite():
+            stmt = stmt.values(
+                permissions=func.json_insert(
+                    self.table.c.permissions,
+                    "$[#]",
+                    permission.permission_string,
+                )
+            )
+        else:
+            stmt = stmt.values(
+                permissions=func.array_append(
+                    self.table.c.permissions, permission.permission_string
+                )
+            )
+
+        result = self.session.execute(stmt)
+        self.session.commit()
+        if result.rowcount == 0:
+            raise NotFoundException(
+                f"{self.object_type.__name__}: {permission.uid} not found or no permission to change."
+            )
+
+    def remove_permission(self, permission: ActionObjectPermission) -> None:
+        # TODO not threadsafe
+        try:
+            permissions = self._get_permissions_for_uid(permission.uid).unwrap()
+            permissions.remove(permission.permission_string)
+        except (NotFoundException, KeyError):
+            # TODO add error handling to permissions
+            return None
+
+        stmt = (
+            self.table.update()
+            .where(self.table.c.id == permission.uid)
+            .values(permissions=list(permissions))
+        )
+        self.session.execute(stmt)
+        self.session.commit()
+        return None
+
+    def has_permission(self, permission: ActionObjectPermission) -> bool:
+        if self.get_role(permission.credentials) in (
+            ServiceRole.ADMIN,
+            ServiceRole.DATA_OWNER,
+        ):
+            return True
+        return self.has_permissions([permission])
+
+    def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool:
+        # TODO: we should use a permissions table to check all permissions at once
+        # TODO: should check for compound permissions
+
+        permission_filters = [
+            sa.and_(
+                self._get_field_filter("id", p.uid),
+                self._get_permission_filter_from_permisson(permission=p),
+            )
+            for p in permissions
+        ]
+
+        stmt = self.table.select().where(
+            sa.and_(
+                *permission_filters,
+            ),
+        )
+        result = self.session.execute(stmt).first()
+        return result is not None
+
+    @as_result(StashException)
+    def _get_permissions_for_uid(self, uid: UID) -> set[str]:
+        stmt = select(self.table.c.permissions).where(self.table.c.id == uid)
+        result = self.session.execute(stmt).scalar_one_or_none()
+        if result is None:
+            raise NotFoundException(f"No permissions found for uid: {uid}")
+        return set(result)
+
+    @as_result(StashException)
+    def get_all_permissions(self) -> dict[UID, set[str]]:
+        stmt = select(self.table.c.id, self.table.c.permissions)
+        results = self.session.execute(stmt).all()
+        return {UID(row.id): set(row.permissions) for row in results}
+
+    # STORAGE PERMISSIONS
+    def has_storage_permission(self, permission: StoragePermission) -> bool:
+        return self.has_storage_permissions([permission])
+
+    @as_result(StashException)
+    def get_all_storage_permissions(self) -> dict[UID, set[UID]]:
+        stmt = select(self.table.c.id, self.table.c.storage_permissions)
+        results = self.session.execute(stmt).all()
+
+        return {
+            UID(row.id): {UID(uid) for uid in row.storage_permissions}
+            for row in results
+        }
+
+    @as_result(NotFoundException)
+    def add_storage_permission(self, permission: StoragePermission) -> None:
+        stmt = self.table.update().where(self.table.c.id == permission.uid)
+        if self._is_sqlite():
+            stmt = stmt.values(
+                storage_permissions=func.json_insert(
+                    self.table.c.storage_permissions,
+                    "$[#]",
+                    permission.permission_string,
+                )
+            )
+        else:
+            stmt = stmt.values(
+                permissions=func.array_append(
+                    self.table.c.storage_permissions, permission.server_uid.no_dash
+                )
+            )
+
+        result = self.session.execute(stmt)
+        self.session.commit()
+        if result.rowcount == 0:
+            raise NotFoundException(
+                f"{self.object_type.__name__}: {permission.uid} not found or no permission to change."
+            )
+        return None
+
+    def has_storage_permissions(self, permissions: list[StoragePermission]) -> bool:
+        permission_filters = [
+            sa.and_(
+                self._get_field_filter("id", p.uid),
+                self.table.c.storage_permissions.contains(
+                    p.server_uid.no_dash
+                    if self._is_sqlite()
+                    else [p.server_uid.no_dash]
+                ),
+            )
+            for p in permissions
+        ]
+
+        stmt = self.table.select().where(
+            sa.and_(
+                *permission_filters,
+            )
+        )
+        result = self.session.execute(stmt).first()
+        return result is not None
+
+    def remove_storage_permission(self, permission: StoragePermission) -> None:
+        # TODO not threadsafe
+        try:
+            permissions = self._get_storage_permissions_for_uid(permission.uid).unwrap()
+            permissions.remove(permission.server_uid)
+        except (NotFoundException, KeyError):
+            # TODO add error handling to permissions
+            return None
+
+        stmt = (
+            self.table.update()
+            .where(self.table.c.id == permission.uid)
+            .values(storage_permissions=[str(uid) for uid in permissions])
+        )
+        self.session.execute(stmt)
+        self.session.commit()
+        return None
+
+    @as_result(StashException)
+    def _get_storage_permissions_for_uid(self, uid: UID) -> set[UID]:
+        stmt = select(self.table.c.id, self.table.c.storage_permissions).where(
+            self.table.c.id == uid
+        )
+        result = self.session.execute(stmt).first()
+        if result is None:
+            raise NotFoundException(f"No storage permissions found for uid: {uid}")
+        return {UID(uid) for uid in result.storage_permissions}
+
+    @as_result(NotFoundException)
+    def add_storage_permissions(self, permissions: list[StoragePermission]) -> None:
+        for permission in permissions:
+            self.add_storage_permission(permission).unwrap()
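The reorganized storage-permission methods branch on dialect: SQLite keeps the permission list as a JSON-encoded array and appends via json_insert with the "$[#]" path, while Postgres uses a native array column and array_append (assumed from the code above). A minimal sketch of that branching (helper name is illustrative):

from sqlalchemy import column, func

def append_permission(col, value: str, is_sqlite: bool):
    # "$[#]" is SQLite's JSON path for "one past the last array element",
    # i.e. append; Postgres appends to a real ARRAY column instead.
    if is_sqlite:
        return func.json_insert(col, "$[#]", value)
    return func.array_append(col, value)

clause = append_permission(column("permissions"), "READ_example", True)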
self.session.execute(stmt).all() @@ -697,7 +698,7 @@ def remove_storage_permission(self, permission: StoragePermission) -> None: return None @as_result(StashException) - def _get_storage_permissions_for_uid(self, uid: UID) -> set[UID]: + def _get_storage_permissions_for_uid(self, uid: UID) -> Set[UID]: # noqa: UP006 stmt = select(self.table.c.id, self.table.c.storage_permissions).where( self.table.c.id == uid ) From 963a8573e88fea133636fa97120df55063dfad93 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 10 Sep 2024 20:20:01 +0200 Subject: [PATCH 127/197] local session management --- packages/syft/src/syft/store/db/sqlite_db.py | 52 ++--- packages/syft/src/syft/store/db/stash.py | 216 ++++++++++++++----- 2 files changed, 173 insertions(+), 95 deletions(-) diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index 509650dc801..674fc58c224 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -1,7 +1,6 @@ # stdlib from pathlib import Path import tempfile -import threading import uuid # third party @@ -9,7 +8,6 @@ from pydantic import Field import sqlalchemy as sa from sqlalchemy import create_engine -from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker # relative @@ -18,8 +16,6 @@ from ...server.credentials import SyftVerifyKey from ...types.uid import UID from .schema import Base -from .utils import dumps -from .utils import loads @serializable(canonical_name="DBConfig", version=1) @@ -58,13 +54,25 @@ def connection_string(self) -> str: class DBManager: def __init__( self, - config: DBConfig, + config: SQLiteDBConfig, server_uid: UID, root_verify_key: SyftVerifyKey, ) -> None: self.config = config - self.server_uid = server_uid self.root_verify_key = root_verify_key + self.server_uid = server_uid + self.engine = create_engine( + config.connection_string, + # json_serializer=dumps, + # json_deserializer=loads, + ) + print(f"Connecting to {config.connection_string}") + self.sessionmaker = sessionmaker(bind=self.engine) + + self.update_settings() + + def update_settings(self) -> None: + pass def init_tables(self) -> None: pass @@ -74,27 +82,8 @@ def reset(self) -> None: class SQLiteDBManager(DBManager): - def __init__( - self, - config: SQLiteDBConfig, - server_uid: UID, - root_verify_key: SyftVerifyKey, - ) -> None: - self.config = config - self.root_verify_key = root_verify_key - self.server_uid = server_uid - self.engine = create_engine( - config.connection_string, json_serializer=dumps, json_deserializer=loads - ) - print(f"Connecting to {config.connection_string}") - self.Session = sessionmaker(bind=self.engine) - - # TODO use AuthedServiceContext for session management instead of threading.local - self.thread_local = threading.local() - - self.update_settings() - def update_settings(self) -> None: + # TODO split SQLite / PostgresDBManager connection = self.engine.connect() if self.engine.dialect.name == "sqlite": @@ -114,17 +103,6 @@ def reset(self) -> None: Base.metadata.drop_all(bind=self.engine) Base.metadata.create_all(self.engine) - # TODO remove - def get_session_threading_local(self) -> Session: - if not hasattr(self.thread_local, "session"): - self.thread_local.session = self.Session() - return self.thread_local.session - - # TODO remove - @property - def session(self) -> Session: - return self.get_session_threading_local() - @classmethod def random( cls, diff --git a/packages/syft/src/syft/store/db/stash.py 
b/packages/syft/src/syft/store/db/stash.py index 614455c6d37..7f5510d92c2 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -1,6 +1,10 @@ # stdlib +from collections.abc import Callable +from functools import wraps +import inspect from typing import Any from typing import Generic +from typing import ParamSpec from typing import Set # noqa: UP035 from typing import cast from typing import get_args @@ -42,6 +46,7 @@ StashT = TypeVar("StashT", bound=SyftObject) T = TypeVar("T") +P = ParamSpec("P") def parse_filters(filter_dict: dict[str, Any] | None) -> list[tuple[str, str, Any]]: @@ -60,6 +65,29 @@ def parse_filters(filter_dict: dict[str, Any] | None) -> list[tuple[str, str, An return filters +def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore + """ + Decorator to inject a session into the function kwargs if it is not provided. + + TODO: This decorator is a temporary fix, we want to move to a DI approach instead: + move db connection and session to context, and pass context to all stash methods. + """ + + # inspect if the function has a session kwarg + sig = inspect.signature(func) + inject_session: bool = "session" in sig.parameters + + @wraps(func) + def wrapper(self: "ObjectStash[StashT]", *args: Any, **kwargs: Any) -> Any: + if inject_session and kwargs.get("session") is None: + with self.sessionmaker() as session: + kwargs["session"] = session + return func(self, *args, **kwargs) + return func(self, *args, **kwargs) + + return wrapper # type: ignore + + class ObjectStash(Generic[StashT]): table: Table object_type: type[SyftObject] @@ -69,6 +97,7 @@ def __init__(self, store: DBManager) -> None: self.db = store self.object_type = self.get_object_type() self.table = create_table(self.object_type, self.dialect) + self.sessionmaker = self.db.sessionmaker @property def dialect(self) -> sa.engine.interfaces.Dialect: @@ -89,8 +118,9 @@ def get_object_type(cls) -> type[StashT]: ) return generic_args[0] - def __len__(self) -> int: - return self.session.query(self.table).count() + @with_session + def __len__(self, session: Session = None) -> int: + return session.query(self.table).count() @classmethod def random(cls, **kwargs: dict) -> Self: @@ -140,7 +170,7 @@ def _print_query(self, stmt: sa.sql.select) -> None: print( stmt.compile( compile_kwargs={"literal_binds": True}, - dialect=self.session.bind.dialect, + dialect=self.db.engine.dialect, ) ) @@ -148,7 +178,8 @@ def _print_query(self, stmt: sa.sql.select) -> None: def unique_fields(self) -> list[str]: return getattr(self.object_type, "__attr_unique__", []) - def is_unique(self, obj: StashT) -> bool: + @with_session + def is_unique(self, obj: StashT, session: Session = None) -> bool: unique_fields = self.unique_fields if not unique_fields: return True @@ -169,7 +200,7 @@ def is_unique(self, obj: StashT) -> bool: *filters, ) - results = query.execute(self.session).all() + results = query.execute(session).all() if len(results) > 1: return False @@ -178,20 +209,31 @@ def is_unique(self, obj: StashT) -> bool: return result.id == obj.id return True - def exists(self, credentials: SyftVerifyKey, uid: UID) -> bool: + @with_session + def exists( + self, credentials: SyftVerifyKey, uid: UID, session: Session = None + ) -> bool: # TODO should be @as_result # TODO needs credentials check? 
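A note on the `with_session` decorator added in this hunk: it inspects the wrapped function's signature once, at decoration time, and only injects a session when the caller did not supply one, opening a short-lived session from the stash's sessionmaker for the duration of the call. Below is a minimal self-contained sketch of the same pattern, assuming SQLAlchemy and an in-memory SQLite engine; the `Repo` class and its `one` method are illustrative stand-ins, not part of this patch.

# Sketch of decorator-based session injection (illustrative, not this patch).
import inspect
from functools import wraps

from sqlalchemy import create_engine, text
from sqlalchemy.orm import Session, sessionmaker


def with_session(func):
    # decide once, at decoration time, whether this function accepts a session
    inject = "session" in inspect.signature(func).parameters

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        if inject and kwargs.get("session") is None:
            # no session supplied: open one just for this call
            with self.sessionmaker() as session:
                kwargs["session"] = session
                return func(self, *args, **kwargs)
        return func(self, *args, **kwargs)

    return wrapper


class Repo:
    def __init__(self) -> None:
        self.sessionmaker = sessionmaker(bind=create_engine("sqlite://"))

    @with_session
    def one(self, session: Session = None) -> int:
        return session.execute(text("SELECT 1")).scalar_one()


print(Repo().one())  # prints 1; a session was opened and injected

Callers that already hold a session pass it through explicitly (session=session), which is how the stash methods in the rest of this patch make nested calls share a single session.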
# TODO use COUNT(*) instead of SELECT query = self.query().filter("id", "eq", uid) - result = query.execute(self.session).first() + result = query.execute(session).first() return result is not None @as_result(SyftException, StashException, NotFoundException) + @with_session def get_by_uid( - self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False + self, + credentials: SyftVerifyKey, + uid: UID, + has_permission: bool = False, + session: Session = None, ) -> StashT: return self.get_one( - credentials=credentials, filters={"id": uid}, has_permission=has_permission + credentials=credentials, + filters={"id": uid}, + has_permission=has_permission, + session=session, ).unwrap() def _get_field_filter( @@ -237,7 +279,10 @@ def row_as_obj(self, row: Row) -> StashT: # TODO make unwrappable serde return deserialize_json(row.fields) - def get_role(self, credentials: SyftVerifyKey) -> ServiceRole: + @with_session + def get_role( + self, credentials: SyftVerifyKey, session: Session = None + ) -> ServiceRole: # relative from ...service.user.user import User @@ -248,7 +293,7 @@ def get_role(self, credentials: SyftVerifyKey) -> ServiceRole: return ServiceRole.GUEST query = self.query(User).filter("verify_key", "eq", credentials) - user = query.execute(self.session).first() + user = query.execute(session).first() if user is None: return ServiceRole.GUEST @@ -261,7 +306,7 @@ def _get_permission_filter_from_permisson( permission_string = permission.permission_string compound_permission_string = permission.compound_permission_string - if self.session.bind.dialect.name == "postgresql": + if self.db.engine.dialect.name == "postgresql": permission_string = [permission_string] # type: ignore compound_permission_string = [compound_permission_string] # type: ignore return sa.or_( @@ -269,6 +314,7 @@ def _get_permission_filter_from_permisson( self.table.c.permissions.contains(compound_permission_string), ) + @with_session def _apply_permission_filter( self, stmt: T, @@ -276,11 +322,12 @@ def _apply_permission_filter( credentials: SyftVerifyKey, permission: ActionPermission = ActionPermission.READ, has_permission: bool = False, + session: Session = None, ) -> T: if has_permission: # ignoring permissions return stmt - role = self.get_role(credentials) + role = self.get_role(credentials, session=session) if role in (ServiceRole.ADMIN, ServiceRole.DATA_OWNER): # admins and data owners have all permissions return stmt @@ -299,6 +346,7 @@ def _apply_permission_filter( return stmt @as_result(SyftException, StashException) + @with_session def set( self, credentials: SyftVerifyKey, @@ -306,6 +354,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, # TODO: check the default value ignore_duplicates: bool = False, + session: Session = None, ) -> StashT: if not self.allow_any_type: self.check_type(obj, self.object_type).unwrap() @@ -349,16 +398,18 @@ def set( permissions=permissions, storage_permissions=storage_permissions, ) - self.session.execute(stmt) - self.session.commit() - return self.get_by_uid(credentials, uid).unwrap() + session.execute(stmt) + session.commit() + return self.get_by_uid(credentials, uid, session=session).unwrap() @as_result(StashException, NotFoundException) + @with_session def update( self, credentials: SyftVerifyKey, obj: StashT, has_permission: bool = False, + session: Session = None, ) -> StashT: """ NOTE: We cannot do partial updates on the database, @@ -381,6 +432,7 @@ def update( credentials=credentials, 
permission=ActionPermission.WRITE, has_permission=has_permission, + session=session, ) fields = serialize_json(obj) try: @@ -391,8 +443,8 @@ def update( ) stmt = stmt.values(fields=fields) - result = self.session.execute(stmt) - self.session.commit() + result = session.execute(stmt) + session.commit() if result.rowcount == 0: raise NotFoundException( f"{self.object_type.__name__}: {obj.id} not found or no permission to update." @@ -400,8 +452,13 @@ def update( return self.get_by_uid(credentials, obj.id).unwrap() @as_result(StashException, NotFoundException) + @with_session def delete_by_uid( - self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False + self, + credentials: SyftVerifyKey, + uid: UID, + has_permission: bool = False, + session: Session = None, ) -> UID: stmt = self.table.delete().where(self._get_field_filter("id", uid)) stmt = self._apply_permission_filter( @@ -409,9 +466,10 @@ def delete_by_uid( credentials=credentials, permission=ActionPermission.WRITE, has_permission=has_permission, + session=session, ) - result = self.session.execute(stmt) - self.session.commit() + result = session.execute(stmt) + session.commit() if result.rowcount == 0: raise NotFoundException( f"{self.object_type.__name__}: {uid} not found or no permission to delete." @@ -419,6 +477,7 @@ def delete_by_uid( return uid @as_result(StashException) + @with_session def get_one( self, credentials: SyftVerifyKey, @@ -427,6 +486,7 @@ def get_one( order_by: str | None = None, sort_order: str | None = None, offset: int = 0, + session: Session = None, ) -> StashT: """ Get first objects from the stash, optionally filtered. @@ -453,20 +513,21 @@ def get_one( query = self.query() if not has_permission: - role = self.get_role(credentials) + role = self.get_role(credentials, session=session) query = query.with_permissions(credentials, role) for field_name, operator, field_value in parse_filters(filters): query = query.filter(field_name, operator, field_value) query = query.order_by(order_by, sort_order).offset(offset) - result = query.execute(self.session).first() + result = query.execute(session).first() if result is None: raise NotFoundException(f"{self.object_type.__name__}: not found") return self.row_as_obj(result) @as_result(StashException) + @with_session def get_all( self, credentials: SyftVerifyKey, @@ -476,6 +537,7 @@ def get_all( sort_order: str | None = None, limit: int | None = None, offset: int = 0, + session: Session = None, ) -> list[StashT]: """ Get all objects from the stash, optionally filtered. 
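For context on the update path in this hunk: every stash row keeps the serialized object in one JSON `fields` column (`fields = serialize_json(obj)` above), which is why the docstring says partial updates are not possible; an update always rewrites the whole document. A self-contained sketch of that storage layout, using illustrative table and column values rather than the actual Syft schema:

# Sketch: whole-document updates against a JSON "fields" column (illustrative).
import sqlalchemy as sa

metadata = sa.MetaData()
objects = sa.Table(
    "objects",
    metadata,
    sa.Column("id", sa.String, primary_key=True),
    sa.Column("fields", sa.JSON, default=dict),
)

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(objects.insert().values(id="abc", fields={"name": "a", "n": 1}))
    # no partial patch: the caller serializes the full object and the update
    # overwrites the entire column value
    conn.execute(
        objects.update()
        .where(objects.c.id == "abc")
        .values(fields={"name": "b", "n": 2})
    )
    print(conn.execute(sa.select(objects.c.fields)).scalar_one())
    # -> {'name': 'b', 'n': 2}

This layout also suggests why permissions live in their own column in this patch: the JSON blob is opaque to row-level filters, so permission predicates are ANDed into each query rather than stored inside `fields`.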
@@ -503,14 +565,14 @@ def get_all( query = self.query() if not has_permission: - role = self.get_role(credentials) + role = self.get_role(credentials, session=session) query = query.with_permissions(credentials, role) for field_name, operator, field_value in parse_filters(filters): query = query.filter(field_name, operator, field_value) query = query.order_by(order_by, sort_order).limit(limit).offset(offset) - result = query.execute(self.session).all() + result = query.execute(session).all() return [self.row_as_obj(row) for row in result] # PERMISSIONS @@ -525,15 +587,21 @@ def get_ownership_permissions( ] @as_result(NotFoundException) - def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: + @with_session + def add_permissions( + self, permissions: list[ActionObjectPermission], session: Session = None + ) -> None: # TODO: should do this in a single transaction # TODO add error handling for permission in permissions: - self.add_permission(permission).unwrap() + self.add_permission(permission, session=session).unwrap() return None @as_result(NotFoundException) - def add_permission(self, permission: ActionObjectPermission) -> None: + @with_session + def add_permission( + self, permission: ActionObjectPermission, session: Session = None + ) -> None: # TODO add error handling stmt = self.table.update().where(self.table.c.id == permission.uid) if self._is_sqlite(): @@ -551,14 +619,17 @@ def add_permission(self, permission: ActionObjectPermission) -> None: ) ) - result = self.session.execute(stmt) - self.session.commit() + result = session.execute(stmt) + session.commit() if result.rowcount == 0: raise NotFoundException( f"{self.object_type.__name__}: {permission.uid} not found or no permission to change." ) - def remove_permission(self, permission: ActionObjectPermission) -> None: + @with_session + def remove_permission( + self, permission: ActionObjectPermission, session: Session = None + ) -> None: # TODO not threadsafe try: permissions = self._get_permissions_for_uid(permission.uid).unwrap() @@ -572,19 +643,25 @@ def remove_permission(self, permission: ActionObjectPermission) -> None: .where(self.table.c.id == permission.uid) .values(permissions=list(permissions)) ) - self.session.execute(stmt) - self.session.commit() + session.execute(stmt) + session.commit() return None - def has_permission(self, permission: ActionObjectPermission) -> bool: - if self.get_role(permission.credentials) in ( + @with_session + def has_permission( + self, permission: ActionObjectPermission, session: Session = None + ) -> bool: + if self.get_role(permission.credentials, session=session) in ( ServiceRole.ADMIN, ServiceRole.DATA_OWNER, ): return True - return self.has_permissions([permission]) + return self.has_permissions([permission], session=session) - def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool: + @with_session + def has_permissions( + self, permissions: list[ActionObjectPermission], session: Session = None + ) -> bool: # TODO: we should use a permissions table to check all permissions at once # TODO: should check for compound permissions @@ -601,31 +678,39 @@ def has_permissions(self, permissions: list[ActionObjectPermission]) -> bool: *permission_filters, ), ) - result = self.session.execute(stmt).first() + result = session.execute(stmt).first() return result is not None @as_result(StashException) - def _get_permissions_for_uid(self, uid: UID) -> Set[str]: # noqa: UP006 + @with_session + def _get_permissions_for_uid(self, uid: UID, session: Session = 
None) -> Set[str]: # noqa: UP006 stmt = select(self.table.c.permissions).where(self.table.c.id == uid) - result = self.session.execute(stmt).scalar_one_or_none() + result = session.execute(stmt).scalar_one_or_none() if result is None: raise NotFoundException(f"No permissions found for uid: {uid}") return set(result) @as_result(StashException) - def get_all_permissions(self) -> dict[UID, Set[str]]: # noqa: UP006 + @with_session + def get_all_permissions(self, session: Session = None) -> dict[UID, Set[str]]: # noqa: UP006 stmt = select(self.table.c.id, self.table.c.permissions) - results = self.session.execute(stmt).all() + results = session.execute(stmt).all() return {UID(row.id): set(row.permissions) for row in results} # STORAGE PERMISSIONS - def has_storage_permission(self, permission: StoragePermission) -> bool: - return self.has_storage_permissions([permission]) + @with_session + def has_storage_permission( + self, permission: StoragePermission, session: Session = None + ) -> bool: + return self.has_storage_permissions([permission], session=session) @as_result(StashException) - def get_all_storage_permissions(self) -> dict[UID, Set[UID]]: # noqa: UP006 + @with_session + def get_all_storage_permissions( + self, session: Session = None + ) -> dict[UID, Set[UID]]: # noqa: UP006 stmt = select(self.table.c.id, self.table.c.storage_permissions) - results = self.session.execute(stmt).all() + results = session.execute(stmt).all() return { UID(row.id): {UID(uid) for uid in row.storage_permissions} @@ -633,7 +718,10 @@ def get_all_storage_permissions(self) -> dict[UID, Set[UID]]: # noqa: UP006 } @as_result(NotFoundException) - def add_storage_permission(self, permission: StoragePermission) -> None: + @with_session + def add_storage_permission( + self, permission: StoragePermission, session: Session = None + ) -> None: stmt = self.table.update().where(self.table.c.id == permission.uid) if self._is_sqlite(): stmt = stmt.values( @@ -650,15 +738,18 @@ def add_storage_permission(self, permission: StoragePermission) -> None: ) ) - result = self.session.execute(stmt) - self.session.commit() + result = session.execute(stmt) + session.commit() if result.rowcount == 0: raise NotFoundException( f"{self.object_type.__name__}: {permission.uid} not found or no permission to change." 
) return None - def has_storage_permissions(self, permissions: list[StoragePermission]) -> bool: + @with_session + def has_storage_permissions( + self, permissions: list[StoragePermission], session: Session = None + ) -> bool: permission_filters = [ sa.and_( self._get_field_filter("id", p.uid), @@ -676,10 +767,13 @@ def has_storage_permissions(self, permissions: list[StoragePermission]) -> bool: *permission_filters, ) ) - result = self.session.execute(stmt).first() + result = session.execute(stmt).first() return result is not None - def remove_storage_permission(self, permission: StoragePermission) -> None: + @with_session + def remove_storage_permission( + self, permission: StoragePermission, session: Session = None + ) -> None: # TODO not threadsafe try: permissions = self._get_storage_permissions_for_uid(permission.uid).unwrap() @@ -693,21 +787,27 @@ def remove_storage_permission(self, permission: StoragePermission) -> None: .where(self.table.c.id == permission.uid) .values(storage_permissions=[str(uid) for uid in permissions]) ) - self.session.execute(stmt) - self.session.commit() + session.execute(stmt) + session.commit() return None @as_result(StashException) - def _get_storage_permissions_for_uid(self, uid: UID) -> Set[UID]: # noqa: UP006 + @with_session + def _get_storage_permissions_for_uid( + self, uid: UID, session: Session = None + ) -> Set[UID]: # noqa: UP006 stmt = select(self.table.c.id, self.table.c.storage_permissions).where( self.table.c.id == uid ) - result = self.session.execute(stmt).first() + result = session.execute(stmt).first() if result is None: raise NotFoundException(f"No storage permissions found for uid: {uid}") return {UID(uid) for uid in result.storage_permissions} @as_result(NotFoundException) - def add_storage_permissions(self, permissions: list[StoragePermission]) -> None: + @with_session + def add_storage_permissions( + self, permissions: list[StoragePermission], session: Session = None + ) -> None: for permission in permissions: - self.add_storage_permission(permission).unwrap() + self.add_storage_permission(permission, session=session).unwrap() From c81764b48cf4224cdc10bb5649303fde31b0a9d8 Mon Sep 17 00:00:00 2001 From: dk Date: Wed, 11 Sep 2024 11:03:31 +0700 Subject: [PATCH 128/197] [clean] remove mongo related tests and fixtures --- packages/syft/tests/conftest.py | 35 - .../tests/syft/stores/action_store_test.py | 3 - .../syft/stores/mongo_document_store_test.py | 1045 ----------------- .../tests/syft/stores/queue_stash_test.py | 31 - .../tests/syft/stores/store_fixtures_test.py | 117 -- 5 files changed, 1231 deletions(-) delete mode 100644 packages/syft/tests/syft/stores/mongo_document_store_test.py diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index eca68d13b12..ad7a16c8166 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -25,16 +25,10 @@ from syft.service.user import user # relative -# our version of mongomock that has a fix for CodecOptions and custom TypeRegistry Support -from .mongomock.mongo_client import MongoClient from .syft.stores.store_fixtures_test import dict_action_store from .syft.stores.store_fixtures_test import dict_document_store from .syft.stores.store_fixtures_test import dict_queue_stash from .syft.stores.store_fixtures_test import dict_store_partition -from .syft.stores.store_fixtures_test import mongo_action_store -from .syft.stores.store_fixtures_test import mongo_document_store -from .syft.stores.store_fixtures_test import mongo_queue_stash -from 
.syft.stores.store_fixtures_test import mongo_store_partition from .syft.stores.store_fixtures_test import sqlite_action_store from .syft.stores.store_fixtures_test import sqlite_document_store from .syft.stores.store_fixtures_test import sqlite_queue_stash @@ -221,31 +215,6 @@ def action_store(worker): yield worker.action_store -@pytest.fixture(scope="session") -def mongo_client(testrun_uid): - """ - A race-free fixture that starts a MongoDB server for an entire pytest session. - Cleans up the server when the session ends, or when the last client disconnects. - """ - db_name = f"pytest_mongo_{testrun_uid}" - - # rand conn str - conn_str = f"mongodb://localhost:27017/{db_name}" - - # create a client, and test the connection - client = MongoClient(conn_str) - assert client.server_info().get("ok") == 1.0 - - yield client - - # stop_mongo_server(db_name) - - -@pytest.fixture(autouse=True) -def patched_mongo_client(monkeypatch): - monkeypatch.setattr("pymongo.mongo_client.MongoClient", MongoClient) - - @pytest.fixture(autouse=True) def patched_session_cache(monkeypatch): # patching compute heavy hashing to speed up tests @@ -308,10 +277,6 @@ def big_dataset() -> Dataset: __all__ = [ - "mongo_store_partition", - "mongo_document_store", - "mongo_queue_stash", - "mongo_action_store", "sqlite_store_partition", "sqlite_workspace", "sqlite_document_store", diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py index 375204908c1..4bf93a2ed65 100644 --- a/packages/syft/tests/syft/stores/action_store_test.py +++ b/packages/syft/tests/syft/stores/action_store_test.py @@ -33,7 +33,6 @@ [ pytest.lazy_fixture("dict_action_store"), pytest.lazy_fixture("sqlite_action_store"), - pytest.lazy_fixture("mongo_action_store"), ], ) def test_action_store_sanity(store: Any): @@ -50,7 +49,6 @@ def test_action_store_sanity(store: Any): [ pytest.lazy_fixture("dict_action_store"), pytest.lazy_fixture("sqlite_action_store"), - pytest.lazy_fixture("mongo_action_store"), ], ) @pytest.mark.parametrize("permission", permissions) @@ -117,7 +115,6 @@ def test_action_store_test_permissions(store: Any, permission: Any): [ pytest.lazy_fixture("dict_action_store"), pytest.lazy_fixture("sqlite_action_store"), - pytest.lazy_fixture("mongo_action_store"), ], ) @pytest.mark.flaky(reruns=3, reruns_delay=3) diff --git a/packages/syft/tests/syft/stores/mongo_document_store_test.py b/packages/syft/tests/syft/stores/mongo_document_store_test.py deleted file mode 100644 index 95df806c189..00000000000 --- a/packages/syft/tests/syft/stores/mongo_document_store_test.py +++ /dev/null @@ -1,1045 +0,0 @@ -# stdlib -from secrets import token_hex -from threading import Thread - -# third party -import pytest - -# syft absolute -from syft.server.credentials import SyftVerifyKey -from syft.service.action.action_permissions import ActionObjectPermission -from syft.service.action.action_permissions import ActionPermission -from syft.service.action.action_permissions import StoragePermission -from syft.service.action.action_store import ActionObjectEXECUTE -from syft.service.action.action_store import ActionObjectOWNER -from syft.service.action.action_store import ActionObjectREAD -from syft.service.action.action_store import ActionObjectWRITE -from syft.store.document_store import PartitionSettings -from syft.store.document_store import QueryKey -from syft.store.document_store import QueryKeys -from syft.store.mongo_client import MongoStoreClientConfig -from syft.store.mongo_document_store 
import MongoStoreConfig -from syft.store.mongo_document_store import MongoStorePartition -from syft.types.errors import SyftException -from syft.types.uid import UID - -# relative -from ...mongomock.collection import Collection as MongoCollection -from .store_constants_test import TEST_VERIFY_KEY_STRING_HACKER -from .store_fixtures_test import mongo_store_partition_fn -from .store_mocks_test import MockObjectType -from .store_mocks_test import MockSyftObject - -PERMISSIONS = [ - ActionObjectOWNER, - ActionObjectREAD, - ActionObjectWRITE, - ActionObjectEXECUTE, -] - - -def test_mongo_store_partition_sanity( - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - - assert hasattr(mongo_store_partition, "_collection") - assert hasattr(mongo_store_partition, "_permissions") - - -@pytest.mark.skip(reason="Test gets stuck at store.init_store()") -def test_mongo_store_partition_init_failed(root_verify_key) -> None: - # won't connect - mongo_config = MongoStoreClientConfig( - connectTimeoutMS=1, - timeoutMS=1, - ) - - store_config = MongoStoreConfig(client_config=mongo_config) - settings = PartitionSettings(name="test", object_type=MockObjectType) - - store = MongoStorePartition( - UID(), root_verify_key, settings=settings, store_config=store_config - ) - - res = store.init_store() - assert res.is_err() - - -def test_mongo_store_partition_set( - root_verify_key, mongo_store_partition: MongoStorePartition -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - - obj = MockSyftObject(data=1) - - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - - assert res.is_ok() - assert res.ok() == obj - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_err() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=True) - assert res.is_ok() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - obj2 = MockSyftObject(data=2) - res = mongo_store_partition.set(root_verify_key, obj2, ignore_duplicates=False) - assert res.is_ok() - assert res.ok() == obj2 - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 2 - ) - - repeats = 5 - for idx in range(repeats): - obj = MockSyftObject(data=idx) - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 3 + idx - ) - - -def test_mongo_store_partition_delete( - root_verify_key, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - repeats = 5 - - objs = [] - for v in range(repeats): - obj = MockSyftObject(data=v) - mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - objs.append(obj) - - assert len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # random object - obj = MockSyftObject(data="bogus") - key = mongo_store_partition.settings.store_key.with_obj(obj) - res = mongo_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) == len(objs) - - # cleanup store - for idx, v in enumerate(objs): - key = 
mongo_store_partition.settings.store_key.with_obj(v) - res = mongo_store_partition.delete(root_verify_key, key) - assert res.is_ok() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - res = mongo_store_partition.delete(root_verify_key, key) - assert res.is_err() - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == len(objs) - idx - 1 - ) - - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 0 - ) - - -def test_mongo_store_partition_update( - root_verify_key, - mongo_store_partition: MongoStorePartition, -) -> None: - mongo_store_partition.init_store() - - # add item - obj = MockSyftObject(data=1) - mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - - # fail to update missing keys - rand_obj = MockSyftObject(data="bogus") - key = mongo_store_partition.settings.store_key.with_obj(rand_obj) - res = mongo_store_partition.update(root_verify_key, key, obj) - assert res.is_err() - - # update the key multiple times - repeats = 5 - for v in range(repeats): - key = mongo_store_partition.settings.store_key.with_obj(obj) - obj_new = MockSyftObject(data=v) - - res = mongo_store_partition.update(root_verify_key, key, obj_new) - assert res.is_ok() - - # The ID should stay the same on update, only the values are updated. - assert ( - len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - == 1 - ) - assert ( - mongo_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - == obj.id - ) - assert ( - mongo_store_partition.all( - root_verify_key, - ) - .ok()[0] - .id - != obj_new.id - ) - assert ( - mongo_store_partition.all( - root_verify_key, - ) - .ok()[0] - .data - == v - ) - - stored = mongo_store_partition.get_all_from_store( - root_verify_key, QueryKeys(qks=[key]) - ) - assert stored.ok()[0].data == v - - -def test_mongo_store_partition_set_threading(root_verify_key, mongo_client) -> None: - thread_cnt = 3 - repeats = 5 - - execution_err = None - mongo_db_name = token_hex(8) - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - for idx in range(repeats): - obj = MockObjectType(data=idx) - - for _ in range(10): - res = mongo_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - return execution_err - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - stored_cnt = len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == thread_cnt * repeats - - -# @pytest.mark.skip( -# reason="PicklingError: Could not pickle the task to send it to the workers." 
-# ) -# def test_mongo_store_partition_set_joblib( -# root_verify_key, -# mongo_client, -# ) -> None: -# thread_cnt = 3 -# repeats = 5 -# mongo_db_name = token_hex(8) - -# def _kv_cbk(tid: int) -> None: -# for idx in range(repeats): -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# obj = MockObjectType(data=idx) - -# for _ in range(10): -# res = mongo_store_partition.set( -# root_verify_key, obj, ignore_duplicates=False -# ) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res - -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) - -# for execution_err in errs: -# assert execution_err is None - -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# stored_cnt = len( -# mongo_store_partition.all( -# root_verify_key, -# ).ok() -# ) -# assert stored_cnt == thread_cnt * repeats - - -def test_mongo_store_partition_update_threading( - root_verify_key, - mongo_client, -) -> None: - thread_cnt = 3 - repeats = 5 - - mongo_db_name = token_hex(8) - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - - obj = MockSyftObject(data=0) - key = mongo_store_partition.settings.store_key.with_obj(obj) - mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - execution_err = None - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - - mongo_store_partition_local = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - for repeat in range(repeats): - obj = MockSyftObject(data=repeat) - - for _ in range(10): - res = mongo_store_partition_local.update(root_verify_key, key, obj) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - -# @pytest.mark.skip( -# reason="PicklingError: Could not pickle the task to send it to the workers." 
-# ) -# def test_mongo_store_partition_update_joblib(root_verify_key, mongo_client) -> None: -# thread_cnt = 3 -# repeats = 5 - -# mongo_db_name = token_hex(8) - -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# obj = MockSyftObject(data=0) -# key = mongo_store_partition.settings.store_key.with_obj(obj) -# mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - -# def _kv_cbk(tid: int) -> None: -# mongo_store_partition_local = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# for repeat in range(repeats): -# obj = MockSyftObject(data=repeat) - -# for _ in range(10): -# res = mongo_store_partition_local.update(root_verify_key, key, obj) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) - -# for execution_err in errs: -# assert execution_err is None - - -def test_mongo_store_partition_set_delete_threading( - root_verify_key, - mongo_client, -) -> None: - thread_cnt = 3 - repeats = 5 - execution_err = None - mongo_db_name = token_hex(8) - - def _kv_cbk(tid: int) -> None: - nonlocal execution_err - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - - for idx in range(repeats): - obj = MockSyftObject(data=idx) - - for _ in range(10): - res = mongo_store_partition.set( - root_verify_key, obj, ignore_duplicates=False - ) - if res.is_ok(): - break - - if res.is_err(): - execution_err = res - assert res.is_ok() - - key = mongo_store_partition.settings.store_key.with_obj(obj) - - res = mongo_store_partition.delete(root_verify_key, key) - if res.is_err(): - execution_err = res - assert res.is_ok(), res - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - assert execution_err is None - - mongo_store_partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - ) - stored_cnt = len( - mongo_store_partition.all( - root_verify_key, - ).ok() - ) - assert stored_cnt == 0 - - -# @pytest.mark.skip( -# reason="PicklingError: Could not pickle the task to send it to the workers." 
-# ) -# def test_mongo_store_partition_set_delete_joblib(root_verify_key, mongo_client) -> None: -# thread_cnt = 3 -# repeats = 5 -# mongo_db_name = token_hex(8) - -# def _kv_cbk(tid: int) -> None: -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, root_verify_key, mongo_db_name=mongo_db_name -# ) - -# for idx in range(repeats): -# obj = MockSyftObject(data=idx) - -# for _ in range(10): -# res = mongo_store_partition.set( -# root_verify_key, obj, ignore_duplicates=False -# ) -# if res.is_ok(): -# break - -# if res.is_err(): -# return res - -# key = mongo_store_partition.settings.store_key.with_obj(obj) - -# res = mongo_store_partition.delete(root_verify_key, key) -# if res.is_err(): -# return res -# return None - -# errs = Parallel(n_jobs=thread_cnt)( -# delayed(_kv_cbk)(idx) for idx in range(thread_cnt) -# ) -# for execution_err in errs: -# assert execution_err is None - -# mongo_store_partition = mongo_store_partition_fn( -# mongo_client, -# root_verify_key, -# mongo_db_name=mongo_db_name, -# ) -# stored_cnt = len( -# mongo_store_partition.all( -# root_verify_key, -# ).ok() -# ) -# assert stored_cnt == 0 - - -def test_mongo_store_partition_permissions_collection( - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - - collection_permissions_status = mongo_store_partition.permissions - assert not collection_permissions_status.is_err() - collection_permissions = collection_permissions_status.ok() - assert isinstance(collection_permissions, MongoCollection) - - -def test_mongo_store_partition_add_remove_permission( - root_verify_key: SyftVerifyKey, mongo_store_partition: MongoStorePartition -) -> None: - """ - Test the add_permission and remove_permission functions of MongoStorePartition - """ - # setting up - res = mongo_store_partition.init_store() - assert res.is_ok() - permissions_collection: MongoCollection = mongo_store_partition.permissions.ok() - obj = MockSyftObject(data=1) - - # add the first permission - obj_read_permission = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.READ, credentials=root_verify_key - ) - mongo_store_partition.add_permission(obj_read_permission) - find_res_1 = permissions_collection.find_one({"_id": obj_read_permission.uid}) - assert find_res_1 is not None - assert len(find_res_1["permissions"]) == 1 - assert find_res_1["permissions"] == { - obj_read_permission.permission_string, - } - - # add the second permission - obj_write_permission = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.WRITE, credentials=root_verify_key - ) - mongo_store_partition.add_permission(obj_write_permission) - - find_res_2 = permissions_collection.find_one({"_id": obj.id}) - assert find_res_2 is not None - assert len(find_res_2["permissions"]) == 2 - assert find_res_2["permissions"] == { - obj_read_permission.permission_string, - obj_write_permission.permission_string, - } - - # add duplicated permission - mongo_store_partition.add_permission(obj_write_permission) - find_res_3 = permissions_collection.find_one({"_id": obj.id}) - assert len(find_res_3["permissions"]) == 2 - assert find_res_3["permissions"] == find_res_2["permissions"] - - # remove the write permission - mongo_store_partition.remove_permission(obj_write_permission) - find_res_4 = permissions_collection.find_one({"_id": obj.id}) - assert len(find_res_4["permissions"]) == 1 - assert find_res_1["permissions"] == { - obj_read_permission.permission_string, - } - - # remove a non-existent 
permission - with pytest.raises(SyftException): - mongo_store_partition.remove_permission( - ActionObjectPermission( - uid=obj.id, - permission=ActionPermission.OWNER, - credentials=root_verify_key, - ) - ) - find_res_5 = permissions_collection.find_one({"_id": obj.id}) - assert len(find_res_5["permissions"]) == 1 - assert find_res_1["permissions"] == { - obj_read_permission.permission_string, - } - - # there is only one permission object - assert permissions_collection.count_documents({}) == 1 - - # add permissions in a loop - new_permissions = [] - repeats = 5 - for idx in range(1, repeats + 1): - new_obj = MockSyftObject(data=idx) - new_obj_read_permission = ActionObjectPermission( - uid=new_obj.id, - permission=ActionPermission.READ, - credentials=root_verify_key, - ) - new_permissions.append(new_obj_read_permission) - mongo_store_partition.add_permission(new_obj_read_permission) - assert permissions_collection.count_documents({}) == 1 + idx - - # remove all the permissions added in the loop - for permission in new_permissions: - mongo_store_partition.remove_permission(permission) - - assert permissions_collection.count_documents({}) == 1 - - -def test_mongo_store_partition_add_remove_storage_permission( - root_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - """ - Test the add_storage_permission and remove_storage_permission functions of MongoStorePartition - """ - - obj = MockSyftObject(data=1) - - storage_permission = StoragePermission( - uid=obj.id, - server_uid=UID(), - ) - assert not mongo_store_partition.has_storage_permission(storage_permission) - mongo_store_partition.add_storage_permission(storage_permission) - assert mongo_store_partition.has_storage_permission(storage_permission) - mongo_store_partition.remove_storage_permission(storage_permission) - assert not mongo_store_partition.has_storage_permission(storage_permission) - - obj2 = MockSyftObject(data=1) - mongo_store_partition.set(root_verify_key, obj2, add_storage_permission=False) - storage_permission3 = StoragePermission( - uid=obj2.id, server_uid=mongo_store_partition.server_uid - ) - assert not mongo_store_partition.has_storage_permission(storage_permission3) - - obj3 = MockSyftObject(data=1) - mongo_store_partition.set(root_verify_key, obj3, add_storage_permission=True) - storage_permission4 = StoragePermission( - uid=obj3.id, server_uid=mongo_store_partition.server_uid - ) - assert mongo_store_partition.has_storage_permission(storage_permission4) - - -def test_mongo_store_partition_add_permissions( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - permissions_collection: MongoCollection = mongo_store_partition.permissions.ok() - obj = MockSyftObject(data=1) - - # add multiple permissions for the first object - permission_1 = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.WRITE, credentials=root_verify_key - ) - permission_2 = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.OWNER, credentials=root_verify_key - ) - permission_3 = ActionObjectPermission( - uid=obj.id, permission=ActionPermission.READ, credentials=guest_verify_key - ) - permissions: list[ActionObjectPermission] = [ - permission_1, - permission_2, - permission_3, - ] - mongo_store_partition.add_permissions(permissions) - - # check if the permissions have been added properly - assert permissions_collection.count_documents({}) == 1 - 
find_res = permissions_collection.find_one({"_id": obj.id}) - assert find_res is not None - assert len(find_res["permissions"]) == 3 - - # add permissions for the second object - obj_2 = MockSyftObject(data=2) - permission_4 = ActionObjectPermission( - uid=obj_2.id, permission=ActionPermission.READ, credentials=root_verify_key - ) - permission_5 = ActionObjectPermission( - uid=obj_2.id, permission=ActionPermission.WRITE, credentials=root_verify_key - ) - mongo_store_partition.add_permissions([permission_4, permission_5]) - - assert permissions_collection.count_documents({}) == 2 - find_res_2 = permissions_collection.find_one({"_id": obj_2.id}) - assert find_res_2 is not None - assert len(find_res_2["permissions"]) == 2 - - -@pytest.mark.parametrize("permission", PERMISSIONS) -def test_mongo_store_partition_has_permission( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, - permission: ActionObjectPermission, -) -> None: - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - - res = mongo_store_partition.init_store() - assert res.is_ok() - - # root permission - obj = MockSyftObject(data=1) - permission_root = permission(uid=obj.id, credentials=root_verify_key) - permission_client = permission(uid=obj.id, credentials=guest_verify_key) - permission_hacker = permission(uid=obj.id, credentials=hacker_verify_key) - mongo_store_partition.add_permission(permission_root) - # only the root user has access to this permission - assert mongo_store_partition.has_permission(permission_root) - assert not mongo_store_partition.has_permission(permission_client) - assert not mongo_store_partition.has_permission(permission_hacker) - - # client permission for another object - obj_2 = MockSyftObject(data=2) - permission_client_2 = permission(uid=obj_2.id, credentials=guest_verify_key) - permission_root_2 = permission(uid=obj_2.id, credentials=root_verify_key) - permisson_hacker_2 = permission(uid=obj_2.id, credentials=hacker_verify_key) - mongo_store_partition.add_permission(permission_client_2) - # the root (admin) and guest client should have this permission - assert mongo_store_partition.has_permission(permission_root_2) - assert mongo_store_partition.has_permission(permission_client_2) - assert not mongo_store_partition.has_permission(permisson_hacker_2) - - # remove permissions - mongo_store_partition.remove_permission(permission_root) - assert not mongo_store_partition.has_permission(permission_root) - assert not mongo_store_partition.has_permission(permission_client) - assert not mongo_store_partition.has_permission(permission_hacker) - - mongo_store_partition.remove_permission(permission_client_2) - assert not mongo_store_partition.has_permission(permission_root_2) - assert not mongo_store_partition.has_permission(permission_client_2) - assert not mongo_store_partition.has_permission(permisson_hacker_2) - - -@pytest.mark.parametrize("permission", PERMISSIONS) -def test_mongo_store_partition_take_ownership( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, - permission: ActionObjectPermission, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - obj = MockSyftObject(data=1) - - # the guest client takes ownership of obj - mongo_store_partition.take_ownership( - uid=obj.id, credentials=guest_verify_key - ).unwrap() - assert 
mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=guest_verify_key) - ) - # the root client will also has the permission - assert mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=root_verify_key) - ) - assert not mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=hacker_verify_key) - ) - - # hacker or root try to take ownership of the obj and will fail - res = mongo_store_partition.take_ownership( - uid=obj.id, credentials=hacker_verify_key - ) - res_2 = mongo_store_partition.take_ownership( - uid=obj.id, credentials=root_verify_key - ) - assert res.is_err() - assert res_2.is_err() - assert ( - res.value.public_message - == res_2.value.public_message - == f"UID: {obj.id} already owned." - ) - - # another object - obj_2 = MockSyftObject(data=2) - # root client takes ownership - mongo_store_partition.take_ownership(uid=obj_2.id, credentials=root_verify_key) - assert mongo_store_partition.has_permission( - permission(uid=obj_2.id, credentials=root_verify_key) - ) - assert not mongo_store_partition.has_permission( - permission(uid=obj_2.id, credentials=guest_verify_key) - ) - assert not mongo_store_partition.has_permission( - permission(uid=obj_2.id, credentials=hacker_verify_key) - ) - - -def test_mongo_store_partition_permissions_set( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - """ - Test the permissions functionalities when using MongoStorePartition._set function - """ - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - res = mongo_store_partition.init_store() - assert res.is_ok() - - # set the object to mongo_store_partition.collection - obj = MockSyftObject(data=1) - res = mongo_store_partition.set(root_verify_key, obj, ignore_duplicates=False) - assert res.is_ok() - assert res.ok() == obj - - # check if the corresponding permissions has been added to the permissions - # collection after the root client claim it - pemissions_collection = mongo_store_partition.permissions.ok() - assert isinstance(pemissions_collection, MongoCollection) - permissions = pemissions_collection.find_one({"_id": obj.id}) - assert permissions is not None - assert isinstance(permissions["permissions"], set) - assert len(permissions["permissions"]) == 4 - for permission in PERMISSIONS: - assert mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=root_verify_key) - ) - - # the hacker tries to set duplicated object but should not be able to claim it - res_2 = mongo_store_partition.set(guest_verify_key, obj, ignore_duplicates=True) - assert res_2.is_ok() - for permission in PERMISSIONS: - assert not mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=hacker_verify_key) - ) - assert mongo_store_partition.has_permission( - permission(uid=obj.id, credentials=root_verify_key) - ) - - -def test_mongo_store_partition_permissions_get_all( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - # set several objects for the root and guest client - num_root_objects: int = 5 - num_guest_objects: int = 3 - for i in range(num_root_objects): - obj = MockSyftObject(data=i) - mongo_store_partition.set( - credentials=root_verify_key, obj=obj, ignore_duplicates=False - ) - for i in 
range(num_guest_objects): - obj = MockSyftObject(data=i) - mongo_store_partition.set( - credentials=guest_verify_key, obj=obj, ignore_duplicates=False - ) - - assert ( - len(mongo_store_partition.all(root_verify_key).ok()) - == num_root_objects + num_guest_objects - ) - assert len(mongo_store_partition.all(guest_verify_key).ok()) == num_guest_objects - assert len(mongo_store_partition.all(hacker_verify_key).ok()) == 0 - - -def test_mongo_store_partition_permissions_delete( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - collection: MongoCollection = mongo_store_partition.collection.ok() - pemissions_collection: MongoCollection = mongo_store_partition.permissions.ok() - hacker_verify_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_HACKER) - - # the root client set an object - obj = MockSyftObject(data=1) - mongo_store_partition.set( - credentials=root_verify_key, obj=obj, ignore_duplicates=False - ) - qk: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj) - # guest or hacker can't delete it - assert not mongo_store_partition.delete(guest_verify_key, qk).is_ok() - assert not mongo_store_partition.delete(hacker_verify_key, qk).is_ok() - # only the root client can delete it - assert mongo_store_partition.delete(root_verify_key, qk).is_ok() - # check if the object and its permission have been deleted - assert collection.count_documents({}) == 0 - assert pemissions_collection.count_documents({}) == 0 - - # the guest client set an object - obj_2 = MockSyftObject(data=2) - mongo_store_partition.set( - credentials=guest_verify_key, obj=obj_2, ignore_duplicates=False - ) - qk_2: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj_2) - # the hacker can't delete it - assert not mongo_store_partition.delete(hacker_verify_key, qk_2).is_ok() - # the guest client can delete it - assert mongo_store_partition.delete(guest_verify_key, qk_2).is_ok() - assert collection.count_documents({}) == 0 - assert pemissions_collection.count_documents({}) == 0 - - # the guest client set another object - obj_3 = MockSyftObject(data=3) - mongo_store_partition.set( - credentials=guest_verify_key, obj=obj_3, ignore_duplicates=False - ) - qk_3: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj_3) - # the root client also has the permission to delete it - assert mongo_store_partition.delete(root_verify_key, qk_3).is_ok() - assert collection.count_documents({}) == 0 - assert pemissions_collection.count_documents({}) == 0 - - -def test_mongo_store_partition_permissions_update( - root_verify_key: SyftVerifyKey, - guest_verify_key: SyftVerifyKey, - mongo_store_partition: MongoStorePartition, -) -> None: - res = mongo_store_partition.init_store() - assert res.is_ok() - # the root client set an object - obj = MockSyftObject(data=1) - mongo_store_partition.set( - credentials=root_verify_key, obj=obj, ignore_duplicates=False - ) - assert len(mongo_store_partition.all(credentials=root_verify_key).ok()) == 1 - - qk: QueryKey = mongo_store_partition.settings.store_key.with_obj(obj) - permsissions: MongoCollection = mongo_store_partition.permissions.ok() - repeats = 5 - - for v in range(repeats): - # the guest client should not have permission to update obj - obj_new = MockSyftObject(data=v) - res = mongo_store_partition.update( - credentials=guest_verify_key, qk=qk, obj=obj_new - ) - assert res.is_err() - # the root client has the permission to 
update obj - res = mongo_store_partition.update( - credentials=root_verify_key, qk=qk, obj=obj_new - ) - assert res.is_ok() - # the id of the object in the permission collection should not be changed - assert permsissions.find_one(qk.as_dict_mongo)["_id"] == obj.id diff --git a/packages/syft/tests/syft/stores/queue_stash_test.py b/packages/syft/tests/syft/stores/queue_stash_test.py index 312766e7c4e..cd004ce1107 100644 --- a/packages/syft/tests/syft/stores/queue_stash_test.py +++ b/packages/syft/tests/syft/stores/queue_stash_test.py @@ -16,7 +16,6 @@ from syft.types.uid import UID # relative -from .store_fixtures_test import mongo_queue_stash_fn from .store_fixtures_test import sqlite_queue_stash_fn @@ -49,7 +48,6 @@ def mock_queue_object(): [ pytest.lazy_fixture("dict_queue_stash"), pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), ], ) def test_queue_stash_sanity(queue: Any) -> None: @@ -63,7 +61,6 @@ def test_queue_stash_sanity(queue: Any) -> None: [ pytest.lazy_fixture("dict_queue_stash"), pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), ], ) # @pytest.mark.flaky(reruns=3, reruns_delay=3) @@ -104,7 +101,6 @@ def test_queue_stash_set_get(root_verify_key, queue: Any) -> None: [ pytest.lazy_fixture("dict_queue_stash"), pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), ], ) @pytest.mark.flaky(reruns=3, reruns_delay=3) @@ -135,7 +131,6 @@ def test_queue_stash_update(root_verify_key, queue: Any) -> None: [ pytest.lazy_fixture("dict_queue_stash"), pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), ], ) @pytest.mark.flaky(reruns=3, reruns_delay=3) @@ -178,7 +173,6 @@ def _kv_cbk(tid: int) -> None: [ pytest.lazy_fixture("dict_queue_stash"), pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), ], ) @pytest.mark.flaky(reruns=3, reruns_delay=3) @@ -222,7 +216,6 @@ def _kv_cbk(tid: int) -> None: [ pytest.lazy_fixture("dict_queue_stash"), pytest.lazy_fixture("sqlite_queue_stash"), - pytest.lazy_fixture("mongo_queue_stash"), ], ) @pytest.mark.flaky(reruns=3, reruns_delay=3) @@ -319,14 +312,6 @@ def create_queue_cbk(): helper_queue_set_threading(root_verify_key, create_queue_cbk) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_set_threading_mongo(root_verify_key, mongo_document_store): - def create_queue_cbk(): - return mongo_queue_stash_fn(mongo_document_store) - - helper_queue_set_threading(root_verify_key, create_queue_cbk) - - def helper_queue_update_threading(root_verify_key, create_queue_cbk) -> None: thread_cnt = 3 repeats = 5 @@ -377,14 +362,6 @@ def create_queue_cbk(): helper_queue_update_threading(root_verify_key, create_queue_cbk) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_update_threading_mongo(root_verify_key, mongo_document_store): - def create_queue_cbk(): - return mongo_queue_stash_fn(mongo_document_store) - - helper_queue_update_threading(root_verify_key, create_queue_cbk) - - def helper_queue_set_delete_threading( root_verify_key, create_queue_cbk, @@ -441,11 +418,3 @@ def create_queue_cbk(): return sqlite_queue_stash_fn(root_verify_key, sqlite_workspace) helper_queue_set_delete_threading(root_verify_key, create_queue_cbk) - - -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_queue_delete_threading_mongo(root_verify_key, mongo_document_store): - def create_queue_cbk(): - return mongo_queue_stash_fn(mongo_document_store) - - 
helper_queue_set_delete_threading(root_verify_key, create_queue_cbk) diff --git a/packages/syft/tests/syft/stores/store_fixtures_test.py b/packages/syft/tests/syft/stores/store_fixtures_test.py index b64d14be8e3..09e684eee7c 100644 --- a/packages/syft/tests/syft/stores/store_fixtures_test.py +++ b/packages/syft/tests/syft/stores/store_fixtures_test.py @@ -14,7 +14,6 @@ from syft.service.action.action_permissions import ActionObjectPermission from syft.service.action.action_permissions import ActionPermission from syft.service.action.action_store import DictActionStore -from syft.service.action.action_store import MongoActionStore from syft.service.action.action_store import SQLiteActionStore from syft.service.queue.queue_stash import QueueStash from syft.service.user.user import User @@ -29,10 +28,6 @@ from syft.store.locks import LockingConfig from syft.store.locks import NoLockingConfig from syft.store.locks import ThreadingLockingConfig -from syft.store.mongo_client import MongoStoreClientConfig -from syft.store.mongo_document_store import MongoDocumentStore -from syft.store.mongo_document_store import MongoStoreConfig -from syft.store.mongo_document_store import MongoStorePartition from syft.store.sqlite_document_store import SQLiteDocumentStore from syft.store.sqlite_document_store import SQLiteStoreClientConfig from syft.store.sqlite_document_store import SQLiteStoreConfig @@ -226,118 +221,6 @@ def sqlite_action_store(sqlite_workspace: tuple[Path, str], request): ) -def mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name: str = "mongo_db", - locking_config_name: str = "nop", -): - mongo_config = MongoStoreClientConfig(client=mongo_client) - - locking_config = str_to_locking_config(locking_config_name) - - store_config = MongoStoreConfig( - client_config=mongo_config, - db_name=mongo_db_name, - locking_config=locking_config, - ) - settings = PartitionSettings(name="test", object_type=MockObjectType) - - return MongoStorePartition( - UID(), root_verify_key, settings=settings, store_config=store_config - ) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def mongo_store_partition(root_verify_key, mongo_client, request): - mongo_db_name = token_hex(8) - locking_config_name = request.param - - partition = mongo_store_partition_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - locking_config_name=locking_config_name, - ) - yield partition - - # cleanup db - try: - mongo_client.drop_database(mongo_db_name) - except BaseException as e: - print("failed to cleanup mongo fixture", e) - - -def mongo_document_store_fn( - mongo_client, - root_verify_key, - mongo_db_name: str = "mongo_db", - locking_config_name: str = "nop", -): - locking_config = str_to_locking_config(locking_config_name) - mongo_config = MongoStoreClientConfig(client=mongo_client) - store_config = MongoStoreConfig( - client_config=mongo_config, db_name=mongo_db_name, locking_config=locking_config - ) - - mongo_client.drop_database(mongo_db_name) - - return MongoDocumentStore(UID(), root_verify_key, store_config=store_config) - - -@pytest.fixture(scope="function", params=locking_scenarios) -def mongo_document_store(root_verify_key, mongo_client, request): - locking_config_name = request.param - mongo_db_name = token_hex(8) - yield mongo_document_store_fn( - mongo_client, - root_verify_key, - mongo_db_name=mongo_db_name, - locking_config_name=locking_config_name, - ) - - -def mongo_queue_stash_fn(mongo_document_store): - return QueueStash(store=mongo_document_store) - 
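The Mongo fixtures deleted in this hunk follow the same factory-plus-fixture pattern the surviving SQLite fixtures keep using: a plain `*_fn` factory builds a fresh store bound to a unique database name, and a function-scoped pytest fixture adds per-test setup and teardown. A minimal sketch of that pattern, with hypothetical `make_store`/`store` names rather than the real Syft helpers:

```python
# Generic factory-plus-fixture sketch; make_store/store are illustrative
# names, not the actual Syft test helpers.
from secrets import token_hex

import pytest


def make_store(db_name: str) -> dict:
    # stand-in for a *_fn factory such as sqlite_queue_stash_fn: build and
    # return a fresh store bound to a unique database name
    return {"db_name": db_name, "rows": {}}


@pytest.fixture(scope="function")
def store():
    db_name = token_hex(8)  # unique per test, so runs cannot collide
    s = make_store(db_name)
    yield s
    # teardown mirrors the dropped `mongo_client.drop_database(...)` cleanup
    s["rows"].clear()
```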
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def mongo_queue_stash(root_verify_key, mongo_client, request):
-    mongo_db_name = token_hex(8)
-    locking_config_name = request.param
-
-    store = mongo_document_store_fn(
-        mongo_client,
-        root_verify_key,
-        mongo_db_name=mongo_db_name,
-        locking_config_name=locking_config_name,
-    )
-    yield mongo_queue_stash_fn(store)
-
-
-@pytest.fixture(scope="function", params=locking_scenarios)
-def mongo_action_store(mongo_client, request):
-    mongo_db_name = token_hex(8)
-    locking_config_name = request.param
-    locking_config = str_to_locking_config(locking_config_name)
-
-    mongo_config = MongoStoreClientConfig(client=mongo_client)
-    store_config = MongoStoreConfig(
-        client_config=mongo_config, db_name=mongo_db_name, locking_config=locking_config
-    )
-    ver_key = SyftVerifyKey.from_string(TEST_VERIFY_KEY_STRING_ROOT)
-    server_uid = UID()
-    document_store = document_store_with_admin(server_uid, ver_key)
-    mongo_action_store = MongoActionStore(
-        server_uid=server_uid,
-        store_config=store_config,
-        root_verify_key=ver_key,
-        document_store=document_store,
-    )
-
-    yield mongo_action_store
-
-
 def dict_store_partition_fn(
     root_verify_key,
     locking_config_name: str = "nop",

From 175a50b2eb507778a83a3d7d056e77874577e3cf Mon Sep 17 00:00:00 2001
From: dk
Date: Wed, 11 Sep 2024 11:19:29 +0700
Subject: [PATCH 129/197] [clean] remove mongo from tox.ini and helm's `base.yaml`

- scanning for postgres container security instead of mongo
---
 .github/workflows/container-scan.yml      | 16 ++++++++--------
 packages/grid/helm/examples/dev/base.yaml |  7 -------
 tox.ini                                   | 12 ++++++------
 3 files changed, 14 insertions(+), 21 deletions(-)

diff --git a/.github/workflows/container-scan.yml b/.github/workflows/container-scan.yml
index 303eb11bc40..0e12d35d357 100644
--- a/.github/workflows/container-scan.yml
+++ b/.github/workflows/container-scan.yml
@@ -224,7 +224,7 @@ jobs:
           name: syft.sbom.json
           path: syft.sbom.json
 
-  scan-mongo-latest-trivy:
+  scan-postgres-latest-trivy:
    permissions:
      contents: read # for actions/checkout to fetch code
      security-events: write # for github/codeql-action/upload-sarif to upload SARIF results
@@ -238,24 +238,24 @@ jobs:
        continue-on-error: true
        uses: aquasecurity/trivy-action@master
        with:
-          image-ref: "mongo:7.0.0"
+          image-ref: "postgres:13"
          format: "cyclonedx"
-          output: "mongo-trivy-results.sbom.json"
+          output: "postgres-trivy-results.sbom.json"
          timeout: "10m0s"
 
      #Upload SBOM to GitHub Artifacts
      - name: Upload SBOM to GitHub Artifacts
        uses: actions/upload-artifact@v4
        with:
-          name: mongo-trivy-results.sbom.json
-          path: mongo-trivy-results.sbom.json
+          name: postgres-trivy-results.sbom.json
+          path: postgres-trivy-results.sbom.json
 
      #Generate sarif file
      - name: Run Trivy vulnerability scanner
        continue-on-error: true
        uses: aquasecurity/trivy-action@master
        with:
-          image-ref: "mongo:7.0.0"
+          image-ref: "postgres:13"
          format: "sarif"
          output: "trivy-results.sarif"
          timeout: "10m0s"
@@ -266,7 +266,7 @@
        with:
          sarif_file: "trivy-results.sarif"
 
-  scan-mongo-latest-snyk:
+  scan-postgres-latest-snyk:
    permissions:
      contents: read # for actions/checkout to fetch code
      security-events: write # for github/codeql-action/upload-sarif to upload SARIF results
@@ -281,7 +281,7 @@
        # This is where you will need to introduce the Snyk API token created with your Snyk account
        SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }}
        with:
-          image: mongo:7.0.0
+          image: postgres:13
          args: --sarif-file-output=snyk-code.sarif
 
      # Replace any "undefined" security severity values with
0. The undefined value is used in the case diff --git a/packages/grid/helm/examples/dev/base.yaml b/packages/grid/helm/examples/dev/base.yaml index 4d290478841..3fc1ad5c4da 100644 --- a/packages/grid/helm/examples/dev/base.yaml +++ b/packages/grid/helm/examples/dev/base.yaml @@ -21,13 +21,6 @@ server: secret: defaultRootPassword: changethis -mongo: - resourcesPreset: null - resources: null - - secret: - rootPassword: example - postgres: resourcesPreset: null resources: null diff --git a/tox.ini b/tox.ini index 4004cf81fbb..f722449463a 100644 --- a/tox.ini +++ b/tox.ini @@ -837,7 +837,7 @@ allowlist_externals = setenv = CLUSTER_NAME = {env:CLUSTER_NAME:syft} CLUSTER_HTTP_PORT = {env:SERVER_PORT:8080} -; Usage for posargs: names of the relevant services among {frontend backend proxy mongo seaweedfs registry} +; Usage for posargs: names of the relevant services among {frontend backend proxy postgres seaweedfs registry} commands = bash -c "env; date; k3d version" @@ -856,7 +856,7 @@ commands = fi" # Mongo - bash -c "if echo '{posargs}' | grep -q 'mongo'; then echo 'Checking readiness of Mongo'; ./scripts/wait_for.sh service mongo --context k3d-$CLUSTER_NAME --namespace syft; fi" + bash -c "if echo '{posargs}' | grep -q 'postgres'; then echo 'Checking readiness of Postgres'; ./scripts/wait_for.sh service postgres --context k3d-$CLUSTER_NAME --namespace syft; fi" # Proxy bash -c "if echo '{posargs}' | grep -q 'proxy'; then echo 'Checking readiness of proxy'; ./scripts/wait_for.sh service proxy --context k3d-$CLUSTER_NAME --namespace syft; fi" @@ -900,7 +900,7 @@ commands = echo "Installing local helm charts"; \ if [[ "{posargs}" == "override" ]]; then \ echo "Overriding resourcesPreset"; \ - helm install ${CLUSTER_NAME} ./helm/syft -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set mongo.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null; \ + helm install ${CLUSTER_NAME} ./helm/syft -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set postgres.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null; \ else \ helm install ${CLUSTER_NAME} ./helm/syft -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace; \ fi \ @@ -910,14 +910,14 @@ commands = helm repo update openmined; \ if [[ "{posargs}" == "override" ]]; then \ echo "Overriding resourcesPreset"; \ - helm install ${CLUSTER_NAME} openmined/syft --version=${SYFT_VERSION} -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set mongo.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null; \ + helm install ${CLUSTER_NAME} openmined/syft --version=${SYFT_VERSION} -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set postgres.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null; \ else \ helm install ${CLUSTER_NAME} 
openmined/syft --version=${SYFT_VERSION} -f ./helm/examples/dev/base.yaml --kube-context k3d-${CLUSTER_NAME} --namespace syft --create-namespace; \ fi \ fi' ; wait for everything else to be loaded - tox -e dev.k8s.ready -- frontend backend mongo proxy seaweedfs registry + tox -e dev.k8s.ready -- frontend backend postgres proxy seaweedfs registry # Run Notebook tests tox -e e2e.test.notebook @@ -1354,7 +1354,7 @@ commands = ' ; wait for everything else to be loaded - tox -e dev.k8s.ready -- frontend backend mongo proxy seaweedfs registry + tox -e dev.k8s.ready -- frontend backend postgres proxy seaweedfs registry bash -c 'python -c "import syft as sy; print(\"Migrating from syft version:\", sy.__version__)"' From 81bf1a3d8089907bed5fb3240b910c708b1be5b2 Mon Sep 17 00:00:00 2001 From: dk Date: Wed, 11 Sep 2024 12:02:23 +0700 Subject: [PATCH 130/197] [clean] delete mongo code from document store - remove mongo from gcp helm charts --- .../deployments/03-deploy-k8s-k3d.ipynb | 4 +- .../grid/helm/examples/azure/azure.high.yaml | 2 +- packages/grid/helm/examples/gcp/gcp.high.yaml | 2 +- packages/grid/helm/examples/gcp/gcp.low.yaml | 2 +- .../grid/helm/examples/gcp/gcp.nosync.yaml | 2 +- packages/syft/src/syft/serde/third_party.py | 6 - .../src/syft/service/action/action_object.py | 1 - packages/syft/src/syft/store/__init__.py | 3 - .../syft/src/syft/store/document_store.py | 25 - packages/syft/src/syft/store/mongo_client.py | 275 ----- packages/syft/src/syft/store/mongo_codecs.py | 31 - .../src/syft/store/mongo_document_store.py | 962 ------------------ scripts/dev_tools.sh | 4 +- 13 files changed, 7 insertions(+), 1312 deletions(-) delete mode 100644 packages/syft/src/syft/store/mongo_client.py delete mode 100644 packages/syft/src/syft/store/mongo_codecs.py delete mode 100644 packages/syft/src/syft/store/mongo_document_store.py diff --git a/notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb b/notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb index c64e6c40f4a..a92f7987e68 100644 --- a/notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb +++ b/notebooks/tutorials/deployments/03-deploy-k8s-k3d.ipynb @@ -78,7 +78,7 @@ "If you want to deploy your Kubernetes cluster in a resource-constrained environment, use the following flags to override the default configurations. 
Please note that you will need at least 1 CPU and 2 GB of RAM on Docker, and some tests may not work in such low-resource environments:\n", "\n", "```sh\n", - "helm install my-syft openmined/syft --version $SYFT_VERSION --namespace syft --create-namespace --set ingress.className=\"traefik\" --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set mongo.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null\n", + "helm install my-syft openmined/syft --version $SYFT_VERSION --namespace syft --create-namespace --set ingress.className=\"traefik\" --set server.resourcesPreset=null --set seaweedfs.resourcesPreset=null --set postgres.resourcesPreset=null --set registry.resourcesPreset=null --set proxy.resourcesPreset=null --set frontend.resourcesPreset=null\n", "```\n", "\n", "\n", @@ -89,7 +89,7 @@ "If you would like to set your own default password even for the production style deployment, use the following command:\n", "\n", "```sh\n", - "helm install my-syft openmined/syft --version $SYFT_VERSION --namespace syft --create-namespace --set ingress.className=\"traefik\" --set global.randomizedSecrets=false --set server.secret.defaultRootPassword=\"changethis\" --set seaweedfs.secret.s3RootPassword=\"admin\" --set mongo.secret.rootPassword=\"example\"\n", + "helm install my-syft openmined/syft --version $SYFT_VERSION --namespace syft --create-namespace --set ingress.className=\"traefik\" --set global.randomizedSecrets=false --set server.secret.defaultRootPassword=\"changethis\" --set seaweedfs.secret.s3RootPassword=\"admin\" --set postgres.secret.rootPassword=\"example\"\n", "```\n", "\n" ] diff --git a/packages/grid/helm/examples/azure/azure.high.yaml b/packages/grid/helm/examples/azure/azure.high.yaml index 3234a62b757..4733fed4cc7 100644 --- a/packages/grid/helm/examples/azure/azure.high.yaml +++ b/packages/grid/helm/examples/azure/azure.high.yaml @@ -38,5 +38,5 @@ registry: frontend: resourcesPreset: medium -mongo: +postgres: resourcesPreset: large diff --git a/packages/grid/helm/examples/gcp/gcp.high.yaml b/packages/grid/helm/examples/gcp/gcp.high.yaml index efdbbe72e68..2a430807fac 100644 --- a/packages/grid/helm/examples/gcp/gcp.high.yaml +++ b/packages/grid/helm/examples/gcp/gcp.high.yaml @@ -97,7 +97,7 @@ frontend: # ================================================================================= -mongo: +postgres: resourcesPreset: large # ================================================================================= diff --git a/packages/grid/helm/examples/gcp/gcp.low.yaml b/packages/grid/helm/examples/gcp/gcp.low.yaml index 94cfc324b0b..8e9e3e7ba35 100644 --- a/packages/grid/helm/examples/gcp/gcp.low.yaml +++ b/packages/grid/helm/examples/gcp/gcp.low.yaml @@ -97,7 +97,7 @@ frontend: # ================================================================================= -mongo: +postgres: resourcesPreset: large # ================================================================================= diff --git a/packages/grid/helm/examples/gcp/gcp.nosync.yaml b/packages/grid/helm/examples/gcp/gcp.nosync.yaml index 02935edfd8f..8e622be5254 100644 --- a/packages/grid/helm/examples/gcp/gcp.nosync.yaml +++ b/packages/grid/helm/examples/gcp/gcp.nosync.yaml @@ -67,7 +67,7 @@ frontend: # ================================================================================= -mongo: +postgres: resourcesPreset: large # ================================================================================= diff --git 
a/packages/syft/src/syft/serde/third_party.py b/packages/syft/src/syft/serde/third_party.py index 89cc5ffaab5..6fadf6261f9 100644 --- a/packages/syft/src/syft/serde/third_party.py +++ b/packages/syft/src/syft/serde/third_party.py @@ -18,7 +18,6 @@ import pyarrow.parquet as pq import pydantic from pydantic._internal._model_construction import ModelMetaclass -from pymongo.collection import Collection # relative from ..types.dicttuple import DictTuple @@ -58,11 +57,6 @@ # exceptions recursive_serde_register(cls=TypeError, canonical_name="TypeError", version=1) -# mongo collection -recursive_serde_register_type( - Collection, canonical_name="pymongo_collection", version=1 -) - def serialize_dataframe(df: DataFrame) -> bytes: table = pa.Table.from_pandas(df) diff --git a/packages/syft/src/syft/service/action/action_object.py b/packages/syft/src/syft/service/action/action_object.py index d7e566b733a..e1513f12bd3 100644 --- a/packages/syft/src/syft/service/action/action_object.py +++ b/packages/syft/src/syft/service/action/action_object.py @@ -254,7 +254,6 @@ class ActionObjectPointer: "created_date", # syft "updated_date", # syft "deleted_date", # syft - "to_mongo", # syft 🟡 TODO 23: Add composeable / inheritable object passthrough attrs "__attr_searchable__", # syft "__canonical_name__", # syft "__version__", # syft diff --git a/packages/syft/src/syft/store/__init__.py b/packages/syft/src/syft/store/__init__.py index 9260d13f956..e69de29bb2d 100644 --- a/packages/syft/src/syft/store/__init__.py +++ b/packages/syft/src/syft/store/__init__.py @@ -1,3 +0,0 @@ -# relative -from .mongo_document_store import MongoDict -from .mongo_document_store import MongoStoreConfig diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index cc97802a08b..5ff2542001c 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -188,16 +188,6 @@ def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey: def as_dict(self) -> dict[str, Any]: return {self.key: self.value} - @property - def as_dict_mongo(self) -> dict[str, Any]: - key = self.key - if key == "id": - key = "_id" - if self.type_list: - # We want to search inside the list of values - return {key: {"$in": self.value}} - return {key: self.value} - @serializable(canonical_name="PartitionKeysWithUID", version=1) class PartitionKeysWithUID(PartitionKeys): @@ -273,21 +263,6 @@ def as_dict(self) -> dict: qk_dict[qk_key] = qk_value return qk_dict - @property - def as_dict_mongo(self) -> dict: - qk_dict = {} - for qk in self.all: - qk_key = qk.key - qk_value = qk.value - if qk_key == "id": - qk_key = "_id" - if qk.type_list: - # We want to search inside the list of values - qk_dict[qk_key] = {"$in": qk_value} - else: - qk_dict[qk_key] = qk_value - return qk_dict - UIDPartitionKey = PartitionKey(key="id", type_=UID) diff --git a/packages/syft/src/syft/store/mongo_client.py b/packages/syft/src/syft/store/mongo_client.py deleted file mode 100644 index 9767059a1cb..00000000000 --- a/packages/syft/src/syft/store/mongo_client.py +++ /dev/null @@ -1,275 +0,0 @@ -# stdlib -import logging -from threading import Lock -from typing import Any - -# third party -from pymongo.collection import Collection as MongoCollection -from pymongo.database import Database as MongoDatabase -from pymongo.errors import ConnectionFailure -from pymongo.mongo_client import MongoClient as PyMongoClient - -# relative -from ..serde.serializable import serializable -from ..types.errors 
import SyftException -from ..types.result import as_result -from ..util.telemetry import TRACING_ENABLED -from .document_store import PartitionSettings -from .document_store import StoreClientConfig -from .document_store import StoreConfig -from .mongo_codecs import SYFT_CODEC_OPTIONS - -if TRACING_ENABLED: - try: - # third party - from opentelemetry.instrumentation.pymongo import PymongoInstrumentor - - PymongoInstrumentor().instrument() - message = "> Added OTEL PymongoInstrumentor" - print(message) - logger = logging.getLogger(__name__) - logger.info(message) - except Exception: # nosec - pass - - -@serializable(canonical_name="MongoStoreClientConfig", version=1) -class MongoStoreClientConfig(StoreClientConfig): - """ - Paramaters: - `hostname`: optional string - hostname or IP address or Unix domain socket path of a single mongod or mongos - instance to connect to, or a mongodb URI, or a list of hostnames (but no more - than one mongodb URI). If `host` is an IPv6 literal it must be enclosed in '[' - and ']' characters following the RFC2732 URL syntax (e.g. '[::1]' for localhost). - Multihomed and round robin DNS addresses are **not** supported. - `port` : optional int - port number on which to connect - `directConnection`: bool - if ``True``, forces this client to connect directly to the specified MongoDB host - as a standalone. If ``false``, the client connects to the entire replica set of which - the given MongoDB host(s) is a part. If this is ``True`` and a mongodb+srv:// URI - or a URI containing multiple seeds is provided, an exception will be raised. - `maxPoolSize`: int. Default 100 - The maximum allowable number of concurrent connections to each connected server. - Requests to a server will block if there are `maxPoolSize` outstanding connections - to the requested server. Defaults to 100. Can be either 0 or None, in which case - there is no limit on the number of concurrent connections. - `minPoolSize` : int. Default 0 - The minimum required number of concurrent connections that the pool will maintain - to each connected server. Default is 0. - `maxIdleTimeMS`: int - The maximum number of milliseconds that a connection can remain idle in the pool - before being removed and replaced. Defaults to `None` (no limit). - `appname`: string - The name of the application that created this MongoClient instance. The server will - log this value upon establishing each connection. It is also recorded in the slow - query log and profile collections. - `maxConnecting`: optional int - The maximum number of connections that each pool can establish concurrently. - Defaults to `2`. - `timeoutMS`: (integer or None) - Controls how long (in milliseconds) the driver will wait when executing an operation - (including retry attempts) before raising a timeout error. ``0`` or ``None`` means - no timeout. - `socketTimeoutMS`: (integer or None) - Controls how long (in milliseconds) the driver will wait for a response after sending - an ordinary (non-monitoring) database operation before concluding that a network error - has occurred. ``0`` or ``None`` means no timeout. Defaults to ``None`` (no timeout). - `connectTimeoutMS`: (integer or None) - Controls how long (in milliseconds) the driver will wait during server monitoring when - connecting a new socket to a server before concluding the server is unavailable. - ``0`` or ``None`` means no timeout. Defaults to ``20000`` (20 seconds). 
- `serverSelectionTimeoutMS`: (integer) - Controls how long (in milliseconds) the driver will wait to find an available, appropriate - server to carry out a database operation; while it is waiting, multiple server monitoring - operations may be carried out, each controlled by `connectTimeoutMS`. - Defaults to ``120000`` (120 seconds). - `waitQueueTimeoutMS`: (integer or None) - How long (in milliseconds) a thread will wait for a socket from the pool if the pool - has no free sockets. Defaults to ``None`` (no timeout). - `heartbeatFrequencyMS`: (optional) - The number of milliseconds between periodic server checks, or None to accept the default - frequency of 10 seconds. - # Auth - username: str - Database username - password: str - Database pass - authSource: str - The database to authenticate on. - Defaults to the database specified in the URI, if provided, or to “admin”. - tls: bool - If True, create the connection to the server using transport layer security. - Defaults to False. - # Testing and connection reuse - client: Optional[PyMongoClient] - If provided, this client is reused. Default = None - - """ - - # Connection - hostname: str | None = "127.0.0.1" - port: int | None = None - directConnection: bool = False - maxPoolSize: int = 200 - minPoolSize: int = 0 - maxIdleTimeMS: int | None = None - maxConnecting: int = 3 - timeoutMS: int = 0 - socketTimeoutMS: int = 0 - connectTimeoutMS: int = 20000 - serverSelectionTimeoutMS: int = 120000 - waitQueueTimeoutMS: int | None = None - heartbeatFrequencyMS: int = 10000 - appname: str = "pysyft" - # Auth - username: str | None = None - password: str | None = None - authSource: str = "admin" - tls: bool | None = False - # Testing and connection reuse - client: Any = None - - # this allows us to have one connection per `Server` object - # in the MongoClientCache - server_obj_python_id: int | None = None - - -class MongoClientCache: - __client_cache__: dict[int, type["MongoClient"] | None] = {} - _lock: Lock = Lock() - - @classmethod - def from_cache(cls, config: MongoStoreClientConfig) -> PyMongoClient | None: - return cls.__client_cache__.get(hash(str(config)), None) - - @classmethod - def set_cache(cls, config: MongoStoreClientConfig, client: PyMongoClient) -> None: - with cls._lock: - cls.__client_cache__[hash(str(config))] = client - - -class MongoClient: - client: PyMongoClient = None - - def __init__(self, config: MongoStoreClientConfig, cache: bool = True) -> None: - self.config = config - if config.client is not None: - self.client = config.client - elif cache: - self.client = MongoClientCache.from_cache(config=config) - - if not cache or self.client is None: - self.connect(config=config).unwrap() - - @as_result(SyftException) - def connect(self, config: MongoStoreClientConfig) -> bool: - self.client = PyMongoClient( - # Connection - host=config.hostname, - port=config.port, - directConnection=config.directConnection, - maxPoolSize=config.maxPoolSize, - minPoolSize=config.minPoolSize, - maxIdleTimeMS=config.maxIdleTimeMS, - maxConnecting=config.maxConnecting, - timeoutMS=config.timeoutMS, - socketTimeoutMS=config.socketTimeoutMS, - connectTimeoutMS=config.connectTimeoutMS, - serverSelectionTimeoutMS=config.serverSelectionTimeoutMS, - waitQueueTimeoutMS=config.waitQueueTimeoutMS, - heartbeatFrequencyMS=config.heartbeatFrequencyMS, - appname=config.appname, - # Auth - username=config.username, - password=config.password, - authSource=config.authSource, - tls=config.tls, - uuidRepresentation="standard", - ) - 
MongoClientCache.set_cache(config=config, client=self.client) - try: - # Check if mongo connection is still up - self.client.admin.command("ping") - except ConnectionFailure as e: - self.client = None - raise SyftException.from_exception(e) - - return True - - @as_result(SyftException) - def with_db(self, db_name: str) -> MongoDatabase: - try: - return self.client[db_name] - except BaseException as e: - raise SyftException.from_exception(e) - - @as_result(SyftException) - def with_collection( - self, - collection_settings: PartitionSettings, - store_config: StoreConfig, - collection_name: str | None = None, - ) -> MongoCollection: - db = self.with_db(db_name=store_config.db_name).unwrap() - - try: - collection_name = ( - collection_name - if collection_name is not None - else collection_settings.name - ) - collection = db.get_collection( - name=collection_name, codec_options=SYFT_CODEC_OPTIONS - ) - except BaseException as e: - raise SyftException.from_exception(e) - - return collection - - @as_result(SyftException) - def with_collection_permissions( - self, collection_settings: PartitionSettings, store_config: StoreConfig - ) -> MongoCollection: - """ - For each collection, create a corresponding collection - that store the permissions to the data in that collection - """ - db = self.with_db(db_name=store_config.db_name).unwrap() - - try: - collection_permissions_name: str = collection_settings.name + "_permissions" - collection_permissions = db.get_collection( - name=collection_permissions_name, codec_options=SYFT_CODEC_OPTIONS - ) - except BaseException as e: - raise SyftException.from_exception(e) - return collection_permissions - - @as_result(SyftException) - def with_collection_storage_permissions( - self, collection_settings: PartitionSettings, store_config: StoreConfig - ) -> MongoCollection: - """ - For each collection, create a corresponding collection - that store the permissions to the data in that collection - """ - db = self.with_db(db_name=store_config.db_name).unwrap() - - try: - collection_storage_permissions_name: str = ( - collection_settings.name + "_storage_permissions" - ) - storage_permissons_collection = db.get_collection( - name=collection_storage_permissions_name, - codec_options=SYFT_CODEC_OPTIONS, - ) - except BaseException as e: - raise SyftException.from_exception(e) - - return storage_permissons_collection - - def close(self) -> None: - self.client.close() - MongoClientCache.__client_cache__.pop(hash(str(self.config)), None) diff --git a/packages/syft/src/syft/store/mongo_codecs.py b/packages/syft/src/syft/store/mongo_codecs.py deleted file mode 100644 index 08b7fa63562..00000000000 --- a/packages/syft/src/syft/store/mongo_codecs.py +++ /dev/null @@ -1,31 +0,0 @@ -# stdlib -from typing import Any - -# third party -from bson import CodecOptions -from bson.binary import Binary -from bson.binary import USER_DEFINED_SUBTYPE -from bson.codec_options import TypeDecoder -from bson.codec_options import TypeRegistry - -# relative -from ..serde.deserialize import _deserialize -from ..serde.serialize import _serialize - - -def fallback_syft_encoder(value: object) -> Binary: - return Binary(_serialize(value, to_bytes=True), USER_DEFINED_SUBTYPE) - - -class SyftMongoBinaryDecoder(TypeDecoder): - bson_type = Binary - - def transform_bson(self, value: Any) -> Any: - if value.subtype == USER_DEFINED_SUBTYPE: - return _deserialize(value, from_bytes=True) - return value - - -syft_codecs = [SyftMongoBinaryDecoder()] -syft_type_registry = TypeRegistry(syft_codecs, 
fallback_encoder=fallback_syft_encoder) -SYFT_CODEC_OPTIONS = CodecOptions(type_registry=syft_type_registry) diff --git a/packages/syft/src/syft/store/mongo_document_store.py b/packages/syft/src/syft/store/mongo_document_store.py deleted file mode 100644 index 7a4bbcdb6ca..00000000000 --- a/packages/syft/src/syft/store/mongo_document_store.py +++ /dev/null @@ -1,962 +0,0 @@ -# stdlib -from collections.abc import Callable -from typing import Any -from typing import Set # noqa: UP035 - -# third party -from pydantic import Field -from pymongo import ASCENDING -from pymongo.collection import Collection as MongoCollection -from typing_extensions import Self - -# relative -from ..serde.deserialize import _deserialize -from ..serde.serializable import serializable -from ..serde.serialize import _serialize -from ..server.credentials import SyftVerifyKey -from ..service.action.action_permissions import ActionObjectEXECUTE -from ..service.action.action_permissions import ActionObjectOWNER -from ..service.action.action_permissions import ActionObjectPermission -from ..service.action.action_permissions import ActionObjectREAD -from ..service.action.action_permissions import ActionObjectWRITE -from ..service.action.action_permissions import ActionPermission -from ..service.action.action_permissions import StoragePermission -from ..service.context import AuthedServiceContext -from ..service.response import SyftSuccess -from ..types.errors import SyftException -from ..types.result import as_result -from ..types.syft_object import SYFT_OBJECT_VERSION_1 -from ..types.syft_object import StorableObjectType -from ..types.syft_object import SyftBaseObject -from ..types.syft_object import SyftObject -from ..types.transforms import TransformContext -from ..types.transforms import transform -from ..types.transforms import transform_method -from ..types.uid import UID -from .document_store import DocumentStore -from .document_store import PartitionKey -from .document_store import PartitionSettings -from .document_store import QueryKey -from .document_store import QueryKeys -from .document_store import StoreConfig -from .document_store import StorePartition -from .document_store_errors import NotFoundException -from .kv_document_store import KeyValueBackingStore -from .locks import LockingConfig -from .locks import NoLockingConfig -from .mongo_client import MongoClient -from .mongo_client import MongoStoreClientConfig - - -@serializable() -class MongoDict(SyftBaseObject): - __canonical_name__ = "MongoDict" - __version__ = SYFT_OBJECT_VERSION_1 - - keys: list[Any] - values: list[Any] - - @property - def dict(self) -> dict[Any, Any]: - return dict(zip(self.keys, self.values)) - - @classmethod - def from_dict(cls, input: dict) -> Self: - return cls(keys=list(input.keys()), values=list(input.values())) - - def __repr__(self) -> str: - return self.dict.__repr__() - - -class MongoBsonObject(StorableObjectType, dict): - pass - - -def _repr_debug_(value: Any) -> str: - if hasattr(value, "_repr_debug_"): - return value._repr_debug_() - return repr(value) - - -def to_mongo(context: TransformContext) -> TransformContext: - output = {} - if context.obj: - unique_keys_dict = context.obj._syft_unique_keys_dict() - search_keys_dict = context.obj._syft_searchable_keys_dict() - all_dict = unique_keys_dict - all_dict.update(search_keys_dict) - for k in all_dict: - value = getattr(context.obj, k, "") - # if the value is a method, store its value - if callable(value): - output[k] = value() - else: - output[k] = value - - 
output["__canonical_name__"] = context.obj.__canonical_name__ - output["__version__"] = context.obj.__version__ - output["__blob__"] = _serialize(context.obj, to_bytes=True) - output["__arepr__"] = _repr_debug_(context.obj) # a comes first in alphabet - - if context.output and "id" in context.output: - output["_id"] = context.output["id"] - - context.output = output - - return context - - -@transform(SyftObject, MongoBsonObject) -def syft_obj_to_mongo() -> list[Callable]: - return [to_mongo] - - -@transform_method(MongoBsonObject, SyftObject) -def from_mongo( - storage_obj: dict, context: TransformContext | None = None -) -> SyftObject: - return _deserialize(storage_obj["__blob__"], from_bytes=True) - - -@serializable(attrs=["storage_type"], canonical_name="MongoStorePartition", version=1) -class MongoStorePartition(StorePartition): - """Mongo StorePartition - - Parameters: - `settings`: PartitionSettings - PySyft specific settings, used for partitioning and indexing. - `store_config`: MongoStoreConfig - Mongo specific configuration - """ - - storage_type: type[StorableObjectType] = MongoBsonObject - - @as_result(SyftException) - def init_store(self) -> bool: - super().init_store().unwrap() - client = MongoClient(config=self.store_config.client_config) - self._collection = client.with_collection( - collection_settings=self.settings, store_config=self.store_config - ).unwrap() - self._permissions = client.with_collection_permissions( - collection_settings=self.settings, store_config=self.store_config - ).unwrap() - self._storage_permissions = client.with_collection_storage_permissions( - collection_settings=self.settings, store_config=self.store_config - ).unwrap() - return self._create_update_index().unwrap() - - # Potentially thread-unsafe methods. - # - # CAUTION: - # * Don't use self.lock here. - # * Do not call the public thread-safe methods here(with locking). - # These methods are called from the public thread-safe API, and will hang the process. 
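The CAUTION comment above encodes an invariant worth spelling out: the underscore-prefixed methods assume the partition lock is already held by their public wrappers, so re-entering a public (locking) method from inside one of them deadlocks the process. A rough sketch of that locked-facade pattern, using an illustrative `Partition` class rather than the actual `StorePartition` code:

```python
# Illustrative locked-facade pattern; Partition is a stand-in, not Syft code.
import threading
from typing import Any


class Partition:
    def __init__(self) -> None:
        # threading.Lock is non-reentrant: re-acquiring it from the same
        # thread blocks forever
        self._lock = threading.Lock()
        self._data: dict[str, Any] = {}

    def set(self, key: str, value: Any) -> Any:
        # public, thread-safe entry point: acquires the lock exactly once
        with self._lock:
            return self._set(key, value)

    def _set(self, key: str, value: Any) -> Any:
        # private helper: must not call self.set() or any other public
        # locking method here -- doing so would hang on self._lock
        self._data[key] = value
        return value
```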
- - @as_result(SyftException) - def _create_update_index(self) -> bool: - """Create or update mongo database indexes""" - collection: MongoCollection = self.collection.unwrap() - - def check_index_keys( - current_keys: list[tuple[str, int]], new_index_keys: list[tuple[str, int]] - ) -> bool: - current_keys.sort() - new_index_keys.sort() - return current_keys == new_index_keys - - syft_obj = self.settings.object_type - - unique_attrs = getattr(syft_obj, "__attr_unique__", []) - object_name = syft_obj.__canonical_name__ - - new_index_keys = [(attr, ASCENDING) for attr in unique_attrs] - - try: - current_indexes = collection.index_information() - except BaseException as e: - raise SyftException.from_exception(e) - index_name = f"{object_name}_index_name" - - current_index_keys = current_indexes.get(index_name, None) - - if current_index_keys is not None: - keys_same = check_index_keys(current_index_keys["key"], new_index_keys) - if keys_same: - return True - - # Drop current index, since incompatible with current object - try: - collection.drop_index(index_or_name=index_name) - except Exception: - raise SyftException( - public_message=( - f"Failed to drop index for object: {object_name}" - f" with index keys: {current_index_keys}" - ) - ) - - # If no new indexes, then skip index creation - if len(new_index_keys) == 0: - return True - - try: - collection.create_index(new_index_keys, unique=True, name=index_name) - except Exception: - raise SyftException( - public_message=f"Failed to create index for {object_name} with index keys: {new_index_keys}" - ) - - return True - - @property - @as_result(SyftException) - def collection(self) -> MongoCollection: - if not hasattr(self, "_collection"): - self.init_store().unwrap() - return self._collection - - @property - @as_result(SyftException) - def permissions(self) -> MongoCollection: - if not hasattr(self, "_permissions"): - self.init_store().unwrap() - return self._permissions - - @property - @as_result(SyftException) - def storage_permissions(self) -> MongoCollection: - if not hasattr(self, "_storage_permissions"): - self.init_store().unwrap() - return self._storage_permissions - - @as_result(SyftException) - def set(self, *args: Any, **kwargs: Any) -> SyftObject: - return self._set(*args, **kwargs).unwrap() - - @as_result(SyftException) - def _set( - self, - credentials: SyftVerifyKey, - obj: SyftObject, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - ignore_duplicates: bool = False, - ) -> SyftObject: - # TODO: Refactor this function since now it's doing both set and - # update at the same time - write_permission = ActionObjectWRITE(uid=obj.id, credentials=credentials) - can_write: bool = self.has_permission(write_permission) - - store_query_key: QueryKey = self.settings.store_key.with_obj(obj) - collection: MongoCollection = self.collection.unwrap() - - store_key_exists = ( - collection.find_one(store_query_key.as_dict_mongo) is not None - ) - if (not store_key_exists) and (not self.item_keys_exist(obj, collection)): - # attempt to claim ownership for writing - can_write = self.take_ownership( - uid=obj.id, credentials=credentials - ).unwrap() - elif not ignore_duplicates: - unique_query_keys: QueryKeys = self.settings.unique_keys.with_obj(obj) - keys = ", ".join(f"`{key.key}`" for key in unique_query_keys.all) - raise SyftException( - public_message=f"Duplication Key Error for {obj}.\nThe fields that should be unique are {keys}." 
- ) - else: - # we are not throwing an error, because we are ignoring duplicates - # we are also not writing though - return obj - - if not can_write: - raise SyftException( - public_message=f"No permission to write object with id {obj.id}" - ) - - storage_obj = obj.to(self.storage_type) - - collection.insert_one(storage_obj) - - # adding permissions - read_permission = ActionObjectPermission( - uid=obj.id, - credentials=credentials, - permission=ActionPermission.READ, - ) - self.add_permission(read_permission) - - if add_permissions is not None: - self.add_permissions(add_permissions) - - if add_storage_permission: - self.add_storage_permission( - StoragePermission( - uid=obj.id, - server_uid=self.server_uid, - ) - ) - - return obj - - def item_keys_exist(self, obj: SyftObject, collection: MongoCollection) -> bool: - qks: QueryKeys = self.settings.unique_keys.with_obj(obj) - query = {"$or": [{k: v} for k, v in qks.as_dict_mongo.items()]} - res = collection.find_one(query) - return res is not None - - @as_result(SyftException) - def _update( - self, - credentials: SyftVerifyKey, - qk: QueryKey, - obj: SyftObject, - has_permission: bool = False, - overwrite: bool = False, - allow_missing_keys: bool = False, - ) -> SyftObject: - collection: MongoCollection = self.collection.unwrap() - - # TODO: optimize the update. The ID should not be overwritten, - # but the qk doesn't necessarily have to include the `id` field either. - - prev_obj = self._get_all_from_store(credentials, QueryKeys(qks=[qk])).unwrap() - if len(prev_obj) == 0: - raise SyftException( - public_message=f"Failed to update missing values for query key: {qk} for type {type(obj)}" - ) - - prev_obj = prev_obj[0] - if has_permission or self.has_permission( - ActionObjectWRITE(uid=prev_obj.id, credentials=credentials) - ): - for key, value in obj.to_dict(exclude_empty=True).items(): - # we don't want to overwrite Mongo's "id_" or Syft's "id" on update - if key == "id": - # protected field - continue - - # Overwrite the value if the key is already present - setattr(prev_obj, key, value) - - # Create the Mongo object - storage_obj = prev_obj.to(self.storage_type) - - try: - collection.update_one( - filter=qk.as_dict_mongo, update={"$set": storage_obj} - ) - except Exception: - raise SyftException(f"Failed to update obj: {obj} with qk: {qk}") - - return prev_obj - else: - raise SyftException(f"Failed to update obj {obj}, you have no permission") - - @as_result(SyftException) - def _find_index_or_search_keys( - self, - credentials: SyftVerifyKey, - index_qks: QueryKeys, - search_qks: QueryKeys, - order_by: PartitionKey | None = None, - ) -> list[SyftObject]: - # TODO: pass index as hint to find method - qks = QueryKeys(qks=(list(index_qks.all) + list(search_qks.all))) - return self._get_all_from_store( - credentials=credentials, qks=qks, order_by=order_by - ).unwrap() - - @property - def data(self) -> dict: - values: list = self._all(credentials=None, has_permission=True).unwrap() - return {v.id: v for v in values} - - @as_result(SyftException) - def _get( - self, - uid: UID, - credentials: SyftVerifyKey, - has_permission: bool | None = False, - ) -> SyftObject: - qks = QueryKeys.from_dict({"id": uid}) - res = self._get_all_from_store( - credentials, qks, order_by=None, has_permission=has_permission - ).unwrap() - if len(res) == 0: - raise NotFoundException - else: - return res[0] - - @as_result(SyftException) - def _get_all_from_store( - self, - credentials: SyftVerifyKey, - qks: QueryKeys, - order_by: PartitionKey | None = None, - 
has_permission: bool | None = False, - ) -> list[SyftObject]: - collection = self.collection.unwrap() - - if order_by is not None: - storage_objs = collection.find(filter=qks.as_dict_mongo).sort(order_by.key) - else: - _default_key = "_id" - storage_objs = collection.find(filter=qks.as_dict_mongo).sort(_default_key) - - syft_objs = [] - for storage_obj in storage_objs: - obj = self.storage_type(storage_obj) - transform_context = TransformContext(output={}, obj=obj) - - syft_obj = obj.to(self.settings.object_type, transform_context) - if has_permission or self.has_permission( - ActionObjectREAD(uid=syft_obj.id, credentials=credentials) - ): - syft_objs.append(syft_obj) - - return syft_objs - - @as_result(SyftException) - def _delete( - self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False - ) -> SyftSuccess: - if not ( - has_permission - or self.has_permission( - ActionObjectWRITE(uid=qk.value, credentials=credentials) - ) - ): - raise SyftException( - public_message=f"You don't have permission to delete object with qk: {qk}" - ) - - collection = self.collection.unwrap() - collection_permissions: MongoCollection = self.permissions.unwrap() - - qks = QueryKeys(qks=qk) - # delete the object - result = collection.delete_one(filter=qks.as_dict_mongo) - # delete the object's permission - result_permission = collection_permissions.delete_one(filter=qks.as_dict_mongo) - if result.deleted_count == 1 and result_permission.deleted_count == 1: - return SyftSuccess(message="Object and its permission are deleted") - elif result.deleted_count == 0: - raise SyftException(public_message=f"Failed to delete object with qk: {qk}") - else: - raise SyftException( - public_message=f"Object with qk: {qk} was deleted, but failed to delete its corresponding permission" - ) - - def has_permission(self, permission: ActionObjectPermission) -> bool: - """Check if the permission is inside the permission collection""" - collection_permissions_status = self.permissions - if collection_permissions_status.is_err(): - return False - collection_permissions: MongoCollection = collection_permissions_status.ok() - - permissions: dict | None = collection_permissions.find_one( - {"_id": permission.uid} - ) - - if permissions is None: - return False - - if ( - permission.credentials - and self.root_verify_key.verify == permission.credentials.verify - ): - return True - - if ( - permission.credentials - and self.has_admin_permissions is not None - and self.has_admin_permissions(permission.credentials) - ): - return True - - if permission.permission_string in permissions["permissions"]: - return True - - # check ALL_READ permission - if ( - permission.permission == ActionPermission.READ - and ActionObjectPermission( - permission.uid, ActionPermission.ALL_READ - ).permission_string - in permissions["permissions"] - ): - return True - - return False - - @as_result(SyftException) - def _get_permissions_for_uid(self, uid: UID) -> Set[str]: # noqa: UP006 - collection_permissions = self.permissions.unwrap() - permissions: dict | None = collection_permissions.find_one({"_id": uid}) - if permissions is None: - raise SyftException( - public_message=f"Permissions for object with UID {uid} not found!" 
- ) - return set(permissions["permissions"]) - - @as_result(SyftException) - def get_all_permissions(self) -> dict[UID, Set[str]]: # noqa: UP006 - # Returns a dictionary of all permissions {object_uid: {*permissions}} - collection_permissions: MongoCollection = self.permissions.unwrap() - permissions = collection_permissions.find({}) - permissions_dict = {} - for permission in permissions: - permissions_dict[permission["_id"]] = permission["permissions"] - return permissions_dict - - def add_permission(self, permission: ActionObjectPermission) -> None: - collection_permissions = self.permissions.unwrap() - - # find the permissions for the given permission.uid - # e.g. permissions = {"_id": "7b88fdef6bff42a8991d294c3d66f757", - # "permissions": set(["permission_str_1", "permission_str_2"]}} - permissions: dict | None = collection_permissions.find_one( - {"_id": permission.uid} - ) - if permissions is None: - # Permission doesn't exist, add a new one - collection_permissions.insert_one( - { - "_id": permission.uid, - "permissions": {permission.permission_string}, - } - ) - else: - # update the permissions with the new permission string - permission_strings: set = permissions["permissions"] - permission_strings.add(permission.permission_string) - collection_permissions.update_one( - {"_id": permission.uid}, {"$set": {"permissions": permission_strings}} - ) - - def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: - for permission in permissions: - self.add_permission(permission) - - def remove_permission(self, permission: ActionObjectPermission) -> None: - collection_permissions = self.permissions.unwrap() - permissions: dict | None = collection_permissions.find_one( - {"_id": permission.uid} - ) - if permissions is None: - raise SyftException( - public_message=f"permission with UID {permission.uid} not found!" - ) - permissions_strings: set = permissions["permissions"] - if permission.permission_string in permissions_strings: - permissions_strings.remove(permission.permission_string) - if len(permissions_strings) > 0: - collection_permissions.update_one( - {"_id": permission.uid}, - {"$set": {"permissions": permissions_strings}}, - ) - else: - collection_permissions.delete_one({"_id": permission.uid}) - else: - raise SyftException( - public_message=f"the permission {permission.permission_string} does not exist!" 
- ) - - def add_storage_permission(self, storage_permission: StoragePermission) -> None: - storage_permissions_collection: MongoCollection = ( - self.storage_permissions.unwrap() - ) - storage_permissions: dict | None = storage_permissions_collection.find_one( - {"_id": storage_permission.uid} - ) - if storage_permissions is None: - # Permission doesn't exist, add a new one - storage_permissions_collection.insert_one( - { - "_id": storage_permission.uid, - "server_uids": {storage_permission.server_uid}, - } - ) - else: - # update the permissions with the new permission string - server_uids: set = storage_permissions["server_uids"] - server_uids.add(storage_permission.server_uid) - storage_permissions_collection.update_one( - {"_id": storage_permission.uid}, - {"$set": {"server_uids": server_uids}}, - ) - - def add_storage_permissions(self, permissions: list[StoragePermission]) -> None: - for permission in permissions: - self.add_storage_permission(permission) - - def has_storage_permission(self, permission: StoragePermission) -> bool: # type: ignore - """Check if the storage_permission is inside the storage_permission collection""" - storage_permissions_collection: MongoCollection = ( - self.storage_permissions.unwrap() - ) - storage_permissions: dict | None = storage_permissions_collection.find_one( - {"_id": permission.uid} - ) - if storage_permissions is None or "server_uids" not in storage_permissions: - return False - return permission.server_uid in storage_permissions["server_uids"] - - def remove_storage_permission(self, storage_permission: StoragePermission) -> None: - storage_permissions_collection = self.storage_permissions.unwrap() - storage_permissions: dict | None = storage_permissions_collection.find_one( - {"_id": storage_permission.uid} - ) - if storage_permissions is None: - raise SyftException( - public_message=f"storage permission with UID {storage_permission.uid} not found!" - ) - server_uids: set = storage_permissions["server_uids"] - if storage_permission.server_uid in server_uids: - server_uids.remove(storage_permission.server_uid) - storage_permissions_collection.update_one( - {"_id": storage_permission.uid}, - {"$set": {"server_uids": server_uids}}, - ) - else: - raise SyftException( - public_message=( - f"the server_uid {storage_permission.server_uid} does not exist in the storage permission!" - ) - ) - - def _get_storage_permissions_for_uid(self, uid: UID) -> Set[UID]: # noqa: UP006 - storage_permissions_collection: MongoCollection = ( - self.storage_permissions.unwrap() - ) - storage_permissions: dict | None = storage_permissions_collection.find_one( - {"_id": uid} - ) - if storage_permissions is None: - raise SyftException( - public_message=f"Storage permissions for object with UID {uid} not found!" 
- ) - return set(storage_permissions["server_uids"]) - - @as_result(SyftException) - def get_all_storage_permissions( - self, - ) -> dict[UID, Set[UID]]: # noqa: UP006 - # Returns a dictionary of all storage permissions {object_uid: {*server_uids}} - storage_permissions_collection: MongoCollection = ( - self.storage_permissions.unwrap() - ) - storage_permissions = storage_permissions_collection.find({}) - storage_permissions_dict = {} - for storage_permission in storage_permissions: - storage_permissions_dict[storage_permission["_id"]] = storage_permission[ - "server_uids" - ] - return storage_permissions_dict - - @as_result(SyftException) - def take_ownership(self, uid: UID, credentials: SyftVerifyKey) -> bool: - collection_permissions: MongoCollection = self.permissions.unwrap() - collection: MongoCollection = self.collection.unwrap() - data: list[UID] | None = collection.find_one({"_id": uid}) - permissions: list[UID] | None = collection_permissions.find_one({"_id": uid}) - - if permissions is not None or data is not None: - raise SyftException(public_message=f"UID: {uid} already owned.") - - # first person using this UID can claim ownership - self.add_permissions( - [ - ActionObjectOWNER(uid=uid, credentials=credentials), - ActionObjectWRITE(uid=uid, credentials=credentials), - ActionObjectREAD(uid=uid, credentials=credentials), - ActionObjectEXECUTE(uid=uid, credentials=credentials), - ] - ) - - return True - - @as_result(SyftException) - def _all( - self, - credentials: SyftVerifyKey, - order_by: PartitionKey | None = None, - has_permission: bool | None = False, - ) -> list[SyftObject]: - qks = QueryKeys(qks=()) - return self._get_all_from_store( - credentials=credentials, - qks=qks, - order_by=order_by, - has_permission=has_permission, - ).unwrap() - - def __len__(self) -> int: - collection_status = self.collection - if collection_status.is_err(): - return 0 - collection: MongoCollection = collection_status.ok() - return collection.count_documents(filter={}) - - @as_result(SyftException) - def _migrate_data( - self, to_klass: SyftObject, context: AuthedServiceContext, has_permission: bool - ) -> bool: - credentials = context.credentials - has_permission = (credentials == self.root_verify_key) or has_permission - collection: MongoCollection = self.collection.unwrap() - - if has_permission: - storage_objs = collection.find({}) - for storage_obj in storage_objs: - obj = self.storage_type(storage_obj) - transform_context = TransformContext(output={}, obj=obj) - value = obj.to(self.settings.object_type, transform_context) - key = obj.get("_id") - try: - migrated_value = value.migrate_to(to_klass.__version__, context) - except Exception: - raise SyftException( - public_message=f"Failed to migrate data to {to_klass} for qk: {key}" - ) - qk = self.settings.store_key.with_obj(key) - self._update( - credentials, - qk=qk, - obj=migrated_value, - has_permission=has_permission, - ).unwrap() - return True - raise SyftException( - public_message="You don't have permissions to migrate data." - ) - - -@serializable(canonical_name="MongoDocumentStore", version=1) -class MongoDocumentStore(DocumentStore): - """Mongo Document Store - - Parameters: - `store_config`: MongoStoreConfig - Mongo specific configuration, including connection configuration, database name, or client class type. 
- """ - - partition_type = MongoStorePartition - - -@serializable( - attrs=["index_name", "settings", "store_config"], - canonical_name="MongoBackingStore", - version=1, -) -class MongoBackingStore(KeyValueBackingStore): - """ - Core logic for the MongoDB key-value store - - Parameters: - `index_name`: str - Index name (can be either 'data' or 'permissions') - `settings`: PartitionSettings - Syft specific settings - `store_config`: StoreConfig - Connection Configuration - `ddtype`: Type - Optional and should be None - Used to make a consistent interface with SQLiteBackingStore - """ - - def __init__( - self, - index_name: str, - settings: PartitionSettings, - store_config: StoreConfig, - ddtype: type | None = None, - ) -> None: - self.index_name = index_name - self.settings = settings - self.store_config = store_config - self.client: MongoClient - self.ddtype = ddtype - self.init_client() - - @as_result(SyftException) - def init_client(self) -> None: - self.client = MongoClient(config=self.store_config.client_config) - self._collection: MongoCollection = self.client.with_collection( - collection_settings=self.settings, - store_config=self.store_config, - collection_name=f"{self.settings.name}_{self.index_name}", - ).unwrap() - - @property - @as_result(SyftException) - def collection(self) -> MongoCollection: - if not hasattr(self, "_collection"): - self.init_client().unwrap() - return self._collection - - def _exist(self, key: UID) -> bool: - collection: MongoCollection = self.collection.unwrap() - result: dict | None = collection.find_one({"_id": key}) - if result is not None: - return True - return False - - def _set(self, key: UID, value: Any) -> None: - if self._exist(key): - self._update(key, value) - else: - collection: MongoCollection = self.collection.unwrap() - try: - bson_data = { - "_id": key, - f"{key}": _serialize(value, to_bytes=True), - "_repr_debug_": _repr_debug_(value), - } - collection.insert_one(bson_data) - except Exception: - raise SyftException(public_message="Cannot insert data.") - - def _update(self, key: UID, value: Any) -> None: - collection: MongoCollection = self.collection.unwrap() - try: - collection.update_one( - {"_id": key}, - { - "$set": { - f"{key}": _serialize(value, to_bytes=True), - "_repr_debug_": _repr_debug_(value), - } - }, - ) - except Exception as e: - raise SyftException( - public_message=f"Failed to update obj: {key} with value: {value}. 
Error: {e}" - ) - - def __setitem__(self, key: Any, value: Any) -> None: - self._set(key, value) - - def _get(self, key: UID) -> Any: - collection: MongoCollection = self.collection.unwrap() - result: dict | None = collection.find_one({"_id": key}) - if result is not None: - return _deserialize(result[f"{key}"], from_bytes=True) - else: - raise KeyError(f"{key} does not exist") - - def __getitem__(self, key: Any) -> Self: - try: - return self._get(key) - except KeyError as e: - if self.ddtype is not None: - return self.ddtype() - raise e - - def _len(self) -> int: - collection: MongoCollection = self.collection.unwrap() - return collection.count_documents(filter={}) - - def __len__(self) -> int: - return self._len() - - def _delete(self, key: UID) -> SyftSuccess: - collection: MongoCollection = self.collection.unwrap() - result = collection.delete_one({"_id": key}) - if result.deleted_count != 1: - raise SyftException(public_message=f"{key} does not exist") - return SyftSuccess(message="Deleted") - - def __delitem__(self, key: str) -> None: - self._delete(key) - - def _delete_all(self) -> None: - collection: MongoCollection = self.collection.unwrap() - collection.delete_many({}) - - def clear(self) -> None: - self._delete_all() - - def _get_all(self) -> Any: - collection_status = self.collection - if collection_status.is_err(): - return collection_status - collection: MongoCollection = collection_status.ok() - result = collection.find() - keys, values = [], [] - for row in result: - keys.append(row["_id"]) - values.append(_deserialize(row[f"{row['_id']}"], from_bytes=True)) - return dict(zip(keys, values)) - - def keys(self) -> Any: - return self._get_all().keys() - - def values(self) -> Any: - return self._get_all().values() - - def items(self) -> Any: - return self._get_all().items() - - def pop(self, key: Any) -> Self: - value = self._get(key) - self._delete(key) - return value - - def __contains__(self, key: Any) -> bool: - return self._exist(key) - - def __iter__(self) -> Any: - return iter(self.keys()) - - def __repr__(self) -> str: - return repr(self._get_all()) - - def copy(self) -> Self: - # 🟡 TODO - raise NotImplementedError - - def update(self, *args: Any, **kwargs: Any) -> None: - """ - Inserts the specified items to the dictionary. - """ - # 🟡 TODO - raise NotImplementedError - - def __del__(self) -> None: - """ - Close the mongo client connection: - - Cleanup client resources and disconnect from MongoDB - - End all server sessions created by this client - - Close all sockets in the connection pools and stop the monitor threads - """ - self.client.close() - - -@serializable() -class MongoStoreConfig(StoreConfig): - __canonical_name__ = "MongoStoreConfig" - """Mongo Store configuration - - Parameters: - `client_config`: MongoStoreClientConfig - Mongo connection details: hostname, port, user, password etc. - `store_type`: Type[DocumentStore] - The type of the DocumentStore. Default: MongoDocumentStore - `db_name`: str - Database name - locking_config: LockingConfig - The config used for store locking. Available options: - * NoLockingConfig: no locking, ideal for single-thread stores. - * ThreadingLockingConfig: threading-based locking, ideal for same-process in-memory stores. - Defaults to NoLockingConfig. 
- """ - - client_config: MongoStoreClientConfig - store_type: type[DocumentStore] = MongoDocumentStore - db_name: str = "app" - backing_store: type[KeyValueBackingStore] = MongoBackingStore - # TODO: should use a distributed lock, with RedisLockingConfig - locking_config: LockingConfig = Field(default_factory=NoLockingConfig) diff --git a/scripts/dev_tools.sh b/scripts/dev_tools.sh index 20a74b597e0..c23b56c05b9 100755 --- a/scripts/dev_tools.sh +++ b/scripts/dev_tools.sh @@ -23,15 +23,13 @@ function docker_list_exposed_ports() { if [[ -z "$1" ]]; then # list db, redis, rabbitmq, and seaweedfs ports - docker_list_exposed_ports "db\|seaweedfs\|mongo" + docker_list_exposed_ports "db\|seaweedfs" else PORT=$1 if docker ps | grep ":${PORT}" | grep -q 'redis'; then ${command} redis://127.0.0.1:${PORT} elif docker ps | grep ":${PORT}" | grep -q 'postgres'; then ${command} postgresql://postgres:changethis@127.0.0.1:${PORT}/app - elif docker ps | grep ":${PORT}" | grep -q 'mongo'; then - ${command} mongodb://root:example@127.0.0.1:${PORT} else ${command} http://localhost:${PORT} fi From 6df49c7d8e9bfd04d9f823b4a094a94a77303d89 Mon Sep 17 00:00:00 2001 From: dk Date: Wed, 11 Sep 2024 14:36:08 +0700 Subject: [PATCH 131/197] [clean] remove mongomock in tests and mongo-related packages --- .pre-commit-config.yaml | 18 +- packages/syft/setup.cfg | 2 - packages/syft/tests/mongomock/__init__.py | 138 - packages/syft/tests/mongomock/__init__.pyi | 30 - packages/syft/tests/mongomock/__version__.py | 15 - packages/syft/tests/mongomock/aggregate.py | 1811 ------------ .../syft/tests/mongomock/codec_options.py | 135 - packages/syft/tests/mongomock/collection.py | 2596 ----------------- .../syft/tests/mongomock/command_cursor.py | 37 - packages/syft/tests/mongomock/database.py | 301 -- packages/syft/tests/mongomock/filtering.py | 601 ---- packages/syft/tests/mongomock/gridfs.py | 68 - packages/syft/tests/mongomock/helpers.py | 474 --- packages/syft/tests/mongomock/mongo_client.py | 222 -- .../syft/tests/mongomock/not_implemented.py | 36 - packages/syft/tests/mongomock/object_id.py | 26 - packages/syft/tests/mongomock/patch.py | 120 - packages/syft/tests/mongomock/py.typed | 0 packages/syft/tests/mongomock/read_concern.py | 21 - .../syft/tests/mongomock/read_preferences.py | 42 - packages/syft/tests/mongomock/results.py | 117 - packages/syft/tests/mongomock/store.py | 191 -- packages/syft/tests/mongomock/thread.py | 94 - .../syft/tests/mongomock/write_concern.py | 45 - packages/syftcli/manifest.yml | 2 +- scripts/reset_mongo.sh | 18 - 26 files changed, 7 insertions(+), 7153 deletions(-) delete mode 100644 packages/syft/tests/mongomock/__init__.py delete mode 100644 packages/syft/tests/mongomock/__init__.pyi delete mode 100644 packages/syft/tests/mongomock/__version__.py delete mode 100644 packages/syft/tests/mongomock/aggregate.py delete mode 100644 packages/syft/tests/mongomock/codec_options.py delete mode 100644 packages/syft/tests/mongomock/collection.py delete mode 100644 packages/syft/tests/mongomock/command_cursor.py delete mode 100644 packages/syft/tests/mongomock/database.py delete mode 100644 packages/syft/tests/mongomock/filtering.py delete mode 100644 packages/syft/tests/mongomock/gridfs.py delete mode 100644 packages/syft/tests/mongomock/helpers.py delete mode 100644 packages/syft/tests/mongomock/mongo_client.py delete mode 100644 packages/syft/tests/mongomock/not_implemented.py delete mode 100644 packages/syft/tests/mongomock/object_id.py delete mode 100644 packages/syft/tests/mongomock/patch.py 
delete mode 100644 packages/syft/tests/mongomock/py.typed delete mode 100644 packages/syft/tests/mongomock/read_concern.py delete mode 100644 packages/syft/tests/mongomock/read_preferences.py delete mode 100644 packages/syft/tests/mongomock/results.py delete mode 100644 packages/syft/tests/mongomock/store.py delete mode 100644 packages/syft/tests/mongomock/thread.py delete mode 100644 packages/syft/tests/mongomock/write_concern.py delete mode 100755 scripts/reset_mongo.sh diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3487d8d0915..e1c50cd3b96 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,41 +3,36 @@ repos: rev: v4.5.0 hooks: - id: check-ast - exclude: ^(packages/syft/tests/mongomock) always_run: true - id: trailing-whitespace always_run: true - exclude: ^(docs/|.+\.md|.bumpversion.cfg|packages/syft/tests/mongomock) + exclude: ^(docs/|.+\.md|.bumpversion.cfg) - id: check-docstring-first always_run: true - exclude: ^(packages/syft/tests/mongomock) - id: check-json always_run: true - exclude: ^(packages/grid/frontend/|packages/syft/tests/mongomock|.vscode) + exclude: ^(packages/grid/frontend/|.vscode) - id: check-added-large-files always_run: true exclude: ^(packages/grid/backend/wheels/.*|docs/img/header.png|docs/img/terminalizer.gif) - id: check-yaml always_run: true - exclude: ^(packages/grid/k8s/rendered/|packages/grid/helm/|packages/syft/tests/mongomock) + exclude: ^(packages/grid/k8s/rendered/|packages/grid/helm/) - id: check-merge-conflict always_run: true args: ["--assume-in-merge"] - id: check-executables-have-shebangs always_run: true - exclude: ^(packages/syft/tests/mongomock) - id: debug-statements always_run: true - exclude: ^(packages/syft/tests/mongomock) - id: name-tests-test always_run: true - exclude: ^(.*/tests/utils/)|^(.*fixtures.py|packages/syft/tests/mongomock)|^(tests/scenarios/bigquery/helpers) + exclude: ^(.*/tests/utils/)|^(.*fixtures.py)|^(tests/scenarios/bigquery/helpers) - id: requirements-txt-fixer always_run: true - exclude: "packages/syft/tests/mongomock" - id: mixed-line-ending args: ["--fix=lf"] - exclude: '\.bat|\.csv|\.ps1$|packages/syft/tests/mongomock' + exclude: '\.bat|\.csv|\.ps1$' - repo: https://github.com/MarcoGorelli/absolufy-imports # This repository has been archived by the owner on Aug 15, 2023. It is now read-only. 
rev: v0.3.1 @@ -88,7 +83,6 @@ repos: hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix, --show-fixes] - exclude: packages/syft/tests/mongomock types_or: [python, pyi, jupyter] - id: ruff-format types_or: [python, pyi, jupyter] @@ -178,7 +172,7 @@ repos: rev: "v3.0.0-alpha.9-for-vscode" hooks: - id: prettier - exclude: ^(packages/grid/helm|packages/grid/frontend/pnpm-lock.yaml|packages/syft/tests/mongomock|.vscode) + exclude: ^(packages/grid/helm|packages/grid/frontend/pnpm-lock.yaml|.vscode) # - repo: meta # hooks: diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 4cc1ae70619..6375b6c7d8a 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -35,7 +35,6 @@ syft = pycapnp==2.0.0 pydantic[email]==2.6.0 pydantic-settings==2.2.1 - pymongo==4.6.3 pynacl==1.5.0 pyzmq>=23.2.1,<=25.1.1 requests==2.32.3 @@ -113,7 +112,6 @@ telemetry = opentelemetry-instrumentation==0.48b0 opentelemetry-instrumentation-requests==0.48b0 opentelemetry-instrumentation-fastapi==0.48b0 - opentelemetry-instrumentation-pymongo==0.48b0 opentelemetry-instrumentation-botocore==0.48b0 opentelemetry-instrumentation-logging==0.48b0 ; opentelemetry-instrumentation-asyncio==0.48b0 diff --git a/packages/syft/tests/mongomock/__init__.py b/packages/syft/tests/mongomock/__init__.py deleted file mode 100644 index 6ce7670902b..00000000000 --- a/packages/syft/tests/mongomock/__init__.py +++ /dev/null @@ -1,138 +0,0 @@ -# stdlib -import os - -try: - # third party - from pymongo.errors import PyMongoError -except ImportError: - - class PyMongoError(Exception): - pass - - -try: - # third party - from pymongo.errors import OperationFailure -except ImportError: - - class OperationFailure(PyMongoError): - def __init__(self, message, code=None, details=None): - super(OperationFailure, self).__init__() - self._message = message - self._code = code - self._details = details - - code = property(lambda self: self._code) - details = property(lambda self: self._details) - - def __str__(self): - return self._message - - -try: - # third party - from pymongo.errors import WriteError -except ImportError: - - class WriteError(OperationFailure): - pass - - -try: - # third party - from pymongo.errors import DuplicateKeyError -except ImportError: - - class DuplicateKeyError(WriteError): - pass - - -try: - # third party - from pymongo.errors import BulkWriteError -except ImportError: - - class BulkWriteError(OperationFailure): - def __init__(self, results): - super(BulkWriteError, self).__init__( - "batch op errors occurred", 65, results - ) - - -try: - # third party - from pymongo.errors import CollectionInvalid -except ImportError: - - class CollectionInvalid(PyMongoError): - pass - - -try: - # third party - from pymongo.errors import InvalidName -except ImportError: - - class InvalidName(PyMongoError): - pass - - -try: - # third party - from pymongo.errors import InvalidOperation -except ImportError: - - class InvalidOperation(PyMongoError): - pass - - -try: - # third party - from pymongo.errors import ConfigurationError -except ImportError: - - class ConfigurationError(PyMongoError): - pass - - -try: - # third party - from pymongo.errors import InvalidURI -except ImportError: - - class InvalidURI(ConfigurationError): - pass - - -from .helpers import ObjectId, utcnow # noqa - - -__all__ = [ - "Database", - "DuplicateKeyError", - "Collection", - "CollectionInvalid", - "InvalidName", - "MongoClient", - "ObjectId", - "OperationFailure", - "WriteConcern", - "ignore_feature", - "patch", - "warn_on_feature", - 
"SERVER_VERSION", -] - -# relative -from .collection import Collection -from .database import Database -from .mongo_client import MongoClient -from .not_implemented import ignore_feature -from .not_implemented import warn_on_feature -from .patch import patch -from .write_concern import WriteConcern - -# The version of the server faked by mongomock. Callers may patch it before creating connections to -# update the behavior of mongomock. -# Keep the default version in sync with docker-compose.yml and travis.yml. -SERVER_VERSION = os.getenv("MONGODB", "5.0.5") diff --git a/packages/syft/tests/mongomock/__init__.pyi b/packages/syft/tests/mongomock/__init__.pyi deleted file mode 100644 index b7ba5e4c03c..00000000000 --- a/packages/syft/tests/mongomock/__init__.pyi +++ /dev/null @@ -1,30 +0,0 @@ -# stdlib -from typing import Any -from typing import Callable -from typing import Literal -from typing import Sequence -from typing import Tuple -from typing import Union -from unittest import mock - -# third party -from bson.objectid import ObjectId -from pymongo import MongoClient -from pymongo.collection import Collection -from pymongo.database import Database -from pymongo.errors import CollectionInvalid -from pymongo.errors import DuplicateKeyError -from pymongo.errors import InvalidName -from pymongo.errors import OperationFailure - -def patch( - servers: Union[str, Tuple[str, int], Sequence[Union[str, Tuple[str, int]]]] = ..., - on_new: Literal["error", "create", "timeout", "pymongo"] = ..., -) -> mock._patch: ... - -_FeatureName = Literal["collation", "session"] - -def ignore_feature(feature: _FeatureName) -> None: ... -def warn_on_feature(feature: _FeatureName) -> None: ... - -SERVER_VERSION: str = ... diff --git a/packages/syft/tests/mongomock/__version__.py b/packages/syft/tests/mongomock/__version__.py deleted file mode 100644 index 14863a3db29..00000000000 --- a/packages/syft/tests/mongomock/__version__.py +++ /dev/null @@ -1,15 +0,0 @@ -# stdlib -from platform import python_version_tuple - -python_version = python_version_tuple() - -if (int(python_version[0]), int(python_version[1])) >= (3, 8): - # stdlib - from importlib.metadata import version - - __version__ = version("mongomock") -else: - # third party - import pkg_resources - - __version__ = pkg_resources.get_distribution("mongomock").version diff --git a/packages/syft/tests/mongomock/aggregate.py b/packages/syft/tests/mongomock/aggregate.py deleted file mode 100644 index 243720b690f..00000000000 --- a/packages/syft/tests/mongomock/aggregate.py +++ /dev/null @@ -1,1811 +0,0 @@ -"""Module to handle the operations within the aggregate pipeline.""" - -# stdlib -import bisect -import collections -import copy -import datetime -import decimal -import functools -import itertools -import math -import numbers -import random -import re -import sys -import warnings - -# third party -from packaging import version -import pytz - -# relative -from . import OperationFailure -from . import command_cursor -from . import filtering -from . 
import helpers - -try: - # third party - from bson import Regex - from bson import decimal128 - from bson.errors import InvalidDocument - - decimal_support = True - _RE_TYPES = (helpers.RE_TYPE, Regex) -except ImportError: - InvalidDocument = OperationFailure - decimal_support = False - _RE_TYPES = helpers.RE_TYPE - -_random = random.Random() - - -group_operators = [ - "$addToSet", - "$avg", - "$first", - "$last", - "$max", - "$mergeObjects", - "$min", - "$push", - "$stdDevPop", - "$stdDevSamp", - "$sum", -] -unary_arithmetic_operators = { - "$abs", - "$ceil", - "$exp", - "$floor", - "$ln", - "$log10", - "$sqrt", - "$trunc", -} -binary_arithmetic_operators = { - "$divide", - "$log", - "$mod", - "$pow", - "$subtract", -} -arithmetic_operators = ( - unary_arithmetic_operators - | binary_arithmetic_operators - | { - "$add", - "$multiply", - } -) -project_operators = [ - "$max", - "$min", - "$avg", - "$sum", - "$stdDevPop", - "$stdDevSamp", - "$arrayElemAt", - "$first", - "$last", -] -control_flow_operators = [ - "$switch", -] -projection_operators = [ - "$let", - "$literal", -] -date_operators = [ - "$dateFromString", - "$dateToString", - "$dateFromParts", - "$dayOfMonth", - "$dayOfWeek", - "$dayOfYear", - "$hour", - "$isoDayOfWeek", - "$isoWeek", - "$isoWeekYear", - "$millisecond", - "$minute", - "$month", - "$second", - "$week", - "$year", -] -conditional_operators = ["$cond", "$ifNull"] -array_operators = [ - "$concatArrays", - "$filter", - "$indexOfArray", - "$map", - "$range", - "$reduce", - "$reverseArray", - "$size", - "$slice", - "$zip", -] -object_operators = [ - "$mergeObjects", -] -text_search_operators = ["$meta"] -string_operators = [ - "$concat", - "$indexOfBytes", - "$indexOfCP", - "$regexMatch", - "$split", - "$strcasecmp", - "$strLenBytes", - "$strLenCP", - "$substr", - "$substrBytes", - "$substrCP", - "$toLower", - "$toUpper", - "$trim", -] -comparison_operators = [ - "$cmp", - "$eq", - "$ne", -] + list(filtering.SORTING_OPERATOR_MAP.keys()) -boolean_operators = ["$and", "$or", "$not"] -set_operators = [ - "$in", - "$setEquals", - "$setIntersection", - "$setDifference", - "$setUnion", - "$setIsSubset", - "$anyElementTrue", - "$allElementsTrue", -] - -type_convertion_operators = [ - "$convert", - "$toString", - "$toInt", - "$toDecimal", - "$toLong", - "$arrayToObject", - "$objectToArray", -] -type_operators = [ - "$isNumber", - "$isArray", -] - - -def _avg_operation(values): - values_list = list(v for v in values if isinstance(v, numbers.Number)) - if not values_list: - return None - return sum(values_list) / float(len(list(values_list))) - - -def _group_operation(values, operator): - values_list = list(v for v in values if v is not None) - if not values_list: - return None - return operator(values_list) - - -def _sum_operation(values): - values_list = list() - if decimal_support: - for v in values: - if isinstance(v, numbers.Number): - values_list.append(v) - elif isinstance(v, decimal128.Decimal128): - values_list.append(v.to_decimal()) - else: - values_list = list(v for v in values if isinstance(v, numbers.Number)) - sum_value = sum(values_list) - return ( - decimal128.Decimal128(sum_value) - if isinstance(sum_value, decimal.Decimal) - else sum_value - ) - - -def _merge_objects_operation(values): - merged_doc = dict() - for v in values: - if isinstance(v, dict): - merged_doc.update(v) - return merged_doc - - -_GROUPING_OPERATOR_MAP = { - "$sum": _sum_operation, - "$avg": _avg_operation, - "$mergeObjects": _merge_objects_operation, - "$min": lambda values: 
_group_operation(values, min), - "$max": lambda values: _group_operation(values, max), - "$first": lambda values: values[0] if values else None, - "$last": lambda values: values[-1] if values else None, -} - - -class _Parser(object): - """Helper to parse expressions within the aggregate pipeline.""" - - def __init__(self, doc_dict, user_vars=None, ignore_missing_keys=False): - self._doc_dict = doc_dict - self._ignore_missing_keys = ignore_missing_keys - self._user_vars = user_vars or {} - - def parse(self, expression): - """Parse a MongoDB expression.""" - if not isinstance(expression, dict): - # May raise a KeyError despite the ignore missing key. - return self._parse_basic_expression(expression) - - if len(expression) > 1 and any(key.startswith("$") for key in expression): - raise OperationFailure( - "an expression specification must contain exactly one field, " - "the name of the expression. Found %d fields in %s" - % (len(expression), expression) - ) - - value_dict = {} - for k, v in expression.items(): - if k in arithmetic_operators: - return self._handle_arithmetic_operator(k, v) - if k in project_operators: - return self._handle_project_operator(k, v) - if k in projection_operators: - return self._handle_projection_operator(k, v) - if k in comparison_operators: - return self._handle_comparison_operator(k, v) - if k in date_operators: - return self._handle_date_operator(k, v) - if k in array_operators: - return self._handle_array_operator(k, v) - if k in conditional_operators: - return self._handle_conditional_operator(k, v) - if k in control_flow_operators: - return self._handle_control_flow_operator(k, v) - if k in set_operators: - return self._handle_set_operator(k, v) - if k in string_operators: - return self._handle_string_operator(k, v) - if k in type_convertion_operators: - return self._handle_type_convertion_operator(k, v) - if k in type_operators: - return self._handle_type_operator(k, v) - if k in boolean_operators: - return self._handle_boolean_operator(k, v) - if k in text_search_operators + projection_operators + object_operators: - raise NotImplementedError( - "'%s' is a valid operation but it is not supported by Mongomock yet." 
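-                    # Dispatch above is plain list membership against the
-                    # operator-category lists defined at module level; a
-                    # "$"-prefixed key found in none of them falls through to
-                    # the "Unrecognized expression" OperationFailure below.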
- % k - ) - if k.startswith("$"): - raise OperationFailure("Unrecognized expression '%s'" % k) - try: - value = self.parse(v) - except KeyError: - if self._ignore_missing_keys: - continue - raise - value_dict[k] = value - - return value_dict - - def parse_many(self, values): - for value in values: - try: - yield self.parse(value) - except KeyError: - if self._ignore_missing_keys: - yield None - else: - raise - - def _parse_to_bool(self, expression): - """Parse a MongoDB expression and then convert it to bool""" - # handles converting `undefined` (in form of KeyError) to False - try: - return helpers.mongodb_to_bool(self.parse(expression)) - except KeyError: - return False - - def _parse_or_None(self, expression): - try: - return self.parse(expression) - except KeyError: - return None - - def _parse_basic_expression(self, expression): - if isinstance(expression, str) and expression.startswith("$"): - if expression.startswith("$$"): - return helpers.get_value_by_dot( - dict( - { - "ROOT": self._doc_dict, - "CURRENT": self._doc_dict, - }, - **self._user_vars, - ), - expression[2:], - can_generate_array=True, - ) - return helpers.get_value_by_dot( - self._doc_dict, expression[1:], can_generate_array=True - ) - return expression - - def _handle_boolean_operator(self, operator, values): - if operator == "$and": - return all([self._parse_to_bool(value) for value in values]) - if operator == "$or": - return any(self._parse_to_bool(value) for value in values) - if operator == "$not": - return not self._parse_to_bool(values) - # This should never happen: it is only a safe fallback if something went wrong. - raise NotImplementedError( # pragma: no cover - "Although '%s' is a valid boolean operator for the " - "aggregation pipeline, it is currently not implemented" - " in Mongomock." 
% operator
-        )
-
-    def _handle_arithmetic_operator(self, operator, values):
-        if operator in unary_arithmetic_operators:
-            try:
-                number = self.parse(values)
-            except KeyError:
-                return None
-            if number is None:
-                return None
-            if not isinstance(number, numbers.Number):
-                raise OperationFailure(
-                    "Parameter to %s must evaluate to a number, got '%s'"
-                    % (operator, type(number))
-                )
-
-            if operator == "$abs":
-                return abs(number)
-            if operator == "$ceil":
-                return math.ceil(number)
-            if operator == "$exp":
-                return math.exp(number)
-            if operator == "$floor":
-                return math.floor(number)
-            if operator == "$ln":
-                return math.log(number)
-            if operator == "$log10":
-                return math.log10(number)
-            if operator == "$sqrt":
-                return math.sqrt(number)
-            if operator == "$trunc":
-                return math.trunc(number)
-
-        if operator in binary_arithmetic_operators:
-            if not isinstance(values, (tuple, list)):
-                raise OperationFailure(
-                    "Parameter to %s must evaluate to a list, got '%s'"
-                    % (operator, type(values))
-                )
-
-            if len(values) != 2:
-                raise OperationFailure("%s must have only 2 parameters" % operator)
-            number_0, number_1 = self.parse_many(values)
-            if number_0 is None or number_1 is None:
-                return None
-
-            if operator == "$divide":
-                return number_0 / number_1
-            if operator == "$log":
-                return math.log(number_0, number_1)
-            if operator == "$mod":
-                return math.fmod(number_0, number_1)
-            if operator == "$pow":
-                return math.pow(number_0, number_1)
-            if operator == "$subtract":
-                if isinstance(number_0, datetime.datetime) and isinstance(
-                    number_1, (int, float)
-                ):
-                    number_1 = datetime.timedelta(milliseconds=number_1)
-                res = number_0 - number_1
-                if isinstance(res, datetime.timedelta):
-                    return round(res.total_seconds() * 1000)
-                return res
-
-        assert isinstance(values, (tuple, list)), (
-            "Parameter to %s must evaluate to a list, got '%s'"
-            % (
-                operator,
-                type(values),
-            )
-        )
-
-        parsed_values = list(self.parse_many(values))
-        assert parsed_values, "%s must have at least one parameter" % operator
-        for value in parsed_values:
-            if value is None:
-                return None
-            assert isinstance(value, numbers.Number), "%s only uses numbers" % operator
-        if operator == "$add":
-            return sum(parsed_values)
-        if operator == "$multiply":
-            return functools.reduce(lambda x, y: x * y, parsed_values)
-
-        # This should never happen: it is only a safe fallback if something went wrong.
-        raise NotImplementedError(  # pragma: no cover
-            "Although '%s' is a valid arithmetic operator for the aggregation "
-            "pipeline, it is currently not implemented in Mongomock." % operator
-        )
-
-    def _handle_project_operator(self, operator, values):
-        if operator in _GROUPING_OPERATOR_MAP:
-            values = (
-                self.parse(values)
-                if isinstance(values, str)
-                else self.parse_many(values)
-            )
-            return _GROUPING_OPERATOR_MAP[operator](values)
-        if operator == "$arrayElemAt":
-            key, value = values
-            array = self.parse(key)
-            index = self.parse(value)
-            try:
-                return array[index]
-            except IndexError as error:
-                raise KeyError("Array have length less than index value") from error
-
-        raise NotImplementedError(
-            "Although '%s' is a valid project operator for the "
-            "aggregation pipeline, it is currently not implemented "
-            "in Mongomock."
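-            # ($arrayElemAt above indexes with plain Python subscripting, so a
-            # negative index counts from the end of the array, matching
-            # MongoDB's documented behavior for that operator.)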
% operator - ) - - def _handle_projection_operator(self, operator, value): - if operator == "$literal": - return value - if operator == "$let": - if not isinstance(value, dict): - raise InvalidDocument("$let only supports an object as its argument") - for field in ("vars", "in"): - if field not in value: - raise OperationFailure( - "Missing '{}' parameter to $let".format(field) - ) - if not isinstance(value["vars"], dict): - raise OperationFailure("invalid parameter: expected an object (vars)") - user_vars = { - var_key: self.parse(var_value) - for var_key, var_value in value["vars"].items() - } - return _Parser( - self._doc_dict, - dict(self._user_vars, **user_vars), - ignore_missing_keys=self._ignore_missing_keys, - ).parse(value["in"]) - raise NotImplementedError( - "Although '%s' is a valid project operator for the " - "aggregation pipeline, it is currently not implemented " - "in Mongomock." % operator - ) - - def _handle_comparison_operator(self, operator, values): - assert len(values) == 2, "Comparison requires two expressions" - a = self.parse(values[0]) - b = self.parse(values[1]) - if operator == "$eq": - return a == b - if operator == "$ne": - return a != b - if operator in filtering.SORTING_OPERATOR_MAP: - return filtering.bson_compare( - filtering.SORTING_OPERATOR_MAP[operator], a, b - ) - raise NotImplementedError( - "Although '%s' is a valid comparison operator for the " - "aggregation pipeline, it is currently not implemented " - " in Mongomock." % operator - ) - - def _handle_string_operator(self, operator, values): - if operator == "$toLower": - parsed = self.parse(values) - return str(parsed).lower() if parsed is not None else "" - if operator == "$toUpper": - parsed = self.parse(values) - return str(parsed).upper() if parsed is not None else "" - if operator == "$concat": - parsed_list = list(self.parse_many(values)) - return ( - None if None in parsed_list else "".join([str(x) for x in parsed_list]) - ) - if operator == "$split": - if len(values) != 2: - raise OperationFailure("split must have 2 items") - try: - string = self.parse(values[0]) - delimiter = self.parse(values[1]) - except KeyError: - return None - - if string is None or delimiter is None: - return None - if not isinstance(string, str): - raise TypeError("split first argument must evaluate to string") - if not isinstance(delimiter, str): - raise TypeError("split second argument must evaluate to string") - return string.split(delimiter) - if operator == "$substr": - if len(values) != 3: - raise OperationFailure("substr must have 3 items") - string = str(self.parse(values[0])) - first = self.parse(values[1]) - length = self.parse(values[2]) - if string is None: - return "" - if first < 0: - warnings.warn( - "Negative starting point given to $substr is accepted only until " - "MongoDB 3.7. This behavior will change in the future." - ) - return "" - if length < 0: - warnings.warn( - "Negative length given to $substr is accepted only until " - "MongoDB 3.7. This behavior will change in the future." 
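-                    # (A negative length effectively means "to the end of the
-                    # string": `second` below falls back to len(string) in
-                    # that case.)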
- ) - second = len(string) if length < 0 else first + length - return string[first:second] - if operator == "$strcasecmp": - if len(values) != 2: - raise OperationFailure("strcasecmp must have 2 items") - a, b = str(self.parse(values[0])), str(self.parse(values[1])) - return 0 if a == b else -1 if a < b else 1 - if operator == "$regexMatch": - if not isinstance(values, dict): - raise OperationFailure( - "$regexMatch expects an object of named arguments but found: %s" - % type(values) - ) - for field in ("input", "regex"): - if field not in values: - raise OperationFailure( - "$regexMatch requires '%s' parameter" % field - ) - unknown_args = set(values) - {"input", "regex", "options"} - if unknown_args: - raise OperationFailure( - "$regexMatch found an unknown argument: %s" % list(unknown_args)[0] - ) - - try: - input_value = self.parse(values["input"]) - except KeyError: - return False - if not isinstance(input_value, str): - raise OperationFailure("$regexMatch needs 'input' to be of type string") - - try: - regex_val = self.parse(values["regex"]) - except KeyError: - return False - options = None - for option in values.get("options", ""): - if option not in "imxs": - raise OperationFailure( - "$regexMatch invalid flag in regex options: %s" % option - ) - re_option = getattr(re, option.upper()) - if options is None: - options = re_option - else: - options |= re_option - if isinstance(regex_val, str): - if options is None: - regex = re.compile(regex_val) - else: - regex = re.compile(regex_val, options) - elif "options" in values and regex_val.flags: - raise OperationFailure( - "$regexMatch: regex option(s) specified in both 'regex' and 'option' fields" - ) - elif isinstance(regex_val, helpers.RE_TYPE): - if options and not regex_val.flags: - regex = re.compile(regex_val.pattern, options) - elif regex_val.flags & ~(re.I | re.M | re.X | re.S): - raise OperationFailure( - "$regexMatch invalid flag in regex options: %s" - % regex_val.flags - ) - else: - regex = regex_val - elif isinstance(regex_val, _RE_TYPES): - # bson.Regex - if regex_val.flags & ~(re.I | re.M | re.X | re.S): - raise OperationFailure( - "$regexMatch invalid flag in regex options: %s" - % regex_val.flags - ) - regex = re.compile(regex_val.pattern, regex_val.flags or options) - else: - raise OperationFailure( - "$regexMatch needs 'regex' to be of type string or regex" - ) - - return bool(regex.search(input_value)) - - # This should never happen: it is only a safe fallback if something went wrong. - raise NotImplementedError( # pragma: no cover - "Although '%s' is a valid string operator for the aggregation " - "pipeline, it is currently not implemented in Mongomock." 
% operator
-        )
-
-    def _handle_date_operator(self, operator, values):
-        if isinstance(values, dict) and values.keys() == {"date", "timezone"}:
-            value = self.parse(values["date"])
-            target_tz = pytz.timezone(values["timezone"])
-            out_value = value.replace(tzinfo=pytz.utc).astimezone(target_tz)
-        else:
-            out_value = self.parse(values)
-
-        if operator == "$dayOfYear":
-            return out_value.timetuple().tm_yday
-        if operator == "$dayOfMonth":
-            return out_value.day
-        if operator == "$dayOfWeek":
-            return (out_value.isoweekday() % 7) + 1
-        if operator == "$year":
-            return out_value.year
-        if operator == "$month":
-            return out_value.month
-        if operator == "$week":
-            return int(out_value.strftime("%U"))
-        if operator == "$hour":
-            return out_value.hour
-        if operator == "$minute":
-            return out_value.minute
-        if operator == "$second":
-            return out_value.second
-        if operator == "$millisecond":
-            return int(out_value.microsecond / 1000)
-        if operator == "$dateToString":
-            if not isinstance(values, dict) or not {"format", "date"} <= set(values):
-                raise OperationFailure(
-                    "$dateToString operator must correspond a dict"
-                    'that has "format" and "date" field.'
-                )
-            if "%L" in out_value["format"]:
-                raise NotImplementedError(
-                    "Although %L is a valid date format for the "
-                    "$dateToString operator, it is currently not implemented "
-                    " in Mongomock."
-                )
-            if "onNull" in values:
-                raise NotImplementedError(
-                    "Although onNull is a valid field for the "
-                    "$dateToString operator, it is currently not implemented "
-                    " in Mongomock."
-                )
-            if "timezone" in values.keys():
-                raise NotImplementedError(
-                    "Although timezone is a valid field for the "
-                    "$dateToString operator, it is currently not implemented "
-                    " in Mongomock."
-                )
-            return out_value["date"].strftime(out_value["format"])
-        if operator == "$dateFromParts":
-            if not isinstance(out_value, dict):
-                raise OperationFailure(
-                    f"{operator} operator must correspond a dict "
-                    'that has "year" or "isoWeekYear" field.'
-                )
-            if len(set(out_value) & {"year", "isoWeekYear"}) != 1:
-                raise OperationFailure(
-                    f"{operator} operator must correspond a dict "
-                    'that has "year" or "isoWeekYear" field.'
-                )
-            for field in ("isoWeekYear", "isoWeek", "isoDayOfWeek", "timezone"):
-                if field in out_value:
-                    raise NotImplementedError(
-                        f"Although {field} is a valid field for the "
-                        f"{operator} operator, it is currently not implemented "
-                        "in Mongomock."
-                    )
-
-            year = out_value["year"]
-            month = out_value.get("month", 1) or 1
-            day = out_value.get("day", 1) or 1
-            hour = out_value.get("hour", 0) or 0
-            minute = out_value.get("minute", 0) or 0
-            second = out_value.get("second", 0) or 0
-            millisecond = out_value.get("millisecond", 0) or 0
-
-            return datetime.datetime(
-                year=year,
-                month=month,
-                day=day,
-                hour=hour,
-                minute=minute,
-                second=second,
-                microsecond=millisecond,
-            )
-
-        raise NotImplementedError(
-            "Although '%s' is a valid date operator for the "
-            "aggregation pipeline, it is currently not implemented "
-            " in Mongomock."
% operator
-        )
-
-    def _handle_array_operator(self, operator, value):
-        if operator == "$concatArrays":
-            if not isinstance(value, (list, tuple)):
-                value = [value]
-
-            parsed_list = list(self.parse_many(value))
-            for parsed_item in parsed_list:
-                if parsed_item is not None and not isinstance(
-                    parsed_item, (list, tuple)
-                ):
-                    raise OperationFailure(
-                        "$concatArrays only supports arrays, not {}".format(
-                            type(parsed_item)
-                        )
-                    )
-
-            return (
-                None
-                if None in parsed_list
-                else list(itertools.chain.from_iterable(parsed_list))
-            )
-
-        if operator == "$map":
-            if not isinstance(value, dict):
-                raise OperationFailure("$map only supports an object as its argument")
-
-            # NOTE: while the two validations below could be achieved with
-            # one-liner set operations (e.g. set(value) - {'input', 'as',
-            # 'in'}), we prefer the iteration-based approaches in order to
-            # mimic MongoDB's behavior regarding the order of evaluation. For
-            # example, MongoDB complains about 'input' parameter missing before
-            # 'in'.
-            for k in ("input", "in"):
-                if k not in value:
-                    raise OperationFailure("Missing '%s' parameter to $map" % k)
-
-            for k in value:
-                if k not in {"input", "as", "in"}:
-                    raise OperationFailure("Unrecognized parameter to $map: %s" % k)
-
-            input_array = self._parse_or_None(value["input"])
-
-            if input_array is None:
-                return None
-
-            if not isinstance(input_array, (list, tuple)):
-                raise OperationFailure(
-                    "input to $map must be an array not %s" % type(input_array)
-                )
-
-            fieldname = value.get("as", "this")
-            in_expr = value["in"]
-            return [
-                _Parser(
-                    self._doc_dict,
-                    dict(self._user_vars, **{fieldname: item}),
-                    ignore_missing_keys=self._ignore_missing_keys,
-                ).parse(in_expr)
-                for item in input_array
-            ]
-
-        if operator == "$size":
-            if isinstance(value, list):
-                if len(value) != 1:
-                    raise OperationFailure(
-                        "Expression $size takes exactly 1 arguments. "
-                        "%d were passed in."
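-                        # ($size also accepts its operand wrapped in a
-                        # one-element list, which is unwrapped just below.)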
% len(value) - ) - value = value[0] - array_value = self._parse_or_None(value) - if not isinstance(array_value, (list, tuple)): - raise OperationFailure( - "The argument to $size must be an array, but was of type: %s" - % ("missing" if array_value is None else type(array_value)) - ) - return len(array_value) - - if operator == "$filter": - if not isinstance(value, dict): - raise OperationFailure( - "$filter only supports an object as its argument" - ) - extra_params = set(value) - {"input", "cond", "as"} - if extra_params: - raise OperationFailure( - "Unrecognized parameter to $filter: %s" % extra_params.pop() - ) - missing_params = {"input", "cond"} - set(value) - if missing_params: - raise OperationFailure( - "Missing '%s' parameter to $filter" % missing_params.pop() - ) - - input_array = self.parse(value["input"]) - fieldname = value.get("as", "this") - cond = value["cond"] - return [ - item - for item in input_array - if _Parser( - self._doc_dict, - dict(self._user_vars, **{fieldname: item}), - ignore_missing_keys=self._ignore_missing_keys, - ).parse(cond) - ] - if operator == "$slice": - if not isinstance(value, list): - raise OperationFailure("$slice only supports a list as its argument") - if len(value) < 2 or len(value) > 3: - raise OperationFailure( - "Expression $slice takes at least 2 arguments, and at most " - "3, but {} were passed in".format(len(value)) - ) - array_value = self.parse(value[0]) - if not isinstance(array_value, list): - raise OperationFailure( - "First argument to $slice must be an array, but is of type: {}".format( - type(array_value) - ) - ) - for num, v in zip(("Second", "Third"), value[1:]): - if not isinstance(v, int): - raise OperationFailure( - "{} argument to $slice must be numeric, but is of type: {}".format( - num, type(v) - ) - ) - if len(value) > 2 and value[2] <= 0: - raise OperationFailure( - "Third argument to $slice must be " "positive: {}".format(value[2]) - ) - - start = value[1] - if start < 0: - if len(value) > 2: - stop = len(array_value) + start + value[2] - else: - stop = None - elif len(value) > 2: - stop = start + value[2] - else: - stop = start - start = 0 - return array_value[start:stop] - - raise NotImplementedError( - "Although '%s' is a valid array operator for the " - "aggregation pipeline, it is currently not implemented " - "in Mongomock." % operator - ) - - def _handle_type_convertion_operator(self, operator, values): - if operator == "$toString": - try: - parsed = self.parse(values) - except KeyError: - return None - if isinstance(parsed, bool): - return str(parsed).lower() - if isinstance(parsed, datetime.datetime): - return parsed.isoformat()[:-3] + "Z" - return str(parsed) - - if operator == "$toInt": - try: - parsed = self.parse(values) - except KeyError: - return None - if decimal_support: - if isinstance(parsed, decimal128.Decimal128): - return int(parsed.to_decimal()) - return int(parsed) - raise NotImplementedError( - "You need to import the pymongo library to support decimal128 type." - ) - - if operator == "$toLong": - try: - parsed = self.parse(values) - except KeyError: - return None - if decimal_support: - if isinstance(parsed, decimal128.Decimal128): - return int(parsed.to_decimal()) - return int(parsed) - raise NotImplementedError( - "You need to import the pymongo library to support decimal128 type." 
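-                # ($toInt and $toLong share the same conversion path here:
-                # Python ints are arbitrary precision, so no 32-/64-bit
-                # distinction is needed.)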
- ) - - # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/toDecimal/ - if operator == "$toDecimal": - if not decimal_support: - raise NotImplementedError( - "You need to import the pymongo library to support decimal128 type." - ) - try: - parsed = self.parse(values) - except KeyError: - return None - if isinstance(parsed, bool): - parsed = "1" if parsed is True else "0" - decimal_value = decimal128.Decimal128(parsed) - elif isinstance(parsed, int): - decimal_value = decimal128.Decimal128(str(parsed)) - elif isinstance(parsed, float): - exp = decimal.Decimal(".00000000000000") - decimal_value = decimal.Decimal(str(parsed)).quantize(exp) - decimal_value = decimal128.Decimal128(decimal_value) - elif isinstance(parsed, decimal128.Decimal128): - decimal_value = parsed - elif isinstance(parsed, str): - try: - decimal_value = decimal128.Decimal128(parsed) - except decimal.InvalidOperation as err: - raise OperationFailure( - "Failed to parse number '%s' in $convert with no onError value:" - "Failed to parse string to decimal" % parsed - ) from err - elif isinstance(parsed, datetime.datetime): - epoch = datetime.datetime.utcfromtimestamp(0) - string_micro_seconds = str( - (parsed - epoch).total_seconds() * 1000 - ).split(".", 1)[0] - decimal_value = decimal128.Decimal128(string_micro_seconds) - else: - raise TypeError("'%s' type is not supported" % type(parsed)) - return decimal_value - - # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/arrayToObject/ - if operator == "$arrayToObject": - try: - parsed = self.parse(values) - except KeyError: - return None - - if parsed is None: - return None - - if not isinstance(parsed, (list, tuple)): - raise OperationFailure( - "$arrayToObject requires an array input, found: {}".format( - type(parsed) - ) - ) - - if all(isinstance(x, dict) and set(x.keys()) == {"k", "v"} for x in parsed): - return {d["k"]: d["v"] for d in parsed} - - if all(isinstance(x, (list, tuple)) and len(x) == 2 for x in parsed): - return dict(parsed) - - raise OperationFailure( - "arrays used with $arrayToObject must contain documents " - "with k and v fields or two-element arrays" - ) - - # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/objectToArray/ - if operator == "$objectToArray": - try: - parsed = self.parse(values) - except KeyError: - return None - - if parsed is None: - return None - - if not isinstance(parsed, (dict, collections.OrderedDict)): - raise OperationFailure( - "$objectToArray requires an object input, found: {}".format( - type(parsed) - ) - ) - - if len(parsed) > 1 and sys.version_info < (3, 6): - raise NotImplementedError( - "Although '%s' is a valid type conversion, it is not implemented for Python 2 " - "and Python 3.5 in Mongomock yet." % operator - ) - - return [{"k": k, "v": v} for k, v in parsed.items()] - - raise NotImplementedError( - "Although '%s' is a valid type conversion operator for the " - "aggregation pipeline, it is currently not implemented " - "in Mongomock." 
% operator - ) - - def _handle_type_operator(self, operator, values): - # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/isNumber/ - if operator == "$isNumber": - try: - parsed = self.parse(values) - except KeyError: - return False - return ( - False - if isinstance(parsed, bool) - else isinstance(parsed, numbers.Number) - ) - - # Document: https://docs.mongodb.com/manual/reference/operator/aggregation/isArray/ - if operator == "$isArray": - try: - parsed = self.parse(values) - except KeyError: - return False - return isinstance(parsed, (tuple, list)) - - raise NotImplementedError( # pragma: no cover - "Although '%s' is a valid type operator for the aggregation pipeline, it is currently " - "not implemented in Mongomock." % operator - ) - - def _handle_conditional_operator(self, operator, values): - # relative - from . import SERVER_VERSION - - if operator == "$ifNull": - fields = values[:-1] - if len(fields) > 1 and version.parse(SERVER_VERSION) <= version.parse( - "4.4" - ): - raise OperationFailure( - "$ifNull supports only one input expression " - " in MongoDB v4.4 and lower" - ) - fallback = values[-1] - for field in fields: - try: - out_value = self.parse(field) - if out_value is not None: - return out_value - except KeyError: - pass - return self.parse(fallback) - if operator == "$cond": - if isinstance(values, list): - condition, true_case, false_case = values - elif isinstance(values, dict): - condition = values["if"] - true_case = values["then"] - false_case = values["else"] - condition_value = self._parse_to_bool(condition) - expression = true_case if condition_value else false_case - return self.parse(expression) - # This should never happen: it is only a safe fallback if something went wrong. - raise NotImplementedError( # pragma: no cover - "Although '%s' is a valid conditional operator for the " - "aggregation pipeline, it is currently not implemented " - " in Mongomock." % operator - ) - - def _handle_control_flow_operator(self, operator, values): - if operator == "$switch": - if not isinstance(values, dict): - raise OperationFailure( - "$switch requires an object as an argument, " - "found: %s" % type(values) - ) - - branches = values.get("branches", []) - if not isinstance(branches, (list, tuple)): - raise OperationFailure( - "$switch expected an array for 'branches', " - "found: %s" % type(branches) - ) - if not branches: - raise OperationFailure("$switch requires at least one branch.") - - for branch in branches: - if not isinstance(branch, dict): - raise OperationFailure( - "$switch expected each branch to be an object, " - "found: %s" % type(branch) - ) - if "case" not in branch: - raise OperationFailure( - "$switch requires each branch have a 'case' expression" - ) - if "then" not in branch: - raise OperationFailure( - "$switch requires each branch have a 'then' expression." - ) - - for branch in branches: - if self._parse_to_bool(branch["case"]): - return self.parse(branch["then"]) - - if "default" not in values: - raise OperationFailure( - "$switch could not find a matching branch for an input, " - "and no default was specified." - ) - return self.parse(values["default"]) - - # This should never happen: it is only a safe fallback if something went wrong. - raise NotImplementedError( # pragma: no cover - "Although '%s' is a valid control flow operator for the " - "aggregation pipeline, it is currently not implemented " - "in Mongomock." 
% operator - ) - - def _handle_set_operator(self, operator, values): - if operator == "$in": - expression, array = values - return self.parse(expression) in self.parse(array) - if operator == "$setUnion": - result = [] - for set_value in values: - for value in self.parse(set_value): - if value not in result: - result.append(value) - return result - if operator == "$setEquals": - set_values = [set(self.parse(value)) for value in values] - for set1, set2 in itertools.combinations(set_values, 2): - if set1 != set2: - return False - return True - raise NotImplementedError( - "Although '%s' is a valid set operator for the aggregation " - "pipeline, it is currently not implemented in Mongomock." % operator - ) - - -def _parse_expression(expression, doc_dict, ignore_missing_keys=False): - """Parse an expression. - - Args: - expression: an Aggregate Expression, see - https://docs.mongodb.com/manual/meta/aggregation-quick-reference/#aggregation-expressions. - doc_dict: the document on which to evaluate the expression. - ignore_missing_keys: if True, missing keys evaluated by the expression are ignored silently - if it is possible. - """ - return _Parser(doc_dict, ignore_missing_keys=ignore_missing_keys).parse(expression) - - -filtering.register_parse_expression(_parse_expression) - - -def _accumulate_group(output_fields, group_list): - doc_dict = {} - for field, value in output_fields.items(): - if field == "_id": - continue - for operator, key in value.items(): - values = [] - for doc in group_list: - try: - values.append(_parse_expression(key, doc)) - except KeyError: - continue - if operator in _GROUPING_OPERATOR_MAP: - doc_dict[field] = _GROUPING_OPERATOR_MAP[operator](values) - elif operator == "$addToSet": - value = [] - val_it = (val or None for val in values) - # Don't use set in case elt in not hashable (like dicts). - for elt in val_it: - if elt not in value: - value.append(elt) - doc_dict[field] = value - elif operator == "$push": - if field not in doc_dict: - doc_dict[field] = values - else: - doc_dict[field].extend(values) - elif operator in group_operators: - raise NotImplementedError( - "Although %s is a valid group operator for the " - "aggregation pipeline, it is currently not implemented " - "in Mongomock." % operator - ) - else: - raise NotImplementedError( - "%s is not a valid group operator for the aggregation " - "pipeline. See http://docs.mongodb.org/manual/meta/" - "aggregation-quick-reference/ for a complete list of " - "valid operators." % operator - ) - return doc_dict - - -def _fix_sort_key(key_getter): - def fixed_getter(doc): - key = key_getter(doc) - # Convert dictionaries to make sorted() work in Python 3. - if isinstance(key, dict): - return [(k, v) for (k, v) in sorted(key.items())] - return key - - return fixed_getter - - -def _handle_lookup_stage(in_collection, database, options): - for operator in ("let", "pipeline"): - if operator in options: - raise NotImplementedError( - "Although '%s' is a valid lookup operator for the " - "aggregation pipeline, it is currently not " - "implemented in Mongomock." 
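-                # (Only the classic equality-match form of $lookup is handled
-                # below: localField/foreignField joins, with list-valued local
-                # fields matched via an implicit $in.)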
% operator - ) - for operator in ("from", "localField", "foreignField", "as"): - if operator not in options: - raise OperationFailure("Must specify '%s' field for a $lookup" % operator) - if not isinstance(options[operator], str): - raise OperationFailure("Arguments to $lookup must be strings") - if operator in ("as", "localField", "foreignField") and options[ - operator - ].startswith("$"): - raise OperationFailure("FieldPath field names may not start with '$'") - if operator == "as" and "." in options[operator]: - raise NotImplementedError( - "Although '.' is valid in the 'as' " - "parameters for the lookup stage of the aggregation " - "pipeline, it is currently not implemented in Mongomock." - ) - - foreign_name = options["from"] - local_field = options["localField"] - foreign_field = options["foreignField"] - local_name = options["as"] - foreign_collection = database.get_collection(foreign_name) - for doc in in_collection: - try: - query = helpers.get_value_by_dot(doc, local_field) - except KeyError: - query = None - if isinstance(query, list): - query = {"$in": query} - matches = foreign_collection.find({foreign_field: query}) - doc[local_name] = [foreign_doc for foreign_doc in matches] - - return in_collection - - -def _recursive_get(match, nested_fields): - head = match.get(nested_fields[0]) - remaining_fields = nested_fields[1:] - if not remaining_fields: - # Final/last field reached. - yield head - return - # More fields to go, must be list, tuple, or dict. - if isinstance(head, (list, tuple)): - for m in head: - # Yield from _recursive_get(m, remaining_fields). - for answer in _recursive_get(m, remaining_fields): - yield answer - elif isinstance(head, dict): - # Yield from _recursive_get(head, remaining_fields). - for answer in _recursive_get(head, remaining_fields): - yield answer - - -def _handle_graph_lookup_stage(in_collection, database, options): - if not isinstance(options.get("maxDepth", 0), int): - raise OperationFailure("Argument 'maxDepth' to $graphLookup must be a number") - if not isinstance(options.get("restrictSearchWithMatch", {}), dict): - raise OperationFailure( - "Argument 'restrictSearchWithMatch' to $graphLookup must be a Dictionary" - ) - if not isinstance(options.get("depthField", ""), str): - raise OperationFailure("Argument 'depthField' to $graphlookup must be a string") - if "startWith" not in options: - raise OperationFailure("Must specify 'startWith' field for a $graphLookup") - for operator in ("as", "connectFromField", "connectToField", "from"): - if operator not in options: - raise OperationFailure( - "Must specify '%s' field for a $graphLookup" % operator - ) - if not isinstance(options[operator], str): - raise OperationFailure( - "Argument '%s' to $graphLookup must be string" % operator - ) - if options[operator].startswith("$"): - raise OperationFailure("FieldPath field names may not start with '$'") - if operator == "as" and "." in options[operator]: - raise NotImplementedError( - "Although '.' is valid in the '%s' " - "parameter for the $graphLookup stage of the aggregation " - "pipeline, it is currently not implemented in Mongomock." 
% operator - ) - - foreign_name = options["from"] - start_with = options["startWith"] - connect_from_field = options["connectFromField"] - connect_to_field = options["connectToField"] - local_name = options["as"] - max_depth = options.get("maxDepth", None) - depth_field = options.get("depthField", None) - restrict_search_with_match = options.get("restrictSearchWithMatch", {}) - foreign_collection = database.get_collection(foreign_name) - out_doc = copy.deepcopy(in_collection) # TODO(pascal): speed the deep copy - - def _find_matches_for_depth(query): - if isinstance(query, list): - query = {"$in": query} - matches = foreign_collection.find({connect_to_field: query}) - new_matches = [] - for new_match in matches: - if ( - filtering.filter_applies(restrict_search_with_match, new_match) - and new_match["_id"] not in found_items - ): - if depth_field is not None: - new_match = collections.OrderedDict( - new_match, **{depth_field: depth} - ) - new_matches.append(new_match) - found_items.add(new_match["_id"]) - return new_matches - - for doc in out_doc: - found_items = set() - depth = 0 - try: - result = _parse_expression(start_with, doc) - except KeyError: - continue - origin_matches = doc[local_name] = _find_matches_for_depth(result) - while origin_matches and (max_depth is None or depth < max_depth): - depth += 1 - newly_discovered_matches = [] - for match in origin_matches: - nested_fields = connect_from_field.split(".") - for match_target in _recursive_get(match, nested_fields): - newly_discovered_matches += _find_matches_for_depth(match_target) - doc[local_name] += newly_discovered_matches - origin_matches = newly_discovered_matches - return out_doc - - -def _handle_group_stage(in_collection, unused_database, options): - grouped_collection = [] - _id = options["_id"] - if _id: - - def _key_getter(doc): - try: - return _parse_expression(_id, doc, ignore_missing_keys=True) - except KeyError: - return None - - def _sort_key_getter(doc): - return filtering.BsonComparable(_key_getter(doc)) - - # Sort the collection only for the itertools.groupby. - # $group does not order its output document. - sorted_collection = sorted(in_collection, key=_sort_key_getter) - grouped = itertools.groupby(sorted_collection, _key_getter) - else: - grouped = [(None, in_collection)] - - for doc_id, group in grouped: - group_list = [x for x in group] - doc_dict = _accumulate_group(options, group_list) - doc_dict["_id"] = doc_id - grouped_collection.append(doc_dict) - - return grouped_collection - - -def _handle_bucket_stage(in_collection, unused_database, options): - unknown_options = set(options) - {"groupBy", "boundaries", "output", "default"} - if unknown_options: - raise OperationFailure( - "Unrecognized option to $bucket: %s." % unknown_options.pop() - ) - if "groupBy" not in options or "boundaries" not in options: - raise OperationFailure( - "$bucket requires 'groupBy' and 'boundaries' to be specified." - ) - group_by = options["groupBy"] - boundaries = options["boundaries"] - if not isinstance(boundaries, list): - raise OperationFailure( - "The $bucket 'boundaries' field must be an array, but found type: %s" - % type(boundaries) - ) - if len(boundaries) < 2: - raise OperationFailure( - "The $bucket 'boundaries' field must have at least 2 values, but " - "found %d value(s)." 
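-            # (At least two boundaries are required because consecutive pairs
-            # define the buckets: a value falls into the half-open interval
-            # [boundaries[i], boundaries[i + 1]), as implemented below with
-            # bisect.)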
% len(boundaries) - ) - if sorted(boundaries) != boundaries: - raise OperationFailure( - "The 'boundaries' option to $bucket must be sorted in ascending order" - ) - output_fields = options.get("output", {"count": {"$sum": 1}}) - default_value = options.get("default", None) - try: - is_default_last = default_value >= boundaries[-1] - except TypeError: - is_default_last = True - - def _get_default_bucket(): - try: - return options["default"] - except KeyError as err: - raise OperationFailure( - "$bucket could not find a matching branch for " - "an input, and no default was specified." - ) from err - - def _get_bucket_id(doc): - """Get the bucket ID for a document. - - Note that it actually returns a tuple with the first - param being a sort key to sort the default bucket even - if it's not the same type as the boundaries. - """ - try: - value = _parse_expression(group_by, doc) - except KeyError: - return (is_default_last, _get_default_bucket()) - index = bisect.bisect_right(boundaries, value) - if index and index < len(boundaries): - return (False, boundaries[index - 1]) - return (is_default_last, _get_default_bucket()) - - in_collection = ((_get_bucket_id(doc), doc) for doc in in_collection) - out_collection = sorted(in_collection, key=lambda kv: kv[0]) - grouped = itertools.groupby(out_collection, lambda kv: kv[0]) - - out_collection = [] - for (unused_key, doc_id), group in grouped: - group_list = [kv[1] for kv in group] - doc_dict = _accumulate_group(output_fields, group_list) - doc_dict["_id"] = doc_id - out_collection.append(doc_dict) - return out_collection - - -def _handle_sample_stage(in_collection, unused_database, options): - if not isinstance(options, dict): - raise OperationFailure("the $sample stage specification must be an object") - size = options.pop("size", None) - if size is None: - raise OperationFailure("$sample stage must specify a size") - if options: - raise OperationFailure( - "unrecognized option to $sample: %s" % set(options).pop() - ) - shuffled = list(in_collection) - _random.shuffle(shuffled) - return shuffled[:size] - - -def _handle_sort_stage(in_collection, unused_database, options): - sort_array = reversed([{x: y} for x, y in options.items()]) - sorted_collection = in_collection - for sort_pair in sort_array: - for sortKey, sortDirection in sort_pair.items(): - sorted_collection = sorted( - sorted_collection, - key=lambda x: filtering.resolve_sort_key(sortKey, x), - reverse=sortDirection < 0, - ) - return sorted_collection - - -def _handle_unwind_stage(in_collection, unused_database, options): - if not isinstance(options, dict): - options = {"path": options} - path = options["path"] - if not isinstance(path, str) or path[0] != "$": - raise ValueError( - "$unwind failed: exception: field path references must be prefixed " - "with a '$' '%s'" % path - ) - path = path[1:] - should_preserve_null_and_empty = options.get("preserveNullAndEmptyArrays") - include_array_index = options.get("includeArrayIndex") - unwound_collection = [] - for doc in in_collection: - try: - array_value = helpers.get_value_by_dot(doc, path) - except KeyError: - if should_preserve_null_and_empty: - unwound_collection.append(doc) - continue - if array_value is None: - if should_preserve_null_and_empty: - unwound_collection.append(doc) - continue - if array_value == []: - if should_preserve_null_and_empty: - new_doc = copy.deepcopy(doc) - # We just ran a get_value_by_dot so we know the value exists. 
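-                # An empty array unwinds to zero documents; with
-                # preserveNullAndEmptyArrays the document is kept but the
-                # now-empty field is dropped, hence the delete below.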
-                helpers.delete_value_by_dot(new_doc, path)
-                unwound_collection.append(new_doc)
-            continue
-        if isinstance(array_value, list):
-            iter_array = enumerate(array_value)
-        else:
-            iter_array = [(None, array_value)]
-        for index, field_item in iter_array:
-            new_doc = copy.deepcopy(doc)
-            new_doc = helpers.set_value_by_dot(new_doc, path, field_item)
-            if include_array_index:
-                new_doc = helpers.set_value_by_dot(new_doc, include_array_index, index)
-            unwound_collection.append(new_doc)
-
-    return unwound_collection
-
-
-# TODO(pascal): Combine with the equivalent function in collection but check
-# what are the allowed overriding.
-def _combine_projection_spec(filter_list, original_filter, prefix=""):
-    """Re-format a projection fields spec into a nested dictionary.
-
-    e.g: ['a', 'b.c', 'b.d'] => {'a': 1, 'b': {'c': 1, 'd': 1}}
-    """
-    if not isinstance(filter_list, list):
-        return filter_list
-
-    filter_dict = collections.OrderedDict()
-
-    for key in filter_list:
-        field, separator, subkey = key.partition(".")
-        if not separator:
-            if isinstance(filter_dict.get(field), list):
-                other_key = field + "." + filter_dict[field][0]
-                raise OperationFailure(
-                    "Invalid $project :: caused by :: specification contains two conflicting paths."
-                    " Cannot specify both %s and %s: %s"
-                    % (repr(prefix + field), repr(prefix + other_key), original_filter)
-                )
-            filter_dict[field] = 1
-            continue
-        if not isinstance(filter_dict.get(field, []), list):
-            raise OperationFailure(
-                "Invalid $project :: caused by :: specification contains two conflicting paths."
-                " Cannot specify both %s and %s: %s"
-                % (repr(prefix + field), repr(prefix + key), original_filter)
-            )
-        filter_dict[field] = filter_dict.get(field, []) + [subkey]
-
-    return collections.OrderedDict(
-        (k, _combine_projection_spec(v, original_filter, prefix="%s%s." % (prefix, k)))
-        for k, v in filter_dict.items()
-    )
-
-
-def _project_by_spec(doc, proj_spec, is_include):
-    output = {}
-    for key, value in doc.items():
-        if key not in proj_spec:
-            if not is_include:
-                output[key] = value
-            continue
-
-        if not isinstance(proj_spec[key], dict):
-            if is_include:
-                output[key] = value
-            continue
-
-        if isinstance(value, dict):
-            output[key] = _project_by_spec(value, proj_spec[key], is_include)
-        elif isinstance(value, list):
-            output[key] = [
-                _project_by_spec(array_value, proj_spec[key], is_include)
-                for array_value in value
-                if isinstance(array_value, dict)
-            ]
-        elif not is_include:
-            output[key] = value
-
-    return output
-
-
-def _handle_replace_root_stage(in_collection, unused_database, options):
-    if "newRoot" not in options:
-        raise OperationFailure(
-            "Parameter 'newRoot' is missing for $replaceRoot operation."
-        )
-    new_root = options["newRoot"]
-    out_collection = []
-    for doc in in_collection:
-        try:
-            new_doc = _parse_expression(new_root, doc, ignore_missing_keys=True)
-        except KeyError:
-            new_doc = None
-        if not isinstance(new_doc, dict):
-            raise OperationFailure(
-                "'newRoot' expression must evaluate to an object, but resulting value was: {}".format(
-                    new_doc
-                )
-            )
-        out_collection.append(new_doc)
-    return out_collection
-
-
-def _handle_project_stage(in_collection, unused_database, options):
-    filter_list = []
-    method = None
-    include_id = options.get("_id")
-    # Compute new values for each field, except inclusion/exclusions that are
-    # handled in one final step.
-    new_fields_collection = None
-    for field, value in options.items():
-        if method is None and (field != "_id" or value):
-            method = "include" if value else "exclude"
-        elif method == "include" and not value and field != "_id":
-            raise OperationFailure(
-                "Bad projection specification, cannot exclude fields "
-                "other than '_id' in an inclusion projection: %s" % options
-            )
-        elif method == "exclude" and value:
-            raise OperationFailure(
-                "Bad projection specification, cannot include fields "
-                "or add computed fields during an exclusion projection: %s" % options
-            )
-        if value in (0, 1, True, False):
-            if field != "_id":
-                filter_list.append(field)
-            continue
-        if not new_fields_collection:
-            new_fields_collection = [{} for unused_doc in in_collection]
-
-        for in_doc, out_doc in zip(in_collection, new_fields_collection):
-            try:
-                out_doc[field] = _parse_expression(
-                    value, in_doc, ignore_missing_keys=True
-                )
-            except KeyError:
-                # Ignore missing key.
-                pass
-    if (method == "include") == (include_id is not False and include_id != 0):
-        filter_list.append("_id")
-
-    if not filter_list:
-        return new_fields_collection
-
-    # Final steps: include or exclude fields and merge with newly created fields.
-    projection_spec = _combine_projection_spec(filter_list, original_filter=options)
-    out_collection = [
-        _project_by_spec(doc, projection_spec, is_include=(method == "include"))
-        for doc in in_collection
-    ]
-    if new_fields_collection:
-        return [dict(a, **b) for a, b in zip(out_collection, new_fields_collection)]
-    return out_collection
-
-
-def _handle_add_fields_stage(in_collection, unused_database, options):
-    if not options:
-        raise OperationFailure(
-            "Invalid $addFields :: caused by :: specification must have at least one field"
-        )
-    out_collection = [dict(doc) for doc in in_collection]
-    for field, value in options.items():
-        for in_doc, out_doc in zip(in_collection, out_collection):
-            try:
-                out_value = _parse_expression(value, in_doc, ignore_missing_keys=True)
-            except KeyError:
-                continue
-            parts = field.split(".")
-            for subfield in parts[:-1]:
-                out_doc[subfield] = out_doc.get(subfield, {})
-                if not isinstance(out_doc[subfield], dict):
-                    out_doc[subfield] = {}
-                out_doc = out_doc[subfield]
-            out_doc[parts[-1]] = out_value
-    return out_collection
-
-
-def _handle_out_stage(in_collection, database, options):
-    # TODO(MetrodataTeam): should leave the origin collection unchanged
-    out_collection = database.get_collection(options)
-    if out_collection.find_one():
-        out_collection.drop()
-    if in_collection:
-        out_collection.insert_many(in_collection)
-    return in_collection
-
-
-def _handle_count_stage(in_collection, database, options):
-    if not isinstance(options, str) or options == "":
-        raise OperationFailure("the count field must be a non-empty string")
-    elif options.startswith("$"):
-        raise OperationFailure("the count field cannot be a $-prefixed path")
-    elif "." in options:
-        raise OperationFailure("the count field cannot contain '.'")
-    return [{options: len(in_collection)}]
-
-
-def _handle_facet_stage(in_collection, database, options):
-    out_collection_by_pipeline = {}
-    for pipeline_title, pipeline in options.items():
-        out_collection_by_pipeline[pipeline_title] = list(
-            process_pipeline(in_collection, database, pipeline, None)
-        )
-    return [out_collection_by_pipeline]
-
-
-def _handle_match_stage(in_collection, database, options):
-    spec = helpers.patch_datetime_awareness_in_document(options)
-    return [
-        doc
-        for doc in in_collection
-        if filtering.filter_applies(
-            spec, helpers.patch_datetime_awareness_in_document(doc)
-        )
-    ]
-
-
-_PIPELINE_HANDLERS = {
-    "$addFields": _handle_add_fields_stage,
-    "$bucket": _handle_bucket_stage,
-    "$bucketAuto": None,
-    "$collStats": None,
-    "$count": _handle_count_stage,
-    "$currentOp": None,
-    "$facet": _handle_facet_stage,
-    "$geoNear": None,
-    "$graphLookup": _handle_graph_lookup_stage,
-    "$group": _handle_group_stage,
-    "$indexStats": None,
-    "$limit": lambda c, d, o: c[:o],
-    "$listLocalSessions": None,
-    "$listSessions": None,
-    "$lookup": _handle_lookup_stage,
-    "$match": _handle_match_stage,
-    "$merge": None,
-    "$out": _handle_out_stage,
-    "$planCacheStats": None,
-    "$project": _handle_project_stage,
-    "$redact": None,
-    "$replaceRoot": _handle_replace_root_stage,
-    "$replaceWith": None,
-    "$sample": _handle_sample_stage,
-    "$set": _handle_add_fields_stage,
-    "$skip": lambda c, d, o: c[o:],
-    "$sort": _handle_sort_stage,
-    "$sortByCount": None,
-    "$unset": None,
-    "$unwind": _handle_unwind_stage,
-}
-
-
-def process_pipeline(collection, database, pipeline, session):
-    if session:
-        raise NotImplementedError("Mongomock does not handle sessions yet")
-
-    for stage in pipeline:
-        for operator, options in stage.items():
-            try:
-                handler = _PIPELINE_HANDLERS[operator]
-            except KeyError as err:
-                raise NotImplementedError(
-                    "%s is not a valid operator for the aggregation pipeline. "
-                    "See http://docs.mongodb.org/manual/meta/aggregation-quick-reference/ "
-                    "for a complete list of valid operators." % operator
-                ) from err
-            if not handler:
-                raise NotImplementedError(
-                    "Although '%s' is a valid operator for the aggregation pipeline, it is "
-                    "currently not implemented in Mongomock." % operator
-                )
-            collection = handler(collection, database, options)
-
-    return command_cursor.CommandCursor(collection)
diff --git a/packages/syft/tests/mongomock/codec_options.py b/packages/syft/tests/mongomock/codec_options.py
deleted file mode 100644
index e71eb41d672..00000000000
--- a/packages/syft/tests/mongomock/codec_options.py
+++ /dev/null
@@ -1,135 +0,0 @@
-"""Tools for specifying BSON codec options."""
-
-# stdlib
-import collections
-
-# third party
-from packaging import version
-
-# relative
-from . import helpers
-
-try:
-    # third party
-    from bson import codec_options
-    from pymongo.common import _UUID_REPRESENTATIONS
-except ImportError:
-    codec_options = None
-    _UUID_REPRESENTATIONS = None
-
-
-class TypeRegistry(object):
-    pass
-
-
-_FIELDS = (
-    "document_class",
-    "tz_aware",
-    "uuid_representation",
-    "unicode_decode_error_handler",
-    "tzinfo",
-)
-
-if codec_options and helpers.PYMONGO_VERSION >= version.parse("3.8"):
-    _DEFAULT_TYPE_REGISTRY = codec_options.TypeRegistry()
-    _FIELDS = _FIELDS + ("type_registry",)
-else:
-    _DEFAULT_TYPE_REGISTRY = TypeRegistry()
-
-if codec_options and helpers.PYMONGO_VERSION >= version.parse("4.3.0"):
-    _DATETIME_CONVERSION_VALUES = codec_options.DatetimeConversion._value2member_map_
-    _DATETIME_CONVERSION_DEFAULT_VALUE = codec_options.DatetimeConversion.DATETIME
-    _FIELDS = _FIELDS + ("datetime_conversion",)
-else:
-    _DATETIME_CONVERSION_VALUES = ()
-    _DATETIME_CONVERSION_DEFAULT_VALUE = None
-
-# New default in Pymongo v4:
-# https://pymongo.readthedocs.io/en/stable/examples/uuid.html#unspecified
-if helpers.PYMONGO_VERSION >= version.parse("4.0"):
-    _DEFAULT_UUID_REPRESENTATION = 0
-else:
-    _DEFAULT_UUID_REPRESENTATION = 3
-
-
-class CodecOptions(collections.namedtuple("CodecOptions", _FIELDS)):
-    def __new__(
-        cls,
-        document_class=dict,
-        tz_aware=False,
-        uuid_representation=None,
-        unicode_decode_error_handler="strict",
-        tzinfo=None,
-        type_registry=None,
-        datetime_conversion=_DATETIME_CONVERSION_DEFAULT_VALUE,
-    ):
-        if document_class != dict:
-            raise NotImplementedError(
-                "Mongomock does not implement custom document_class yet: %r"
-                % document_class
-            )
-
-        if not isinstance(tz_aware, bool):
-            raise TypeError("tz_aware must be True or False")
-
-        if uuid_representation is None:
-            uuid_representation = _DEFAULT_UUID_REPRESENTATION
-
-        if unicode_decode_error_handler not in ("strict", None):
-            raise NotImplementedError(
-                "Mongomock does not handle custom unicode_decode_error_handler yet"
-            )
-
-        if tzinfo:
-            raise NotImplementedError("Mongomock does not handle custom tzinfo yet")
-
-        values = (
-            document_class,
-            tz_aware,
-            uuid_representation,
-            unicode_decode_error_handler,
-            tzinfo,
-        )
-
-        if "type_registry" in _FIELDS:
-            if not type_registry:
-                type_registry = _DEFAULT_TYPE_REGISTRY
-            values = values + (type_registry,)
-
-        if "datetime_conversion" in _FIELDS:
-            if (
-                datetime_conversion
-                and datetime_conversion not in _DATETIME_CONVERSION_VALUES
-            ):
-                raise TypeError(
-                    "datetime_conversion must be member of DatetimeConversion"
-                )
-            values = values + (datetime_conversion,)
-
-        return tuple.__new__(cls, values)
-
-    def with_options(self, **kwargs):
-        opts = self._asdict()
-        opts.update(kwargs)
-        return CodecOptions(**opts)
-
-    def to_pymongo(self):
-        if not codec_options:
-            return None
-
-        uuid_representation = self.uuid_representation
-        if _UUID_REPRESENTATIONS and isinstance(self.uuid_representation, str):
-            uuid_representation = _UUID_REPRESENTATIONS[uuid_representation]
-
-        return codec_options.CodecOptions(
-            uuid_representation=uuid_representation,
-            unicode_decode_error_handler=self.unicode_decode_error_handler,
-            type_registry=self.type_registry,
-        )
-
-
-def is_supported(custom_codec_options):
-    if not custom_codec_options:
-        return None
-
-    return CodecOptions(**custom_codec_options._asdict())
diff --git a/packages/syft/tests/mongomock/collection.py b/packages/syft/tests/mongomock/collection.py
deleted file mode 100644
index 8a677300355..00000000000
--- a/packages/syft/tests/mongomock/collection.py
+++ /dev/null
@@ -1,2596 +0,0 @@
-# future
-from __future__ import division
-
-# stdlib
-import collections
-from collections import OrderedDict
-from collections.abc import Iterable
-from collections.abc import Mapping
-from collections.abc import MutableMapping
-import copy
-import functools
-import itertools
-import json
-import math
-import time
-import warnings
-
-# third party
-from packaging import version
-
-try:
-    # third party
-    from bson import BSON
-    from bson import SON
-    from bson import json_util
-    from bson.codec_options import CodecOptions
-    from bson.errors import InvalidDocument
-except ImportError:
-    json_utils = SON = BSON = None
-    CodecOptions = None
-try:
-    # third party
-    import execjs
-except ImportError:
-    execjs = None
-
-try:
-    # third party
-    from pymongo import ReadPreference
-    from pymongo import ReturnDocument
-    from pymongo.operations import IndexModel
-
-    _READ_PREFERENCE_PRIMARY = ReadPreference.PRIMARY
-except ImportError:
-
-    class IndexModel(object):
-        pass
-
-    class ReturnDocument(object):
-        BEFORE = False
-        AFTER = True
-
-    from .read_preferences import PRIMARY as _READ_PREFERENCE_PRIMARY
-
-# relative
-from . import BulkWriteError
-from . import ConfigurationError
-from . import DuplicateKeyError
-from . import InvalidOperation
-from . import ObjectId
-from . import OperationFailure
-from . import WriteError
-from . import aggregate
-from . import codec_options as mongomock_codec_options
-from . import filtering
-from . import helpers
-from . import utcnow
-from .filtering import filter_applies
-from .not_implemented import raise_for_feature as raise_not_implemented
-from .results import BulkWriteResult
-from .results import DeleteResult
-from .results import InsertManyResult
-from .results import InsertOneResult
-from .results import UpdateResult
-from .write_concern import WriteConcern
-
-try:
-    # third party
-    from pymongo.read_concern import ReadConcern
-except ImportError:
-    # relative
-    from .read_concern import ReadConcern
-
-_KwargOption = collections.namedtuple("KwargOption", ["typename", "default", "attrs"])
-
-_WITH_OPTIONS_KWARGS = {
-    "read_preference": _KwargOption(
-        "pymongo.read_preference.ReadPreference",
-        _READ_PREFERENCE_PRIMARY,
-        ("document", "mode", "mongos_mode", "max_staleness"),
-    ),
-    "write_concern": _KwargOption(
-        "pymongo.write_concern.WriteConcern",
-        WriteConcern(),
-        ("acknowledged", "document"),
-    ),
-}
-
-
-def _bson_encode(document, codec_options):
-    if CodecOptions:
-        if isinstance(codec_options, mongomock_codec_options.CodecOptions):
-            codec_options = codec_options.to_pymongo()
-        if isinstance(codec_options, CodecOptions):
-            BSON.encode(document, check_keys=True, codec_options=codec_options)
-        else:
-            BSON.encode(document, check_keys=True)
-
-
-def validate_is_mapping(option, value):
-    if not isinstance(value, Mapping):
-        raise TypeError(
-            "%s must be an instance of dict, bson.son.SON, or "
-            "other type that inherits from "
-            "collections.Mapping" % (option,)
-        )
-
-
-def validate_is_mutable_mapping(option, value):
-    if not isinstance(value, MutableMapping):
-        raise TypeError(
-            "%s must be an instance of dict, bson.son.SON, or "
-            "other type that inherits from "
-            "collections.MutableMapping" % (option,)
-        )
-
-
-def validate_ok_for_replace(replacement):
-    validate_is_mapping("replacement", replacement)
-    if replacement:
-        first = next(iter(replacement))
-        if first.startswith("$"):
-            raise ValueError("replacement can not include $ operators")
-
-
-def validate_ok_for_update(update):
-
validate_is_mapping("update", update) - if not update: - raise ValueError("update only works with $ operators") - first = next(iter(update)) - if not first.startswith("$"): - raise ValueError("update only works with $ operators") - - -def validate_write_concern_params(**params): - if params: - WriteConcern(**params) - - -class BulkWriteOperation(object): - def __init__(self, builder, selector, is_upsert=False): - self.builder = builder - self.selector = selector - self.is_upsert = is_upsert - - def upsert(self): - assert not self.is_upsert - return BulkWriteOperation(self.builder, self.selector, is_upsert=True) - - def register_remove_op(self, multi, hint=None): - collection = self.builder.collection - selector = self.selector - - def exec_remove(): - if multi: - op_result = collection.delete_many(selector, hint=hint).raw_result - else: - op_result = collection.delete_one(selector, hint=hint).raw_result - if op_result.get("ok"): - return {"nRemoved": op_result.get("n")} - err = op_result.get("err") - if err: - return {"writeErrors": [err]} - return {} - - self.builder.executors.append(exec_remove) - - def remove(self): - assert not self.is_upsert - self.register_remove_op(multi=True) - - def remove_one( - self, - ): - assert not self.is_upsert - self.register_remove_op(multi=False) - - def register_update_op(self, document, multi, **extra_args): - if not extra_args.get("remove"): - validate_ok_for_update(document) - - collection = self.builder.collection - selector = self.selector - - def exec_update(): - result = collection._update( - spec=selector, - document=document, - multi=multi, - upsert=self.is_upsert, - **extra_args, - ) - ret_val = {} - if result.get("upserted"): - ret_val["upserted"] = result.get("upserted") - ret_val["nUpserted"] = result.get("n") - else: - matched = result.get("n") - if matched is not None: - ret_val["nMatched"] = matched - modified = result.get("nModified") - if modified is not None: - ret_val["nModified"] = modified - if result.get("err"): - ret_val["err"] = result.get("err") - return ret_val - - self.builder.executors.append(exec_update) - - def update(self, document, hint=None): - self.register_update_op(document, multi=True, hint=hint) - - def update_one(self, document, hint=None): - self.register_update_op(document, multi=False, hint=hint) - - def replace_one(self, document, hint=None): - self.register_update_op(document, multi=False, remove=True, hint=hint) - - -def _combine_projection_spec(projection_fields_spec): - """Re-format a projection fields spec into a nested dictionary. - - e.g: {'a': 1, 'b.c': 1, 'b.d': 1} => {'a': 1, 'b': {'c': 1, 'd': 1}} - """ - - tmp_spec = OrderedDict() - for f, v in projection_fields_spec.items(): - if "." 
not in f: - if isinstance(tmp_spec.get(f), dict): - if not v: - raise NotImplementedError( - "Mongomock does not support overriding excluding projection: %s" - % projection_fields_spec - ) - raise OperationFailure("Path collision at %s" % f) - tmp_spec[f] = v - else: - split_field = f.split(".", 1) - base_field, new_field = tuple(split_field) - if not isinstance(tmp_spec.get(base_field), dict): - if base_field in tmp_spec: - raise OperationFailure( - "Path collision at %s remaining portion %s" % (f, new_field) - ) - tmp_spec[base_field] = OrderedDict() - tmp_spec[base_field][new_field] = v - - combined_spec = OrderedDict() - for f, v in tmp_spec.items(): - if isinstance(v, dict): - combined_spec[f] = _combine_projection_spec(v) - else: - combined_spec[f] = v - - return combined_spec - - -def _project_by_spec(doc, combined_projection_spec, is_include, container): - if "$" in combined_projection_spec: - if is_include: - raise NotImplementedError( - "Positional projection is not implemented in mongomock" - ) - raise OperationFailure( - "Cannot exclude array elements with the positional operator" - ) - - doc_copy = container() - - for key, val in doc.items(): - spec = combined_projection_spec.get(key, None) - if isinstance(spec, dict): - if isinstance(val, (list, tuple)): - doc_copy[key] = [ - _project_by_spec(sub_doc, spec, is_include, container) - for sub_doc in val - ] - elif isinstance(val, dict): - doc_copy[key] = _project_by_spec(val, spec, is_include, container) - elif (is_include and spec is not None) or (not is_include and spec is None): - doc_copy[key] = _copy_field(val, container) - - return doc_copy - - -def _copy_field(obj, container): - if isinstance(obj, list): - new = [] - for item in obj: - new.append(_copy_field(item, container)) - return new - if isinstance(obj, dict): - new = container() - for key, value in obj.items(): - new[key] = _copy_field(value, container) - return new - return copy.copy(obj) - - -def _recursive_key_check_null_character(data): - for key, value in data.items(): - if "\0" in key: - raise InvalidDocument( - f"Field names cannot contain the null character (found: {key})" - ) - if isinstance(value, Mapping): - _recursive_key_check_null_character(value) - - -def _validate_data_fields(data): - _recursive_key_check_null_character(data) - for key in data.keys(): - if key.startswith("$"): - raise InvalidDocument( - f'Top-level field names cannot start with the "$" sign ' - f"(found: {key})" - ) - - -class BulkOperationBuilder(object): - def __init__(self, collection, ordered=False, bypass_document_validation=False): - self.collection = collection - self.ordered = ordered - self.results = {} - self.executors = [] - self.done = False - self._insert_returns_nModified = True - self._update_returns_nModified = True - self._bypass_document_validation = bypass_document_validation - - def find(self, selector): - return BulkWriteOperation(self, selector) - - def insert(self, doc): - def exec_insert(): - self.collection.insert_one( - doc, bypass_document_validation=self._bypass_document_validation - ) - return {"nInserted": 1} - - self.executors.append(exec_insert) - - def __aggregate_operation_result(self, total_result, key, value): - agg_val = total_result.get(key) - assert agg_val is not None, ( - "Unknow operation result %s=%s" " (unrecognized key)" % (key, value) - ) - if isinstance(agg_val, int): - total_result[key] += value - elif isinstance(agg_val, list): - if key == "upserted": - new_element = {"index": len(agg_val), "_id": value} - agg_val.append(new_element) 
- else: - agg_val.append(value) - else: - assert False, ( - "Fixme: missed aggreation rule for type: %s for" - " key {%s=%s}" - % ( - type(agg_val), - key, - agg_val, - ) - ) - - def _set_nModified_policy(self, insert, update): - self._insert_returns_nModified = insert - self._update_returns_nModified = update - - def execute(self, write_concern=None): - if not self.executors: - raise InvalidOperation("Bulk operation empty!") - if self.done: - raise InvalidOperation("Bulk operation already executed!") - self.done = True - result = { - "nModified": 0, - "nUpserted": 0, - "nMatched": 0, - "writeErrors": [], - "upserted": [], - "writeConcernErrors": [], - "nRemoved": 0, - "nInserted": 0, - } - - has_update = False - has_insert = False - broken_nModified_info = False - for index, execute_func in enumerate(self.executors): - exec_name = execute_func.__name__ - try: - op_result = execute_func() - except WriteError as error: - result["writeErrors"].append( - { - "index": index, - "code": error.code, - "errmsg": str(error), - } - ) - if self.ordered: - break - continue - for key, value in op_result.items(): - self.__aggregate_operation_result(result, key, value) - if exec_name == "exec_update": - has_update = True - if "nModified" not in op_result: - broken_nModified_info = True - has_insert |= exec_name == "exec_insert" - - if broken_nModified_info: - result.pop("nModified") - elif has_insert and self._insert_returns_nModified: - pass - elif has_update and self._update_returns_nModified: - pass - elif self._update_returns_nModified and self._insert_returns_nModified: - pass - else: - result.pop("nModified") - - if result.get("writeErrors"): - raise BulkWriteError(result) - - return result - - def add_insert(self, doc): - self.insert(doc) - - def add_update( - self, - selector, - doc, - multi=False, - upsert=False, - collation=None, - array_filters=None, - hint=None, - ): - if array_filters: - raise_not_implemented( - "array_filters", "Array filters are not implemented in mongomock yet." - ) - write_operation = BulkWriteOperation(self, selector, is_upsert=upsert) - write_operation.register_update_op(doc, multi, hint=hint) - - def add_replace(self, selector, doc, upsert, collation=None, hint=None): - write_operation = BulkWriteOperation(self, selector, is_upsert=upsert) - write_operation.replace_one(doc, hint=hint) - - def add_delete(self, selector, just_one, collation=None, hint=None): - write_operation = BulkWriteOperation(self, selector, is_upsert=False) - write_operation.register_remove_op(not just_one, hint=hint) - - -class Collection(object): - def __init__( - self, - database, - name, - _db_store, - write_concern=None, - read_concern=None, - read_preference=None, - codec_options=None, - ): - self.database = database - self._name = name - self._db_store = _db_store - self._write_concern = write_concern or WriteConcern() - if read_concern and not isinstance(read_concern, ReadConcern): - raise TypeError( - "read_concern must be an instance of pymongo.read_concern.ReadConcern" - ) - self._read_concern = read_concern or ReadConcern() - self._read_preference = read_preference or _READ_PREFERENCE_PRIMARY - self._codec_options = codec_options or mongomock_codec_options.CodecOptions() - - def __repr__(self): - return "Collection({0}, '{1}')".format(self.database, self.name) - - def __getitem__(self, name): - return self.database[self.name + "." + name] - - def __getattr__(self, attr): - if attr.startswith("_"): - raise AttributeError( - "%s has no attribute '%s'. 
To access the %s.%s collection, use database['%s.%s']." - % (self.__class__.__name__, attr, self.name, attr, self.name, attr) - ) - return self.__getitem__(attr) - - def __call__(self, *args, **kwargs): - name = self._name if "." not in self._name else self._name.split(".")[-1] - raise TypeError( - "'Collection' object is not callable. If you meant to call the '%s' method on a " - "'Collection' object it is failing because no such method exists." % name - ) - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.database == other.database and self.name == other.name - return NotImplemented - - if helpers.PYMONGO_VERSION >= version.parse("3.12"): - - def __hash__(self): - return hash((self.database, self.name)) - - @property - def full_name(self): - return "{0}.{1}".format(self.database.name, self._name) - - @property - def name(self): - return self._name - - @property - def write_concern(self): - return self._write_concern - - @property - def read_concern(self): - return self._read_concern - - @property - def read_preference(self): - return self._read_preference - - @property - def codec_options(self): - return self._codec_options - - def initialize_unordered_bulk_op(self, bypass_document_validation=False): - return BulkOperationBuilder( - self, ordered=False, bypass_document_validation=bypass_document_validation - ) - - def initialize_ordered_bulk_op(self, bypass_document_validation=False): - return BulkOperationBuilder( - self, ordered=True, bypass_document_validation=bypass_document_validation - ) - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def insert( - self, - data, - manipulate=True, - check_keys=True, - continue_on_error=False, - **kwargs, - ): - warnings.warn( - "insert is deprecated. Use insert_one or insert_many " "instead.", - DeprecationWarning, - stacklevel=2, - ) - validate_write_concern_params(**kwargs) - return self._insert(data) - - def insert_one(self, document, bypass_document_validation=False, session=None): - if not bypass_document_validation: - validate_is_mutable_mapping("document", document) - return InsertOneResult(self._insert(document, session), acknowledged=True) - - def insert_many( - self, documents, ordered=True, bypass_document_validation=False, session=None - ): - if not isinstance(documents, Iterable) or not documents: - raise TypeError("documents must be a non-empty list") - documents = list(documents) - if not bypass_document_validation: - for document in documents: - validate_is_mutable_mapping("document", document) - return InsertManyResult( - self._insert(documents, session, ordered=ordered), acknowledged=True - ) - - @property - def _store(self): - return self._db_store[self._name] - - def _insert(self, data, session=None, ordered=True): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - if not isinstance(data, Mapping): - results = [] - write_errors = [] - num_inserted = 0 - for index, item in enumerate(data): - try: - results.append(self._insert(item)) - except WriteError as error: - write_errors.append( - { - "index": index, - "code": error.code, - "errmsg": str(error), - "op": item, - } - ) - if ordered: - break - else: - continue - num_inserted += 1 - if write_errors: - raise BulkWriteError( - { - "writeErrors": write_errors, - "nInserted": num_inserted, - } - ) - return results - - if not all(isinstance(k, str) for k in data): - raise ValueError("Document keys must be strings") - - if BSON: - # bson validation - check_keys = helpers.PYMONGO_VERSION < 
version.parse("3.6") - if not check_keys: - _validate_data_fields(data) - - _bson_encode(data, self._codec_options) - - # Like pymongo, we should fill the _id in the inserted dict (odd behavior, - # but we need to stick to it), so we must patch in-place the data dict - if "_id" not in data: - data["_id"] = ObjectId() - - object_id = data["_id"] - if isinstance(object_id, dict): - object_id = helpers.hashdict(object_id) - if object_id in self._store: - raise DuplicateKeyError("E11000 Duplicate Key Error", 11000) - - data = helpers.patch_datetime_awareness_in_document(data) - - self._store[object_id] = data - try: - self._ensure_uniques(data) - except DuplicateKeyError: - # Rollback - del self._store[object_id] - raise - return data["_id"] - - def _ensure_uniques(self, new_data): - # Note we consider new_data is already inserted in db - for index in self._store.indexes.values(): - if not index.get("unique"): - continue - unique = index.get("key") - is_sparse = index.get("sparse") - partial_filter_expression = index.get("partialFilterExpression") - find_kwargs = {} - for key, _ in unique: - try: - find_kwargs[key] = helpers.get_value_by_dot(new_data, key) - except KeyError: - find_kwargs[key] = None - if is_sparse and set(find_kwargs.values()) == {None}: - continue - if partial_filter_expression is not None: - find_kwargs = {"$and": [partial_filter_expression, find_kwargs]} - answer_count = len(list(self._iter_documents(find_kwargs))) - if answer_count > 1: - raise DuplicateKeyError("E11000 Duplicate Key Error", 11000) - - def _internalize_dict(self, d): - return {k: copy.deepcopy(v) for k, v in d.items()} - - def _has_key(self, doc, key): - key_parts = key.split(".") - sub_doc = doc - for part in key_parts: - if part not in sub_doc: - return False - sub_doc = sub_doc[part] - return True - - def update_one( - self, - filter, - update, - upsert=False, - bypass_document_validation=False, - collation=None, - array_filters=None, - hint=None, - session=None, - let=None, - ): - if not bypass_document_validation: - validate_ok_for_update(update) - return UpdateResult( - self._update( - filter, - update, - upsert=upsert, - hint=hint, - session=session, - collation=collation, - array_filters=array_filters, - let=let, - ), - acknowledged=True, - ) - - def update_many( - self, - filter, - update, - upsert=False, - array_filters=None, - bypass_document_validation=False, - collation=None, - hint=None, - session=None, - let=None, - ): - if not bypass_document_validation: - validate_ok_for_update(update) - return UpdateResult( - self._update( - filter, - update, - upsert=upsert, - multi=True, - hint=hint, - session=session, - collation=collation, - array_filters=array_filters, - let=let, - ), - acknowledged=True, - ) - - def replace_one( - self, - filter, - replacement, - upsert=False, - bypass_document_validation=False, - session=None, - hint=None, - ): - if not bypass_document_validation: - validate_ok_for_replace(replacement) - return UpdateResult( - self._update( - filter, replacement, upsert=upsert, hint=hint, session=session - ), - acknowledged=True, - ) - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def update( - self, - spec, - document, - upsert=False, - manipulate=False, - multi=False, - check_keys=False, - **kwargs, - ): - warnings.warn( - "update is deprecated. 
Use replace_one, update_one or " - "update_many instead.", - DeprecationWarning, - stacklevel=2, - ) - return self._update( - spec, document, upsert, manipulate, multi, check_keys, **kwargs - ) - - def _update( - self, - spec, - document, - upsert=False, - manipulate=False, - multi=False, - check_keys=False, - hint=None, - session=None, - collation=None, - let=None, - array_filters=None, - **kwargs, - ): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - if hint: - raise NotImplementedError( - "The hint argument of update is valid but has not been implemented in " - "mongomock yet" - ) - if collation: - raise_not_implemented( - "collation", - "The collation argument of update is valid but has not been implemented in " - "mongomock yet", - ) - if array_filters: - raise_not_implemented( - "array_filters", "Array filters are not implemented in mongomock yet." - ) - if let: - raise_not_implemented( - "let", - "The let argument of update is valid but has not been implemented in mongomock " - "yet", - ) - spec = helpers.patch_datetime_awareness_in_document(spec) - document = helpers.patch_datetime_awareness_in_document(document) - validate_is_mapping("spec", spec) - validate_is_mapping("document", document) - - if self.database.client.server_info()["versionArray"] < [5]: - for operator in _updaters: - if not document.get(operator, True): - raise WriteError( - "'%s' is empty. You must specify a field like so: {%s: {: ...}}" - % (operator, operator), - ) - - updated_existing = False - upserted_id = None - num_updated = 0 - num_matched = 0 - for existing_document in itertools.chain(self._iter_documents(spec), [None]): - # we need was_insert for the setOnInsert update operation - was_insert = False - # the sentinel document means we should do an upsert - if existing_document is None: - if not upsert or num_matched: - continue - # For upsert operation we have first to create a fake existing_document, - # update it like a regular one, then finally insert it - if spec.get("_id") is not None: - _id = spec["_id"] - elif document.get("_id") is not None: - _id = document["_id"] - else: - _id = ObjectId() - to_insert = dict(spec, _id=_id) - to_insert = self._expand_dots(to_insert) - to_insert, _ = self._discard_operators(to_insert) - existing_document = to_insert - was_insert = True - else: - original_document_snapshot = copy.deepcopy(existing_document) - updated_existing = True - num_matched += 1 - first = True - subdocument = None - for k, v in document.items(): - if k in _updaters: - updater = _updaters[k] - subdocument = ( - self._update_document_fields_with_positional_awareness( - existing_document, v, spec, updater, subdocument - ) - ) - - elif k == "$rename": - for src, dst in v.items(): - if "." in src or "." 
in dst: - raise NotImplementedError( - "Using the $rename operator with dots is a valid MongoDB " - "operation, but it is not yet supported by mongomock" - ) - if self._has_key(existing_document, src): - existing_document[dst] = existing_document.pop(src) - - elif k == "$setOnInsert": - if not was_insert: - continue - subdocument = ( - self._update_document_fields_with_positional_awareness( - existing_document, v, spec, _set_updater, subdocument - ) - ) - - elif k == "$currentDate": - subdocument = ( - self._update_document_fields_with_positional_awareness( - existing_document, - v, - spec, - _current_date_updater, - subdocument, - ) - ) - - elif k == "$addToSet": - for field, value in v.items(): - nested_field_list = field.rsplit(".") - if len(nested_field_list) == 1: - if field not in existing_document: - existing_document[field] = [] - # document should be a list append to it - if isinstance(value, dict): - if "$each" in value: - # append the list to the field - existing_document[field] += [ - obj - for obj in list(value["$each"]) - if obj not in existing_document[field] - ] - continue - if value not in existing_document[field]: - existing_document[field].append(value) - continue - # push to array in a nested attribute - else: - # create nested attributes if they do not exist - subdocument = existing_document - for field_part in nested_field_list[:-1]: - if field_part == "$": - break - if field_part not in subdocument: - subdocument[field_part] = {} - - subdocument = subdocument[field_part] - - # get subdocument with $ oprator support - subdocument, _ = self._get_subdocument( - existing_document, spec, nested_field_list - ) - - # we're pushing a list - push_results = [] - if nested_field_list[-1] in subdocument: - # if the list exists, then use that list - push_results = subdocument[nested_field_list[-1]] - - if isinstance(value, dict) and "$each" in value: - push_results += [ - obj - for obj in list(value["$each"]) - if obj not in push_results - ] - elif value not in push_results: - push_results.append(value) - - subdocument[nested_field_list[-1]] = push_results - elif k == "$pull": - for field, value in v.items(): - nested_field_list = field.rsplit(".") - # nested fields includes a positional element - # need to find that element - if "$" in nested_field_list: - if not subdocument: - subdocument, _ = self._get_subdocument( - existing_document, spec, nested_field_list - ) - - # value should be a dictionary since we're pulling - pull_results = [] - # and the last subdoc should be an array - for obj in subdocument[nested_field_list[-1]]: - if isinstance(obj, dict): - for pull_key, pull_value in value.items(): - if obj[pull_key] != pull_value: - pull_results.append(obj) - continue - if obj != value: - pull_results.append(obj) - - # cannot write to doc directly as it doesn't save to - # existing_document - subdocument[nested_field_list[-1]] = pull_results - else: - arr = existing_document - for field_part in nested_field_list: - if field_part not in arr: - break - arr = arr[field_part] - if not isinstance(arr, list): - continue - - arr_copy = copy.deepcopy(arr) - if isinstance(value, dict): - for obj in arr_copy: - try: - is_matching = filter_applies(value, obj) - except OperationFailure: - is_matching = False - if is_matching: - arr.remove(obj) - continue - - if filter_applies({"field": value}, {"field": obj}): - arr.remove(obj) - else: - for obj in arr_copy: - if value == obj: - arr.remove(obj) - elif k == "$pullAll": - for field, value in v.items(): - nested_field_list = 
field.rsplit(".") - if len(nested_field_list) == 1: - if field in existing_document: - arr = existing_document[field] - existing_document[field] = [ - obj for obj in arr if obj not in value - ] - continue - else: - subdocument, _ = self._get_subdocument( - existing_document, spec, nested_field_list - ) - - if nested_field_list[-1] in subdocument: - arr = subdocument[nested_field_list[-1]] - subdocument[nested_field_list[-1]] = [ - obj for obj in arr if obj not in value - ] - elif k == "$push": - for field, value in v.items(): - # Find the place where to push. - nested_field_list = field.rsplit(".") - subdocument, field = self._get_subdocument( - existing_document, spec, nested_field_list - ) - - # Push the new element or elements. - if isinstance(subdocument, dict) and field not in subdocument: - subdocument[field] = [] - push_results = subdocument[field] - if isinstance(value, dict) and "$each" in value: - if "$position" in value: - push_results = ( - push_results[0 : value["$position"]] - + list(value["$each"]) - + push_results[value["$position"] :] - ) - else: - push_results += list(value["$each"]) - - if "$sort" in value: - sort_spec = value["$sort"] - if isinstance(sort_spec, dict): - sort_key = set(sort_spec.keys()).pop() - push_results = sorted( - push_results, - key=lambda d: helpers.get_value_by_dot( - d, sort_key - ), - reverse=set(sort_spec.values()).pop() < 0, - ) - else: - push_results = sorted( - push_results, reverse=sort_spec < 0 - ) - - if "$slice" in value: - slice_value = value["$slice"] - if slice_value < 0: - push_results = push_results[slice_value:] - elif slice_value == 0: - push_results = [] - else: - push_results = push_results[:slice_value] - - unused_modifiers = set(value.keys()) - { - "$each", - "$slice", - "$position", - "$sort", - } - if unused_modifiers: - raise WriteError( - "Unrecognized clause in $push: " - + unused_modifiers.pop() - ) - else: - push_results.append(value) - subdocument[field] = push_results - else: - if first: - # replace entire document - for key in document.keys(): - if key.startswith("$"): - # can't mix modifiers with non-modifiers in - # update - raise ValueError( - "field names cannot start with $ [{}]".format(k) - ) - _id = spec.get("_id", existing_document.get("_id")) - existing_document.clear() - if _id is not None: - existing_document["_id"] = _id - if BSON: - # bson validation - check_keys = helpers.PYMONGO_VERSION < version.parse("3.6") - if not check_keys: - _validate_data_fields(document) - _bson_encode(document, self.codec_options) - existing_document.update(self._internalize_dict(document)) - if existing_document["_id"] != _id: - raise OperationFailure( - "The _id field cannot be changed from {0} to {1}".format( - existing_document["_id"], _id - ) - ) - break - else: - # can't mix modifiers with non-modifiers in update - raise ValueError("Invalid modifier specified: {}".format(k)) - first = False - # if empty document comes - if not document: - _id = spec.get("_id", existing_document.get("_id")) - existing_document.clear() - if _id: - existing_document["_id"] = _id - - if was_insert: - upserted_id = self._insert(existing_document) - num_updated += 1 - elif existing_document != original_document_snapshot: - # Document has been modified in-place. - - # Make sure the ID was not change. - if original_document_snapshot.get("_id") != existing_document.get( - "_id" - ): - # Rollback. 
- self._store[original_document_snapshot["_id"]] = ( - original_document_snapshot - ) - raise WriteError( - "After applying the update, the (immutable) field '_id' was found to have " - "been altered to _id: {}".format(existing_document.get("_id")) - ) - - # Make sure it still respect the unique indexes and, if not, to - # revert modifications - try: - self._ensure_uniques(existing_document) - num_updated += 1 - except DuplicateKeyError: - # Rollback. - self._store[original_document_snapshot["_id"]] = ( - original_document_snapshot - ) - raise - - if not multi: - break - - return { - "connectionId": self.database.client._id, - "err": None, - "n": num_matched, - "nModified": num_updated if updated_existing else 0, - "ok": 1, - "upserted": upserted_id, - "updatedExisting": updated_existing, - } - - def _get_subdocument(self, existing_document, spec, nested_field_list): - """This method retrieves the subdocument of the existing_document.nested_field_list. - - It uses the spec to filter through the items. It will continue to grab nested documents - until it can go no further. It will then return the subdocument that was last saved. - '$' is the positional operator, so we use the $elemMatch in the spec to find the right - subdocument in the array. - """ - # Current document in view. - doc = existing_document - # Previous document in view. - parent_doc = existing_document - # Current spec in view. - subspec = spec - # Whether spec is following the document. - is_following_spec = True - # Walk down the dictionary. - for index, subfield in enumerate(nested_field_list): - if subfield == "$": - if not is_following_spec: - raise WriteError( - "The positional operator did not find the match needed from the query" - ) - # Positional element should have the equivalent elemMatch in the query. - subspec = subspec["$elemMatch"] - is_following_spec = False - # Iterate through. 
- for spec_index, item in enumerate(doc): - if filter_applies(subspec, item): - subfield = spec_index - break - else: - raise WriteError( - "The positional operator did not find the match needed from the query" - ) - - parent_doc = doc - if isinstance(parent_doc, list): - subfield = int(subfield) - if is_following_spec and (subfield < 0 or subfield >= len(subspec)): - is_following_spec = False - - if index == len(nested_field_list) - 1: - return parent_doc, subfield - - if not isinstance(parent_doc, list): - if subfield not in parent_doc: - parent_doc[subfield] = {} - if is_following_spec and subfield not in subspec: - is_following_spec = False - - doc = parent_doc[subfield] - if is_following_spec: - subspec = subspec[subfield] - - def _expand_dots(self, doc): - expanded = {} - paths = {} - for k, v in doc.items(): - - def _raise_incompatible(subkey): - raise WriteError( - "cannot infer query fields to set, both paths '%s' and '%s' are matched" - % (k, paths[subkey]) - ) - - if k in paths: - _raise_incompatible(k) - - key_parts = k.split(".") - sub_expanded = expanded - - paths[k] = k - for i, key_part in enumerate(key_parts[:-1]): - if key_part not in sub_expanded: - sub_expanded[key_part] = {} - sub_expanded = sub_expanded[key_part] - key = ".".join(key_parts[: i + 1]) - if not isinstance(sub_expanded, dict): - _raise_incompatible(key) - paths[key] = k - sub_expanded[key_parts[-1]] = v - return expanded - - def _discard_operators(self, doc): - if not doc or not isinstance(doc, dict): - return doc, False - new_doc = OrderedDict() - for k, v in doc.items(): - if k == "$eq": - return v, False - if k.startswith("$"): - continue - new_v, discarded = self._discard_operators(v) - if not discarded: - new_doc[k] = new_v - return new_doc, not bool(new_doc) - - def find( - self, - filter=None, - projection=None, - skip=0, - limit=0, - no_cursor_timeout=False, - cursor_type=None, - sort=None, - allow_partial_results=False, - oplog_replay=False, - modifiers=None, - batch_size=0, - manipulate=True, - collation=None, - session=None, - max_time_ms=None, - allow_disk_use=False, - **kwargs, - ): - spec = filter - if spec is None: - spec = {} - validate_is_mapping("filter", spec) - for kwarg, value in kwargs.items(): - if value: - raise OperationFailure("Unrecognized field '%s'" % kwarg) - return ( - Cursor(self, spec, sort, projection, skip, limit, collation=collation) - .max_time_ms(max_time_ms) - .allow_disk_use(allow_disk_use) - ) - - def _get_dataset(self, spec, sort, fields, as_class): - dataset = self._iter_documents(spec) - if sort: - for sort_key, sort_direction in reversed(sort): - if sort_key == "$natural": - if sort_direction < 0: - dataset = iter(reversed(list(dataset))) - continue - if sort_key.startswith("$"): - raise NotImplementedError( - "Sorting by {} is not implemented in mongomock yet".format( - sort_key - ) - ) - dataset = iter( - sorted( - dataset, - key=lambda x: filtering.resolve_sort_key(sort_key, x), - reverse=sort_direction < 0, - ) - ) - for document in dataset: - yield self._copy_only_fields(document, fields, as_class) - - def _extract_projection_operators(self, fields): - """Removes and returns fields with projection operators.""" - result = {} - allowed_projection_operators = {"$elemMatch", "$slice"} - for key, value in fields.items(): - if isinstance(value, dict): - for op in value: - if op not in allowed_projection_operators: - raise ValueError("Unsupported projection option: {}".format(op)) - result[key] = value - - for key in result: - del fields[key] - - return result - 
- def _apply_projection_operators(self, ops, doc, doc_copy): - """Applies projection operators to copied document.""" - for field, op in ops.items(): - if field not in doc_copy: - if field in doc: - # field was not copied yet (since we are in include mode) - doc_copy[field] = doc[field] - else: - # field doesn't exist in original document, no work to do - continue - - if "$slice" in op: - if not isinstance(doc_copy[field], list): - raise OperationFailure( - "Unsupported type {} for slicing operation: {}".format( - type(doc_copy[field]), op - ) - ) - op_value = op["$slice"] - slice_ = None - if isinstance(op_value, list): - if len(op_value) != 2: - raise OperationFailure( - "Unsupported slice format {} for slicing operation: {}".format( - op_value, op - ) - ) - skip, limit = op_value - if skip < 0: - skip = len(doc_copy[field]) + skip - last = min(skip + limit, len(doc_copy[field])) - slice_ = slice(skip, last) - elif isinstance(op_value, int): - count = op_value - start = 0 - end = len(doc_copy[field]) - if count < 0: - start = max(0, len(doc_copy[field]) + count) - else: - end = min(count, len(doc_copy[field])) - slice_ = slice(start, end) - - if slice_: - doc_copy[field] = doc_copy[field][slice_] - else: - raise OperationFailure( - "Unsupported slice value {} for slicing operation: {}".format( - op_value, op - ) - ) - - if "$elemMatch" in op: - if isinstance(doc_copy[field], list): - # find the first item that matches - matched = False - for item in doc_copy[field]: - if filter_applies(op["$elemMatch"], item): - matched = True - doc_copy[field] = [item] - break - - # None have matched - if not matched: - del doc_copy[field] - - else: - # remove the field since there is None to iterate - del doc_copy[field] - - def _copy_only_fields(self, doc, fields, container): - """Copy only the specified fields.""" - - # https://pymongo.readthedocs.io/en/stable/migrate-to-pymongo4.html#collection-find-returns-entire-document-with-empty-projection - if ( - fields is None - or not fields - and helpers.PYMONGO_VERSION >= version.parse("4.0") - ): - return _copy_field(doc, container) - - if not fields: - fields = {"_id": 1} - if not isinstance(fields, dict): - fields = helpers.fields_list_to_dict(fields) - - # we can pass in something like {'_id':0, 'field':1}, so pull the id - # value out and hang on to it until later - id_value = fields.pop("_id", 1) - - # filter out fields with projection operators, we will take care of them later - projection_operators = self._extract_projection_operators(fields) - - # other than the _id field, all fields must be either includes or - # excludes, this can evaluate to 0 - if len(set(list(fields.values()))) > 1: - raise ValueError("You cannot currently mix including and excluding fields.") - - # if we have novalues passed in, make a doc_copy based on the - # id_value - if not fields: - if id_value == 1: - doc_copy = container() - else: - doc_copy = _copy_field(doc, container) - else: - doc_copy = _project_by_spec( - doc, - _combine_projection_spec(fields), - is_include=list(fields.values())[0], - container=container, - ) - - # set the _id value if we requested it, otherwise remove it - if id_value == 0: - doc_copy.pop("_id", None) - else: - if "_id" in doc: - doc_copy["_id"] = doc["_id"] - - fields["_id"] = id_value # put _id back in fields - - # time to apply the projection operators and put back their fields - self._apply_projection_operators(projection_operators, doc, doc_copy) - for field, op in projection_operators.items(): - fields[field] = op - return doc_copy - - 
def _update_document_fields(self, doc, fields, updater): - """Implements the $set behavior on an existing document""" - for k, v in fields.items(): - self._update_document_single_field(doc, k, v, updater) - - def _update_document_fields_positional( - self, doc, fields, spec, updater, subdocument=None - ): - """Implements the $set behavior on an existing document""" - for k, v in fields.items(): - if "$" in k: - field_name_parts = k.split(".") - if not subdocument: - current_doc = doc - subspec = spec - for part in field_name_parts[:-1]: - if part == "$": - subspec_dollar = subspec.get("$elemMatch", subspec) - for item in current_doc: - if filter_applies(subspec_dollar, item): - current_doc = item - break - continue - - new_spec = {} - for el in subspec: - if el.startswith(part): - if len(el.split(".")) > 1: - new_spec[".".join(el.split(".")[1:])] = subspec[el] - else: - new_spec = subspec[el] - subspec = new_spec - current_doc = current_doc[part] - - subdocument = current_doc - if field_name_parts[-1] == "$" and isinstance(subdocument, list): - for i, doc in enumerate(subdocument): - subspec_dollar = subspec.get("$elemMatch", subspec) - if filter_applies(subspec_dollar, doc): - subdocument[i] = v - break - continue - - updater(subdocument, field_name_parts[-1], v) - continue - # otherwise, we handle it the standard way - self._update_document_single_field(doc, k, v, updater) - - return subdocument - - def _update_document_fields_with_positional_awareness( - self, existing_document, v, spec, updater, subdocument - ): - positional = any("$" in key for key in v.keys()) - - if positional: - return self._update_document_fields_positional( - existing_document, v, spec, updater, subdocument - ) - self._update_document_fields(existing_document, v, updater) - return subdocument - - def _update_document_single_field(self, doc, field_name, field_value, updater): - field_name_parts = field_name.split(".") - for part in field_name_parts[:-1]: - if isinstance(doc, list): - try: - if part == "$": - doc = doc[0] - else: - doc = doc[int(part)] - continue - except ValueError: - pass - elif isinstance(doc, dict): - if updater is _unset_updater and part not in doc: - # If the parent doesn't exists, so does it child. - return - doc = doc.setdefault(part, {}) - else: - return - field_name = field_name_parts[-1] - updater(doc, field_name, field_value, codec_options=self._codec_options) - - def _iter_documents(self, filter): - # Validate the filter even if no documents can be returned. - if self._store.is_empty: - filter_applies(filter, {}) - - return ( - document - for document in list(self._store.documents) - if filter_applies(filter, document) - ) - - def find_one(self, filter=None, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg - # Allow calling find_one with a non-dict argument that gets used as - # the id for the query. 
- if filter is None: - filter = {} - if not isinstance(filter, Mapping): - filter = {"_id": filter} - - try: - return next(self.find(filter, *args, **kwargs)) - except StopIteration: - return None - - def find_one_and_delete(self, filter, projection=None, sort=None, **kwargs): - kwargs["remove"] = True - validate_is_mapping("filter", filter) - return self._find_and_modify(filter, projection, sort=sort, **kwargs) - - def find_one_and_replace( - self, - filter, - replacement, - projection=None, - sort=None, - upsert=False, - return_document=ReturnDocument.BEFORE, - **kwargs, - ): - validate_is_mapping("filter", filter) - validate_ok_for_replace(replacement) - return self._find_and_modify( - filter, projection, replacement, upsert, sort, return_document, **kwargs - ) - - def find_one_and_update( - self, - filter, - update, - projection=None, - sort=None, - upsert=False, - return_document=ReturnDocument.BEFORE, - **kwargs, - ): - validate_is_mapping("filter", filter) - validate_ok_for_update(update) - return self._find_and_modify( - filter, projection, update, upsert, sort, return_document, **kwargs - ) - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def find_and_modify( - self, - query={}, - update=None, - upsert=False, - sort=None, - full_response=False, - manipulate=False, - fields=None, - **kwargs, - ): - warnings.warn( - "find_and_modify is deprecated, use find_one_and_delete" - ", find_one_and_replace, or find_one_and_update instead", - DeprecationWarning, - stacklevel=2, - ) - if "projection" in kwargs: - raise TypeError( - "find_and_modify() got an unexpected keyword argument 'projection'" - ) - return self._find_and_modify( - query, - update=update, - upsert=upsert, - sort=sort, - projection=fields, - **kwargs, - ) - - def _find_and_modify( - self, - query, - projection=None, - update=None, - upsert=False, - sort=None, - return_document=ReturnDocument.BEFORE, - session=None, - **kwargs, - ): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - remove = kwargs.get("remove", False) - if kwargs.get("new", False) and remove: - # message from mongodb - raise OperationFailure("remove and returnNew can't co-exist") - - if not (remove or update): - raise ValueError("Must either update or remove") - - if remove and update: - raise ValueError("Can't do both update and remove") - - old = self.find_one(query, projection=projection, sort=sort) - if not old and not upsert: - return - - if old and "_id" in old: - query = {"_id": old["_id"]} - - if remove: - self.delete_one(query) - else: - updated = self._update(query, update, upsert) - if updated["upserted"]: - query = {"_id": updated["upserted"]} - - if return_document is ReturnDocument.AFTER or kwargs.get("new"): - return self.find_one(query, projection) - return old - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def save(self, to_save, manipulate=True, check_keys=True, **kwargs): - warnings.warn( - "save is deprecated. 
Use insert_one or replace_one " "instead", - DeprecationWarning, - stacklevel=2, - ) - validate_is_mutable_mapping("to_save", to_save) - validate_write_concern_params(**kwargs) - - if "_id" not in to_save: - return self.insert(to_save) - self._update( - {"_id": to_save["_id"]}, - to_save, - True, - manipulate, - check_keys=True, - **kwargs, - ) - return to_save.get("_id", None) - - def delete_one(self, filter, collation=None, hint=None, session=None): - validate_is_mapping("filter", filter) - return DeleteResult( - self._delete(filter, collation=collation, hint=hint, session=session), True - ) - - def delete_many(self, filter, collation=None, hint=None, session=None): - validate_is_mapping("filter", filter) - return DeleteResult( - self._delete( - filter, collation=collation, hint=hint, multi=True, session=session - ), - True, - ) - - def _delete(self, filter, collation=None, hint=None, multi=False, session=None): - if hint: - raise NotImplementedError( - "The hint argument of delete is valid but has not been implemented in " - "mongomock yet" - ) - if collation: - raise_not_implemented( - "collation", - "The collation argument of delete is valid but has not been " - "implemented in mongomock yet", - ) - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - filter = helpers.patch_datetime_awareness_in_document(filter) - if filter is None: - filter = {} - if not isinstance(filter, Mapping): - filter = {"_id": filter} - to_delete = list(self.find(filter)) - deleted_count = 0 - for doc in to_delete: - doc_id = doc["_id"] - if isinstance(doc_id, dict): - doc_id = helpers.hashdict(doc_id) - del self._store[doc_id] - deleted_count += 1 - if not multi: - break - - return { - "connectionId": self.database.client._id, - "n": deleted_count, - "ok": 1.0, - "err": None, - } - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def remove(self, spec_or_id=None, multi=True, **kwargs): - warnings.warn( - "remove is deprecated. Use delete_one or delete_many " "instead.", - DeprecationWarning, - stacklevel=2, - ) - validate_write_concern_params(**kwargs) - return self._delete(spec_or_id, multi=multi) - - def count(self, filter=None, **kwargs): - warnings.warn( - "count is deprecated. Use estimated_document_count or " - "count_documents instead. 
Please note that $where must be replaced " - "by $expr, $near must be replaced by $geoWithin with $center, and " - "$nearSphere must be replaced by $geoWithin with $centerSphere", - DeprecationWarning, - stacklevel=2, - ) - if kwargs.pop("session", None): - raise_not_implemented( - "session", "Mongomock does not handle sessions yet" - ) - if filter is None: - return len(self._store) - spec = helpers.patch_datetime_awareness_in_document(filter) - return len(list(self._iter_documents(spec))) - - def count_documents(self, filter, **kwargs): - if kwargs.pop("collation", None): - raise_not_implemented( - "collation", - "The collation argument of count_documents is valid but has not been " - "implemented in mongomock yet", - ) - if kwargs.pop("session", None): - raise_not_implemented("session", "Mongomock does not handle sessions yet") - skip = kwargs.pop("skip", 0) - if "limit" in kwargs: - limit = kwargs.pop("limit") - if not isinstance(limit, (int, float)): - raise OperationFailure("the limit must be specified as a number") - if limit <= 0: - raise OperationFailure("the limit must be positive") - limit = math.floor(limit) - else: - limit = None - unknown_kwargs = set(kwargs) - {"maxTimeMS", "hint"} - if unknown_kwargs: - raise OperationFailure("unrecognized field '%s'" % unknown_kwargs.pop()) - - spec = helpers.patch_datetime_awareness_in_document(filter) - doc_num = len(list(self._iter_documents(spec))) - count = max(doc_num - skip, 0) - return count if limit is None else min(count, limit) - - def estimated_document_count(self, **kwargs): - if kwargs.pop("session", None): - raise ConfigurationError( - "estimated_document_count does not support sessions" - ) - unknown_kwargs = set(kwargs) - {"limit", "maxTimeMS", "hint"} - if self.database.client.server_info()["versionArray"] < [5]: - unknown_kwargs.discard("skip") - if unknown_kwargs: - raise OperationFailure( - "BSON field 'count.%s' is an unknown field." % list(unknown_kwargs)[0] - ) - return self.count_documents({}, **kwargs) - - def drop(self, session=None): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - self.database.drop_collection(self.name) - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def ensure_index(self, key_or_list, cache_for=300, **kwargs): - return self.create_index(key_or_list, cache_for, **kwargs) - - def create_index(self, key_or_list, cache_for=300, session=None, **kwargs): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - index_list = helpers.create_index_list(key_or_list) - is_unique = kwargs.pop("unique", False) - is_sparse = kwargs.pop("sparse", False) - - index_name = kwargs.pop("name", helpers.gen_index_name(index_list)) - index_dict = {"key": index_list} - if is_sparse: - index_dict["sparse"] = True - if is_unique: - index_dict["unique"] = True - if "expireAfterSeconds" in kwargs and kwargs["expireAfterSeconds"] is not None: - index_dict["expireAfterSeconds"] = kwargs.pop("expireAfterSeconds") - if ( - "partialFilterExpression" in kwargs - and kwargs["partialFilterExpression"] is not None - ): - index_dict["partialFilterExpression"] = kwargs.pop( - "partialFilterExpression" - ) - - existing_index = self._store.indexes.get(index_name) - if existing_index and index_dict != existing_index: - raise OperationFailure( - "Index with name: %s already exists with different options" % index_name - ) - - # Check that documents already verify the uniqueness of this new index.
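- # Walk every stored document, build its key tuple for the new index, and throw DuplicateKeyError into the generator on the first collision.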
- if is_unique: - indexed = set() - indexed_list = [] - documents_gen = self._store.documents - for doc in documents_gen: - index = [] - for key, unused_order in index_list: - try: - index.append(helpers.get_value_by_dot(doc, key)) - except KeyError: - if is_sparse: - continue - index.append(None) - if is_sparse and not index: - continue - index = tuple(index) - try: - if index in indexed: - # Need to throw this inside the generator so it can clean the locks - documents_gen.throw( - DuplicateKeyError("E11000 Duplicate Key Error", 11000), - None, - None, - ) - indexed.add(index) - except TypeError as err: - # index is not hashable. - if index in indexed_list: - documents_gen.throw( - DuplicateKeyError("E11000 Duplicate Key Error", 11000), - None, - err, - ) - indexed_list.append(index) - - self._store.create_index(index_name, index_dict) - - return index_name - - def create_indexes(self, indexes, session=None): - for index in indexes: - if not isinstance(index, IndexModel): - raise TypeError( - "%s is not an instance of pymongo.operations.IndexModel" % index - ) - - return [ - self.create_index( - index.document["key"].items(), - session=session, - expireAfterSeconds=index.document.get("expireAfterSeconds"), - unique=index.document.get("unique", False), - sparse=index.document.get("sparse", False), - name=index.document.get("name"), - ) - for index in indexes - ] - - def drop_index(self, index_or_name, session=None): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - if isinstance(index_or_name, list): - name = helpers.gen_index_name(index_or_name) - else: - name = index_or_name - try: - self._store.drop_index(name) - except KeyError as err: - raise OperationFailure("index not found with name [%s]" % name) from err - - def drop_indexes(self, session=None): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - self._store.indexes = {} - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def reindex(self, session=None): - if session: - raise_not_implemented( - "session", "Mongomock does not handle sessions yet" - ) - - def _list_all_indexes(self): - if not self._store.is_created: - return - yield "_id_", {"key": [("_id", 1)]} - for name, information in self._store.indexes.items(): - yield name, information - - def list_indexes(self, session=None): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - for name, information in self._list_all_indexes(): - yield dict(information, key=dict(information["key"]), name=name, v=2) - - def index_information(self, session=None): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - return {name: dict(index, v=2) for name, index in self._list_all_indexes()} - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def map_reduce( - self, - map_func, - reduce_func, - out, - full_response=False, - query=None, - limit=0, - session=None, - ): - if execjs is None: - raise NotImplementedError( - "PyExecJS is required in order to run Map-Reduce. " - "Use 'pip install pyexecjs pymongo' to support Map-Reduce mock." 
- ) - if session: - raise_not_implemented( - "session", "Mongomock does not handle sessions yet" - ) - if limit == 0: - limit = None - start_time = time.perf_counter() - out_collection = None - reduced_rows = None - full_dict = { - "counts": {"input": 0, "reduce": 0, "emit": 0, "output": 0}, - "timeMillis": 0, - "ok": 1.0, - "result": None, - } - map_ctx = execjs.compile( - """ - function doMap(fnc, docList) { - var mappedDict = {}; - function emit(key, val) { - if (key['$oid']) { - mapped_key = '$oid' + key['$oid']; - } - else { - mapped_key = key; - } - if(!mappedDict[mapped_key]) { - mappedDict[mapped_key] = []; - } - mappedDict[mapped_key].push(val); - } - mapper = eval('('+fnc+')'); - var mappedList = new Array(); - for(var i=0; i 1: - full_dict["counts"]["reduce"] += 1 - full_dict["counts"]["output"] = len(reduced_rows) - if isinstance(out, (str, bytes)): - out_collection = getattr(self.database, out) - out_collection.drop() - out_collection.insert(reduced_rows) - ret_val = out_collection - full_dict["result"] = out - elif isinstance(out, SON) and out.get("replace") and out.get("db"): - # Must be of the format SON([('replace','results'),('db','outdb')]) - out_db = getattr(self.database._client, out["db"]) - out_collection = getattr(out_db, out["replace"]) - out_collection.insert(reduced_rows) - ret_val = out_collection - full_dict["result"] = {"db": out["db"], "collection": out["replace"]} - elif isinstance(out, dict) and out.get("inline"): - ret_val = reduced_rows - full_dict["result"] = reduced_rows - else: - raise TypeError("'out' must be an instance of string, dict or bson.SON") - time_millis = (time.perf_counter() - start_time) * 1000 - full_dict["timeMillis"] = int(round(time_millis)) - if full_response: - ret_val = full_dict - return ret_val - - def inline_map_reduce( - self, - map_func, - reduce_func, - full_response=False, - query=None, - limit=0, - session=None, - ): - return self.map_reduce( - map_func, - reduce_func, - {"inline": 1}, - full_response, - query, - limit, - session=session, - ) - - def distinct(self, key, filter=None, session=None): - if session: - raise_not_implemented("session", "Mongomock does not handle sessions yet") - return self.find(filter).distinct(key) - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def group(self, key, condition, initial, reduce, finalize=None): - if helpers.PYMONGO_VERSION >= version.parse("3.6"): - raise OperationFailure("no such command: 'group'") - if execjs is None: - raise NotImplementedError( - "PyExecJS is required in order to use group. " - "Use 'pip install pyexecjs pymongo' to support group mock." 
- ) - reduce_ctx = execjs.compile( - """ - function doReduce(fnc, docList) { - reducer = eval('('+fnc+')'); - for(var i=0, l=docList.length; i 0: - doc += [None] * len_diff - doc[field_index] = value - - -def _unset_updater(doc, field_name, value, codec_options=None): - if isinstance(doc, dict): - doc.pop(field_name, None) - - -def _inc_updater(doc, field_name, value, codec_options=None): - if isinstance(doc, dict): - doc[field_name] = doc.get(field_name, 0) + value - - if isinstance(doc, list): - field_index = int(field_name) - if field_index < 0: - raise WriteError("Negative index provided") - try: - doc[field_index] += value - except IndexError: - len_diff = field_index - (len(doc) - 1) - doc += [None] * len_diff - doc[field_index] = value - - -def _max_updater(doc, field_name, value, codec_options=None): - if isinstance(doc, dict): - doc[field_name] = max(doc.get(field_name, value), value) - - -def _min_updater(doc, field_name, value, codec_options=None): - if isinstance(doc, dict): - doc[field_name] = min(doc.get(field_name, value), value) - - -def _pop_updater(doc, field_name, value, codec_options=None): - if value not in {1, -1}: - raise WriteError("$pop expects 1 or -1, found: " + str(value)) - - if isinstance(doc, dict): - if isinstance(doc[field_name], (tuple, list)): - doc[field_name] = list(doc[field_name]) - _pop_from_list(doc[field_name], value) - return - raise WriteError("Path contains element of non-array type") - - if isinstance(doc, list): - field_index = int(field_name) - if field_index < 0: - raise WriteError("Negative index provided") - if field_index >= len(doc): - return - _pop_from_list(doc[field_index], value) - - -def _pop_from_list(list_instance, mongo_pop_value, codec_options=None): - if not list_instance: - return - - if mongo_pop_value == 1: - list_instance.pop() - elif mongo_pop_value == -1: - list_instance.pop(0) - - -def _current_date_updater(doc, field_name, value, codec_options=None): - if isinstance(doc, dict): - if value == {"$type": "timestamp"}: - # TODO(juannyg): get_current_timestamp should also be using helpers utcnow, - # as it currently using time.time internally - doc[field_name] = helpers.get_current_timestamp() - else: - doc[field_name] = utcnow() - - -_updaters = { - "$set": _set_updater, - "$unset": _unset_updater, - "$inc": _inc_updater, - "$max": _max_updater, - "$min": _min_updater, - "$pop": _pop_updater, -} diff --git a/packages/syft/tests/mongomock/command_cursor.py b/packages/syft/tests/mongomock/command_cursor.py deleted file mode 100644 index 025bb836e24..00000000000 --- a/packages/syft/tests/mongomock/command_cursor.py +++ /dev/null @@ -1,37 +0,0 @@ -class CommandCursor(object): - def __init__(self, collection, curser_info=None, address=None, retrieved=0): - self._collection = iter(collection) - self._id = None - self._address = address - self._data = {} - self._retrieved = retrieved - self._batch_size = 0 - self._killed = self._id == 0 - - @property - def address(self): - return self._address - - def close(self): - pass - - def batch_size(self, batch_size): - return self - - @property - def alive(self): - return True - - def __iter__(self): - return self - - def next(self): - return next(self._collection) - - __next__ = next - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - return diff --git a/packages/syft/tests/mongomock/database.py b/packages/syft/tests/mongomock/database.py deleted file mode 100644 index 3b1a7e59f70..00000000000 --- a/packages/syft/tests/mongomock/database.py +++ 
/dev/null @@ -1,301 +0,0 @@ -# stdlib -import warnings - -# third party -from packaging import version - -# relative -from . import CollectionInvalid -from . import InvalidName -from . import OperationFailure -from . import codec_options as mongomock_codec_options -from . import helpers -from . import read_preferences -from . import store -from .collection import Collection -from .filtering import filter_applies - -try: - # third party - from pymongo import ReadPreference - - _READ_PREFERENCE_PRIMARY = ReadPreference.PRIMARY -except ImportError: - _READ_PREFERENCE_PRIMARY = read_preferences.PRIMARY - -try: - # third party - from pymongo.read_concern import ReadConcern -except ImportError: - # relative - from .read_concern import ReadConcern - -_LIST_COLLECTION_FILTER_ALLOWED_OPERATORS = frozenset(["$regex", "$eq", "$ne"]) - - -def _verify_list_collection_supported_op(keys): - if set(keys) - _LIST_COLLECTION_FILTER_ALLOWED_OPERATORS: - raise NotImplementedError( - "list collection names filter operator {0} is not implemented yet in mongomock " - "allowed operators are {1}".format( - keys, _LIST_COLLECTION_FILTER_ALLOWED_OPERATORS - ) - ) - - -class Database(object): - def __init__( - self, - client, - name, - _store, - read_preference=None, - codec_options=None, - read_concern=None, - ): - self.name = name - self._client = client - self._collection_accesses = {} - self._store = _store or store.DatabaseStore() - self._read_preference = read_preference or _READ_PREFERENCE_PRIMARY - mongomock_codec_options.is_supported(codec_options) - self._codec_options = codec_options or mongomock_codec_options.CodecOptions() - if read_concern and not isinstance(read_concern, ReadConcern): - raise TypeError( - "read_concern must be an instance of pymongo.read_concern.ReadConcern" - ) - self._read_concern = read_concern or ReadConcern() - - def __getitem__(self, coll_name): - return self.get_collection(coll_name) - - def __getattr__(self, attr): - if attr.startswith("_"): - raise AttributeError( - "%s has no attribute '%s'. To access the %s collection, use database['%s']." - % (self.__class__.__name__, attr, attr, attr) - ) - return self[attr] - - def __repr__(self): - return "Database({0}, '{1}')".format(self._client, self.name) - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self._client == other._client and self.name == other.name - return NotImplemented - - if helpers.PYMONGO_VERSION >= version.parse("3.12"): - - def __hash__(self): - return hash((self._client, self.name)) - - @property - def client(self): - return self._client - - @property - def read_preference(self): - return self._read_preference - - @property - def codec_options(self): - return self._codec_options - - @property - def read_concern(self): - return self._read_concern - - def _get_created_collections(self): - return self._store.list_created_collection_names() - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def collection_names(self, include_system_collections=True, session=None): - warnings.warn( - "collection_names is deprecated. Use list_collection_names instead." - ) - if include_system_collections: - return list(self._get_created_collections()) - return self.list_collection_names(session=session) - - def list_collections(self, filter=None, session=None, nameOnly=False): - raise NotImplementedError( - "list_collections is a valid method of Database but has not been implemented in " - "mongomock yet." 
- ) - - def list_collection_names(self, filter=None, session=None): - """filter: only the name field is supported, with the $eq, $ne or $regex operators - - session: not supported - for the supported operators, please see _LIST_COLLECTION_FILTER_ALLOWED_OPERATORS - """ - field_name = "name" - - if session: - raise NotImplementedError("Mongomock does not handle sessions yet") - - if filter: - if not filter.get("name"): - raise NotImplementedError( - "list collection {0} might be valid but is not " - "implemented yet in mongomock".format(filter) - ) - - filter = ( - {field_name: {"$eq": filter.get(field_name)}} - if isinstance(filter.get(field_name), str) - else filter - ) - - _verify_list_collection_supported_op(filter.get(field_name).keys()) - - return [ - name - for name in list(self._store._collections) - if filter_applies(filter, {field_name: name}) - and not name.startswith("system.") - ] - - return [ - name - for name in self._get_created_collections() - if not name.startswith("system.") - ] - - def get_collection( - self, - name, - codec_options=None, - read_preference=None, - write_concern=None, - read_concern=None, - ): - if read_preference is not None: - read_preferences.ensure_read_preference_type( - "read_preference", read_preference - ) - mongomock_codec_options.is_supported(codec_options) - try: - return self._collection_accesses[name].with_options( - codec_options=codec_options or self._codec_options, - read_preference=read_preference or self.read_preference, - read_concern=read_concern, - write_concern=write_concern, - ) - except KeyError: - self._ensure_valid_collection_name(name) - collection = self._collection_accesses[name] = Collection( - self, - name=name, - read_concern=read_concern, - write_concern=write_concern, - read_preference=read_preference or self.read_preference, - codec_options=codec_options or self._codec_options, - _db_store=self._store, - ) - return collection - - def drop_collection(self, name_or_collection, session=None): - if session: - raise NotImplementedError("Mongomock does not handle sessions yet") - if isinstance(name_or_collection, Collection): - name_or_collection._store.drop() - else: - self._store[name_or_collection].drop() - - def _ensure_valid_collection_name(self, name): - # These are the same checks that are done in pymongo. - if not isinstance(name, str): - raise TypeError("name must be an instance of str") - if not name or ".." in name: - raise InvalidName("collection names cannot be empty") - if name[0] == "."
or name[-1] == ".": - raise InvalidName("collection names must not start or end with '.'") - if "$" in name: - raise InvalidName("collection names must not contain '$'") - if "\x00" in name: - raise InvalidName("collection names must not contain the null character") - - def create_collection(self, name, **kwargs): - self._ensure_valid_collection_name(name) - if name in self.list_collection_names(): - raise CollectionInvalid("collection %s already exists" % name) - - if kwargs: - raise NotImplementedError("Special options not supported") - - self._store.create_collection(name) - return self[name] - - def rename_collection(self, name, new_name, dropTarget=False): - """Changes the name of an existing collection.""" - self._ensure_valid_collection_name(new_name) - - # Reference for server implementation: - # https://docs.mongodb.com/manual/reference/command/renameCollection/ - if not self._store[name].is_created: - raise OperationFailure( - 'The collection "{0}" does not exist.'.format(name), 10026 - ) - if new_name in self._store: - if dropTarget: - self.drop_collection(new_name) - else: - raise OperationFailure( - 'The target collection "{0}" already exists'.format(new_name), 10027 - ) - self._store.rename(name, new_name) - return {"ok": 1} - - def dereference(self, dbref, session=None): - if session: - raise NotImplementedError("Mongomock does not handle sessions yet") - - if not hasattr(dbref, "collection") or not hasattr(dbref, "id"): - raise TypeError("cannot dereference a %s" % type(dbref)) - if dbref.database is not None and dbref.database != self.name: - raise ValueError( - "trying to dereference a DBRef that points to " - "another database (%r not %r)" % (dbref.database, self.name) - ) - return self[dbref.collection].find_one({"_id": dbref.id}) - - def command(self, command, **unused_kwargs): - if isinstance(command, str): - command = {command: 1} - if "ping" in command: - return {"ok": 1.0} - # TODO(pascal): Differentiate NotImplementedError for valid commands - # and OperationFailure if the command is not valid. - raise NotImplementedError( - "command is a valid Database method but is not implemented in Mongomock yet" - ) - - def with_options( - self, - codec_options=None, - read_preference=None, - write_concern=None, - read_concern=None, - ): - mongomock_codec_options.is_supported(codec_options) - - if write_concern: - raise NotImplementedError( - "write_concern is a valid parameter for with_options but is not implemented yet in " - "mongomock" - ) - - if read_preference is None or read_preference == self._read_preference: - return self - - return Database( - self._client, - self.name, - self._store, - read_preference=read_preference or self._read_preference, - codec_options=codec_options or self._codec_options, - read_concern=read_concern or self._read_concern, - ) diff --git a/packages/syft/tests/mongomock/filtering.py b/packages/syft/tests/mongomock/filtering.py deleted file mode 100644 index 345b94c7b88..00000000000 --- a/packages/syft/tests/mongomock/filtering.py +++ /dev/null @@ -1,601 +0,0 @@ -# stdlib -from datetime import datetime -import itertools -import numbers -import operator -import re -import uuid - -# relative -from . 
import OperationFailure -from .helpers import ObjectId -from .helpers import RE_TYPE - -try: - # stdlib - from types import NoneType -except ImportError: - NoneType = type(None) - -try: - # third party - from bson import DBRef - from bson import Regex - - _RE_TYPES = (RE_TYPE, Regex) -except ImportError: - DBRef = None - _RE_TYPES = (RE_TYPE,) - -try: - # third party - from bson.decimal128 import Decimal128 -except ImportError: - Decimal128 = None - -_TOP_LEVEL_OPERATORS = {"$expr", "$text", "$where", "$jsonSchema"} - - -_NOT_IMPLEMENTED_OPERATORS = { - "$bitsAllClear", - "$bitsAllSet", - "$bitsAnyClear", - "$bitsAnySet", - "$geoIntersects", - "$geoWithin", - "$maxDistance", - "$minDistance", - "$near", - "$nearSphere", -} - - -def filter_applies(search_filter, document): - """Applies given filter - - This function implements MongoDB's matching strategy over documents in the find() method - and other related scenarios (like $elemMatch) - """ - return _filterer_inst.apply(search_filter, document) - - -class _Filterer(object): - """An object to help applying a filter, using the MongoDB query language.""" - - # This is populated using register_parse_expression further down. - parse_expression = [] - - def __init__(self): - self._operator_map = dict( - { - "$eq": _list_expand(operator_eq), - "$ne": _list_expand( - lambda dv, sv: not operator_eq(dv, sv), negative=True - ), - "$all": self._all_op, - "$in": _in_op, - "$nin": lambda dv, sv: not _in_op(dv, sv), - "$exists": lambda dv, sv: bool(sv) == (dv is not None), - "$regex": _not_None_and(_regex), - "$elemMatch": self._elem_match_op, - "$size": _size_op, - "$type": _type_op, - }, - **{ - key: _not_None_and(_list_expand(_compare_objects(op))) - for key, op in SORTING_OPERATOR_MAP.items() - }, - ) - - def apply(self, search_filter, document): - if not isinstance(search_filter, dict): - raise OperationFailure( - "the match filter must be an expression in an object" - ) - - for key, search in search_filter.items(): - # Top level operators. 
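- # $comment is skipped, $and/$or/$nor recurse through self.apply, $expr is delegated to the parse_expression hook registered by the aggregate module, and any other $-prefixed key is rejected.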
- if key == "$comment": - continue - if key in LOGICAL_OPERATOR_MAP: - if not search: - raise OperationFailure( - "BadValue $and/$or/$nor must be a nonempty array" - ) - if not LOGICAL_OPERATOR_MAP[key](document, search, self.apply): - return False - continue - if key == "$expr": - parse_expression = self.parse_expression[0] - if not parse_expression(search, document, ignore_missing_keys=True): - return False - continue - if key in _TOP_LEVEL_OPERATORS: - raise NotImplementedError( - "The {} operator is not implemented in mongomock yet".format(key) - ) - if key.startswith("$"): - raise OperationFailure("unknown top level operator: " + key) - - is_match = False - - is_checking_negative_match = isinstance(search, dict) and { - "$ne", - "$nin", - } & set(search.keys()) - is_checking_positive_match = not isinstance(search, dict) or ( - set(search.keys()) - {"$ne", "$nin"} - ) - has_candidates = False - - if search == {"$exists": False} and not iter_key_candidates(key, document): - continue - - if isinstance(search, dict) and "$all" in search: - if not self._all_op(iter_key_candidates(key, document), search["$all"]): - return False - # if there are no query operators then continue - if len(search) == 1: - continue - - for doc_val in iter_key_candidates(key, document): - has_candidates |= doc_val is not None - is_ops_filter = ( - search - and isinstance(search, dict) - and all(key.startswith("$") for key in search.keys()) - ) - if is_ops_filter: - if "$options" in search and "$regex" in search: - search = _combine_regex_options(search) - unknown_operators = set(search) - set(self._operator_map) - {"$not"} - if unknown_operators: - not_implemented_operators = ( - unknown_operators & _NOT_IMPLEMENTED_OPERATORS - ) - if not_implemented_operators: - raise NotImplementedError( - "'%s' is a valid operation but it is not supported by Mongomock " - "yet." % list(not_implemented_operators)[0] - ) - raise OperationFailure( - "unknown operator: " + list(unknown_operators)[0] - ) - is_match = ( - all( - operator_string in self._operator_map - and self._operator_map[operator_string](doc_val, search_val) - or operator_string == "$not" - and self._not_op(document, key, search_val) - for operator_string, search_val in search.items() - ) - and search - ) - elif isinstance(search, _RE_TYPES) and isinstance(doc_val, (str, list)): - is_match = _regex(doc_val, search) - elif key in LOGICAL_OPERATOR_MAP: - if not search: - raise OperationFailure( - "BadValue $and/$or/$nor must be a nonempty array" - ) - is_match = LOGICAL_OPERATOR_MAP[key](document, search, self.apply) - elif isinstance(doc_val, (list, tuple)): - is_match = search in doc_val or search == doc_val - if isinstance(search, ObjectId): - is_match |= str(search) in doc_val - else: - is_match = (doc_val == search) or ( - search is None and doc_val is None - ) - - # When checking negative match, all the elements should match. - if is_checking_negative_match and not is_match: - return False - - # If not checking negative matches, the first match is enough for this criterion.
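- # e.g. {"tags": "a"} succeeds on the first matching candidate, while {"tags": {"$ne": "a"}} must hold for every candidate value.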
- if is_match and not is_checking_negative_match: - break - - if not is_match and (has_candidates or is_checking_positive_match): - return False - - return True - - def _not_op(self, d, k, s): - if isinstance(s, dict): - for key in s.keys(): - if key not in self._operator_map and key not in LOGICAL_OPERATOR_MAP: - raise OperationFailure("unknown operator: %s" % key) - elif isinstance(s, _RE_TYPES): - pass - else: - raise OperationFailure("$not needs a regex or a document") - return not self.apply({k: s}, d) - - def _elem_match_op(self, doc_val, query): - if not isinstance(doc_val, list): - return False - if not isinstance(query, dict): - raise OperationFailure("$elemMatch needs an Object") - for item in doc_val: - try: - if self.apply(query, item): - return True - except OperationFailure: - if self.apply({"field": query}, {"field": item}): - return True - return False - - def _all_op(self, doc_val, search_val): - if isinstance(doc_val, list) and doc_val and isinstance(doc_val[0], list): - doc_val = list(itertools.chain.from_iterable(doc_val)) - dv = _force_list(doc_val) - matches = [] - for x in search_val: - if isinstance(x, dict) and "$elemMatch" in x: - matches.append(self._elem_match_op(doc_val, x["$elemMatch"])) - else: - matches.append(x in dv) - return all(matches) - - -def iter_key_candidates(key, doc): - """Get possible subdocuments or lists that are referred to by the key in question - - Returns the appropriate nested value if the key includes dot notation. - """ - if not key: - return [doc] - - if doc is None: - return () - - if isinstance(doc, list): - return _iter_key_candidates_sublist(key, doc) - - if not isinstance(doc, dict): - return () - - key_parts = key.split(".") - if len(key_parts) == 1: - return [doc.get(key, None)] - - sub_key = ".".join(key_parts[1:]) - sub_doc = doc.get(key_parts[0], {}) - return iter_key_candidates(sub_key, sub_doc) - - -def _iter_key_candidates_sublist(key, doc): - """Iterates over candidates - - :param doc: a list to be searched for candidates for our key - :param key: the string key to be matched - """ - key_parts = key.split(".") - sub_key = key_parts.pop(0) - key_remainder = ".".join(key_parts) - try: - sub_key_int = int(sub_key) - except ValueError: - sub_key_int = None - - if sub_key_int is None: - # subkey is not an integer... - ret = [] - for sub_doc in doc: - if isinstance(sub_doc, dict): - if sub_key in sub_doc: - ret.extend(iter_key_candidates(key_remainder, sub_doc[sub_key])) - else: - ret.append(None) - return ret - - # subkey is an index - if sub_key_int >= len(doc): - return () # dead end - sub_doc = doc[sub_key_int] - if key_parts: - return iter_key_candidates(".".join(key_parts), sub_doc) - return [sub_doc] - - -def _force_list(v): - return v if isinstance(v, (list, tuple)) else [v] - - -def _in_op(doc_val, search_val): - if not isinstance(search_val, (list, tuple)): - raise OperationFailure("$in needs an array") - if doc_val is None and None in search_val: - return True - doc_val = _force_list(doc_val) - is_regex_list = [isinstance(x, _RE_TYPES) for x in search_val] - if not any(is_regex_list): - return any(x in search_val for x in doc_val) - for x, is_regex in zip(search_val, is_regex_list): - if (is_regex and _regex(doc_val, x)) or (x in doc_val): - return True - return False - - -def _not_None_and(f): - """Wrap an operator to return False if the first arg is None""" - return lambda v, l: v is not None and f(v, l) - - -def _compare_objects(op): - """Wrap an operator to also compare objects following BSON comparison.
- - See https://docs.mongodb.com/manual/reference/bson-type-comparison-order/#objects - """ - - def _wrapped(a, b): - # Do not compare uncomparable types, see Type Bracketing: - # https://docs.mongodb.com/manual/reference/method/db.collection.find/#type-bracketing - return bson_compare(op, a, b, can_compare_types=False) - - return _wrapped - - -def bson_compare(op, a, b, can_compare_types=True): - """Compare two elements using BSON comparison. - - Args: - op: the basic operation to compare (e.g. operator.lt, operator.ge). - a: the first operand - b: the second operand - can_compare_types: if True, according to BSON's definition order - between types is used, otherwise always return False when types are - different. - """ - a_type = _get_compare_type(a) - b_type = _get_compare_type(b) - if a_type != b_type: - return can_compare_types and op(a_type, b_type) - - # Compare DBRefs as dicts - if type(a).__name__ == "DBRef" and hasattr(a, "as_doc"): - a = a.as_doc() - if type(b).__name__ == "DBRef" and hasattr(b, "as_doc"): - b = b.as_doc() - - if isinstance(a, dict): - # MongoDb server compares the type before comparing the keys - # https://github.com/mongodb/mongo/blob/f10f214/src/mongo/bson/bsonelement.cpp#L516 - # even though the documentation does not say anything about that. - a = [(_get_compare_type(v), k, v) for k, v in a.items()] - b = [(_get_compare_type(v), k, v) for k, v in b.items()] - - if isinstance(a, (tuple, list)): - for item_a, item_b in zip(a, b): - if item_a != item_b: - return bson_compare(op, item_a, item_b) - return bson_compare(op, len(a), len(b)) - - if isinstance(a, NoneType): - return op(0, 0) - - # bson handles bytes as binary in python3+: - # https://api.mongodb.com/python/current/api/bson/index.html - if isinstance(a, bytes): - # Performs the same operation as described by: - # https://docs.mongodb.com/manual/reference/bson-type-comparison-order/#bindata - if len(a) != len(b): - return op(len(a), len(b)) - # bytes is always treated as subtype 0 by the bson library - return op(a, b) - - -def _get_compare_type(val): - """Get a number representing the base type of the value used for comparison. - - See https://docs.mongodb.com/manual/reference/bson-type-comparison-order/ - also https://github.com/mongodb/mongo/blob/46b28bb/src/mongo/bson/bsontypes.h#L175 - for canonical values. - """ - if isinstance(val, NoneType): - return 5 - if isinstance(val, bool): - return 40 - if isinstance(val, numbers.Number): - return 10 - if isinstance(val, str): - return 15 - if isinstance(val, dict): - return 20 - if isinstance(val, (tuple, list)): - return 25 - if isinstance(val, uuid.UUID): - return 30 - if isinstance(val, bytes): - return 30 - if isinstance(val, ObjectId): - return 35 - if isinstance(val, datetime): - return 45 - if isinstance(val, _RE_TYPES): - return 50 - if DBRef and isinstance(val, DBRef): - # According to the C++ code, this should be 55 but apparently sending a DBRef through - # pymongo is stored as a dict. 
- return 20 - return 0 - - -def _regex(doc_val, regex): - if not (isinstance(doc_val, (str, list)) or isinstance(doc_val, RE_TYPE)): - return False - if isinstance(regex, str): - regex = re.compile(regex) - if not isinstance(regex, RE_TYPE): - # bson.Regex - regex = regex.try_compile() - return any( - regex.search(item) for item in _force_list(doc_val) if isinstance(item, str) - ) - - -def _size_op(doc_val, search_val): - if isinstance(doc_val, (list, tuple, dict)): - return search_val == len(doc_val) - return search_val == 1 if doc_val and doc_val is not None else 0 - - -def _list_expand(f, negative=False): - def func(doc_val, search_val): - if isinstance(doc_val, (list, tuple)) and not isinstance( - search_val, (list, tuple) - ): - if negative: - return all(f(val, search_val) for val in doc_val) - return any(f(val, search_val) for val in doc_val) - return f(doc_val, search_val) - - return func - - -def _type_op(doc_val, search_val, in_array=False): - if search_val not in TYPE_MAP: - raise OperationFailure("%r is not a valid $type" % search_val) - elif TYPE_MAP[search_val] is None: - raise NotImplementedError( - "%s is a valid $type but not implemented" % search_val - ) - if TYPE_MAP[search_val](doc_val): - return True - if isinstance(doc_val, (list, tuple)) and not in_array: - return any(_type_op(val, search_val, in_array=True) for val in doc_val) - return False - - -def _combine_regex_options(search): - if not isinstance(search["$options"], str): - raise OperationFailure("$options has to be a string") - - options = None - for option in search["$options"]: - if option not in "imxs": - continue - re_option = getattr(re, option.upper()) - if options is None: - options = re_option - else: - options |= re_option - - search_copy = dict(search) - del search_copy["$options"] - - if options is None: - return search_copy - - if isinstance(search["$regex"], _RE_TYPES): - if isinstance(search["$regex"], RE_TYPE): - search_copy["$regex"] = re.compile( - search["$regex"].pattern, search["$regex"].flags | options - ) - else: - # bson.Regex - regex = search["$regex"] - search_copy["$regex"] = regex.__class__( - regex.pattern, regex.flags | options - ) - else: - search_copy["$regex"] = re.compile(search["$regex"], options) - return search_copy - - -def operator_eq(doc_val, search_val): - if doc_val is None and search_val is None: - return True - return operator.eq(doc_val, search_val) - - -SORTING_OPERATOR_MAP = { - "$gt": operator.gt, - "$gte": operator.ge, - "$lt": operator.lt, - "$lte": operator.le, -} - - -LOGICAL_OPERATOR_MAP = { - "$or": lambda d, subq, filter_func: any(filter_func(q, d) for q in subq), - "$and": lambda d, subq, filter_func: all(filter_func(q, d) for q in subq), - "$nor": lambda d, subq, filter_func: all(not filter_func(q, d) for q in subq), - "$not": lambda d, subq, filter_func: (not filter_func(q, d) for q in subq), -} - - -TYPE_MAP = { - "double": lambda v: isinstance(v, float), - "string": lambda v: isinstance(v, str), - "object": lambda v: isinstance(v, dict), - "array": lambda v: isinstance(v, list), - "binData": lambda v: isinstance(v, bytes), - "undefined": None, - "objectId": lambda v: isinstance(v, ObjectId), - "bool": lambda v: isinstance(v, bool), - "date": lambda v: isinstance(v, datetime), - "null": None, - "regex": None, - "dbPointer": None, - "javascript": None, - "symbol": None, - "javascriptWithScope": None, - "int": lambda v: ( - isinstance(v, int) and not isinstance(v, bool) and v.bit_length() <= 32 - ), - "timestamp": None, - "long": lambda v: ( - 
isinstance(v, int) and not isinstance(v, bool) and v.bit_length() > 32 - ), - "decimal": (lambda v: isinstance(v, Decimal128)) if Decimal128 else None, - "number": lambda v: ( - # pylint: disable-next=isinstance-second-argument-not-valid-type - isinstance(v, (int, float) + ((Decimal128,) if Decimal128 else ())) - and not isinstance(v, bool) - ), - "minKey": None, - "maxKey": None, -} - - -def resolve_key(key, doc): - return next(iter(iter_key_candidates(key, doc)), None) - - -def resolve_sort_key(key, doc): - value = resolve_key(key, doc) - # see http://docs.mongodb.org/manual/reference/method/cursor.sort/#ascending-descending-sort - if value is None: - return 1, BsonComparable(None) - - # List or tuples are sorted solely by their first value. - if isinstance(value, (tuple, list)): - if not value: - return 0, BsonComparable(None) - return 1, BsonComparable(value[0]) - - return 1, BsonComparable(value) - - -class BsonComparable(object): - """Wraps a value in a BSON-like object so that values can be compared to one another.""" - - def __init__(self, obj): - self.obj = obj - - def __lt__(self, other): - return bson_compare(operator.lt, self.obj, other.obj) - - -_filterer_inst = _Filterer() - - -# Developer note: to avoid a cross-module dependency (filtering requires aggregation, that requires -# filtering), the aggregation module needs to register its parse_expression function here. -def register_parse_expression(parse_expression): - """Register the parse_expression function from the aggregate module.""" - - del _Filterer.parse_expression[:] - _Filterer.parse_expression.append(parse_expression) diff --git a/packages/syft/tests/mongomock/gridfs.py b/packages/syft/tests/mongomock/gridfs.py deleted file mode 100644 index 13a59999855..00000000000 --- a/packages/syft/tests/mongomock/gridfs.py +++ /dev/null @@ -1,68 +0,0 @@ -# stdlib -from unittest import mock - -# relative -from . import Collection as MongoMockCollection -from . import Database as MongoMockDatabase -from ..collection import Cursor as MongoMockCursor - -try: - # third party - from gridfs.grid_file import GridOut as PyMongoGridOut - from gridfs.grid_file import GridOutCursor as PyMongoGridOutCursor - from pymongo.collection import Collection as PyMongoCollection - from pymongo.database import Database as PyMongoDatabase - - _HAVE_PYMONGO = True -except ImportError: - _HAVE_PYMONGO = False - - -# This is a copy of GridOutCursor but with a different base. Note that we -# need both classes as one might want to access both mongomock and real -# MongoDB.
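- # It iterates the mocked .files collection and wraps each raw file document in a PyMongoGridOut.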
-class _MongoMockGridOutCursor(MongoMockCursor): - def __init__(self, collection, *args, **kwargs): - self.__root_collection = collection - super(_MongoMockGridOutCursor, self).__init__(collection.files, *args, **kwargs) - - def next(self): - next_file = super(_MongoMockGridOutCursor, self).next() - return PyMongoGridOut( - self.__root_collection, file_document=next_file, session=self.session - ) - - __next__ = next - - def add_option(self, *args, **kwargs): - raise NotImplementedError() - - def remove_option(self, *args, **kwargs): - raise NotImplementedError() - - def _clone_base(self, session): - return _MongoMockGridOutCursor(self.__root_collection, session=session) - - -def _create_grid_out_cursor(collection, *args, **kwargs): - if isinstance(collection, MongoMockCollection): - return _MongoMockGridOutCursor(collection, *args, **kwargs) - return PyMongoGridOutCursor(collection, *args, **kwargs) - - -def enable_gridfs_integration(): - """This function enables the use of mongomock Databases and Collections inside gridfs - - The gridfs library uses `isinstance` to make sure the passed elements - are valid `pymongo.Database/Collection`, so we monkey patch those types in the gridfs modules - (luckily, in the modules where they are used, they are only used with isinstance). - """ - - if not _HAVE_PYMONGO: - raise NotImplementedError("gridfs mocking requires pymongo to work") - - mock.patch("gridfs.Database", (PyMongoDatabase, MongoMockDatabase)).start() - mock.patch( - "gridfs.grid_file.Collection", (PyMongoCollection, MongoMockCollection) - ).start() - mock.patch("gridfs.GridOutCursor", _create_grid_out_cursor).start() diff --git a/packages/syft/tests/mongomock/helpers.py b/packages/syft/tests/mongomock/helpers.py deleted file mode 100644 index 13f6892cae5..00000000000 --- a/packages/syft/tests/mongomock/helpers.py +++ /dev/null @@ -1,474 +0,0 @@ -# stdlib -from collections import OrderedDict -from collections import abc -from datetime import datetime -from datetime import timedelta -from datetime import tzinfo -import re -import time -from urllib.parse import unquote_plus -import warnings - -# third party -from packaging import version - -# relative -from . import InvalidURI - -# Get ObjectId from bson if available or import a crafted one. This is not used -# in this module but is made available for callers of this module. -try: - # third party - from bson import ObjectId # pylint: disable=unused-import - from bson import Timestamp - from pymongo import version as pymongo_version - - PYMONGO_VERSION = version.parse(pymongo_version) - HAVE_PYMONGO = True -except ImportError: - from .object_id import ObjectId # noqa - - Timestamp = None - # Default Pymongo version if not present. - PYMONGO_VERSION = version.parse("4.0") - HAVE_PYMONGO = False - -# Cache the RegExp pattern type.
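- # The class of a compiled pattern, obtained by compiling an empty pattern rather than naming re.Pattern directly.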
-RE_TYPE = type(re.compile("")) -_HOST_MATCH = re.compile(r"^([^@]+@)?([^:]+|\[[^\]]+\])(:([^:]+))?$") -_SIMPLE_HOST_MATCH = re.compile(r"^([^:]+|\[[^\]]+\])(:([^:]+))?$") - -try: - # third party - from bson.tz_util import utc -except ImportError: - - class _FixedOffset(tzinfo): - def __init__(self, offset, name): - self.__offset = timedelta(minutes=offset) - self.__name = name - - def __getinitargs__(self): - return self.__offset, self.__name - - def utcoffset(self, dt): - return self.__offset - - def tzname(self, dt): - return self.__name - - def dst(self, dt): - return timedelta(0) - - utc = _FixedOffset(0, "UTC") - - -ASCENDING = 1 -DESCENDING = -1 - - -def utcnow(): - """Simple wrapper for datetime.utcnow - - This provides a centralized definition of "now" in the mongomock realm, - allowing users to transform the value of "now" to the future or the past, - based on their testing needs. For example: - - ```python - def test_x(self): - with mock.patch("mongomock.utcnow") as mm_utc: - mm_utc.return_value = datetime.utcnow() + timedelta(hours=100) - # Test some things "100 hours" in the future - ``` - """ - return datetime.utcnow() - - -def print_deprecation_warning(old_param_name, new_param_name): - warnings.warn( - "'%s' has been deprecated to be in line with pymongo implementation, a new parameter '%s' " - "should be used instead. the old parameter will be kept for backward compatibility " - "purposes." % (old_param_name, new_param_name), - DeprecationWarning, - ) - - -def create_index_list(key_or_list, direction=None): - """Helper to generate a list of (key, direction) pairs. - - It takes such a list, or a single key, or a single key and direction. - """ - if isinstance(key_or_list, str): - return [(key_or_list, direction or ASCENDING)] - if not isinstance(key_or_list, (list, tuple, abc.Iterable)): - raise TypeError( - "if no direction is specified, " "key_or_list must be an instance of list" - ) - return key_or_list - - -def gen_index_name(index_list): - """Generate an index name based on the list of keys with directions.""" - - return "_".join(["%s_%s" % item for item in index_list]) - - -class hashdict(dict): - """hashable dict implementation, suitable for use as a key into other dicts. - - >>> h1 = hashdict({'apples': 1, 'bananas':2}) - >>> h2 = hashdict({'bananas': 3, 'mangoes': 5}) - >>> h1+h2 - hashdict(apples=1, bananas=3, mangoes=5) - >>> d1 = {} - >>> d1[h1] = 'salad' - >>> d1[h1] - 'salad' - >>> d1[h2] - Traceback (most recent call last): - ...
- KeyError: hashdict(bananas=3, mangoes=5) - - based on answers from - http://stackoverflow.com/questions/1151658/python-hashable-dicts - """ - - def __key(self): - return frozenset( - ( - k, - ( - hashdict(v) - if isinstance(v, dict) - else tuple(v) - if isinstance(v, list) - else v - ), - ) - for k, v in self.items() - ) - - def __repr__(self): - return "{0}({1})".format( - self.__class__.__name__, - ", ".join( - "{0}={1}".format(str(i[0]), repr(i[1])) for i in sorted(self.__key()) - ), - ) - - def __hash__(self): - return hash(self.__key()) - - def __setitem__(self, key, value): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def __delitem__(self, key): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def clear(self): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def pop(self, *args, **kwargs): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def popitem(self, *args, **kwargs): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def setdefault(self, *args, **kwargs): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def update(self, *args, **kwargs): - raise TypeError( - "{0} does not support item assignment".format(self.__class__.__name__) - ) - - def __add__(self, right): - result = hashdict(self) - dict.update(result, right) - return result - - -def fields_list_to_dict(fields): - """Takes a list of field names and returns a matching dictionary. - - ['a', 'b'] becomes {'a': 1, 'b': 1} - - and - - ['a.b.c', 'd', 'a.c'] becomes {'a.b.c': 1, 'd': 1, 'a.c': 1} - """ - as_dict = {} - for field in fields: - if not isinstance(field, str): - raise TypeError( - "fields must be a list of key names, each an instance of str" - ) - as_dict[field] = 1 - return as_dict - - -def parse_uri(uri, default_port=27017, warn=False): - """A simplified version of pymongo.uri_parser.parse_uri. - - Returns a dict with: - - nodelist, a tuple of (host, port) - - database the name of the database or None if no database is provided in the URI. - - An invalid MongoDB connection URI may raise an InvalidURI exception, - however, the URI is not fully parsed and some invalid URIs may not result - in an exception. - - 'mongodb://host1/database' becomes 'host1', 27017, 'database' - - and - - 'mongodb://host1' becomes 'host1', 27017, None - """ - SCHEME = "mongodb://" - - if not uri.startswith(SCHEME): - raise InvalidURI("Invalid URI scheme: URI " "must begin with '%s'" % (SCHEME,)) - - scheme_free = uri[len(SCHEME) :] - - if not scheme_free: - raise InvalidURI("Must provide at least one hostname or IP.") - - dbase = None - - # Check for unix domain sockets in the uri - if ".sock" in scheme_free: - host_part, _, path_part = scheme_free.rpartition("/") - if not host_part: - host_part = path_part - path_part = "" - if "/" in host_part: - raise InvalidURI( - "Any '/' in a unix domain socket must be" " URL encoded: %s" % host_part - ) - path_part = unquote_plus(path_part) - else: - host_part, _, path_part = scheme_free.partition("/") - - if not path_part and "?" 
in host_part: - raise InvalidURI("A '/' is required between " "the host list and any options.") - - nodelist = [] - if "," in host_part: - hosts = host_part.split(",") - else: - hosts = [host_part] - for host in hosts: - match = _HOST_MATCH.match(host) - if not match: - raise ValueError( - "Reserved characters such as ':' must be escaped according RFC " - "2396. An IPv6 address literal must be enclosed in '[' and ']' " - "according to RFC 2732." - ) - host = match.group(2) - if host.startswith("[") and host.endswith("]"): - host = host[1:-1] - - port = match.group(4) - if port: - try: - port = int(port) - if port < 0 or port > 65535: - raise ValueError() - except ValueError as err: - raise ValueError( - "Port must be an integer between 0 and 65535:", port - ) from err - else: - port = default_port - - nodelist.append((host, port)) - - if path_part and path_part[0] != "?": - dbase, _, _ = path_part.partition("?") - if "." in dbase: - dbase, _ = dbase.split(".", 1) - - if dbase is not None: - dbase = unquote_plus(dbase) - - return {"nodelist": tuple(nodelist), "database": dbase} - - -def split_hosts(hosts, default_port=27017): - """Split the entity into a list of tuples of host and port.""" - - nodelist = [] - for entity in hosts.split(","): - port = default_port - if entity.endswith(".sock"): - port = None - - match = _SIMPLE_HOST_MATCH.match(entity) - if not match: - raise ValueError( - "Reserved characters such as ':' must be escaped according RFC " - "2396. An IPv6 address literal must be enclosed in '[' and ']' " - "according to RFC 2732." - ) - host = match.group(1) - if host.startswith("[") and host.endswith("]"): - host = host[1:-1] - - if match.group(3): - try: - port = int(match.group(3)) - if port < 0 or port > 65535: - raise ValueError() - except ValueError as err: - raise ValueError( - "Port must be an integer between 0 and 65535:", port - ) from err - - nodelist.append((host, port)) - - return nodelist - - -_LAST_TIMESTAMP_INC = [] - - -def get_current_timestamp(): - """Get the current timestamp as a bson Timestamp object.""" - if not Timestamp: - raise NotImplementedError( - "timestamp is not supported. Import pymongo to use it." - ) - now = int(time.time()) - if _LAST_TIMESTAMP_INC and _LAST_TIMESTAMP_INC[0] == now: - _LAST_TIMESTAMP_INC[1] += 1 - else: - del _LAST_TIMESTAMP_INC[:] - _LAST_TIMESTAMP_INC.extend([now, 1]) - return Timestamp(now, _LAST_TIMESTAMP_INC[1]) - - -def patch_datetime_awareness_in_document(value): - # MongoDB is supposed to store everything as timezone-naive utc dates - # Hence we have to convert incoming datetimes to avoid errors while - # mixing tz aware and naive. - # On top of that, MongoDB date precision is up to the millisecond, where Python - # datetimes use microseconds, so we must lower the precision to mimic mongo.
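- # e.g. 12:00:00.123456+02:00 becomes the naive utc value 10:00:00.123000.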
- for best_type in (OrderedDict, dict): - if isinstance(value, best_type): - return best_type( - (k, patch_datetime_awareness_in_document(v)) for k, v in value.items() - ) - if isinstance(value, (tuple, list)): - return [patch_datetime_awareness_in_document(item) for item in value] - if isinstance(value, datetime): - mongo_us = (value.microsecond // 1000) * 1000 - if value.tzinfo: - return (value - value.utcoffset()).replace( - tzinfo=None, microsecond=mongo_us - ) - return value.replace(microsecond=mongo_us) - if Timestamp and isinstance(value, Timestamp) and not value.time and not value.inc: - return get_current_timestamp() - return value - - -def make_datetime_timezone_aware_in_document(value): - # MongoClient support tz_aware=True parameter to return timezone-aware - # datetime objects. Given the date is stored internally without timezone - # information, all returned datetime have utc as timezone. - if isinstance(value, dict): - return { - k: make_datetime_timezone_aware_in_document(v) for k, v in value.items() - } - if isinstance(value, (tuple, list)): - return [make_datetime_timezone_aware_in_document(item) for item in value] - if isinstance(value, datetime): - return value.replace(tzinfo=utc) - return value - - -def get_value_by_dot(doc, key, can_generate_array=False): - """Get dictionary value using dotted key""" - result = doc - key_items = key.split(".") - for key_index, key_item in enumerate(key_items): - if isinstance(result, dict): - result = result[key_item] - - elif isinstance(result, (list, tuple)): - try: - int_key = int(key_item) - except ValueError as err: - if not can_generate_array: - raise KeyError(key_index) from err - remaining_key = ".".join(key_items[key_index:]) - return [get_value_by_dot(subdoc, remaining_key) for subdoc in result] - - try: - result = result[int_key] - except (ValueError, IndexError) as err: - raise KeyError(key_index) from err - - else: - raise KeyError(key_index) - - return result - - -def set_value_by_dot(doc, key, value): - """Set dictionary value using dotted key""" - try: - parent_key, child_key = key.rsplit(".", 1) - parent = get_value_by_dot(doc, parent_key) - except ValueError: - child_key = key - parent = doc - - if isinstance(parent, dict): - parent[child_key] = value - elif isinstance(parent, (list, tuple)): - try: - parent[int(child_key)] = value - except (ValueError, IndexError) as err: - raise KeyError() from err - else: - raise KeyError() - - return doc - - -def delete_value_by_dot(doc, key): - """Delete dictionary value using dotted key. - - This function assumes that the value exists. - """ - try: - parent_key, child_key = key.rsplit(".", 1) - parent = get_value_by_dot(doc, parent_key) - except ValueError: - child_key = key - parent = doc - - del parent[child_key] - - return doc - - -def mongodb_to_bool(value): - """Converts any value to bool the way MongoDB does it""" - - return value not in [False, None, 0] diff --git a/packages/syft/tests/mongomock/mongo_client.py b/packages/syft/tests/mongomock/mongo_client.py deleted file mode 100644 index 560a7ce0f11..00000000000 --- a/packages/syft/tests/mongomock/mongo_client.py +++ /dev/null @@ -1,222 +0,0 @@ -# stdlib -import itertools -import warnings - -# third party -from packaging import version - -# relative -from . import ConfigurationError -from . import codec_options as mongomock_codec_options -from . import helpers -from . 
import read_preferences -from .database import Database -from .store import ServerStore - -try: - # third party - from pymongo import ReadPreference - from pymongo.uri_parser import parse_uri - from pymongo.uri_parser import split_hosts - - _READ_PREFERENCE_PRIMARY = ReadPreference.PRIMARY -except ImportError: - # relative - from .helpers import parse_uri - from .helpers import split_hosts - - _READ_PREFERENCE_PRIMARY = read_preferences.PRIMARY - - -def _convert_version_to_list(version_str): - pieces = [int(part) for part in version_str.split(".")] - return pieces + [0] * (4 - len(pieces)) - - -class MongoClient(object): - HOST = "localhost" - PORT = 27017 - _CONNECTION_ID = itertools.count() - - def __init__( - self, - host=None, - port=None, - document_class=dict, - tz_aware=False, - connect=True, - _store=None, - read_preference=None, - uuidRepresentation=None, - type_registry=None, - **kwargs, - ): - if host: - self.host = host[0] if isinstance(host, (list, tuple)) else host - else: - self.host = self.HOST - self.port = port or self.PORT - - self._tz_aware = tz_aware - self._codec_options = mongomock_codec_options.CodecOptions( - tz_aware=tz_aware, - uuid_representation=uuidRepresentation, - type_registry=type_registry, - ) - self._database_accesses = {} - self._store = _store or ServerStore() - self._id = next(self._CONNECTION_ID) - self._document_class = document_class - if read_preference is not None: - read_preferences.ensure_read_preference_type( - "read_preference", read_preference - ) - self._read_preference = read_preference or _READ_PREFERENCE_PRIMARY - - dbase = None - - if "://" in self.host: - res = parse_uri(self.host, default_port=self.port, warn=True) - self.host, self.port = res["nodelist"][0] - dbase = res["database"] - else: - self.host, self.port = split_hosts(self.host, default_port=self.port)[0] - - self.__default_database_name = dbase - # relative - from . import SERVER_VERSION - - self._server_version = SERVER_VERSION - - def __getitem__(self, db_name): - return self.get_database(db_name) - - def __getattr__(self, attr): - return self[attr] - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - def __repr__(self): - return "mongomock.MongoClient('{0}', {1})".format(self.host, self.port) - - def __eq__(self, other): - if isinstance(other, self.__class__): - return self.address == other.address - return NotImplemented - - if helpers.PYMONGO_VERSION >= version.parse("3.12"): - - def __hash__(self): - return hash(self.address) - - def close(self): - pass - - @property - def is_mongos(self): - return True - - @property - def is_primary(self): - return True - - @property - def address(self): - return self.host, self.port - - @property - def read_preference(self): - return self._read_preference - - @property - def codec_options(self): - return self._codec_options - - def server_info(self): - return { - "version": self._server_version, - "sysInfo": "Mock", - "versionArray": _convert_version_to_list(self._server_version), - "bits": 64, - "debug": False, - "maxBsonObjectSize": 16777216, - "ok": 1, - } - - if helpers.PYMONGO_VERSION < version.parse("4.0"): - - def database_names(self): - warnings.warn( - "database_names is deprecated. Use list_database_names instead." 
- ) - return self.list_database_names() - - def list_database_names(self): - return self._store.list_created_database_names() - - def drop_database(self, name_or_db): - def drop_collections_for_db(_db): - db_store = self._store[_db.name] - for col_name in db_store.list_created_collection_names(): - _db.drop_collection(col_name) - - if isinstance(name_or_db, Database): - db = next(db for db in self._database_accesses.values() if db is name_or_db) - if db: - drop_collections_for_db(db) - - elif name_or_db in self._store: - db = self.get_database(name_or_db) - drop_collections_for_db(db) - - def get_database( - self, - name=None, - codec_options=None, - read_preference=None, - write_concern=None, - read_concern=None, - ): - if name is None: - db = self.get_default_database( - codec_options=codec_options, - read_preference=read_preference, - write_concern=write_concern, - read_concern=read_concern, - ) - else: - db = self._database_accesses.get(name) - if db is None: - db_store = self._store[name] - db = self._database_accesses[name] = Database( - self, - name, - read_preference=read_preference or self.read_preference, - codec_options=codec_options or self._codec_options, - _store=db_store, - read_concern=read_concern, - ) - return db - - def get_default_database(self, default=None, **kwargs): - name = self.__default_database_name - name = name if name is not None else default - if name is None: - raise ConfigurationError("No default database name defined or provided.") - - return self.get_database(name=name, **kwargs) - - def alive(self): - """The original MongoConnection.alive method checks the status of the server. - - In our case as we mock the actual server, we should always return True. - """ - return True - - def start_session(self, causal_consistency=True, default_transaction_options=None): - """Start a logical session.""" - raise NotImplementedError("Mongomock does not support sessions yet") diff --git a/packages/syft/tests/mongomock/not_implemented.py b/packages/syft/tests/mongomock/not_implemented.py deleted file mode 100644 index 990b89a411e..00000000000 --- a/packages/syft/tests/mongomock/not_implemented.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Module to handle features that are not implemented yet.""" - -_IGNORED_FEATURES = { - "array_filters": False, - "collation": False, - "let": False, - "session": False, -} - - -def _ensure_ignorable_feature(feature): - if feature not in _IGNORED_FEATURES: - raise KeyError( - "%s is not an error that can be ignored: maybe it has been implemented in Mongomock. 
" - "Here is the list of features that can be ignored: %s" - % (feature, _IGNORED_FEATURES.keys()) - ) - - -def ignore_feature(feature): - """Ignore a feature instead of raising a NotImplementedError.""" - _ensure_ignorable_feature(feature) - _IGNORED_FEATURES[feature] = True - - -def warn_on_feature(feature): - """Rasie a NotImplementedError the next times a feature is used.""" - _ensure_ignorable_feature(feature) - _IGNORED_FEATURES[feature] = False - - -def raise_for_feature(feature, reason): - _ensure_ignorable_feature(feature) - if _IGNORED_FEATURES[feature]: - return False - raise NotImplementedError(reason) diff --git a/packages/syft/tests/mongomock/object_id.py b/packages/syft/tests/mongomock/object_id.py deleted file mode 100644 index 281e8b02663..00000000000 --- a/packages/syft/tests/mongomock/object_id.py +++ /dev/null @@ -1,26 +0,0 @@ -# stdlib -import uuid - - -class ObjectId(object): - def __init__(self, id=None): - super(ObjectId, self).__init__() - if id is None: - self._id = uuid.uuid1() - else: - self._id = uuid.UUID(id) - - def __eq__(self, other): - return isinstance(other, ObjectId) and other._id == self._id - - def __ne__(self, other): - return not self == other - - def __hash__(self): - return hash(self._id) - - def __repr__(self): - return "ObjectId({0})".format(self._id) - - def __str__(self): - return str(self._id) diff --git a/packages/syft/tests/mongomock/patch.py b/packages/syft/tests/mongomock/patch.py deleted file mode 100644 index 8db36497e75..00000000000 --- a/packages/syft/tests/mongomock/patch.py +++ /dev/null @@ -1,120 +0,0 @@ -# stdlib -import time - -# relative -from .mongo_client import MongoClient - -try: - # stdlib - from unittest import mock - - _IMPORT_MOCK_ERROR = None -except ImportError: - try: - # third party - import mock - - _IMPORT_MOCK_ERROR = None - except ImportError as error: - _IMPORT_MOCK_ERROR = error - -try: - # third party - import pymongo - from pymongo.uri_parser import parse_uri - from pymongo.uri_parser import split_hosts - - _IMPORT_PYMONGO_ERROR = None -except ImportError as error: - # relative - from .helpers import parse_uri - from .helpers import split_hosts - - _IMPORT_PYMONGO_ERROR = error - - -def _parse_any_host(host, default_port=27017): - if isinstance(host, tuple): - return _parse_any_host(host[0], host[1]) - if "://" in host: - return parse_uri(host, warn=True)["nodelist"] - return split_hosts(host, default_port=default_port) - - -def patch(servers="localhost", on_new="error"): - """Patch pymongo.MongoClient. - - This will patch the class MongoClient and use mongomock to mock MongoDB - servers. It keeps a consistant state of servers across multiple clients so - you can do: - - ``` - client = pymongo.MongoClient(host='localhost', port=27017) - client.db.coll.insert_one({'name': 'Pascal'}) - - other_client = pymongo.MongoClient('mongodb://localhost:27017') - client.db.coll.find_one() - ``` - - The data is persisted as long as the patch lives. - - Args: - on_new: Behavior when accessing a new server (not in servers): - 'create': mock a new empty server, accept any client connection. - 'error': raise a ValueError immediately when trying to access. - 'timeout': behave as pymongo when a server does not exist, raise an - error after a timeout. - 'pymongo': use an actual pymongo client. - servers: a list of server that are avaiable. 
- """ - - if _IMPORT_MOCK_ERROR: - raise _IMPORT_MOCK_ERROR # pylint: disable=raising-bad-type - - if _IMPORT_PYMONGO_ERROR: - PyMongoClient = None - else: - PyMongoClient = pymongo.MongoClient - - persisted_clients = {} - parsed_servers = set() - for server in servers if isinstance(servers, (list, tuple)) else [servers]: - parsed_servers.update(_parse_any_host(server)) - - def _create_persistent_client(*args, **kwargs): - if _IMPORT_PYMONGO_ERROR: - raise _IMPORT_PYMONGO_ERROR # pylint: disable=raising-bad-type - - client = MongoClient(*args, **kwargs) - - try: - persisted_client = persisted_clients[client.address] - client._store = persisted_client._store - return client - except KeyError: - pass - - if client.address in parsed_servers or on_new == "create": - persisted_clients[client.address] = client - return client - - if on_new == "timeout": - # TODO(pcorpet): Only wait when trying to access the server's data. - time.sleep(kwargs.get("serverSelectionTimeoutMS", 30000)) - raise pymongo.errors.ServerSelectionTimeoutError( - "%s:%d: [Errno 111] Connection refused" % client.address - ) - - if on_new == "pymongo": - return PyMongoClient(*args, **kwargs) - - raise ValueError( - "MongoDB server %s:%d does not exist.\n" % client.address - + "%s" % parsed_servers - ) - - class _PersistentClient: - def __new__(cls, *args, **kwargs): - return _create_persistent_client(*args, **kwargs) - - return mock.patch("pymongo.MongoClient", _PersistentClient) diff --git a/packages/syft/tests/mongomock/py.typed b/packages/syft/tests/mongomock/py.typed deleted file mode 100644 index e69de29bb2d..00000000000 diff --git a/packages/syft/tests/mongomock/read_concern.py b/packages/syft/tests/mongomock/read_concern.py deleted file mode 100644 index 229e0f78bb4..00000000000 --- a/packages/syft/tests/mongomock/read_concern.py +++ /dev/null @@ -1,21 +0,0 @@ -class ReadConcern(object): - def __init__(self, level=None): - self._document = {} - - if level is not None: - self._document["level"] = level - - @property - def level(self): - return self._document.get("level") - - @property - def ok_for_legacy(self): - return True - - @property - def document(self): - return self._document.copy() - - def __eq__(self, other): - return other.document == self.document diff --git a/packages/syft/tests/mongomock/read_preferences.py b/packages/syft/tests/mongomock/read_preferences.py deleted file mode 100644 index a9349e6c576..00000000000 --- a/packages/syft/tests/mongomock/read_preferences.py +++ /dev/null @@ -1,42 +0,0 @@ -class _Primary(object): - @property - def mongos_mode(self): - return "primary" - - @property - def mode(self): - return 0 - - @property - def name(self): - return "Primary" - - @property - def document(self): - return {"mode": "primary"} - - @property - def tag_sets(self): - return [{}] - - @property - def max_staleness(self): - return -1 - - @property - def min_wire_version(self): - return 0 - - -def ensure_read_preference_type(key, value): - """Raise a TypeError if the value is not a type compatible for ReadPreference.""" - for attr in ("document", "mode", "mongos_mode", "max_staleness"): - if not hasattr(value, attr): - raise TypeError( - "{} must be an instance of {}".format( - key, "pymongo.read_preference.ReadPreference" - ) - ) - - -PRIMARY = _Primary() diff --git a/packages/syft/tests/mongomock/results.py b/packages/syft/tests/mongomock/results.py deleted file mode 100644 index 07633a6e82e..00000000000 --- a/packages/syft/tests/mongomock/results.py +++ /dev/null @@ -1,117 +0,0 @@ -try: - # third party - 
from pymongo.results import BulkWriteResult - from pymongo.results import DeleteResult - from pymongo.results import InsertManyResult - from pymongo.results import InsertOneResult - from pymongo.results import UpdateResult -except ImportError: - - class _WriteResult(object): - def __init__(self, acknowledged=True): - self.__acknowledged = acknowledged - - @property - def acknowledged(self): - return self.__acknowledged - - class InsertOneResult(_WriteResult): - __slots__ = ("__inserted_id", "__acknowledged") - - def __init__(self, inserted_id, acknowledged=True): - self.__inserted_id = inserted_id - super(InsertOneResult, self).__init__(acknowledged) - - @property - def inserted_id(self): - return self.__inserted_id - - class InsertManyResult(_WriteResult): - __slots__ = ("__inserted_ids", "__acknowledged") - - def __init__(self, inserted_ids, acknowledged=True): - self.__inserted_ids = inserted_ids - super(InsertManyResult, self).__init__(acknowledged) - - @property - def inserted_ids(self): - return self.__inserted_ids - - class UpdateResult(_WriteResult): - __slots__ = ("__raw_result", "__acknowledged") - - def __init__(self, raw_result, acknowledged=True): - self.__raw_result = raw_result - super(UpdateResult, self).__init__(acknowledged) - - @property - def raw_result(self): - return self.__raw_result - - @property - def matched_count(self): - if self.upserted_id is not None: - return 0 - return self.__raw_result.get("n", 0) - - @property - def modified_count(self): - return self.__raw_result.get("nModified") - - @property - def upserted_id(self): - return self.__raw_result.get("upserted") - - class DeleteResult(_WriteResult): - __slots__ = ("__raw_result", "__acknowledged") - - def __init__(self, raw_result, acknowledged=True): - self.__raw_result = raw_result - super(DeleteResult, self).__init__(acknowledged) - - @property - def raw_result(self): - return self.__raw_result - - @property - def deleted_count(self): - return self.__raw_result.get("n", 0) - - class BulkWriteResult(_WriteResult): - __slots__ = ("__bulk_api_result", "__acknowledged") - - def __init__(self, bulk_api_result, acknowledged): - self.__bulk_api_result = bulk_api_result - super(BulkWriteResult, self).__init__(acknowledged) - - @property - def bulk_api_result(self): - return self.__bulk_api_result - - @property - def inserted_count(self): - return self.__bulk_api_result.get("nInserted") - - @property - def matched_count(self): - return self.__bulk_api_result.get("nMatched") - - @property - def modified_count(self): - return self.__bulk_api_result.get("nModified") - - @property - def deleted_count(self): - return self.__bulk_api_result.get("nRemoved") - - @property - def upserted_count(self): - return self.__bulk_api_result.get("nUpserted") - - @property - def upserted_ids(self): - if self.__bulk_api_result: - return dict( - (upsert["index"], upsert["_id"]) - for upsert in self.bulk_api_result["upserted"] - ) diff --git a/packages/syft/tests/mongomock/store.py b/packages/syft/tests/mongomock/store.py deleted file mode 100644 index 9cef7206329..00000000000 --- a/packages/syft/tests/mongomock/store.py +++ /dev/null @@ -1,191 +0,0 @@ -# stdlib -import collections -import datetime -import functools - -# relative -from .helpers import utcnow -from .thread import RWLock - - -class ServerStore(object): - """Object holding the data for a whole server (many databases).""" - - def __init__(self): - self._databases = {} - - def __getitem__(self, db_name): - try: - return self._databases[db_name] - except KeyError: - db = 
self._databases[db_name] = DatabaseStore() - return db - - def __contains__(self, db_name): - return self[db_name].is_created - - def list_created_database_names(self): - return [name for name, db in self._databases.items() if db.is_created] - - -class DatabaseStore(object): - """Object holding the data for a database (many collections).""" - - def __init__(self): - self._collections = {} - - def __getitem__(self, col_name): - try: - return self._collections[col_name] - except KeyError: - col = self._collections[col_name] = CollectionStore(col_name) - return col - - def __contains__(self, col_name): - return self[col_name].is_created - - def list_created_collection_names(self): - return [name for name, col in self._collections.items() if col.is_created] - - def create_collection(self, name): - col = self[name] - col.create() - return col - - def rename(self, name, new_name): - col = self._collections.pop(name, CollectionStore(new_name)) - col.name = new_name - self._collections[new_name] = col - - @property - def is_created(self): - return any(col.is_created for col in self._collections.values()) - - -class CollectionStore(object): - """Object holding the data for a collection.""" - - def __init__(self, name): - self._documents = collections.OrderedDict() - self.indexes = {} - self._is_force_created = False - self.name = name - self._ttl_indexes = {} - - # 694 - Lock for safely iterating and mutating OrderedDicts - self._rwlock = RWLock() - - def create(self): - self._is_force_created = True - - @property - def is_created(self): - return self._documents or self.indexes or self._is_force_created - - def drop(self): - self._documents = collections.OrderedDict() - self.indexes = {} - self._ttl_indexes = {} - self._is_force_created = False - - def create_index(self, index_name, index_dict): - self.indexes[index_name] = index_dict - if index_dict.get("expireAfterSeconds") is not None: - self._ttl_indexes[index_name] = index_dict - - def drop_index(self, index_name): - self._remove_expired_documents() - - # The main index object should raise a KeyError, but the - # TTL indexes have no meaning to the outside. 
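-        # (Dropping an unknown index name therefore raises a KeyError, while
-        # the TTL bookkeeping below is cleaned up best-effort.)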
-        del self.indexes[index_name]
-        self._ttl_indexes.pop(index_name, None)
-
-    @property
-    def is_empty(self):
-        self._remove_expired_documents()
-        return not self._documents
-
-    def __contains__(self, key):
-        self._remove_expired_documents()
-        with self._rwlock.reader():
-            return key in self._documents
-
-    def __getitem__(self, key):
-        self._remove_expired_documents()
-        with self._rwlock.reader():
-            return self._documents[key]
-
-    def __setitem__(self, key, val):
-        with self._rwlock.writer():
-            self._documents[key] = val
-
-    def __delitem__(self, key):
-        with self._rwlock.writer():
-            del self._documents[key]
-
-    def __len__(self):
-        self._remove_expired_documents()
-        with self._rwlock.reader():
-            return len(self._documents)
-
-    @property
-    def documents(self):
-        self._remove_expired_documents()
-        with self._rwlock.reader():
-            for doc in self._documents.values():
-                yield doc
-
-    def _remove_expired_documents(self):
-        for index in self._ttl_indexes.values():
-            self._expire_documents(index)
-
-    def _expire_documents(self, index):
-        # TODO(juannyg): use a caching mechanism to avoid re-expiring the documents if
-        # we just did and no document was added / updated
-
-        # Ignore non-integer values
-        try:
-            expiry = int(index["expireAfterSeconds"])
-        except ValueError:
-            return
-
-        # Ignore compound keys
-        if len(index["key"]) > 1:
-            return
-
-        # "key" structure = list of (field name, direction) tuples
-        ttl_field_name = next(iter(index["key"]))[0]
-        ttl_now = utcnow()
-
-        with self._rwlock.reader():
-            expired_ids = [
-                doc["_id"]
-                for doc in self._documents.values()
-                if self._value_meets_expiry(doc.get(ttl_field_name), expiry, ttl_now)
-            ]
-
-        for exp_id in expired_ids:
-            del self[exp_id]
-
-    def _value_meets_expiry(self, val, expiry, ttl_now):
-        val_to_compare = _get_min_datetime_from_value(val)
-        try:
-            return (ttl_now - val_to_compare).total_seconds() >= expiry
-        except TypeError:
-            return False
-
-
-def _get_min_datetime_from_value(val):
-    if not val:
-        return datetime.datetime.max
-    if isinstance(val, list):
-        return functools.reduce(_min_dt, [datetime.datetime.max] + val)
-    return val
-
-
-def _min_dt(dt1, dt2):
-    try:
-        return dt1 if dt1 < dt2 else dt2
-    except TypeError:
-        return dt1
diff --git a/packages/syft/tests/mongomock/thread.py b/packages/syft/tests/mongomock/thread.py
deleted file mode 100644
index ff673e44309..00000000000
--- a/packages/syft/tests/mongomock/thread.py
+++ /dev/null
@@ -1,94 +0,0 @@
-# stdlib
-from contextlib import contextmanager
-import threading
-
-
-class RWLock:
-    """Lock enabling multiple readers but only 1 exclusive writer
-
-    Source: https://cutt.ly/Ij70qaq
-    """
-
-    def __init__(self):
-        self._read_switch = _LightSwitch()
-        self._write_switch = _LightSwitch()
-        self._no_readers = threading.Lock()
-        self._no_writers = threading.Lock()
-        self._readers_queue = threading.RLock()
-
-    @contextmanager
-    def reader(self):
-        self._reader_acquire()
-        try:
-            yield
-        except Exception:  # pylint: disable=W0706
-            raise
-        finally:
-            self._reader_release()
-
-    @contextmanager
-    def writer(self):
-        self._writer_acquire()
-        try:
-            yield
-        except Exception:  # pylint: disable=W0706
-            raise
-        finally:
-            self._writer_release()
-
-    def _reader_acquire(self):
-        """Readers should block whenever a writer has acquired"""
-        self._readers_queue.acquire()
-        self._no_readers.acquire()
-        self._read_switch.acquire(self._no_writers)
-        self._no_readers.release()
-        self._readers_queue.release()
-
-    def _reader_release(self):
-        self._read_switch.release(self._no_writers)
-
-    def 
_writer_acquire(self):
-        """Acquire the writer lock.
-
-        Only the first writer will lock the readtry and then
-        all subsequent writers can simply use the resource as
-        it gets freed by the previous writer. The very last writer must
-        release the readtry semaphore, thus opening the gate for readers
-        to try reading.
-
-        No reader can engage in the entry section if the readtry semaphore
-        has been set by a writer previously.
-        """
-        self._write_switch.acquire(self._no_readers)
-        self._no_writers.acquire()
-
-    def _writer_release(self):
-        self._no_writers.release()
-        self._write_switch.release(self._no_readers)
-
-
-class _LightSwitch:
-    """An auxiliary "light switch"-like object
-
-    The first thread turns on the "switch", the last one turns it off.
-
-    Source: https://cutt.ly/Ij70qaq
-    """
-
-    def __init__(self):
-        self._counter = 0
-        self._mutex = threading.RLock()
-
-    def acquire(self, lock):
-        self._mutex.acquire()
-        self._counter += 1
-        if self._counter == 1:
-            lock.acquire()
-        self._mutex.release()
-
-    def release(self, lock):
-        self._mutex.acquire()
-        self._counter -= 1
-        if self._counter == 0:
-            lock.release()
-        self._mutex.release()
diff --git a/packages/syft/tests/mongomock/write_concern.py b/packages/syft/tests/mongomock/write_concern.py
deleted file mode 100644
index 93760445647..00000000000
--- a/packages/syft/tests/mongomock/write_concern.py
+++ /dev/null
@@ -1,45 +0,0 @@
-def _with_default_values(document):
-    if "w" in document:
-        return document
-    return dict(document, w=1)
-
-
-class WriteConcern(object):
-    def __init__(self, w=None, wtimeout=None, j=None, fsync=None):
-        self._document = {}
-        if w is not None:
-            self._document["w"] = w
-        if wtimeout is not None:
-            self._document["wtimeout"] = wtimeout
-        if j is not None:
-            self._document["j"] = j
-        if fsync is not None:
-            self._document["fsync"] = fsync
-
-    def __eq__(self, other):
-        try:
-            return _with_default_values(other.document) == _with_default_values(
-                self.document
-            )
-        except AttributeError:
-            return NotImplemented
-
-    def __ne__(self, other):
-        try:
-            return _with_default_values(other.document) != _with_default_values(
-                self.document
-            )
-        except AttributeError:
-            return NotImplemented
-
-    @property
-    def acknowledged(self):
-        return True
-
-    @property
-    def document(self):
-        return self._document.copy()
-
-    @property
-    def is_server_default(self):
-        return not self._document
diff --git a/packages/syftcli/manifest.yml b/packages/syftcli/manifest.yml
index 6dce61b13d5..bfc5224ce48 100644
--- a/packages/syftcli/manifest.yml
+++ b/packages/syftcli/manifest.yml
@@ -6,7 +6,7 @@ dockerTag: 0.9.2-beta.2
 images:
   - docker.io/openmined/syft-frontend:0.9.2-beta.2
   - docker.io/openmined/syft-backend:0.9.2-beta.2
-  - docker.io/library/mongo:7.0.4
+  - docker.io/library/postgres:13
   - docker.io/traefik:v2.11.0
 
 configFiles:
diff --git a/scripts/reset_mongo.sh b/scripts/reset_mongo.sh
deleted file mode 100755
index ac1641f68e4..00000000000
--- a/scripts/reset_mongo.sh
+++ /dev/null
@@ -1,18 +0,0 @@
-#!/bin/bash
-
-# WARNING: this will drop the app database in all your mongo dbs
-echo $1
-
-if [ -z $1 ]; then
-    MONGO_CONTAINER_NAME=$(docker ps --format '{{.Names}}' | grep -m 1 mongo)
-else
-    MONGO_CONTAINER_NAME=$1
-fi
-
-DROPCMD="<&1
\ No newline at end of file

From fee08de28916387111a161bfeb7020e407a65808 Mon Sep 17 00:00:00 2001
From: dk
Date: Wed, 11 Sep 2024 14:52:24 +0700
Subject: [PATCH 132/197] [clean] remove mongo related api reference

- remove unused scripts
---
 .isort.cfg                                    |  2 +-
.../api_reference/syft.store.mongo_client.rst | 31 -------------- .../api_reference/syft.store.mongo_codecs.rst | 35 ---------------- .../syft.store.mongo_document_store.rst | 40 ------------------- docs/source/api_reference/syft.store.rst | 3 -- scripts/reset_network.sh | 19 --------- 6 files changed, 1 insertion(+), 129 deletions(-) delete mode 100644 docs/source/api_reference/syft.store.mongo_client.rst delete mode 100644 docs/source/api_reference/syft.store.mongo_codecs.rst delete mode 100644 docs/source/api_reference/syft.store.mongo_document_store.rst delete mode 100755 scripts/reset_network.sh diff --git a/.isort.cfg b/.isort.cfg index aeb09bb8f36..26309a07039 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -20,4 +20,4 @@ import_heading_localfolder=relative ignore_comments=False force_grid_wrap=True honor_noqa=True -skip_glob=packages/syft/src/syft/__init__.py,packages/grid/data/*,packages/syft/tests/mongomock/* \ No newline at end of file +skip_glob=packages/syft/src/syft/__init__.py,packages/grid/data/* \ No newline at end of file diff --git a/docs/source/api_reference/syft.store.mongo_client.rst b/docs/source/api_reference/syft.store.mongo_client.rst deleted file mode 100644 index a21d43700aa..00000000000 --- a/docs/source/api_reference/syft.store.mongo_client.rst +++ /dev/null @@ -1,31 +0,0 @@ -syft.store.mongo\_client -======================== - -.. automodule:: syft.store.mongo_client - - - - - - - - - - - - .. rubric:: Classes - - .. autosummary:: - - MongoClient - MongoClientCache - MongoStoreClientConfig - - - - - - - - - diff --git a/docs/source/api_reference/syft.store.mongo_codecs.rst b/docs/source/api_reference/syft.store.mongo_codecs.rst deleted file mode 100644 index 1d91b779e95..00000000000 --- a/docs/source/api_reference/syft.store.mongo_codecs.rst +++ /dev/null @@ -1,35 +0,0 @@ -syft.store.mongo\_codecs -======================== - -.. automodule:: syft.store.mongo_codecs - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - fallback_syft_encoder - - - - - - .. rubric:: Classes - - .. autosummary:: - - SyftMongoBinaryDecoder - - - - - - - - - diff --git a/docs/source/api_reference/syft.store.mongo_document_store.rst b/docs/source/api_reference/syft.store.mongo_document_store.rst deleted file mode 100644 index 30fdb6bc6ca..00000000000 --- a/docs/source/api_reference/syft.store.mongo_document_store.rst +++ /dev/null @@ -1,40 +0,0 @@ -syft.store.mongo\_document\_store -================================= - -.. automodule:: syft.store.mongo_document_store - - - - - - - - .. rubric:: Functions - - .. autosummary:: - - from_mongo - syft_obj_to_mongo - to_mongo - - - - - - .. rubric:: Classes - - .. autosummary:: - - MongoBsonObject - MongoDocumentStore - MongoStoreConfig - MongoStorePartition - - - - - - - - - diff --git a/docs/source/api_reference/syft.store.rst b/docs/source/api_reference/syft.store.rst index b21cf230488..e83e8699025 100644 --- a/docs/source/api_reference/syft.store.rst +++ b/docs/source/api_reference/syft.store.rst @@ -32,8 +32,5 @@ syft.store.kv_document_store syft.store.linked_obj syft.store.locks - syft.store.mongo_client - syft.store.mongo_codecs - syft.store.mongo_document_store syft.store.sqlite_document_store diff --git a/scripts/reset_network.sh b/scripts/reset_network.sh deleted file mode 100755 index ce5f863ff14..00000000000 --- a/scripts/reset_network.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash - -MONGO_CONTAINER_NAME=$(docker ps --format '{{.Names}}' | grep -m 1 mongo) -DROPCMD="<&1 - -# flush the worker queue -. 
${BASH_SOURCE%/*}/flush_queue.sh - -# reset docker service to clear out weird network issues -sudo service docker restart - -# make sure all containers start -. ${BASH_SOURCE%/*}/../packages/grid/scripts/containers.sh From 3eabf45b253b94967ca2aa3d194cc36a6c578040 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 11 Sep 2024 11:06:42 +0200 Subject: [PATCH 133/197] fix lint --- packages/syft/src/syft/server/server.py | 5 ++- .../syft/src/syft/server/service_registry.py | 2 +- .../src/syft/service/action/action_service.py | 2 + .../data_subject/data_subject_service.py | 18 -------- .../syft/service/network/network_service.py | 43 ++++++------------- .../syft/src/syft/service/network/utils.py | 2 +- .../src/syft/service/queue/queue_service.py | 20 +-------- .../syft/service/request/request_service.py | 6 +-- packages/syft/src/syft/service/service.py | 2 + .../src/syft/service/user/user_service.py | 4 +- .../syft/service/worker/worker_image_stash.py | 6 +++ .../syft/service/worker/worker_pool_stash.py | 10 ++++- .../src/syft/service/worker/worker_stash.py | 6 +++ packages/syft/src/syft/store/linked_obj.py | 2 +- .../tests/syft/users/user_service_test.py | 22 ++++++++++ 15 files changed, 69 insertions(+), 81 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index e952c0f4633..bdfc73a6ad4 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -51,6 +51,7 @@ from ..service.metadata.server_metadata import ServerMetadata from ..service.network.utils import PeerHealthCheckTask from ..service.notifier.notifier_service import NotifierService +from ..service.output.output_service import OutputStash from ..service.queue.base_queue import AbstractMessageHandler from ..service.queue.base_queue import QueueConsumer from ..service.queue.base_queue import QueueProducer @@ -903,8 +904,8 @@ def job_stash(self) -> JobStash: return self.services.job.stash @property - def output_stash(self) -> JobStash: - return self.get_service("outputservice").stash + def output_stash(self) -> OutputStash: + return self.services.output.stash @property def worker_stash(self) -> WorkerStash: diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index 0689185f4cc..23052804385 100644 --- a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -119,7 +119,7 @@ def get_service_classes( def _construct_services(cls, server: "Server") -> dict[str, AbstractService]: service_dict = {} for field_name, service_cls in cls.get_service_classes().items(): - service = service_cls(store=server.db) + service = service_cls(store=server.db) # type: ignore service_dict[field_name] = service return service_dict diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index ddca7e03ebd..fa5a1ee8a8a 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -56,6 +56,8 @@ @serializable(canonical_name="ActionService", version=1) class ActionService(AbstractService): + stash: ActionObjectStash + def __init__(self, store: DocumentStore) -> None: # TODO remove self.store, use self.stash instead self.store = ActionObjectStash(store) diff --git a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py index 
cfa9f611afc..331c22a9d49 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py @@ -28,24 +28,6 @@ def get_by_name(self, credentials: SyftVerifyKey, name: str) -> DataSubject: filters={"name": name}, ).unwrap() - @as_result(StashException) - def update( - self, - credentials: SyftVerifyKey, - data_subject: DataSubject, - has_permission: bool = False, - ) -> DataSubject: - self.check_type(data_subject, DataSubject).unwrap() - return ( - super() - .update( - credentials=credentials, - obj=data_subject, - has_permission=has_permission, - ) - .unwrap() - ) - @serializable(canonical_name="DataSubjectService", version=1) class DataSubjectService(AbstractService): diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 28f28db75a8..9301e702890 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -17,8 +17,6 @@ from ...service.settings.settings import ServerSettings from ...store.db.stash import ObjectStash from ...store.document_store import DocumentStore -from ...store.document_store import PartitionKey -from ...store.document_store import QueryKeys from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException @@ -59,10 +57,6 @@ logger = logging.getLogger(__name__) -VerifyKeyPartitionKey = PartitionKey(key="verify_key", type_=SyftVerifyKey) -ServerTypePartitionKey = PartitionKey(key="server_type", type_=ServerType) -OrderByNamePartitionKey = PartitionKey(key="name", type_=str) - REVERSE_TUNNEL_ENABLED = "REVERSE_TUNNEL_ENABLED" @@ -91,20 +85,6 @@ def get_by_name(self, credentials: SyftVerifyKey, name: str) -> ServerPeer: e, public_message=f"ServerPeer with {name} not found" ) - @as_result(StashException) - def update( - self, - credentials: SyftVerifyKey, - peer_update: ServerPeerUpdate, - has_permission: bool = False, - ) -> ServerPeer: - self.check_type(peer_update, ServerPeerUpdate).unwrap() - return ( - super() - .update(credentials, peer_update, has_permission=has_permission) - .unwrap() - ) - @as_result(StashException) def create_or_update_peer( self, credentials: SyftVerifyKey, peer: ServerPeer @@ -140,18 +120,18 @@ def create_or_update_peer( def get_by_verify_key( self, credentials: SyftVerifyKey, verify_key: SyftVerifyKey ) -> ServerPeer: - qks = QueryKeys(qks=[VerifyKeyPartitionKey.with_obj(verify_key)]) - return self.query_one(credentials, qks).unwrap( - private_message=f"ServerPeer with {verify_key} not found" - ) + return self.get_one( + credentials=credentials, + filters={"verify_key": verify_key}, + ).unwrap() @as_result(StashException) def get_by_server_type( self, credentials: SyftVerifyKey, server_type: ServerType ) -> list[ServerPeer]: - qks = QueryKeys(qks=[ServerTypePartitionKey.with_obj(server_type)]) - return self.query_all( - credentials=credentials, qks=qks, order_by=OrderByNamePartitionKey + return self.get( + credentials=credentials, + filters={"server_type": server_type}, ).unwrap() @@ -399,7 +379,8 @@ def get_all_peers(self, context: AuthedServiceContext) -> list[ServerPeer]: """Get all Peers""" return self.stash.get_all( credentials=context.server.verify_key, - order_by=OrderByNamePartitionKey, + order_by="name", + sort_order="asc", ).unwrap() @service_method( @@ -439,7 +420,7 @@ def update_peer( peer = 
self.stash.update( credentials=context.server.verify_key, - peer_update=peer_update, + obj=peer_update, ).unwrap() self.set_reverse_tunnel_config(context=context, remote_server_peer=peer) @@ -599,7 +580,7 @@ def add_route( ) self.stash.update( credentials=context.server.verify_key, - peer_update=peer_update, + obj=peer_update, ).unwrap() return SyftSuccess( @@ -727,7 +708,7 @@ def delete_route( id=remote_server_peer.id, server_routes=remote_server_peer.server_routes ) self.stash.update( - credentials=context.server.verify_key, peer_update=peer_update + credentials=context.server.verify_key, obj=peer_update ).unwrap() return SyftSuccess(message=return_message) diff --git a/packages/syft/src/syft/service/network/utils.py b/packages/syft/src/syft/service/network/utils.py index 280e836f17b..655d40b9b3e 100644 --- a/packages/syft/src/syft/service/network/utils.py +++ b/packages/syft/src/syft/service/network/utils.py @@ -88,7 +88,7 @@ def peer_route_heathcheck(self, context: AuthedServiceContext) -> None: result = network_stash.update( credentials=context.server.verify_key, - peer_update=peer_update, + obj=peer_update, has_permission=True, ) diff --git a/packages/syft/src/syft/service/queue/queue_service.py b/packages/syft/src/syft/service/queue/queue_service.py index c898893ee35..d47f21052a8 100644 --- a/packages/syft/src/syft/service/queue/queue_service.py +++ b/packages/syft/src/syft/service/queue/queue_service.py @@ -2,30 +2,14 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore -from ...types.uid import UID -from ..context import AuthedServiceContext +from ...store.db.sqlite_db import DBManager from ..service import AbstractService -from ..service import service_method -from ..user.user_roles import DATA_SCIENTIST_ROLE_LEVEL -from .queue_stash import QueueItem from .queue_stash import QueueStash @serializable(canonical_name="QueueService", version=1) class QueueService(AbstractService): - store: DocumentStore stash: QueueStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = QueueStash(store=store) - - @service_method( - path="queue.get_subjobs", - name="get_subjobs", - roles=DATA_SCIENTIST_ROLE_LEVEL, - ) - def get_subjobs(self, context: AuthedServiceContext, uid: UID) -> list[QueueItem]: - # FIX: There is no get_by_parent_id in QueueStash - return self.stash.get_by_parent_id(context.credentials, uid=uid) diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index 6343007f6ca..37ceeeae02a 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...store.linked_obj import LinkedObject from ...types.errors import SyftException from ...types.result import as_result @@ -37,11 +37,9 @@ @serializable(canonical_name="RequestService", version=1) class RequestService(AbstractService): - store: DocumentStore stash: RequestStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = RequestStash(store=store) @service_method(path="request.submit", name="submit", roles=GUEST_ROLE_LEVEL) diff --git 
a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py index 784eca2e340..49749711853 100644 --- a/packages/syft/src/syft/service/service.py +++ b/packages/syft/src/syft/service/service.py @@ -37,6 +37,7 @@ from ..serde.signature import signature_remove_context from ..serde.signature import signature_remove_self from ..server.credentials import SyftVerifyKey +from ..store.db.stash import ObjectStash from ..store.document_store import DocumentStore from ..store.linked_obj import LinkedObject from ..types.errors import SyftException @@ -71,6 +72,7 @@ class AbstractService: server: AbstractServer server_uid: UID store_type: type = DocumentStore + stash: ObjectStash @as_result(SyftException) def resolve_link( diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index d71cda9c5f7..9d8f54605bd 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -557,9 +557,7 @@ def delete(self, context: AuthedServiceContext, uid: UID) -> UID: ) # TODO: Remove notifications for the deleted user - self.stash.delete_by_uid( - credentials=context.credentials, uid=uid, has_permission=True - ).unwrap() + self.stash.delete_by_uid(credentials=context.credentials, uid=uid).unwrap() return uid diff --git a/packages/syft/src/syft/service/worker/worker_image_stash.py b/packages/syft/src/syft/service/worker/worker_image_stash.py index 211f6072831..aa7131252bd 100644 --- a/packages/syft/src/syft/service/worker/worker_image_stash.py +++ b/packages/syft/src/syft/service/worker/worker_image_stash.py @@ -2,12 +2,16 @@ # third party +# third party +import sqlalchemy as sa + # relative from ...custom_worker.config import DockerWorkerConfig from ...custom_worker.config import WorkerConfig from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash +from ...store.db.stash import with_session from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException @@ -20,6 +24,7 @@ @serializable(canonical_name="SyftWorkerImageSQLStash", version=1) class SyftWorkerImageStash(ObjectStash[SyftWorkerImage]): @as_result(SyftException, StashException, NotFoundException) + @with_session def set( self, credentials: SyftVerifyKey, @@ -27,6 +32,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, ignore_duplicates: bool = False, + session: sa.Session = None, ) -> SyftWorkerImage: # By default syft images have all read permission add_permissions = [] if add_permissions is None else add_permissions diff --git a/packages/syft/src/syft/service/worker/worker_pool_stash.py b/packages/syft/src/syft/service/worker/worker_pool_stash.py index f3490ce10a7..e9dc4d7dfde 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_stash.py +++ b/packages/syft/src/syft/service/worker/worker_pool_stash.py @@ -2,10 +2,14 @@ # third party +# third party +import sqlalchemy as sa + # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash +from ...store.db.stash import with_session from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result @@ -29,6 +33,7 @@ def get_by_name(self, 
credentials: SyftVerifyKey, pool_name: str) -> WorkerPool: ) @as_result(StashException) + @with_session def set( self, credentials: SyftVerifyKey, @@ -36,6 +41,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, ignore_duplicates: bool = False, + session: sa.Session = None, ) -> WorkerPool: # By default all worker pools have all read permission add_permissions = [] if add_permissions is None else add_permissions @@ -58,7 +64,7 @@ def set( def get_by_image_uid( self, credentials: SyftVerifyKey, image_uid: UID ) -> list[WorkerPool]: - return self.get_by_fields( + return self.get_all( credentials=credentials, - fields={"image_id": image_uid.no_dash}, + filters={"image_id": image_uid.no_dash}, ).unwrap() diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py index 01202eb9ba2..e911258707b 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -2,10 +2,14 @@ # third party +# third party +import sqlalchemy as sa + # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash +from ...store.db.stash import with_session from ...store.document_store import PartitionKey from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException @@ -22,6 +26,7 @@ @serializable(canonical_name="WorkerSQLStash", version=1) class WorkerStash(ObjectStash[SyftWorker]): @as_result(StashException) + @with_session def set( self, credentials: SyftVerifyKey, @@ -29,6 +34,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, ignore_duplicates: bool = False, + session: sa.Session = None, ) -> SyftWorker: # By default all worker pools have all read permission add_permissions = [] if add_permissions is None else add_permissions diff --git a/packages/syft/src/syft/store/linked_obj.py b/packages/syft/src/syft/store/linked_obj.py index 5d9c29c9d9d..d3e40372842 100644 --- a/packages/syft/src/syft/store/linked_obj.py +++ b/packages/syft/src/syft/store/linked_obj.py @@ -76,7 +76,7 @@ def update_with_context( raise SyftException(public_message=f"context {context}'s server is None") service = context.server.get_service(self.service_type) if hasattr(service, "stash"): - result = service.stash.update(credentials, obj) + result = service.stash.update(credentials, obj).unwrap() else: raise SyftException( public_message=f"service {service} does not have a stash" diff --git a/packages/syft/tests/syft/users/user_service_test.py b/packages/syft/tests/syft/users/user_service_test.py index 95b05bb6be7..0452bb21ede 100644 --- a/packages/syft/tests/syft/users/user_service_test.py +++ b/packages/syft/tests/syft/users/user_service_test.py @@ -9,6 +9,7 @@ from pytest import MonkeyPatch # syft absolute +import syft as sy from syft import orchestra from syft.client.client import SyftClient from syft.server.credentials import SyftVerifyKey @@ -793,3 +794,24 @@ def test_reset_password(): ) guest_client.reset_password(token=temp_token, new_password="Password123") server.login(email="new_syft_user@openmined.org", password="Password123") + + +def test_root_cannot_be_deleted(): + server = orchestra.launch(name="datasite-test", reset=True) + datasite_client = server.login(email="info@openmined.org", password="changethis") + + new_admin_email = 
"admin@openmined.org" + new_admin_pass = "changethis2" + datasite_client.register( + name="second admin", + email=new_admin_email, + password=new_admin_pass, + password_verify=new_admin_pass, + ) + # update role + new_user_id = datasite_client.users.search(email=new_admin_email)[0].id + datasite_client.users.update(uid=new_user_id, role="admin") + + new_admin_client = server.login(email=new_admin_email, password=new_admin_pass) + with sy.raises(sy.SyftException): + new_admin_client.users.delete(datasite_client.account.id) From 895e3cd57affcf8fb5e11f2e6f4f2e0be3f38a5c Mon Sep 17 00:00:00 2001 From: khoaguin Date: Wed, 11 Sep 2024 16:31:13 +0700 Subject: [PATCH 134/197] [script] `reset_k8s.sh` works with postgres --- scripts/reset_k8s.sh | 53 +++++++++++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 13 deletions(-) diff --git a/scripts/reset_k8s.sh b/scripts/reset_k8s.sh index d0d245be6f2..033cb24ed31 100755 --- a/scripts/reset_k8s.sh +++ b/scripts/reset_k8s.sh @@ -1,22 +1,49 @@ #!/bin/bash -# WARNING: this will drop the 'app' database in your mongo-0 instance in the syft namespace echo $1 -# Dropping the database on mongo-0 -if [ -z $1 ]; then - MONGO_POD_NAME="mongo-0" -else - MONGO_POD_NAME=$1 -fi +# Default pod name +DEFAULT_POD_NAME="postgres-0" + +# Use the provided pod name or the default +POSTGRES_POD_NAME=${1:-$DEFAULT_POD_NAME} + +# SQL commands to reset all tables +RESET_COMMAND=" +DO \$\$ +DECLARE + r RECORD; +BEGIN + -- Disable all triggers + SET session_replication_role = 'replica'; + + -- Truncate all tables in the current schema + FOR r IN (SELECT tablename FROM pg_tables WHERE schemaname = current_schema()) LOOP + EXECUTE 'TRUNCATE TABLE ' || quote_ident(r.tablename) || ' CASCADE'; + END LOOP; + + -- Re-enable all triggers + SET session_replication_role = 'origin'; +END \$\$; + +-- Reset all sequences +DO \$\$ +DECLARE + r RECORD; +BEGIN + FOR r IN (SELECT sequence_name FROM information_schema.sequences WHERE sequence_schema = current_schema()) LOOP + EXECUTE 'ALTER SEQUENCE ' || quote_ident(r.sequence_name) || ' RESTART WITH 1'; + END LOOP; +END \$\$; +" -DROPCMD="<&1 +echo "All tables in $POSTGRES_POD_NAME have been reset." 
# Resetting the backend pod BACKEND_POD=$(kubectl get pods -n syft -o jsonpath="{.items[*].metadata.name}" | tr ' ' '\n' | grep -E ".*backend.*") From 3272d9bfbd9f25029d7e4f5ba06186e16cf230a5 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 11 Sep 2024 12:02:54 +0200 Subject: [PATCH 135/197] fix session --- packages/syft/src/syft/service/worker/worker_image_stash.py | 5 +++-- packages/syft/src/syft/service/worker/worker_pool_stash.py | 5 +++-- packages/syft/src/syft/service/worker/worker_stash.py | 5 +++-- packages/syft/src/syft/store/db/sqlite_db.py | 1 + 4 files changed, 10 insertions(+), 6 deletions(-) diff --git a/packages/syft/src/syft/service/worker/worker_image_stash.py b/packages/syft/src/syft/service/worker/worker_image_stash.py index aa7131252bd..dc220905839 100644 --- a/packages/syft/src/syft/service/worker/worker_image_stash.py +++ b/packages/syft/src/syft/service/worker/worker_image_stash.py @@ -3,7 +3,7 @@ # third party # third party -import sqlalchemy as sa +from sqlalchemy.orm import Session # relative from ...custom_worker.config import DockerWorkerConfig @@ -32,7 +32,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, ignore_duplicates: bool = False, - session: sa.Session = None, + session: Session = None, ) -> SyftWorkerImage: # By default syft images have all read permission add_permissions = [] if add_permissions is None else add_permissions @@ -57,6 +57,7 @@ def set( add_permissions=add_permissions, add_storage_permission=add_storage_permission, ignore_duplicates=ignore_duplicates, + session=session, ) .unwrap() ) diff --git a/packages/syft/src/syft/service/worker/worker_pool_stash.py b/packages/syft/src/syft/service/worker/worker_pool_stash.py index e9dc4d7dfde..699ee19a47a 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_stash.py +++ b/packages/syft/src/syft/service/worker/worker_pool_stash.py @@ -3,7 +3,7 @@ # third party # third party -import sqlalchemy as sa +from sqlalchemy.orm import Session # relative from ...serde.serializable import serializable @@ -41,7 +41,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, ignore_duplicates: bool = False, - session: sa.Session = None, + session: Session = None, ) -> WorkerPool: # By default all worker pools have all read permission add_permissions = [] if add_permissions is None else add_permissions @@ -56,6 +56,7 @@ def set( add_permissions=add_permissions, add_storage_permission=add_storage_permission, ignore_duplicates=ignore_duplicates, + session=session, ) .unwrap() ) diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py index e911258707b..48a192ecd19 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -3,7 +3,7 @@ # third party # third party -import sqlalchemy as sa +from sqlalchemy.orm import Session # relative from ...serde.serializable import serializable @@ -34,7 +34,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, ignore_duplicates: bool = False, - session: sa.Session = None, + session: Session = None, ) -> SyftWorker: # By default all worker pools have all read permission add_permissions = [] if add_permissions is None else add_permissions @@ -49,6 +49,7 @@ def set( add_permissions=add_permissions, ignore_duplicates=ignore_duplicates, 
add_storage_permission=add_storage_permission, + session=session, ) .unwrap() ) diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index 674fc58c224..28043c89001 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -89,6 +89,7 @@ def update_settings(self) -> None: if self.engine.dialect.name == "sqlite": connection.execute(sa.text("PRAGMA journal_mode = WAL")) connection.execute(sa.text("PRAGMA busy_timeout = 5000")) + # TODO check connection.execute(sa.text("PRAGMA temp_store = 2")) connection.execute(sa.text("PRAGMA synchronous = 1")) From aea98771d0f29a926b134c2031a86f78e74390d9 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Wed, 11 Sep 2024 13:03:13 +0200 Subject: [PATCH 136/197] replace DocumentStore with DBManager --- packages/syft/src/syft/server/server.py | 4 +-- .../src/syft/service/action/action_service.py | 34 +++++++++---------- .../src/syft/service/action/action_store.py | 4 +-- .../syft/src/syft/service/api/api_service.py | 6 ++-- .../attestation/attestation_service.py | 6 ++-- .../src/syft/service/blob_storage/service.py | 6 ++-- .../src/syft/service/code/status_service.py | 6 ++-- .../syft/service/code/user_code_service.py | 6 ++-- .../code_history/code_history_service.py | 6 ++-- .../data_subject_member_service.py | 6 ++-- .../data_subject/data_subject_service.py | 6 ++-- .../syft/service/dataset/dataset_service.py | 6 ++-- .../syft/service/enclave/enclave_service.py | 8 ++--- .../syft/src/syft/service/job/job_service.py | 6 ++-- .../syft/src/syft/service/log/log_service.py | 6 ++-- .../syft/service/metadata/metadata_service.py | 6 ++-- .../service/migration/migration_service.py | 6 ++-- .../syft/service/network/network_service.py | 6 ++-- .../notification/notification_service.py | 6 ++-- .../syft/service/notifier/notifier_service.py | 6 ++-- .../src/syft/service/output/output_service.py | 8 ++--- .../syft/src/syft/service/policy/policy.py | 2 +- .../src/syft/service/policy/policy_service.py | 6 ++-- .../syft/service/project/project_service.py | 6 ++-- packages/syft/src/syft/service/queue/queue.py | 2 +- .../syft/src/syft/service/request/request.py | 2 +- .../syft/service/settings/settings_service.py | 6 ++-- .../src/syft/service/sync/sync_service.py | 10 +++--- .../src/syft/service/user/user_service.py | 8 ++--- .../service/worker/image_registry_service.py | 6 ++-- .../service/worker/worker_image_service.py | 6 ++-- .../syft/service/worker/worker_image_stash.py | 4 +-- .../service/worker/worker_pool_service.py | 6 ++-- .../syft/service/worker/worker_pool_stash.py | 4 +-- .../src/syft/service/worker/worker_service.py | 6 ++-- .../src/syft/service/worker/worker_stash.py | 4 +-- .../syft/src/syft/store/document_store.py | 1 - packages/syft/tests/syft/action_test.py | 6 ++-- .../migrations/protocol_communication_test.py | 6 ++-- .../syft/service/action/action_object_test.py | 2 +- .../service/action/action_service_test.py | 2 +- 41 files changed, 98 insertions(+), 151 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index bdfc73a6ad4..0802f1cddc6 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -420,7 +420,7 @@ def __init__( self.services: ServiceRegistry = ServiceRegistry.for_server(self) self.db.init_tables() # self.services.user.stash.init_root_user() - self.action_store = self.services.action.store + self.action_store = 
self.services.action.stash create_admin_new( name=root_username, @@ -1436,7 +1436,7 @@ def add_queueitem_to_queue( result_obj.syft_server_location = self.id result_obj.syft_client_verify_key = credentials - if not self.services.action.store.exists( + if not self.services.action.stash.exists( credentials=credentials, uid=action.result_id ): self.services.action.set_result_to_store( diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index fa5a1ee8a8a..38b53cf977d 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -9,7 +9,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.datetime import DateTime @@ -58,10 +58,8 @@ class ActionService(AbstractService): stash: ActionObjectStash - def __init__(self, store: DocumentStore) -> None: - # TODO remove self.store, use self.stash instead - self.store = ActionObjectStash(store) - self.stash = self.store + def __init__(self, store: DBManager) -> None: + self.stash = ActionObjectStash(store) @service_method(path="action.np_array", name="np_array") def np_array(self, context: AuthedServiceContext, data: Any) -> Any: @@ -181,7 +179,7 @@ def _set( or has_result_read_permission ) - self.store.set_or_update( + self.stash.set_or_update( uid=action_object.id, credentials=context.credentials, syft_object=action_object, @@ -240,7 +238,7 @@ def resolve_links( ) -> ActionObject: """Get an object from the action store""" # If user has permission to get the object / object exists - result = self.store.get(uid=uid, credentials=context.credentials).unwrap() + result = self.stash.get(uid=uid, credentials=context.credentials).unwrap() # If it's not a leaf if result.is_link: @@ -274,7 +272,7 @@ def _get( resolve_nested: bool = True, ) -> ActionObject | TwinObject: """Get an object from the action store""" - obj = self.store.get( + obj = self.stash.get( uid=uid, credentials=context.credentials, has_permission=has_permission ).unwrap() @@ -317,7 +315,7 @@ def get_pointer( self, context: AuthedServiceContext, uid: UID ) -> ActionObjectPointer: """Get a pointer from the action store""" - obj = self.store.get_pointer( + obj = self.stash.get_pointer( uid=uid, credentials=context.credentials, server_uid=context.server.id ).unwrap() @@ -331,7 +329,7 @@ def get_pointer( @service_method(path="action.get_mock", name="get_mock", roles=GUEST_ROLE_LEVEL) def get_mock(self, context: AuthedServiceContext, uid: UID) -> SyftObject: """Get a pointer from the action store""" - return self.store.get_mock(credentials=context.credentials, uid=uid).unwrap() + return self.stash.get_mock(credentials=context.credentials, uid=uid).unwrap() @service_method( path="action.has_storage_permission", @@ -339,12 +337,12 @@ def get_mock(self, context: AuthedServiceContext, uid: UID) -> SyftObject: roles=GUEST_ROLE_LEVEL, ) def has_storage_permission(self, context: AuthedServiceContext, uid: UID) -> bool: - return self.store.has_storage_permission( + return self.stash.has_storage_permission( StoragePermission(uid=uid, server_uid=context.server.id) ) def has_read_permission(self, context: AuthedServiceContext, uid: UID) -> bool: - return self.store.has_permissions( + 
return self.stash.has_permissions( [ActionObjectREAD(uid=uid, credentials=context.credentials)] ) @@ -563,7 +561,7 @@ def blob_permission( if len(output_readers) > 0: store_permissions = [store_permission(x) for x in output_readers] - self.store.add_permissions(store_permissions) + self.stash.add_permissions(store_permissions) if result_blob_id is not None: blob_permissions = [blob_permission(x) for x in output_readers] @@ -885,12 +883,12 @@ def has_read_permission_for_action_result( ActionObjectREAD(uid=_id, credentials=context.credentials) for _id in action_obj_ids ] - return self.store.has_permissions(permissions) + return self.stash.has_permissions(permissions) @service_method(path="action.exists", name="exists", roles=GUEST_ROLE_LEVEL) def exists(self, context: AuthedServiceContext, obj_id: UID) -> bool: """Checks if the given object id exists in the Action Store""" - return self.store.exists(context.credentials, obj_id) + return self.stash.exists(context.credentials, obj_id) @service_method( path="action.delete", @@ -901,7 +899,7 @@ def exists(self, context: AuthedServiceContext, obj_id: UID) -> bool: def delete( self, context: AuthedServiceContext, uid: UID, soft_delete: bool = False ) -> SyftSuccess: - obj = self.store.get(uid=uid, credentials=context.credentials).unwrap() + obj = self.stash.get(uid=uid, credentials=context.credentials).unwrap() return_msg = [] @@ -957,7 +955,7 @@ def _delete_from_action_store( soft_delete: bool = False, ) -> SyftSuccess: if soft_delete: - obj = self.store.get(uid=uid, credentials=context.credentials).unwrap() + obj = self.stash.get(uid=uid, credentials=context.credentials).unwrap() if isinstance(obj, TwinObject): self._soft_delete_action_obj( @@ -969,7 +967,7 @@ def _delete_from_action_store( if isinstance(obj, ActionObject): self._soft_delete_action_obj(context=context, action_obj=obj).unwrap() else: - self.store.delete_by_uid(credentials=context.credentials, uid=uid).unwrap() + self.stash.delete_by_uid(credentials=context.credentials, uid=uid).unwrap() return SyftSuccess(message=f"Action object with uid '{uid}' deleted.") diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index c90492ed4b8..bd5fcf45e10 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -67,14 +67,14 @@ def get_pointer( if has_permissions: if isinstance(obj, TwinObject): return obj.private.syft_point_to(server_uid) - return obj.syft_point_to(server_uid) + return obj.syft_point_to(server_uid) # type: ignore # if its a twin with a mock anyone can have this if isinstance(obj, TwinObject): return obj.mock.syft_point_to(server_uid) # finally worst case you get ActionDataEmpty so you can still trace - return obj.as_empty().syft_point_to(server_uid) + return obj.as_empty().syft_point_to(server_uid) # type: ignore @as_result(SyftException, StashException) def set_or_update( # type: ignore diff --git a/packages/syft/src/syft/service/api/api_service.py b/packages/syft/src/syft/service/api/api_service.py index a8c443a6271..cce0d53698a 100644 --- a/packages/syft/src/syft/service/api/api_service.py +++ b/packages/syft/src/syft/service/api/api_service.py @@ -10,7 +10,7 @@ from ...serde.serializable import serializable from ...service.action.action_endpoint import CustomEndpointActionObject from ...service.action.action_object import ActionObject -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from 
...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException @@ -37,11 +37,9 @@ @serializable(canonical_name="APIService", version=1) class APIService(AbstractService): - store: DocumentStore stash: TwinAPIEndpointStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = TwinAPIEndpointStash(store=store) @service_method( diff --git a/packages/syft/src/syft/service/attestation/attestation_service.py b/packages/syft/src/syft/service/attestation/attestation_service.py index debb838c4b1..1794ef6b03a 100644 --- a/packages/syft/src/syft/service/attestation/attestation_service.py +++ b/packages/syft/src/syft/service/attestation/attestation_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...types.errors import SyftException from ...types.result import as_result from ...util.util import str_to_bool @@ -24,8 +24,8 @@ class AttestationService(AbstractService): """This service is responsible for getting all sorts of attestations for any client.""" - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: + pass @as_result(SyftException) def perform_request( diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index ec335433edc..f8069566268 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -11,7 +11,7 @@ from ...store.blob_storage import BlobRetrieval from ...store.blob_storage.on_disk import OnDiskBlobDeposit from ...store.blob_storage.seaweedfs import SeaweedFSBlobDeposit -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...types.blob_storage import AzureSecureFilePathLocation from ...types.blob_storage import BlobFileType from ...types.blob_storage import BlobStorageEntry @@ -37,12 +37,10 @@ @serializable(canonical_name="BlobStorageService", version=1) class BlobStorageService(AbstractService): - store: DocumentStore stash: BlobStorageStash remote_profile_stash: RemoteProfileStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = BlobStorageStash(store=store) self.remote_profile_stash = RemoteProfileStash(store=store) diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index 1b977b69e3a..f6709ebab3a 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -4,8 +4,8 @@ # relative from ...serde.serializable import serializable +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store import PartitionSettings from ...types.uid import UID from ...util.trace_decorator import instrument @@ -30,11 +30,9 @@ class StatusStash(ObjectStash[UserCodeStatusCollection]): @serializable(canonical_name="UserCodeStatusService", version=1) class UserCodeStatusService(AbstractService): - store: DocumentStore stash: StatusStash - def __init__(self, store: DocumentStore): - self.store = store + def __init__(self, store: 
DBManager): self.stash = StatusStash(store=store) @service_method(path="code_status.create", name="create", roles=ADMIN_ROLE_LEVEL) diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 6230ef179df..29fa9093f14 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject @@ -60,11 +60,9 @@ class IsExecutionAllowedEnum(str, Enum): @serializable(canonical_name="UserCodeService", version=1) class UserCodeService(AbstractService): - store: DocumentStore stash: UserCodeStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = UserCodeStash(store=store) @service_method( diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index 0943978759f..fa7da4976fa 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -3,7 +3,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...types.uid import UID from ..code.user_code import SubmitUserCode @@ -24,11 +24,9 @@ @serializable(canonical_name="CodeHistoryService", version=1) class CodeHistoryService(AbstractService): - store: DocumentStore stash: CodeHistoryStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = CodeHistoryStash(store=store) @service_method( diff --git a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py index a7aec8e6e44..ec0241b54e0 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py @@ -3,8 +3,8 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store_errors import StashException from ...types.result import as_result from ..context import AuthedServiceContext @@ -38,11 +38,9 @@ def get_all_for_child( @serializable(canonical_name="DataSubjectMemberService", version=1) class DataSubjectMemberService(AbstractService): - store: DocumentStore stash: DataSubjectMemberStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = DataSubjectMemberStash(store=store) def add( diff --git a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py index 331c22a9d49..8c67cd84849 100644 --- 
a/packages/syft/src/syft/service/data_subject/data_subject_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py @@ -5,8 +5,8 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store_errors import StashException from ...types.result import as_result from ..context import AuthedServiceContext @@ -31,11 +31,9 @@ def get_by_name(self, credentials: SyftVerifyKey, name: str) -> DataSubject: @serializable(canonical_name="DataSubjectService", version=1) class DataSubjectService(AbstractService): - store: DocumentStore stash: DataSubjectStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = DataSubjectStash(store=store) @service_method(path="data_subject.add", name="add_data_subject") diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index de6fdf9cb90..93ab6f8ad5b 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...types.dicttuple import DictTuple from ...types.uid import UID from ..action.action_permissions import ActionObjectPermission @@ -68,11 +68,9 @@ def _paginate_dataset_collection( @serializable(canonical_name="DatasetService", version=1) class DatasetService(AbstractService): - store: DocumentStore stash: DatasetStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = DatasetStash(store=store) @service_method( diff --git a/packages/syft/src/syft/service/enclave/enclave_service.py b/packages/syft/src/syft/service/enclave/enclave_service.py index 2f88c60e123..76d8c455e52 100644 --- a/packages/syft/src/syft/service/enclave/enclave_service.py +++ b/packages/syft/src/syft/service/enclave/enclave_service.py @@ -2,13 +2,11 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ..service import AbstractService @serializable(canonical_name="EnclaveService", version=1) class EnclaveService(AbstractService): - store: DocumentStore - - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: + pass diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index 1ec9babac67..3b31f168f17 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable from ...server.worker_settings import WorkerSettings -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...types.errors import SyftException from ...types.uid import UID from ..action.action_object import ActionObject @@ -40,11 +40,9 @@ def wait_until(predicate: Callable[[], bool], timeout: int = 10) -> SyftSuccess: @serializable(canonical_name="JobService", version=1) class JobService(AbstractService): - store: 
DocumentStore stash: JobStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = JobStash(store=store) @service_method( diff --git a/packages/syft/src/syft/service/log/log_service.py b/packages/syft/src/syft/service/log/log_service.py index d3529b0906f..9d25679b895 100644 --- a/packages/syft/src/syft/service/log/log_service.py +++ b/packages/syft/src/syft/service/log/log_service.py @@ -1,6 +1,6 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...types.uid import UID from ..action.action_permissions import StoragePermission from ..context import AuthedServiceContext @@ -16,11 +16,9 @@ @serializable(canonical_name="LogService", version=1) class LogService(AbstractService): - store: DocumentStore stash: LogStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = LogStash(store=store) @service_method(path="log.add", name="add", roles=DATA_SCIENTIST_ROLE_LEVEL) diff --git a/packages/syft/src/syft/service/metadata/metadata_service.py b/packages/syft/src/syft/service/metadata/metadata_service.py index 70453d9b084..603091ca4ee 100644 --- a/packages/syft/src/syft/service/metadata/metadata_service.py +++ b/packages/syft/src/syft/service/metadata/metadata_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ..context import AuthedServiceContext from ..service import AbstractService from ..service import service_method @@ -12,8 +12,8 @@ @serializable(canonical_name="MetadataService", version=1) class MetadataService(AbstractService): - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: + pass @service_method( path="metadata.get_metadata", name="get_metadata", roles=GUEST_ROLE_LEVEL diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index f150f5e72ed..581c3561c92 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -6,8 +6,8 @@ # relative from ...serde.serializable import serializable +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store_errors import NotFoundException from ...types.blob_storage import BlobStorageEntry from ...types.errors import SyftException @@ -34,11 +34,9 @@ @serializable(canonical_name="MigrationService", version=1) class MigrationService(AbstractService): - store: DocumentStore stash: SyftMigrationStateStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = SyftMigrationStateStash(store=store) @service_method(path="migration", name="get_version") diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 9301e702890..d14fadb25ed 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -15,8 +15,8 @@ from ...server.credentials import SyftVerifyKey from ...server.worker_settings 
import WorkerSettings from ...service.settings.settings import ServerSettings +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException @@ -137,11 +137,9 @@ def get_by_server_type( @serializable(canonical_name="NetworkService", version=1) class NetworkService(AbstractService): - store: DocumentStore stash: NetworkStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = NetworkStash(store=store) if reverse_tunnel_enabled(): self.rtunnel_service = ReverseTunnelService() diff --git a/packages/syft/src/syft/service/notification/notification_service.py b/packages/syft/src/syft/service/notification/notification_service.py index 15b2b9725d8..9da8ae4a935 100644 --- a/packages/syft/src/syft/service/notification/notification_service.py +++ b/packages/syft/src/syft/service/notification/notification_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import StashException from ...types.errors import SyftException from ...types.result import as_result @@ -28,11 +28,9 @@ @serializable(canonical_name="NotificationService", version=1) class NotificationService(AbstractService): - store: DocumentStore stash: NotificationStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = NotificationStash(store=store) @service_method(path="notifications.send", name="send") diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index 20b8f26ed92..34a964b4a03 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -8,7 +8,7 @@ # relative from ...abstract_server import AbstractServer from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException @@ -34,11 +34,9 @@ class RateLimitException(SyftException): @serializable(canonical_name="NotifierService", version=1) class NotifierService(AbstractService): - store: DocumentStore stash: NotifierStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = NotifierStash(store=store) @as_result(StashException) diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index e0111b22429..7f59cf4e1e8 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -7,8 +7,8 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store import PartitionKey from ...store.document_store_errors import 
StashException from ...store.linked_obj import LinkedObject @@ -213,11 +213,9 @@ def get_by_output_policy_id( @serializable(canonical_name="OutputService", version=1) class OutputService(AbstractService): - store: DocumentStore stash: OutputStash - def __init__(self, store: DocumentStore): - self.store = store + def __init__(self, store: DBManager): self.stash = OutputStash(store=store) @service_method( @@ -286,7 +284,7 @@ def has_output_read_permissions( ActionObjectREAD(uid=_id.id, credentials=user_verify_key) for _id in result_ids ] - if context.server.services.action.store.has_permissions(permissions): + if context.server.services.action.stash.has_permissions(permissions): return True return False diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index 3b1f33c0a08..db57609af93 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -335,7 +335,7 @@ class UserOwned(PolicyRule): def is_owned( self, context: AuthedServiceContext, action_object: ActionObject ) -> bool: - action_store = context.server.services.action.store + action_store = context.server.services.action.stash return action_store.has_permission( ActionObjectPermission( action_object.id, ActionPermission.OWNER, context.credentials diff --git a/packages/syft/src/syft/service/policy/policy_service.py b/packages/syft/src/syft/service/policy/policy_service.py index fe32c226dbf..da8759a3540 100644 --- a/packages/syft/src/syft/service/policy/policy_service.py +++ b/packages/syft/src/syft/service/policy/policy_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...types.uid import UID from ..context import AuthedServiceContext from ..response import SyftSuccess @@ -16,11 +16,9 @@ @serializable(canonical_name="PolicyService", version=1) class PolicyService(AbstractService): - store: DocumentStore stash: UserPolicyStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = UserPolicyStash(store=store) @service_method(path="policy.get_all", name="get_all") diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index 3a2543e38a1..b04ec27bf8b 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject @@ -32,11 +32,9 @@ @serializable(canonical_name="ProjectService", version=1) class ProjectService(AbstractService): - store: DocumentStore stash: ProjectStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = ProjectStash(store=store) @as_result(SyftException) diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index 25c6b6094eb..76fdb1d4c42 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -233,7 +233,7 @@ def 
handle_message_multiprocessing( public_message=f"Job {queue_item.job_id} not found!" ) - job_item.server_uid = worker.id + job_item.server_uid = worker.id # type: ignore[assignment] job_item.result = result job_item.resolved = True job_item.status = job_status diff --git a/packages/syft/src/syft/service/request/request.py b/packages/syft/src/syft/service/request/request.py index 39e3a632a4a..674c66a019a 100644 --- a/packages/syft/src/syft/service/request/request.py +++ b/packages/syft/src/syft/service/request/request.py @@ -112,7 +112,7 @@ class ActionStoreChange(Change): @as_result(SyftException) def _run(self, context: ChangeContext, apply: bool) -> SyftSuccess: - action_store = context.server.services.action.store + action_store = context.server.services.action.stash # can we ever have a lineage ID in the store? obj_uid = self.linked_obj.object_uid diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index b067644342f..a0adbbf8575 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -8,7 +8,7 @@ # relative from ...abstract_server import ServerSideType from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.sqlite_document_store import SQLiteStoreConfig @@ -48,11 +48,9 @@ @serializable(canonical_name="SettingsService", version=1) class SettingsService(AbstractService): - store: DocumentStore stash: SettingsStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = SettingsStash(store=store) @service_method(path="settings.get", name="get") diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index c9e3764e46d..740f55f02af 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -6,8 +6,8 @@ # relative from ...client.api import ServerIdentity from ...serde.serializable import serializable +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash -from ...store.document_store import DocumentStore from ...store.document_store import NewBaseStash from ...store.document_store_errors import NotFoundException from ...store.linked_obj import LinkedObject @@ -40,7 +40,7 @@ def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> ObjectStash: if isinstance(item, ActionObject): service = context.server.services.action # type: ignore - return service.store # type: ignore + return service.stash # type: ignore service = context.server.get_service(TYPE_TO_SERVICE[type(item)]) # type: ignore return service.stash @@ -48,11 +48,9 @@ def get_store(context: AuthedServiceContext, item: SyncableSyftObject) -> Object @instrument @serializable(canonical_name="SyncService", version=1) class SyncService(AbstractService): - store: DocumentStore stash: SyncStash - def __init__(self, store: DocumentStore): - self.store = store + def __init__(self, store: DBManager): self.stash = SyncStash(store=store) def add_actionobject_read_permissions( @@ -61,7 +59,7 @@ def add_actionobject_read_permissions( action_object: ActionObject, new_permissions: list[ActionObjectPermission], ) -> 
None: - store_to = context.server.services.action.store # type: ignore + store_to = context.server.services.action.stash # type: ignore for permission in new_permissions: if permission.permission == ActionPermission.READ: store_to.add_permission(permission) diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 9d8f54605bd..34d5071fa7a 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -11,7 +11,7 @@ from ...serde.serializable import serializable from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject @@ -81,11 +81,9 @@ def _paginate( @serializable(canonical_name="UserService", version=1) class UserService(AbstractService): - store: DocumentStore stash: UserStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = UserStash(store=store) @as_result(StashException) @@ -522,7 +520,7 @@ def update( ).unwrap() if user.role == ServiceRole.ADMIN: - settings_stash = SettingsStash(store=self.store) + settings_stash = SettingsStash(store=self.stash.db) settings = settings_stash.get_all( context.credentials, limit=1, sort_order="desc" ).unwrap() diff --git a/packages/syft/src/syft/service/worker/image_registry_service.py b/packages/syft/src/syft/service/worker/image_registry_service.py index 7310decd85d..e9e80f21892 100644 --- a/packages/syft/src/syft/service/worker/image_registry_service.py +++ b/packages/syft/src/syft/service/worker/image_registry_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...types.errors import SyftException from ...types.uid import UID from ..context import AuthedServiceContext @@ -20,11 +20,9 @@ @serializable(canonical_name="SyftImageRegistryService", version=1) class SyftImageRegistryService(AbstractService): - store: DocumentStore stash: SyftImageRegistryStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = SyftImageRegistryStash(store=store) @service_method( diff --git a/packages/syft/src/syft/service/worker/worker_image_service.py b/packages/syft/src/syft/service/worker/worker_image_service.py index 37622a36d71..ab9cf5a250f 100644 --- a/packages/syft/src/syft/service/worker/worker_image_service.py +++ b/packages/syft/src/syft/service/worker/worker_image_service.py @@ -10,7 +10,7 @@ from ...custom_worker.config import WorkerConfig from ...custom_worker.k8s import IN_KUBERNETES from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...types.datetime import DateTime from ...types.dicttuple import DictTuple from ...types.errors import SyftException @@ -31,11 +31,9 @@ @serializable(canonical_name="SyftWorkerImageService", version=1) class SyftWorkerImageService(AbstractService): - store: DocumentStore stash: SyftWorkerImageStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> 
None: self.stash = SyftWorkerImageStash(store=store) @service_method( diff --git a/packages/syft/src/syft/service/worker/worker_image_stash.py b/packages/syft/src/syft/service/worker/worker_image_stash.py index aa7131252bd..f3bc56c8177 100644 --- a/packages/syft/src/syft/service/worker/worker_image_stash.py +++ b/packages/syft/src/syft/service/worker/worker_image_stash.py @@ -3,7 +3,7 @@ # third party # third party -import sqlalchemy as sa +from sqlalchemy.orm import Session # relative from ...custom_worker.config import DockerWorkerConfig @@ -32,7 +32,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, ignore_duplicates: bool = False, - session: sa.Session = None, + session: Session = None, ) -> SyftWorkerImage: # By default syft images have all read permission add_permissions = [] if add_permissions is None else add_permissions diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index 4ceced2bf26..ae46958997d 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -12,7 +12,7 @@ from ...custom_worker.k8s import IN_KUBERNETES from ...custom_worker.runner_k8s import KubernetesRunner from ...serde.serializable import serializable -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject @@ -52,11 +52,9 @@ @serializable(canonical_name="SyftWorkerPoolService", version=1) class SyftWorkerPoolService(AbstractService): - store: DocumentStore stash: SyftWorkerPoolStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = SyftWorkerPoolStash(store=store) self.image_stash = SyftWorkerImageStash(store=store) diff --git a/packages/syft/src/syft/service/worker/worker_pool_stash.py b/packages/syft/src/syft/service/worker/worker_pool_stash.py index e9dc4d7dfde..d78afde7ee5 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_stash.py +++ b/packages/syft/src/syft/service/worker/worker_pool_stash.py @@ -3,7 +3,7 @@ # third party # third party -import sqlalchemy as sa +from sqlalchemy.orm import Session # relative from ...serde.serializable import serializable @@ -41,7 +41,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, ignore_duplicates: bool = False, - session: sa.Session = None, + session: Session = None, ) -> WorkerPool: # By default all worker pools have all read permission add_permissions = [] if add_permissions is None else add_permissions diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 625a88a46b4..28f443cb5fa 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -13,7 +13,7 @@ from ...custom_worker.runner_k8s import KubernetesRunner from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import DocumentStore +from ...store.db.sqlite_db import DBManager from ...store.document_store import SyftSuccess from ...store.document_store_errors import StashException from ...types.errors import 
SyftException @@ -39,11 +39,9 @@ @serializable(canonical_name="WorkerService", version=1) class WorkerService(AbstractService): - store: DocumentStore stash: WorkerStash - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = WorkerStash(store=store) @service_method( diff --git a/packages/syft/src/syft/service/worker/worker_stash.py b/packages/syft/src/syft/service/worker/worker_stash.py index e911258707b..038a4bfa35a 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -3,7 +3,7 @@ # third party # third party -import sqlalchemy as sa +from sqlalchemy.orm import Session # relative from ...serde.serializable import serializable @@ -34,7 +34,7 @@ def set( add_permissions: list[ActionObjectPermission] | None = None, add_storage_permission: bool = True, ignore_duplicates: bool = False, - session: sa.Session = None, + session: Session = None, ) -> SyftWorker: # By default all worker pools have all read permission add_permissions = [] if add_permissions is None else add_permissions diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 9c9dbcb55d0..3ef1f6247b7 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -680,7 +680,6 @@ class NewBaseStash: partition: StorePartition def __init__(self, store: DocumentStore) -> None: - self.store = store self.partition = store.partition(type(self).settings) @as_result(StashException) diff --git a/packages/syft/tests/syft/action_test.py b/packages/syft/tests/syft/action_test.py index a2e172663ca..bb9ffb0bad2 100644 --- a/packages/syft/tests/syft/action_test.py +++ b/packages/syft/tests/syft/action_test.py @@ -23,7 +23,7 @@ def test_actionobject_method(worker): root_datasite_client = worker.root_client assert root_datasite_client.settings.enable_eager_execution(enable=True) - action_store = worker.services.action.store + action_store = worker.services.action.stash obj = ActionObject.from_obj("abc") pointer = obj.send(root_datasite_client) assert len(action_store._data) == 1 @@ -75,7 +75,7 @@ def test_lib_function_action(worker): assert isinstance(res, ActionObject) assert all(res == np.array([0, 0, 0])) - assert len(worker.services.action.store._data) > 0 + assert len(worker.services.action.stash._data) > 0 def test_call_lib_function_action2(worker): @@ -90,7 +90,7 @@ def test_lib_class_init_action(worker): assert isinstance(res, ActionObject) assert res == np.float32(4.0) - assert len(worker.services.action.store._data) > 0 + assert len(worker.services.action.stash._data) > 0 def test_call_lib_wo_permission(worker): diff --git a/packages/syft/tests/syft/migrations/protocol_communication_test.py b/packages/syft/tests/syft/migrations/protocol_communication_test.py index 059c729d921..da8ca1fd37a 100644 --- a/packages/syft/tests/syft/migrations/protocol_communication_test.py +++ b/packages/syft/tests/syft/migrations/protocol_communication_test.py @@ -20,6 +20,7 @@ from syft.service.service import ServiceConfigRegistry from syft.service.service import service_method from syft.service.user.user_roles import GUEST_ROLE_LEVEL +from syft.store.db.sqlite_db import DBManager from syft.store.document_store import DocumentStore from syft.store.document_store import NewBaseStash from syft.store.document_store import PartitionSettings @@ -85,7 +86,7 @@ class SyftMockObjectStash(NewBaseStash): 
object_type=syft_object, ) - def __init__(self, store: DocumentStore) -> None: + def __init__(self, store: DBManager) -> None: super().__init__(store=store) return SyftMockObjectStash @@ -103,8 +104,7 @@ class SyftMockObjectService(AbstractService): stash: stash_klass # type: ignore __module__: str = "syft.test" - def __init__(self, store: DocumentStore) -> None: - self.store = store + def __init__(self, store: DBManager) -> None: self.stash = stash_klass(store=store) @service_method( diff --git a/packages/syft/tests/syft/service/action/action_object_test.py b/packages/syft/tests/syft/service/action/action_object_test.py index 3a990bd3e5d..a57a61e1509 100644 --- a/packages/syft/tests/syft/service/action/action_object_test.py +++ b/packages/syft/tests/syft/service/action/action_object_test.py @@ -506,7 +506,7 @@ def test_actionobject_syft_get_path(testcase): def test_actionobject_syft_send_get(worker, testcase): root_datasite_client = worker.root_client root_datasite_client._fetch_api(root_datasite_client.credentials) - action_store = worker.services.action.store + action_store = worker.services.action.stash orig_obj = testcase obj = helper_make_action_obj(orig_obj) diff --git a/packages/syft/tests/syft/service/action/action_service_test.py b/packages/syft/tests/syft/service/action/action_service_test.py index a2c6d3f4b4e..e97362a6340 100644 --- a/packages/syft/tests/syft/service/action/action_service_test.py +++ b/packages/syft/tests/syft/service/action/action_service_test.py @@ -22,6 +22,6 @@ def test_action_service_sanity(worker): obj = ActionObject.from_obj("abc") pointer = obj.send(root_datasite_client) - assert len(service.store._data) == 1 + assert len(service.stash._data) == 1 res = pointer.capitalize() assert res[0] == "A" From 6d0b4b9f0bc449749e870c254e33dfe37234432a Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Wed, 11 Sep 2024 13:29:52 +0200 Subject: [PATCH 137/197] fix env variables in module scope, fix root admin already exists --- packages/syft/src/syft/server/server.py | 25 +++++++++++++++---- .../syft/src/syft/service/user/user_stash.py | 10 ++++++++ 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index bdfc73a6ad4..c4825879bdc 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -422,7 +422,7 @@ def __init__( # self.services.user.stash.init_root_user() self.action_store = self.services.action.store - create_admin_new( + create_root_admin( name=root_username, email=root_email, password=root_password, # nosec @@ -1700,17 +1700,32 @@ def create_initial_settings(self, admin_email: str) -> ServerSettings: ).unwrap() -def create_admin_new( +def create_root_admin( name: str, email: str, password: str, server: Server, ) -> User | None: + """ + If no root admin exists: + - all exists checks on the user stash will fail, as we cannot get the role for the admin to check if it exists + - result: a new admin is always created + + If a root admin exists with a different email: + - cause: DEFAULT_USER_EMAIL env variable is set to a different email than the root admin in the db + - verify_key_exists will return True + - result: no new admin is created, as the server already has a root admin + """ user_stash = server.services.user.stash - user_exists = user_stash.email_exists(email=email).unwrap() - if user_exists: - logger.debug("Admin not created, admin already exists") + email_exists = user_stash.email_exists(email=email).unwrap() + if 
email_exists:
+        logger.debug("Admin not created, a user with this email already exists")
+        return None
+
+    verify_key_exists = user_stash.verify_key_exists(server.verify_key).unwrap()
+    if verify_key_exists:
+        logger.debug("Admin not created, this server already has a root admin")
         return None
 
     create_user = UserCreate(
diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py
index 1a4a97fa3a1..6b51bf73afb 100644
--- a/packages/syft/src/syft/service/user/user_stash.py
+++ b/packages/syft/src/syft/service/user/user_stash.py
@@ -64,6 +64,16 @@ def email_exists(self, email: str) -> bool:
         except NotFoundException:
             return False
 
+    @as_result(StashException)
+    def verify_key_exists(self, verify_key: SyftVerifyKey) -> bool:
+        try:
+            self.get_by_verify_key(
+                credentials=self.admin_verify_key(), verify_key=verify_key
+            ).unwrap()
+            return True
+        except NotFoundException:
+            return False
+
     @as_result(StashException, NotFoundException)
     def get_by_role(self, credentials: SyftVerifyKey, role: ServiceRole) -> User:
         try:

From d5a950a68a8092834712fc5cb7356ce617242435 Mon Sep 17 00:00:00 2001
From: Shubham Gupta
Date: Wed, 11 Sep 2024 21:52:37 +0530
Subject: [PATCH 138/197] add a connection pool class add a query executor class

---
 .../syft/store/postgres_pool_connection.py    | 116 ++++++++++++++++++
 .../src/syft/store/postgres_query_executor.py |  82 +++++++++++++
 2 files changed, 198 insertions(+)
 create mode 100644 packages/syft/src/syft/store/postgres_pool_connection.py
 create mode 100644 packages/syft/src/syft/store/postgres_query_executor.py

diff --git a/packages/syft/src/syft/store/postgres_pool_connection.py b/packages/syft/src/syft/store/postgres_pool_connection.py
new file mode 100644
index 00000000000..41b462b262b
--- /dev/null
+++ b/packages/syft/src/syft/store/postgres_pool_connection.py
@@ -0,0 +1,116 @@
+# stdlib
+from collections.abc import Generator
+from contextlib import contextmanager
+import logging
+import time
+
+# third party
+from psycopg import Connection
+from psycopg_pool import ConnectionPool
+from psycopg_pool import PoolTimeout
+
+# relative
+from ..types.errors import SyftException
+
+logger = logging.getLogger(__name__)
+
+
+MIN_DB_POOL_SIZE = 1
+MAX_DB_POOL_SIZE = 10
+DEFAULT_POOL_CONN_TIMEOUT = 30
+CONN_RETRY_INTERVAL = 1
+
+
+class PostgresPoolConnection:
+    def __init__(
+        self,
+        client_config: dict,
+        min_size: int = MIN_DB_POOL_SIZE,
+        max_size: int = MAX_DB_POOL_SIZE,
+        timeout: int = DEFAULT_POOL_CONN_TIMEOUT,
+        retry_interval: int = CONN_RETRY_INTERVAL,
+        pool_kwargs: dict | None = None,
+    ) -> None:
+        connect_kwargs = self._connection_kwargs_from_config(client_config)
+
+        # https://www.psycopg.org/psycopg3/docs/advanced/prepare.html#using-prepared-statements-with-pgbouncer
+        # This should default to None to allow the connection pool to manage the prepare threshold
+        connect_kwargs["prepare_threshold"] = None
+
+        self.pool = ConnectionPool(
+            kwargs=connect_kwargs,
+            open=False,
+            check=ConnectionPool.check_connection,
+            min_size=min_size,
+            max_size=max_size,
+            **(pool_kwargs or {}),
+        )
+        logger.info(
+            f"Connection pool created with min_size={min_size} and max_size={max_size}"
+        )
+        logger.info(f"Connected to {connect_kwargs.get('dbname')}")
+        logger.info(f"PostgreSQL Pool connection: {self.pool.get_stats()}")
+        self.timeout = timeout
+        self.retry_interval = retry_interval
+
+    @contextmanager
+    def get_connection(self) -> Generator[Connection, None, None]:
+        """Provide a connection from the pool, waiting if necessary until one is available."""
+        conn = None
+        start_time = time.time()
+
+        try:
+            while True:
+                try:
+                    conn = self.pool.getconn(timeout=self.retry_interval)
+                    if conn:
+                        yield conn  # Return the connection object to be used in the context
+                        break
+                except PoolTimeout as e:
+                    elapsed_time = time.time() - start_time
+                    if elapsed_time >= self.timeout:
+                        message = f"Could not get a connection from database pool within {self.timeout} seconds."
+                        raise SyftException.from_exception(
+                            e,
+                            public_message=message,
+                        )
+                    logger.warning(
+                        f"Connection not available, retrying... ({elapsed_time:.2f} seconds elapsed)"
+                    )
+                    time.sleep(self.retry_interval)
+
+        except Exception as e:
+            logger.error(f"Error getting connection from pool: {e}")
+            yield None
+        finally:
+            if conn:
+                self.pool.putconn(conn)
+
+    def release_connection(self, conn: Connection) -> None:
+        """Release a connection back to the pool."""
+        try:
+            if conn.closed or conn.broken:
+                # psycopg_pool's putconn takes no close flag; a connection
+                # returned in closed state is discarded by the pool
+                conn.close()
+                self.pool.putconn(conn)
+                logger.info("Broken connection closed and removed from pool.")
+            else:
+                self.pool.putconn(conn)
+                logger.info("Connection released back to pool.")
+        except Exception as e:
+            logger.error(f"Error releasing connection: {e}")
+
+    def _connection_kwargs_from_config(self, config: dict) -> dict:
+        return {
+            "dbname": config.get("dbname"),
+            "user": config.get("user"),
+            "password": config.get("password"),
+            "host": config.get("host"),
+            "port": config.get("port"),
+        }
+
+    def close_all_connections(self) -> None:
+        """Close all connections in the pool and shut down the pool."""
+        try:
+            self.pool.close()
+            logger.info("All connections closed and pool shut down.")
+        except Exception as e:
+            logger.error(f"Error closing connection pool: {e}")
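For illustration, a minimal usage sketch of the pool class added above. This is not part of the patch; the database name, credentials, and pool sizes are placeholders, and the config keys simply mirror _connection_kwargs_from_config:

    # Hypothetical usage of PostgresPoolConnection (not part of the patch).
    from syft.store.postgres_pool_connection import PostgresPoolConnection

    client_config = {
        "dbname": "syft_db",      # placeholder values throughout
        "user": "syft",
        "password": "changethis",
        "host": "localhost",
        "port": 5432,
    }

    pool = PostgresPoolConnection(client_config=client_config, min_size=2, max_size=5)
    pool.pool.open()  # the pool is constructed with open=False, so open it before first use
    with pool.get_connection() as conn:
        if conn is not None:  # get_connection yields None when no connection could be obtained
            with conn.cursor() as cur:
                cur.execute("SELECT 1")
                print(cur.fetchone())
    pool.close_all_connections()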
diff --git a/packages/syft/src/syft/store/postgres_query_executor.py b/packages/syft/src/syft/store/postgres_query_executor.py
new file mode 100644
index 00000000000..dea9bb849b0
--- /dev/null
+++ b/packages/syft/src/syft/store/postgres_query_executor.py
@@ -0,0 +1,82 @@
+# stdlib
+import logging
+from typing import Any
+
+# third party
+import psycopg
+from psycopg import Cursor
+
+# relative
+from .postgres_pool_connection import PostgresPoolConnection
+
+logger = logging.getLogger(__name__)
+
+
+MAX_QUERY_RETRIES = 3
+QUERY_RETRY_DELAY = 5
+
+
+class PostgresQueryExecutor:
+    def __init__(
+        self,
+        connection_pool: PostgresPoolConnection,
+        retries: int = MAX_QUERY_RETRIES,
+        retry_delay: int = QUERY_RETRY_DELAY,
+    ) -> None:
+        self.connection_pool = connection_pool
+        self.retries = retries
+        self.retry_delay = retry_delay
+
+    def execute_query(self, query: str, args: list[Any] | None = None) -> Cursor | None:
+        """
+        Execute a query on the database using a context-managed connection.
+        Handles `InFailedSqlTransaction` errors by rolling back the transaction.
+        Returns a cursor object after execution for further handling by the caller.
+
+        :param query: SQL query to execute.
+        :param args: Query parameters (optional).
+        :return: Cursor object or None if an error occurs.
+        """
+        attempt = 0
+        while attempt < self.retries:
+            try:
+                # Using the context manager for the connection
+                with self.connection_pool.get_connection() as conn:
+                    if conn is None:
+                        return None
+
+                    cur = conn.cursor()
+
+                    # Check if connection is in failed state (i.e., in a failed transaction);
+                    # psycopg 3 exposes this via conn.info.transaction_status rather than
+                    # the psycopg2-era psycopg.extensions status constants
+                    if conn.info.transaction_status == psycopg.pq.TransactionStatus.INERROR:
+                        logger.warning(
+                            "Transaction is in a failed state. Rolling back."
+                        )
+                        conn.rollback()
+
+                    cur.execute(query, args)
+
+                    conn.commit()
+
+                    return cur  # Return the cursor object
+
+            except psycopg.errors.InFailedSqlTransaction as e:
+                logger.error(f"Transaction failed and is in an invalid state: {e}")
+                if conn and not conn.closed:
+                    conn.rollback()  # Roll back the transaction
+                attempt += 1  # Retry the query after rollback
+
+            except (psycopg.OperationalError, psycopg.errors.AdminShutdown) as e:
+                logger.error(
+                    f"Server error or termination: {e}. Retrying ({attempt + 1}/{self.retries})..."
+                )
+                attempt += 1
+
+            except Exception as e:
+                logger.error(f"Error executing query: {e}")
+                if conn and not conn.closed:
+                    conn.rollback()  # Roll back on any general error
+                return None
+
+        logger.error(f"Query failed after {self.retries} attempts.")
+        return None
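Again for illustration, a minimal sketch of how the executor composes with the pool above. This is not part of the patch, and the table name and parameter are invented. Note that execute_query returns the cursor after get_connection's finally block has already handed the connection back to the pool, so callers should consume results promptly:

    # Hypothetical usage of PostgresQueryExecutor (not part of the patch).
    from syft.store.postgres_query_executor import PostgresQueryExecutor

    executor = PostgresQueryExecutor(connection_pool=pool, retries=3)

    # psycopg fills %s placeholders from the args sequence, avoiding string interpolation
    cur = executor.execute_query(
        "SELECT id, email FROM users WHERE email = %s", ["info@openmined.org"]
    )
    if cur is not None:
        rows = cur.fetchall()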
From fd6ef15c0355042550b68e585d76c835147907a6 Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Thu, 12 Sep 2024 09:01:36 +0200
Subject: [PATCH 139/197] add protocol

---
 packages/syft/src/syft/protocol/protocol_version.json | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json
index d7f1a607404..a8d210436e3 100644
--- a/packages/syft/src/syft/protocol/protocol_version.json
+++ b/packages/syft/src/syft/protocol/protocol_version.json
@@ -135,7 +135,7 @@
     "Notification": {
       "1": {
         "version": 1,
-        "hash": "af4cb232bff390c431e399975f048b34da7e940ace8b23b940a3b398c91c5326",
+        "hash": "cbd07fe9c549c5f05d0880d1417f18a5bfaa71eba65c63ff9cb761e274fffc54",
         "action": "add"
       }
     },
@@ -218,7 +218,7 @@
     "SyftWorkerImage": {
       "1": {
         "version": 1,
-        "hash": "44da7badfbe573d5403d3ab78c077f17dbefc560b81fdf927b671815be047441",
+        "hash": "3771910757bf7534369ceb3e6cf1fc2e3a5438d8d094a67b0b548d501a0ec63f",
         "action": "add"
       }
     },

From 90d1ee46568eb21d9b15238bd34f114f388c906d Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Thu, 12 Sep 2024 09:18:09 +0200
Subject: [PATCH 140/197] split db manager class

---
 packages/syft/src/syft/store/db/sqlite_db.py | 59 +++++++++++++-------
 1 file changed, 38 insertions(+), 21 deletions(-)

diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py
index 28043c89001..b845ccc1754 100644
--- a/packages/syft/src/syft/store/db/sqlite_db.py
+++ b/packages/syft/src/syft/store/db/sqlite_db.py
@@ -1,6 +1,8 @@
 # stdlib
 from pathlib import Path
 import tempfile
+from typing import Generic
+from typing import TypeVar
 import uuid
 
 # third party
@@ -51,10 +53,13 @@ def connection_string(self) -> str:
         return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}"
 
 
-class DBManager:
+ConfigT = TypeVar("ConfigT", bound=DBConfig)
+
+
+class DBManager(Generic[ConfigT]):
     def __init__(
         self,
-        config: SQLiteDBConfig,
+        config: ConfigT,
         server_uid: UID,
         root_verify_key: SyftVerifyKey,
     ) -> None:
@@ -74,25 +79,6 @@ def __init__(
     def update_settings(self) -> None:
         pass
 
-    def init_tables(self) -> None:
-        pass
-
-    def reset(self) -> None:
-        pass
-
-
-class SQLiteDBManager(DBManager):
-    def update_settings(self) -> None:
-        # TODO split SQLite / PostgresDBManager
-        connection = self.engine.connect()
-
-        if self.engine.dialect.name == "sqlite":
-            connection.execute(sa.text("PRAGMA journal_mode = WAL"))
-            connection.execute(sa.text("PRAGMA busy_timeout = 5000"))
-            # TODO check
-            connection.execute(sa.text("PRAGMA temp_store = 2"))
-            connection.execute(sa.text("PRAGMA synchronous = 1"))
-
     def init_tables(self) -> None:
         if self.config.reset:
             # drop all tables that we know about
@@ -104,6 +90,15 @@ def reset(self) -> None:
         Base.metadata.drop_all(bind=self.engine)
         Base.metadata.create_all(self.engine)
 
+
+class SQLiteDBManager(DBManager[SQLiteDBConfig]):
+    def update_settings(self) -> None:
+        connection = self.engine.connect()
+        connection.execute(sa.text("PRAGMA journal_mode = WAL"))
+        connection.execute(sa.text("PRAGMA busy_timeout = 5000"))
+        connection.execute(sa.text("PRAGMA temp_store = 2"))
+        connection.execute(sa.text("PRAGMA synchronous = 1"))
+
     @classmethod
     def random(
         cls,
@@ -120,3 +115,25 @@ def random(
             server_uid=server_uid,
             root_verify_key=root_verify_key,
         )
+
+
+class PostgresDBManager(DBManager[PostgresDBConfig]):
+    @classmethod
+    def random(
+        cls,
+        *,
+        config: PostgresDBConfig | None = None,
+        server_uid: UID | None = None,
+        root_verify_key: SyftVerifyKey | None = None,
+    ) -> "PostgresDBManager":
+        if config is None:
+            raise ValueError("Cannot create a postgres db without a config")
+
+        root_verify_key = root_verify_key or SyftSigningKey.generate().verify_key
+        server_uid = server_uid or UID()
+        return PostgresDBManager(
+            config=config,
+            server_uid=server_uid,
+            root_verify_key=root_verify_key,
+        )
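To make the intent of the split concrete, an illustrative sketch of the resulting pattern follows. It is not part of the patch and the subclass names are hypothetical; the point is that each manager binds a concrete config type, so self.config is precisely typed in subclasses:

    # Hypothetical subclass pair showing the DBManager[ConfigT] pattern.
    from syft.store.db.sqlite_db import DBConfig, DBManager

    class MyDBConfig(DBConfig):
        @property
        def connection_string(self) -> str:
            return "sqlite://"  # placeholder in-memory connection string

    class MyDBManager(DBManager[MyDBConfig]):
        def update_settings(self) -> None:
            # self.config is typed as MyDBConfig here, not as the DBConfig base
            print(self.config.connection_string)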
From 1aa70c50121f81f77f0247d0646cd0cb80d954d0 Mon Sep 17 00:00:00 2001
From: eelcovdw
Date: Thu, 12 Sep 2024 09:24:42 +0200
Subject: [PATCH 141/197] rename db_sqlite

---
 packages/syft/src/syft/server/server.py       |  6 ++--
 .../syft/src/syft/server/worker_settings.py   |  2 +-
 .../src/syft/service/action/action_service.py |  2 +-
 .../syft/src/syft/service/api/api_service.py  |  2 +-
 .../attestation/attestation_service.py        |  2 +-
 .../src/syft/service/blob_storage/service.py  |  2 +-
 .../src/syft/service/code/status_service.py   |  2 +-
 .../syft/service/code/user_code_service.py    |  2 +-
 .../code_history/code_history_service.py      |  2 +-
 .../data_subject_member_service.py            |  2 +-
 .../data_subject/data_subject_service.py      |  2 +-
 .../syft/service/dataset/dataset_service.py   |  2 +-
 .../syft/service/enclave/enclave_service.py   |  2 +-
 .../syft/src/syft/service/job/job_service.py  |  2 +-
 .../syft/src/syft/service/log/log_service.py  |  2 +-
 .../syft/service/metadata/metadata_service.py |  2 +-
 .../service/migration/migration_service.py    |  2 +-
 .../syft/service/network/network_service.py   |  2 +-
 .../notification/notification_service.py      |  2 +-
 .../syft/service/notifier/notifier_service.py |  2 +-
 .../src/syft/service/output/output_service.py |  2 +-
 .../src/syft/service/policy/policy_service.py |  2 +-
 .../syft/service/project/project_service.py   |  2 +-
 .../src/syft/service/queue/queue_service.py   |  2 +-
 .../syft/service/request/request_service.py   |  2 +-
 .../syft/service/settings/settings_service.py |  2 +-
 .../src/syft/service/sync/sync_service.py     |  2 +-
 .../syft/src/syft/service/sync/sync_stash.py  |  2 +-
 .../src/syft/service/user/user_service.py     |  2 +-
 .../service/worker/image_registry_service.py  |  2 +-
 .../service/worker/worker_image_service.py    |  2 +-
 .../service/worker/worker_pool_service.py     |  2 +-
 .../src/syft/service/worker/worker_service.py |  2 +-
 .../src/syft/store/db/{sqlite_db.py => db.py} |  0
 packages/syft/src/syft/store/db/stash.py      |  4 +--
 packages/syft/src/syft/store/db/utils.py      | 35 -------------------
 .../migrations/protocol_communication_test.py |  2 +-
 .../tests/syft/stores/action_store_test.py    |  2 +-
 .../syft/tests/syft/stores/base_stash_test.py |  4 +--
 .../tests/syft/stores/store_fixtures_test.py  |  4 +--
 packages/syft/tests/syft/worker_test.py       |  2 +-
 41 files changed, 44 insertions(+), 79 deletions(-)
 rename 
packages/syft/src/syft/store/db/{sqlite_db.py => db.py} (100%) delete mode 100644 packages/syft/src/syft/store/db/utils.py diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 23cfffbcfba..2743fce3fc7 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -87,9 +87,9 @@ from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit -from ..store.db.sqlite_db import DBConfig -from ..store.db.sqlite_db import SQLiteDBConfig -from ..store.db.sqlite_db import SQLiteDBManager +from ..store.db.db import DBConfig +from ..store.db.db import SQLiteDBConfig +from ..store.db.db import SQLiteDBManager from ..store.db.stash import ObjectStash from ..store.document_store import StoreConfig from ..store.document_store_errors import NotFoundException diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py index c76e5617d0d..f06359ec95d 100644 --- a/packages/syft/src/syft/server/worker_settings.py +++ b/packages/syft/src/syft/server/worker_settings.py @@ -13,7 +13,7 @@ from ..server.credentials import SyftSigningKey from ..service.queue.base_queue import QueueConfig from ..store.blob_storage import BlobStorageConfig -from ..store.db.sqlite_db import DBConfig +from ..store.db.db import DBConfig from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftObject from ..types.uid import UID diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 38b53cf977d..bd871a18164 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -9,7 +9,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.datetime import DateTime diff --git a/packages/syft/src/syft/service/api/api_service.py b/packages/syft/src/syft/service/api/api_service.py index cce0d53698a..9df52905439 100644 --- a/packages/syft/src/syft/service/api/api_service.py +++ b/packages/syft/src/syft/service/api/api_service.py @@ -10,7 +10,7 @@ from ...serde.serializable import serializable from ...service.action.action_endpoint import CustomEndpointActionObject from ...service.action.action_object import ActionObject -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException diff --git a/packages/syft/src/syft/service/attestation/attestation_service.py b/packages/syft/src/syft/service/attestation/attestation_service.py index 1794ef6b03a..93110f72d74 100644 --- a/packages/syft/src/syft/service/attestation/attestation_service.py +++ b/packages/syft/src/syft/service/attestation/attestation_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...types.errors import SyftException from ...types.result import as_result from ...util.util 
import str_to_bool diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index f8069566268..78da77d7eab 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -11,7 +11,7 @@ from ...store.blob_storage import BlobRetrieval from ...store.blob_storage.on_disk import OnDiskBlobDeposit from ...store.blob_storage.seaweedfs import SeaweedFSBlobDeposit -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...types.blob_storage import AzureSecureFilePathLocation from ...types.blob_storage import BlobFileType from ...types.blob_storage import BlobStorageEntry diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index f6709ebab3a..a1fa9300bdf 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings from ...types.uid import UID diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 29fa9093f14..79490223f45 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index fa7da4976fa..e6cea5a4a21 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -3,7 +3,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...types.uid import UID from ..code.user_code import SubmitUserCode diff --git a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py index ec0241b54e0..e7e482a4337 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py @@ -3,7 +3,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import StashException from ...types.result import as_result diff --git a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py index 8c67cd84849..ecde100edf5 100644 --- 
a/packages/syft/src/syft/service/data_subject/data_subject_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import StashException from ...types.result import as_result diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index 93ab6f8ad5b..a25a49ee60d 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...types.dicttuple import DictTuple from ...types.uid import UID from ..action.action_permissions import ActionObjectPermission diff --git a/packages/syft/src/syft/service/enclave/enclave_service.py b/packages/syft/src/syft/service/enclave/enclave_service.py index 76d8c455e52..064c8806f91 100644 --- a/packages/syft/src/syft/service/enclave/enclave_service.py +++ b/packages/syft/src/syft/service/enclave/enclave_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ..service import AbstractService diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index 3b31f168f17..5cff02ecb74 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable from ...server.worker_settings import WorkerSettings -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...types.errors import SyftException from ...types.uid import UID from ..action.action_object import ActionObject diff --git a/packages/syft/src/syft/service/log/log_service.py b/packages/syft/src/syft/service/log/log_service.py index 9d25679b895..d4b96a0deed 100644 --- a/packages/syft/src/syft/service/log/log_service.py +++ b/packages/syft/src/syft/service/log/log_service.py @@ -1,6 +1,6 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...types.uid import UID from ..action.action_permissions import StoragePermission from ..context import AuthedServiceContext diff --git a/packages/syft/src/syft/service/metadata/metadata_service.py b/packages/syft/src/syft/service/metadata/metadata_service.py index 603091ca4ee..b7b450b037b 100644 --- a/packages/syft/src/syft/service/metadata/metadata_service.py +++ b/packages/syft/src/syft/service/metadata/metadata_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ..context import AuthedServiceContext from ..service import AbstractService from ..service import service_method diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 581c3561c92..68ac808473f 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ 
b/packages/syft/src/syft/service/migration/migration_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...types.blob_storage import BlobStorageEntry diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index d14fadb25ed..1f4d656201c 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -15,7 +15,7 @@ from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings from ...service.settings.settings import ServerSettings -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException diff --git a/packages/syft/src/syft/service/notification/notification_service.py b/packages/syft/src/syft/service/notification/notification_service.py index 9da8ae4a935..7d6cb83e2f1 100644 --- a/packages/syft/src/syft/service/notification/notification_service.py +++ b/packages/syft/src/syft/service/notification/notification_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.document_store_errors import StashException from ...types.errors import SyftException from ...types.result import as_result diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index 34a964b4a03..adae5484b05 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -8,7 +8,7 @@ # relative from ...abstract_server import AbstractServer from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 7f59cf4e1e8..3d80a8e4148 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -7,7 +7,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey from ...store.document_store_errors import StashException diff --git a/packages/syft/src/syft/service/policy/policy_service.py b/packages/syft/src/syft/service/policy/policy_service.py index da8759a3540..1e5f430d109 100644 --- a/packages/syft/src/syft/service/policy/policy_service.py +++ b/packages/syft/src/syft/service/policy/policy_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...types.uid import UID from ..context import AuthedServiceContext 
from ..response import SyftSuccess diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index b04ec27bf8b..2df2fd42e1c 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject diff --git a/packages/syft/src/syft/service/queue/queue_service.py b/packages/syft/src/syft/service/queue/queue_service.py index d47f21052a8..b98f344745d 100644 --- a/packages/syft/src/syft/service/queue/queue_service.py +++ b/packages/syft/src/syft/service/queue/queue_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ..service import AbstractService from .queue_stash import QueueStash diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index 37ceeeae02a..4acb027f1fd 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.linked_obj import LinkedObject from ...types.errors import SyftException from ...types.result import as_result diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index a0adbbf8575..10890350e2d 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -8,7 +8,7 @@ # relative from ...abstract_server import ServerSideType from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.sqlite_document_store import SQLiteStoreConfig diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index d044edf3b96..bc7326bd974 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -6,7 +6,7 @@ # relative from ...client.api import ServerIdentity from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import NewBaseStash from ...store.document_store_errors import NotFoundException diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py index 4090e62dec1..d6824a3ac89 100644 --- a/packages/syft/src/syft/service/sync/sync_stash.py +++ b/packages/syft/src/syft/service/sync/sync_stash.py @@ -1,7 +1,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from 
...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings from ...store.document_store_errors import StashException diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 34d5071fa7a..54b959085e6 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -11,7 +11,7 @@ from ...serde.serializable import serializable from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject diff --git a/packages/syft/src/syft/service/worker/image_registry_service.py b/packages/syft/src/syft/service/worker/image_registry_service.py index e9e80f21892..83a30bb670b 100644 --- a/packages/syft/src/syft/service/worker/image_registry_service.py +++ b/packages/syft/src/syft/service/worker/image_registry_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...types.errors import SyftException from ...types.uid import UID from ..context import AuthedServiceContext diff --git a/packages/syft/src/syft/service/worker/worker_image_service.py b/packages/syft/src/syft/service/worker/worker_image_service.py index ab9cf5a250f..a5f05f94dac 100644 --- a/packages/syft/src/syft/service/worker/worker_image_service.py +++ b/packages/syft/src/syft/service/worker/worker_image_service.py @@ -10,7 +10,7 @@ from ...custom_worker.config import WorkerConfig from ...custom_worker.k8s import IN_KUBERNETES from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...types.datetime import DateTime from ...types.dicttuple import DictTuple from ...types.errors import SyftException diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index ae46958997d..55b103ba369 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -12,7 +12,7 @@ from ...custom_worker.k8s import IN_KUBERNETES from ...custom_worker.runner_k8s import KubernetesRunner from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 28f443cb5fa..300c0b6ed3d 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -13,7 +13,7 @@ from ...custom_worker.runner_k8s import KubernetesRunner from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.db import DBManager from ...store.document_store import SyftSuccess from ...store.document_store_errors import StashException from ...types.errors import SyftException 
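[Editor's note] The `utils.py` module deleted just below carries the store's JSON round-trip helpers: `json.dumps(default=...)` to stringify types the encoder cannot handle natively (UID, UUID, Enum), and `json.loads(object_hook=...)` to revive them on the way back. A minimal, self-contained sketch of that pattern follows; `Tag` is a hypothetical stand-in for `UID` used purely for illustration, not a name from the codebase:

# stdlib
import json


class Tag:
    """Hypothetical stand-in for syft's UID, for illustration only."""

    def __init__(self, value: str) -> None:
        self.value = value


def _default(val: object) -> str:
    # json.dumps calls this only for values it cannot encode natively.
    if isinstance(val, Tag):
        return f"Tag:{val.value}"
    return str(val)


def _object_hook(obj: dict) -> dict:
    # json.loads calls this for every decoded JSON object (dict),
    # giving us a chance to revive custom types from their string form.
    return {
        k: Tag(v[len("Tag:"):]) if isinstance(v, str) and v.startswith("Tag:") else v
        for k, v in obj.items()
    }


blob = json.dumps({"id": Tag("abc123")}, default=_default)
assert json.loads(blob, object_hook=_object_hook)["id"].value == "abc123"

These two hooks match what SQLAlchemy's `create_engine(json_serializer=..., json_deserializer=...)` parameters expect, which is presumably why the helpers lived next to the engine setup before this consolidation.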
diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/db.py similarity index 100% rename from packages/syft/src/syft/store/db/sqlite_db.py rename to packages/syft/src/syft/store/db/db.py diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 7f5510d92c2..0fa6618c59d 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -38,11 +38,11 @@ from ...types.uid import UID from ..document_store_errors import NotFoundException from ..document_store_errors import StashException +from .db import DBManager +from .db import SQLiteDBManager from .query import Query from .schema import Base from .schema import create_table -from .sqlite_db import DBManager -from .sqlite_db import SQLiteDBManager StashT = TypeVar("StashT", bound=SyftObject) T = TypeVar("T") diff --git a/packages/syft/src/syft/store/db/utils.py b/packages/syft/src/syft/store/db/utils.py deleted file mode 100644 index 870186fa66c..00000000000 --- a/packages/syft/src/syft/store/db/utils.py +++ /dev/null @@ -1,35 +0,0 @@ -# stdlib -from enum import Enum -import json -from typing import Any -from uuid import UUID - -# relative -from ...serde.json_serde import Json -from ...types.uid import UID - - -def _default_dumps(val: Any) -> Json: # type: ignore - if isinstance(val, UID): - return str(val.no_dash) - elif isinstance(val, UUID): - return val.hex - elif issubclass(type(val), Enum): - return val.name - elif val is None: - return None - return str(val) - - -def _default_loads(val: Any) -> Any: # type: ignore - if "UID" in val: - return UID(val) - return val - - -def dumps(d: Any) -> str: - return json.dumps(d, default=_default_dumps) - - -def loads(d: str) -> Any: - return json.loads(d, object_hook=_default_loads) diff --git a/packages/syft/tests/syft/migrations/protocol_communication_test.py b/packages/syft/tests/syft/migrations/protocol_communication_test.py index da8ca1fd37a..64c670d5ea3 100644 --- a/packages/syft/tests/syft/migrations/protocol_communication_test.py +++ b/packages/syft/tests/syft/migrations/protocol_communication_test.py @@ -20,7 +20,7 @@ from syft.service.service import ServiceConfigRegistry from syft.service.service import service_method from syft.service.user.user_roles import GUEST_ROLE_LEVEL -from syft.store.db.sqlite_db import DBManager +from syft.store.db.db import DBManager from syft.store.document_store import DocumentStore from syft.store.document_store import NewBaseStash from syft.store.document_store import PartitionSettings diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py index d40fa15e8ca..5c2fe63be0d 100644 --- a/packages/syft/tests/syft/stores/action_store_test.py +++ b/packages/syft/tests/syft/stores/action_store_test.py @@ -16,7 +16,7 @@ from syft.service.user.user import User from syft.service.user.user_roles import ServiceRole from syft.service.user.user_stash import UserStash -from syft.store.db.sqlite_db import DBManager +from syft.store.db.db import DBManager from syft.types.uid import UID # relative diff --git a/packages/syft/tests/syft/stores/base_stash_test.py b/packages/syft/tests/syft/stores/base_stash_test.py index 451db6fbd74..3a629999388 100644 --- a/packages/syft/tests/syft/stores/base_stash_test.py +++ b/packages/syft/tests/syft/stores/base_stash_test.py @@ -12,8 +12,8 @@ # syft absolute from syft.serde.serializable import serializable -from syft.store.db.sqlite_db import SQLiteDBConfig 
-from syft.store.db.sqlite_db import SQLiteDBManager +from syft.store.db.db import SQLiteDBConfig +from syft.store.db.db import SQLiteDBManager from syft.store.db.stash import ObjectStash from syft.store.document_store import PartitionKey from syft.store.document_store_errors import NotFoundException diff --git a/packages/syft/tests/syft/stores/store_fixtures_test.py b/packages/syft/tests/syft/stores/store_fixtures_test.py index 7225be989b7..c226c89fe26 100644 --- a/packages/syft/tests/syft/stores/store_fixtures_test.py +++ b/packages/syft/tests/syft/stores/store_fixtures_test.py @@ -9,8 +9,8 @@ from syft.service.user.user import UserCreate from syft.service.user.user_roles import ServiceRole from syft.service.user.user_stash import UserStash -from syft.store.db.sqlite_db import SQLiteDBConfig -from syft.store.db.sqlite_db import SQLiteDBManager +from syft.store.db.db import SQLiteDBConfig +from syft.store.db.db import SQLiteDBManager from syft.store.document_store import DocumentStore from syft.types.uid import UID diff --git a/packages/syft/tests/syft/worker_test.py b/packages/syft/tests/syft/worker_test.py index 816ce02b909..92faf0d36b7 100644 --- a/packages/syft/tests/syft/worker_test.py +++ b/packages/syft/tests/syft/worker_test.py @@ -24,7 +24,7 @@ from syft.service.user.user import UserCreate from syft.service.user.user import UserView from syft.service.user.user_stash import UserStash -from syft.store.db.sqlite_db import SQLiteDBManager +from syft.store.db.db import SQLiteDBManager from syft.types.errors import SyftException from syft.types.result import Ok From 5ff95b9d12d270a1a6b2b93ca7c025e431057ecd Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 12 Sep 2024 09:25:27 +0200 Subject: [PATCH 142/197] add partial updates --- packages/syft/src/syft/store/db/stash.py | 32 +++++++++++++++++++++--- packages/syft/tests/syft/network_test.py | 31 +++++++++++++++++++++++ 2 files changed, 60 insertions(+), 3 deletions(-) create mode 100644 packages/syft/tests/syft/network_test.py diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 7f5510d92c2..edccac2e35d 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -10,6 +10,7 @@ from typing import get_args # third party +from pydantic import ValidationError import sqlalchemy as sa from sqlalchemy import Row from sqlalchemy import Table @@ -34,6 +35,7 @@ from ...service.user.user_roles import ServiceRole from ...types.errors import SyftException from ...types.result import as_result +from ...types.syft_metaclass import Empty from ...types.syft_object import SyftObject from ...types.uid import UID from ..document_store_errors import NotFoundException @@ -402,7 +404,26 @@ def set( session.commit() return self.get_by_uid(credentials, uid, session=session).unwrap() - @as_result(StashException, NotFoundException) + @as_result(ValidationError, AttributeError) + def apply_partial_update( + self, original_obj: StashT, update_obj: SyftObject + ) -> StashT: + for key, value in update_obj.__dict__.items(): + if value is Empty: + continue + + if key in original_obj.__dict__: + setattr(original_obj, key, value) + else: + raise AttributeError( + f"{type(update_obj).__name__}.{key} not found in {type(original_obj).__name__}" + ) + + # validate the new fields + self.object_type.model_validate(original_obj) + return original_obj + + @as_result(StashException, NotFoundException, AttributeError, ValidationError) @with_session def update( self, @@ 
-419,8 +440,13 @@ def update( - To fix, we either need db-supported computed fields, or know in our ORM which fields should be re-computed. """ - if not self.allow_any_type: - self.check_type(obj, self.object_type).unwrap() + if not isinstance(obj, self.object_type): + original_obj = self.get_by_uid( + credentials, obj.id, session=session + ).unwrap() + obj = self.apply_partial_update( + original_obj=original_obj, update_obj=obj + ).unwrap() # TODO has_permission is not used if not self.is_unique(obj): diff --git a/packages/syft/tests/syft/network_test.py b/packages/syft/tests/syft/network_test.py new file mode 100644 index 00000000000..3bb4b5e84e3 --- /dev/null +++ b/packages/syft/tests/syft/network_test.py @@ -0,0 +1,31 @@ +# syft absolute +from syft.abstract_server import ServerType +from syft.server.credentials import SyftSigningKey +from syft.service.network.network_service import NetworkStash +from syft.service.network.server_peer import ServerPeer +from syft.service.network.server_peer import ServerPeerUpdate +from syft.types.uid import UID + + +def test_add_route() -> None: + uid = UID() + peer = ServerPeer( + id=uid, + name="test", + verify_key=SyftSigningKey.generate().verify_key, + server_type=ServerType.DATASITE, + admin_email="info@openmined.org", + ) + network_stash = NetworkStash.random() + + network_stash.set( + credentials=network_stash.db.root_verify_key, + obj=peer, + ).unwrap() + peer_update = ServerPeerUpdate(id=uid, name="new name") + peer = network_stash.update( + credentials=network_stash.db.root_verify_key, + obj=peer_update, + ).unwrap() + + assert peer.name == "new name" From 3b77c22d1e3669653a0c749128051a47a8febfb4 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 12 Sep 2024 10:21:05 +0200 Subject: [PATCH 143/197] root admin delete check --- packages/syft/src/syft/service/api/api.py | 4 +- .../syft/src/syft/service/api/api_service.py | 4 +- .../service/notification/email_templates.py | 4 +- .../syft/service/notifier/notifier_service.py | 2 +- .../syft/service/request/request_service.py | 2 +- .../src/syft/service/user/user_service.py | 57 +++++++++++-------- .../syft/src/syft/service/user/user_stash.py | 11 ++-- .../tests/syft/users/user_service_test.py | 19 +------ 8 files changed, 47 insertions(+), 56 deletions(-) diff --git a/packages/syft/src/syft/service/api/api.py b/packages/syft/src/syft/service/api/api.py index 44567e8d8ef..1cebbd1e0fc 100644 --- a/packages/syft/src/syft/service/api/api.py +++ b/packages/syft/src/syft/service/api/api.py @@ -575,7 +575,7 @@ def exec_code( api_service = context.server.get_service("apiservice") api_service.stash.upsert( - context.server.services.user.admin_verify_key(), self + context.server.services.user.root_verify_key, self ).unwrap() print = original_print # type: ignore @@ -650,7 +650,7 @@ def code_string(context: TransformContext) -> TransformContext: ) context.server = cast(AbstractServer, context.server) - admin_key = context.server.services.user.admin_verify_key() + admin_key = context.server.services.user.root_verify_key # If endpoint exists **AND** (has visible access **OR** the user is admin) if endpoint_type is not None and ( diff --git a/packages/syft/src/syft/service/api/api_service.py b/packages/syft/src/syft/service/api/api_service.py index 9df52905439..8df26f0ac15 100644 --- a/packages/syft/src/syft/service/api/api_service.py +++ b/packages/syft/src/syft/service/api/api_service.py @@ -261,7 +261,7 @@ def api_endpoints( context: AuthedServiceContext, ) -> list[TwinAPIEndpointView]: """Retrieves a list of 
available API endpoints view available to the user.""" - admin_key = context.server.services.user.admin_verify_key() + admin_key = context.server.services.user.root_verify_key all_api_endpoints = self.stash.get_all(admin_key).unwrap() api_endpoint_view = [ @@ -585,7 +585,7 @@ def execute_server_side_endpoint_mock_by_id( def get_endpoint_by_uid( self, context: AuthedServiceContext, uid: UID ) -> TwinAPIEndpoint: - admin_key = context.server.services.user.admin_verify_key() + admin_key = context.server.services.user.root_verify_key return self.stash.get_by_uid(admin_key, uid).unwrap() @as_result(StashException) diff --git a/packages/syft/src/syft/service/notification/email_templates.py b/packages/syft/src/syft/service/notification/email_templates.py index 2ebc0908a88..c53da5fabef 100644 --- a/packages/syft/src/syft/service/notification/email_templates.py +++ b/packages/syft/src/syft/service/notification/email_templates.py @@ -133,7 +133,7 @@ def email_title(notification: "Notification", context: AuthedServiceContext) -> @staticmethod def email_body(notification: "Notification", context: AuthedServiceContext) -> str: user_service = context.server.services.user - admin_verify_key = user_service.admin_verify_key() + admin_verify_key = user_service.root_verify_key user = user_service.stash.get_by_verify_key( credentials=admin_verify_key, verify_key=notification.to_user_verify_key ).unwrap() @@ -224,7 +224,7 @@ def email_title(notification: "Notification", context: AuthedServiceContext) -> @staticmethod def email_body(notification: "Notification", context: AuthedServiceContext) -> str: user_service = context.server.services.user - admin_verify_key = user_service.admin_verify_key() + admin_verify_key = user_service.root_verify_key admin = user_service.get_by_verify_key(admin_verify_key).unwrap() admin_name = admin.name diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index adae5484b05..8f788d09431 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -322,7 +322,7 @@ def set_email_rate_limit( def dispatch_notification( self, context: AuthedServiceContext, notification: Notification ) -> SyftSuccess: - admin_key = context.server.services.user.admin_verify_key() + admin_key = context.server.services.user.root_verify_key # Silently fail on notification not delivered try: diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index 4acb027f1fd..ed94c185689 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -57,7 +57,7 @@ def submit( request, ).unwrap() - root_verify_key = context.server.services.user.admin_verify_key() + root_verify_key = context.server.services.user.root_verify_key if send_message: message_subject = f"Result to request {str(request.id)[:4]}...{str(request.id)[-3:]}\ diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 54b959085e6..6ffee6756c9 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -130,7 +130,7 @@ def forgot_password( "If the email is valid, we sent a password " + "reset token to your email or a password request to the admin." 
) - root_key = self.admin_verify_key() + root_key = self.root_verify_key root_context = AuthedServiceContext(server=context.server, credentials=root_key) @@ -241,7 +241,7 @@ def reset_password( self, context: UnauthedServiceContext, token: str, new_password: str ) -> SyftSuccess: """Resets a certain user password using a temporary token.""" - root_key = self.admin_verify_key() + root_key = self.root_verify_key root_context = AuthedServiceContext(server=context.server, credentials=root_key) try: @@ -348,7 +348,7 @@ def get_index( def signing_key_for_verify_key(self, verify_key: SyftVerifyKey) -> UserPrivateKey: user = self.stash.get_by_verify_key( - credentials=self.stash.admin_verify_key(), verify_key=verify_key + credentials=self.stash.root_verify_key, verify_key=verify_key ).unwrap() return user.to(UserPrivateKey) @@ -537,27 +537,36 @@ def update( @service_method(path="user.delete", name="delete", roles=GUEST_ROLE_LEVEL) def delete(self, context: AuthedServiceContext, uid: UID) -> UID: - user = self.stash.get_by_uid(credentials=context.credentials, uid=uid).unwrap() + user_to_delete = self.stash.get_by_uid( + credentials=context.credentials, uid=uid + ).unwrap() - if ( + # Cannot delete root user + if user_to_delete.verify_key == self.root_verify_key: + raise UserPermissionError( + private_message=f"User {context.credentials} attempted to delete root user." + ) + + # - Admins can delete any user + # - Data Owners can delete Data Scientists and Guests + has_delete_permissions = ( context.role == ServiceRole.ADMIN or context.role == ServiceRole.DATA_OWNER - and user.role - in [ - ServiceRole.GUEST, - ServiceRole.DATA_SCIENTIST, - ] - ): - pass - else: + and user_to_delete.role in [ServiceRole.GUEST, ServiceRole.DATA_SCIENTIST] + ) + + if not has_delete_permissions: raise UserPermissionError( - f"User {context.credentials} ({context.role}) tried to delete user {uid} ({user.role})" + private_message=( + f"User {context.credentials} ({context.role}) tried to delete user " + f"{uid} ({user_to_delete.role})" + ) ) # TODO: Remove notifications for the deleted user - self.stash.delete_by_uid(credentials=context.credentials, uid=uid).unwrap() - - return uid + return self.stash.delete_by_uid( + credentials=context.credentials, uid=uid + ).unwrap() def exchange_credentials(self, context: UnauthedServiceContext) -> SyftSuccess: """Verify user @@ -568,7 +577,7 @@ def exchange_credentials(self, context: UnauthedServiceContext) -> SyftSuccess: raise SyftException(public_message="Invalid login credentials") user = self.stash.get_by_email( - credentials=self.admin_verify_key(), email=context.login_credentials.email + credentials=self.root_verify_key, email=context.login_credentials.email ).unwrap() if check_pwd(context.login_credentials.password, user.hashed_password): @@ -587,9 +596,9 @@ def exchange_credentials(self, context: UnauthedServiceContext) -> SyftSuccess: return SyftSuccess(message="Login successful.", value=user.to(UserPrivateKey)) - def admin_verify_key(self) -> SyftVerifyKey: - # TODO: Remove passthrough method? - return self.stash.admin_verify_key() + @property + def root_verify_key(self) -> SyftVerifyKey: + return self.stash.root_verify_key def register( self, context: ServerServiceContext, new_user: UserCreate @@ -630,7 +639,7 @@ def register( success_message = f"User '{user.name}' successfully registered!" 
# Notification Step - root_key = self.admin_verify_key() + root_key = self.root_verify_key root_context = AuthedServiceContext(server=context.server, credentials=root_key) link = None @@ -657,7 +666,7 @@ def register( @as_result(StashException) def user_verify_key(self, email: str) -> SyftVerifyKey: # we are bypassing permissions here, so dont use to return a result directly to the user - credentials = self.admin_verify_key() + credentials = self.root_verify_key user = self.stash.get_by_email(credentials=credentials, email=email).unwrap() if user.verify_key is None: raise UserError(f"User {email} has no verify key") @@ -666,7 +675,7 @@ def user_verify_key(self, email: str) -> SyftVerifyKey: @as_result(StashException) def get_by_verify_key(self, verify_key: SyftVerifyKey) -> UserView: # we are bypassing permissions here, so dont use to return a result directly to the user - credentials = self.admin_verify_key() + credentials = self.root_verify_key user = self.stash.get_by_verify_key( credentials=credentials, verify_key=verify_key ).unwrap() diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index 6b51bf73afb..cdacee66f2f 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -1,3 +1,5 @@ +# third party + # relative from ...serde.serializable import serializable from ...server.credentials import SyftSigningKey @@ -31,13 +33,10 @@ def init_root_user(self) -> None: ), ) - def admin_verify_key(self) -> SyftVerifyKey: - return self.root_verify_key - @as_result(StashException, NotFoundException) def admin_user(self) -> User: # TODO: This returns only one user, the first user with the role ADMIN - admin_credentials = self.admin_verify_key() + admin_credentials = self.root_verify_key return self.get_by_role( credentials=admin_credentials, role=ServiceRole.ADMIN ).unwrap() @@ -59,7 +58,7 @@ def get_by_email(self, credentials: SyftVerifyKey, email: str) -> User: @as_result(StashException) def email_exists(self, email: str) -> bool: try: - self.get_by_email(credentials=self.admin_verify_key(), email=email).unwrap() + self.get_by_email(credentials=self.root_verify_key, email=email).unwrap() return True except NotFoundException: return False @@ -68,7 +67,7 @@ def email_exists(self, email: str) -> bool: def verify_key_exists(self, verify_key: SyftVerifyKey) -> bool: try: self.get_by_verify_key( - credentials=self.admin_verify_key(), verify_key=verify_key + credentials=self.root_verify_key, verify_key=verify_key ).unwrap() return True except NotFoundException: diff --git a/packages/syft/tests/syft/users/user_service_test.py b/packages/syft/tests/syft/users/user_service_test.py index 0452bb21ede..efa69cd1c2b 100644 --- a/packages/syft/tests/syft/users/user_service_test.py +++ b/packages/syft/tests/syft/users/user_service_test.py @@ -526,27 +526,10 @@ def mock_get_by_email(credentials: SyftVerifyKey, email: str) -> NoReturn: assert exc.value.public_message == expected_output -def test_userservice_admin_verify_key_error( - monkeypatch: MonkeyPatch, user_service: UserService -) -> None: - expected_output = "failed to get admin verify_key" - - def mock_admin_verify_key() -> UID: - raise SyftException(public_message=expected_output) - - monkeypatch.setattr(user_service.stash, "admin_verify_key", mock_admin_verify_key) - - with pytest.raises(SyftException) as exc: - user_service.admin_verify_key() - - assert exc.type == SyftException - assert exc.value.public_message == 
expected_output - - def test_userservice_admin_verify_key_success( monkeypatch: MonkeyPatch, user_service: UserService, worker ) -> None: - response = user_service.admin_verify_key() + response = user_service.root_verify_key assert isinstance(response, SyftVerifyKey) assert response == worker.root_client.credentials.verify_key From 5a96ae95c9876faa43e8732e4e4bdc5fdf97d193 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 12 Sep 2024 10:21:17 +0200 Subject: [PATCH 144/197] fix --- packages/syft/src/syft/service/user/user_stash.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index cdacee66f2f..90364a60565 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -1,5 +1,3 @@ -# third party - # relative from ...serde.serializable import serializable from ...server.credentials import SyftSigningKey From d631b90fc8dc7ec52eb21fde64707aa8585936e0 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 12 Sep 2024 10:39:59 +0200 Subject: [PATCH 145/197] fix tests that delete root user --- packages/syft/tests/syft/action_test.py | 6 ------ packages/syft/tests/syft/users/user_code_test.py | 6 ------ 2 files changed, 12 deletions(-) diff --git a/packages/syft/tests/syft/action_test.py b/packages/syft/tests/syft/action_test.py index bb9ffb0bad2..851a83cb7c2 100644 --- a/packages/syft/tests/syft/action_test.py +++ b/packages/syft/tests/syft/action_test.py @@ -32,11 +32,9 @@ def test_actionobject_method(worker): assert res[0] == "A" -@pytest.mark.parametrize("delete_original_admin", [False, True]) def test_new_admin_has_action_object_permission( worker: Worker, faker: Faker, - delete_original_admin: bool, ) -> None: root_client = worker.root_client @@ -60,10 +58,6 @@ def test_new_admin_has_action_object_permission( root_client.api.services.user.update(uid=admin.account.id, role=ServiceRole.ADMIN) - if delete_original_admin: - res = root_client.api.services.user.delete(root_client.account.id) - assert not isinstance(res, SyftError) - assert admin.api.services.action.get(obj.id) == obj diff --git a/packages/syft/tests/syft/users/user_code_test.py b/packages/syft/tests/syft/users/user_code_test.py index ddf0a467d5d..5de444197ba 100644 --- a/packages/syft/tests/syft/users/user_code_test.py +++ b/packages/syft/tests/syft/users/user_code_test.py @@ -47,12 +47,10 @@ def test_repr_markdown_not_throwing_error(guest_client: DatasiteClient) -> None: assert result[0]._repr_markdown_() -@pytest.mark.parametrize("delete_original_admin", [False, True]) def test_new_admin_can_list_user_code( worker: Worker, ds_client: DatasiteClient, faker: Faker, - delete_original_admin: bool, ) -> None: root_client = worker.root_client @@ -72,10 +70,6 @@ def test_new_admin_can_list_user_code( ) assert result.role == ServiceRole.ADMIN - if delete_original_admin: - res = root_client.api.services.user.delete(root_client.account.id) - assert not isinstance(res, SyftError) - user_code_stash = worker.services.user_code.stash user_codes = user_code_stash._data From ffc2b52f0aac1efd44aaf5461fc611ecd266bfd0 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 12 Sep 2024 11:06:47 +0200 Subject: [PATCH 146/197] implement nested queries --- .../notification/notification_stash.py | 3 +-- packages/syft/src/syft/store/db/db.py | 4 ++++ packages/syft/src/syft/store/db/query.py | 4 ++++ packages/syft/src/syft/store/db/stash.py | 2 +- 
.../syft/tests/syft/stores/base_stash_test.py | 24 +++++++++++++++++++ 5 files changed, 34 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/notification/notification_stash.py b/packages/syft/src/syft/service/notification/notification_stash.py index eba855856be..029cf1b325a 100644 --- a/packages/syft/src/syft/service/notification/notification_stash.py +++ b/packages/syft/src/syft/service/notification/notification_stash.py @@ -69,11 +69,10 @@ def get_notification_for_linked_obj( credentials: SyftVerifyKey, linked_obj: LinkedObject, ) -> Notification: - # TODO does this work? return self.get_one( credentials, filters={ - "linked_obj": linked_obj, + "linked_obj.id": linked_obj.id, }, ).unwrap() diff --git a/packages/syft/src/syft/store/db/db.py b/packages/syft/src/syft/store/db/db.py index b845ccc1754..71250660d32 100644 --- a/packages/syft/src/syft/store/db/db.py +++ b/packages/syft/src/syft/store/db/db.py @@ -107,6 +107,8 @@ def random( server_uid: UID | None = None, root_verify_key: SyftVerifyKey | None = None, ) -> "SQLiteDBManager": + """Get a SQLiteDBManager with random values for the config, server_uid, and root_verify_key. + Intended for testing purposes only.""" root_verify_key = root_verify_key or SyftSigningKey.generate().verify_key server_uid = server_uid or UID() config = config or SQLiteDBConfig() @@ -126,6 +128,8 @@ def random( server_uid: UID | None = None, root_verify_key: SyftVerifyKey | None = None, ) -> "PostgresDBManager": + """Get a PostgresDBManager with random values for the config, server_uid, and root_verify_key. + Intended for testing purposes only.""" if config is None: raise ValueError("Cannot create a postgres db without a config") diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py index 8dc13f1ccb5..d232a1c6572 100644 --- a/packages/syft/src/syft/store/db/query.py +++ b/packages/syft/src/syft/store/db/query.py @@ -271,6 +271,10 @@ def _eq_filter( if field == "id": return table.c.id == UID(value) + if "." in field: + # magic! 
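+        # (editor's note) the "magic": SQLAlchemy JSON columns accept a list of
+        # path segments as an index, so splitting "linked_obj.id" into
+        # ["linked_obj", "id"] turns table.c.fields[field] below into a nested
+        # JSON path lookup instead of a single top-level key access.
+        # (func.json_quote is SQLite-specific; PATCH 147 later adjusts this
+        # per dialect for postgres vs sqlite.)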
+ field = field.split(".") # type: ignore + json_value = serialize_json(value) return table.c.fields[field] == func.json_quote(json_value) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 6ba36551741..9a618fad7ac 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -545,7 +545,7 @@ def get_one( for field_name, operator, field_value in parse_filters(filters): query = query.filter(field_name, operator, field_value) - query = query.order_by(order_by, sort_order).offset(offset) + query = query.order_by(order_by, sort_order).offset(offset).limit(1) result = query.execute(session).first() if result is None: raise NotFoundException(f"{self.object_type.__name__}: not found") diff --git a/packages/syft/tests/syft/stores/base_stash_test.py b/packages/syft/tests/syft/stores/base_stash_test.py index 3a629999388..9999848fe5e 100644 --- a/packages/syft/tests/syft/stores/base_stash_test.py +++ b/packages/syft/tests/syft/stores/base_stash_test.py @@ -12,12 +12,14 @@ # syft absolute from syft.serde.serializable import serializable +from syft.service.request.request_service import RequestService from syft.store.db.db import SQLiteDBConfig from syft.store.db.db import SQLiteDBManager from syft.store.db.stash import ObjectStash from syft.store.document_store import PartitionKey from syft.store.document_store_errors import NotFoundException from syft.store.document_store_errors import StashException +from syft.store.linked_obj import LinkedObject from syft.types.errors import SyftException from syft.types.syft_object import SyftObject from syft.types.uid import UID @@ -32,6 +34,7 @@ class MockObject(SyftObject): desc: str importance: int value: int + linked_obj: LinkedObject | None = None __attr_searchable__ = ["id", "name", "desc", "importance"] __attr_unique__ = ["id", "name"] @@ -292,6 +295,27 @@ def test_basestash_query_one( ).unwrap() +def test_basestash_query_linked_obj( + root_verify_key, base_stash: MockStash, mock_object: MockObject +) -> None: + mock_object.linked_obj = LinkedObject( + object_type=MockObject, + object_uid=UID(), + id=UID(), + tags=["tag1", "tag2"], + server_uid=UID(), + service_type=RequestService, + ) + base_stash.set(root_verify_key, mock_object).unwrap() + + result = base_stash.get_one( + root_verify_key, + filters={"linked_obj.id": mock_object.linked_obj.id}, + ).unwrap() + + assert result == mock_object + + def test_basestash_query_all( root_verify_key, base_stash: MockStash, mock_objects: list[MockObject], faker: Faker ) -> None: From cadba5c9e90b2dc3045fbe961068dd79ff004901 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 12 Sep 2024 14:38:42 +0530 Subject: [PATCH 147/197] remove duplicate postgres reference in base.yml move db manager and db config to base.py define postgres db manager fix json_quote based on postgres or sqlite fix lint --- packages/grid/backend/grid/core/server.py | 2 +- packages/grid/default.env | 2 +- packages/grid/helm/examples/dev/base.yaml | 7 - packages/grid/helm/syft/values.yaml | 32 -- packages/syft/src/syft/server/server.py | 33 +- packages/syft/src/syft/server/uvicorn.py | 5 +- .../syft/src/syft/server/worker_settings.py | 2 +- .../src/syft/service/action/action_service.py | 2 +- .../syft/src/syft/service/api/api_service.py | 2 +- .../attestation/attestation_service.py | 2 +- .../src/syft/service/blob_storage/service.py | 2 +- .../src/syft/service/code/status_service.py | 2 +- .../syft/service/code/user_code_service.py | 2 +- 
.../code_history/code_history_service.py | 2 +- .../data_subject_member_service.py | 2 +- .../data_subject/data_subject_service.py | 2 +- .../syft/service/dataset/dataset_service.py | 2 +- .../syft/service/enclave/enclave_service.py | 2 +- .../syft/src/syft/service/job/job_service.py | 2 +- .../syft/src/syft/service/log/log_service.py | 2 +- .../syft/service/metadata/metadata_service.py | 2 +- .../service/migration/migration_service.py | 2 +- .../syft/service/network/network_service.py | 2 +- .../notification/notification_service.py | 2 +- .../syft/service/notifier/notifier_service.py | 2 +- .../src/syft/service/output/output_service.py | 2 +- .../src/syft/service/policy/policy_service.py | 2 +- .../syft/service/project/project_service.py | 2 +- .../src/syft/service/queue/queue_service.py | 2 +- .../syft/service/request/request_service.py | 2 +- .../syft/service/settings/settings_service.py | 2 +- .../src/syft/service/sync/sync_service.py | 2 +- .../syft/src/syft/service/sync/sync_stash.py | 2 +- .../src/syft/service/user/user_service.py | 2 +- .../service/worker/image_registry_service.py | 2 +- .../service/worker/worker_image_service.py | 2 +- .../service/worker/worker_pool_service.py | 2 +- .../src/syft/service/worker/worker_service.py | 2 +- packages/syft/src/syft/store/db/base.py | 53 +++ .../syft/src/syft/store/db/postgres_db.py | 61 ++++ packages/syft/src/syft/store/db/query.py | 14 +- packages/syft/src/syft/store/db/sqlite_db.py | 57 +--- packages/syft/src/syft/store/db/stash.py | 2 +- .../syft/store/postgres_pool_connection.py | 116 ------- .../src/syft/store/postgres_query_executor.py | 82 ----- .../syft/store/postgresql_document_store.py | 313 ------------------ .../migrations/protocol_communication_test.py | 2 +- .../tests/syft/stores/action_store_test.py | 2 +- 48 files changed, 189 insertions(+), 658 deletions(-) create mode 100644 packages/syft/src/syft/store/db/base.py create mode 100644 packages/syft/src/syft/store/db/postgres_db.py delete mode 100644 packages/syft/src/syft/store/postgres_pool_connection.py delete mode 100644 packages/syft/src/syft/store/postgres_query_executor.py delete mode 100644 packages/syft/src/syft/store/postgresql_document_store.py diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py index 15147a9d778..94f209613b7 100644 --- a/packages/grid/backend/grid/core/server.py +++ b/packages/grid/backend/grid/core/server.py @@ -16,7 +16,7 @@ from syft.service.queue.zmq_client import ZMQQueueConfig from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig from syft.store.blob_storage.seaweedfs import SeaweedFSConfig -from syft.store.db.sqlite_db import PostgresDBConfig +from syft.store.db.postgres_db import PostgresDBConfig from syft.store.db.sqlite_db import SQLiteDBConfig from syft.types.uid import UID diff --git a/packages/grid/default.env b/packages/grid/default.env index 21179646f28..49697538cd8 100644 --- a/packages/grid/default.env +++ b/packages/grid/default.env @@ -108,7 +108,7 @@ RATHOLE_PORT=2333 # POSTGRESQL_IMAGE=postgres # export POSTGRESQL_VERSION="15" POSTGRESQL_DBNAME=syftdb_postgres -POSTGRESQL_HOST=localhost +POSTGRESQL_HOST=postgres POSTGRESQL_PORT=5432 POSTGRESQL_USERNAME=syft_postgres POSTGRESQL_PASSWORD=example diff --git a/packages/grid/helm/examples/dev/base.yaml b/packages/grid/helm/examples/dev/base.yaml index 44ef6395758..3fc1ad5c4da 100644 --- a/packages/grid/helm/examples/dev/base.yaml +++ b/packages/grid/helm/examples/dev/base.yaml @@ -28,13 +28,6 @@ postgres: secret: 
rootPassword: example -postgres: - resourcesPreset: null - resources: null - - secret: - rootPassword: example - seaweedfs: resourcesPreset: null resources: null diff --git a/packages/grid/helm/syft/values.yaml b/packages/grid/helm/syft/values.yaml index 5a65b80a6a2..93c7ac826d4 100644 --- a/packages/grid/helm/syft/values.yaml +++ b/packages/grid/helm/syft/values.yaml @@ -44,38 +44,6 @@ postgres: rootPassword: null # ================================================================================= -postgres: -# Postgres config - port: 5432 - username: syft_postgres - dbname: syftdb_postgres - host: postgres - - # Extra environment vars - env: null - - # Pod labels & annotations - podLabels: null - podAnnotations: null - - # Node selector for pods - nodeSelector: null - - # Pod Resource Limits - resourcesPreset: large - resources: null - - # PVC storage size - storageSize: 5Gi - - # Mongo secret name. Override this if you want to use a self-managed secret. - secretKeyName: postgres-secret - - # default/custom secret raw values - secret: - rootPassword: null -# ================================================================================= - frontend: # Extra environment vars env: null diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index a376a85c6fd..5462c33f729 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -87,7 +87,9 @@ from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit -from ..store.db.sqlite_db import DBConfig +from ..store.db.base import DBConfig +from ..store.db.postgres_db import PostgresDBConfig +from ..store.db.postgres_db import PostgresDBManager from ..store.db.sqlite_db import SQLiteDBConfig from ..store.db.sqlite_db import SQLiteDBManager from ..store.db.stash import ObjectStash @@ -336,7 +338,6 @@ def __init__( smtp_host: str | None = None, association_request_auto_approval: bool = False, background_tasks: bool = False, - store_client_config: dict | None = None, consumer_type: ConsumerType | None = None, ): # 🟡 TODO 22: change our ENV variable format and default init args to make this @@ -411,9 +412,12 @@ def __init__( filename=f"{self.id}_json.db", path=self.get_temp_dir("db"), ) - # db_config = PostgresDBConfig(reset=False) + + if reset: + db_config.reset = True self.db_config = db_config + self.db: PostgresDBManager | SQLiteDBManager | None = None self.init_stores(db_config=self.db_config) @@ -739,7 +743,6 @@ def named( in_memory_workers: bool = True, association_request_auto_approval: bool = False, background_tasks: bool = False, - store_client_config: dict | None = None, consumer_type: ConsumerType | None = None, ) -> Server: uid = get_named_server_uid(name) @@ -771,7 +774,6 @@ def named( reset=reset, association_request_auto_approval=association_request_auto_approval, background_tasks=background_tasks, - store_client_config=store_client_config, consumer_type=consumer_type, ) @@ -895,11 +897,20 @@ def reload_user_code() -> None: CODE_RELOADER[ti] = reload_user_code def init_stores(self, db_config: DBConfig) -> None: - self.db = SQLiteDBManager( - config=db_config, - server_uid=self.id, - root_verify_key=self.verify_key, - ) + if isinstance(db_config, SQLiteDBConfig): + self.db = SQLiteDBManager( + config=db_config, + server_uid=self.id, + root_verify_key=self.verify_key, + ) + elif isinstance(db_config, PostgresDBConfig): 
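+            # (editor's note) backend dispatch is driven by the DBConfig
+            # subclass: SQLiteDBConfig selects SQLiteDBManager above, and
+            # PostgresDBConfig selects the PostgresDBManager introduced in
+            # this patch; any other config type raises below.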
+ self.db = PostgresDBManager( + config=db_config, + server_uid=self.id, + root_verify_key=self.verify_key, + ) + else: + raise SyftException(public_message=f"Unsupported DB config: {db_config}") self.queue_stash = QueueStash(store=self.db) @@ -975,7 +986,7 @@ def update_self(self, settings: ServerSettings) -> None: # settings and services are resolved. def get_settings(self) -> ServerSettings | None: if self._settings: - return self._settings + return self._settings # type: ignore if self.signing_key is None: raise ValueError(f"{self} has no signing key") diff --git a/packages/syft/src/syft/server/uvicorn.py b/packages/syft/src/syft/server/uvicorn.py index 6f2617cc7b7..13fff13d973 100644 --- a/packages/syft/src/syft/server/uvicorn.py +++ b/packages/syft/src/syft/server/uvicorn.py @@ -27,6 +27,7 @@ from ..abstract_server import ServerSideType from ..client.client import API_PATH from ..deployment_type import DeploymentType +from ..store.db.base import DBConfig from ..util.autoreload import enable_autoreload from ..util.constants import DEFAULT_TIMEOUT from ..util.telemetry import TRACING_ENABLED @@ -66,7 +67,7 @@ class AppSettings(BaseSettings): n_consumers: int = 0 association_request_auto_approval: bool = False background_tasks: bool = False - store_client_config: dict | None = None + db_config: DBConfig | None = None model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None") @@ -232,7 +233,6 @@ def serve_server( association_request_auto_approval: bool = False, background_tasks: bool = False, debug: bool = False, - store_client_config: dict | None = None, ) -> tuple[Callable, Callable]: starting_uvicorn_event = multiprocessing.Event() @@ -262,7 +262,6 @@ def serve_server( "debug": debug, "starting_uvicorn_event": starting_uvicorn_event, "deployment_type": deployment_type, - "store_client_config": store_client_config, }, ) diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py index c76e5617d0d..e2a3c64a0e8 100644 --- a/packages/syft/src/syft/server/worker_settings.py +++ b/packages/syft/src/syft/server/worker_settings.py @@ -13,7 +13,7 @@ from ..server.credentials import SyftSigningKey from ..service.queue.base_queue import QueueConfig from ..store.blob_storage import BlobStorageConfig -from ..store.db.sqlite_db import DBConfig +from ..store.db.base import DBConfig from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftObject from ..types.uid import UID diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 38b53cf977d..6fb49804d1e 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -9,7 +9,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.datetime import DateTime diff --git a/packages/syft/src/syft/service/api/api_service.py b/packages/syft/src/syft/service/api/api_service.py index cce0d53698a..0f63e50f01b 100644 --- a/packages/syft/src/syft/service/api/api_service.py +++ b/packages/syft/src/syft/service/api/api_service.py @@ -10,7 +10,7 @@ from ...serde.serializable import serializable from ...service.action.action_endpoint import 
CustomEndpointActionObject from ...service.action.action_object import ActionObject -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException diff --git a/packages/syft/src/syft/service/attestation/attestation_service.py b/packages/syft/src/syft/service/attestation/attestation_service.py index 1794ef6b03a..4fcc434900e 100644 --- a/packages/syft/src/syft/service/attestation/attestation_service.py +++ b/packages/syft/src/syft/service/attestation/attestation_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...types.errors import SyftException from ...types.result import as_result from ...util.util import str_to_bool diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index f8069566268..3067810e099 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -11,7 +11,7 @@ from ...store.blob_storage import BlobRetrieval from ...store.blob_storage.on_disk import OnDiskBlobDeposit from ...store.blob_storage.seaweedfs import SeaweedFSBlobDeposit -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...types.blob_storage import AzureSecureFilePathLocation from ...types.blob_storage import BlobFileType from ...types.blob_storage import BlobStorageEntry diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index f6709ebab3a..dca5f93b04a 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings from ...types.uid import UID diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index 29fa9093f14..d6ee23a9923 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index fa7da4976fa..b6f820cbff1 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -3,7 +3,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.document_store_errors import NotFoundException from ...types.uid import UID from ..code.user_code import SubmitUserCode diff --git 
a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py index ec0241b54e0..01bbfd3b20c 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py @@ -3,7 +3,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import StashException from ...types.result import as_result diff --git a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py index 8c67cd84849..a6817a36382 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import StashException from ...types.result import as_result diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index 93ab6f8ad5b..0bfbf535954 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...types.dicttuple import DictTuple from ...types.uid import UID from ..action.action_permissions import ActionObjectPermission diff --git a/packages/syft/src/syft/service/enclave/enclave_service.py b/packages/syft/src/syft/service/enclave/enclave_service.py index 76d8c455e52..f0f449578ea 100644 --- a/packages/syft/src/syft/service/enclave/enclave_service.py +++ b/packages/syft/src/syft/service/enclave/enclave_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ..service import AbstractService diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index 3b31f168f17..c1e1c68e18f 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable from ...server.worker_settings import WorkerSettings -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...types.errors import SyftException from ...types.uid import UID from ..action.action_object import ActionObject diff --git a/packages/syft/src/syft/service/log/log_service.py b/packages/syft/src/syft/service/log/log_service.py index 9d25679b895..ec6d9b57ad6 100644 --- a/packages/syft/src/syft/service/log/log_service.py +++ b/packages/syft/src/syft/service/log/log_service.py @@ -1,6 +1,6 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...types.uid import UID from ..action.action_permissions import 
StoragePermission from ..context import AuthedServiceContext diff --git a/packages/syft/src/syft/service/metadata/metadata_service.py b/packages/syft/src/syft/service/metadata/metadata_service.py index 603091ca4ee..82043f80461 100644 --- a/packages/syft/src/syft/service/metadata/metadata_service.py +++ b/packages/syft/src/syft/service/metadata/metadata_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ..context import AuthedServiceContext from ..service import AbstractService from ..service import service_method diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 581c3561c92..e5566123c9e 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...types.blob_storage import BlobStorageEntry diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index d14fadb25ed..03eae13f3bd 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -15,7 +15,7 @@ from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings from ...service.settings.settings import ServerSettings -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException diff --git a/packages/syft/src/syft/service/notification/notification_service.py b/packages/syft/src/syft/service/notification/notification_service.py index 9da8ae4a935..afe3d97c662 100644 --- a/packages/syft/src/syft/service/notification/notification_service.py +++ b/packages/syft/src/syft/service/notification/notification_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.document_store_errors import StashException from ...types.errors import SyftException from ...types.result import as_result diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index 34a964b4a03..ce922aa26f9 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -8,7 +8,7 @@ # relative from ...abstract_server import AbstractServer from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 7f59cf4e1e8..3d8c10ef699 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ 
b/packages/syft/src/syft/service/output/output_service.py @@ -7,7 +7,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey from ...store.document_store_errors import StashException diff --git a/packages/syft/src/syft/service/policy/policy_service.py b/packages/syft/src/syft/service/policy/policy_service.py index da8759a3540..e900cf0850b 100644 --- a/packages/syft/src/syft/service/policy/policy_service.py +++ b/packages/syft/src/syft/service/policy/policy_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...types.uid import UID from ..context import AuthedServiceContext from ..response import SyftSuccess diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index b04ec27bf8b..7b513f0c367 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject diff --git a/packages/syft/src/syft/service/queue/queue_service.py b/packages/syft/src/syft/service/queue/queue_service.py index d47f21052a8..aab81efa5a7 100644 --- a/packages/syft/src/syft/service/queue/queue_service.py +++ b/packages/syft/src/syft/service/queue/queue_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ..service import AbstractService from .queue_stash import QueueStash diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index 37ceeeae02a..7ada7772875 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.linked_obj import LinkedObject from ...types.errors import SyftException from ...types.result import as_result diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index a0adbbf8575..5e5b8a2dbfb 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -8,7 +8,7 @@ # relative from ...abstract_server import ServerSideType from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.sqlite_document_store import SQLiteStoreConfig diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 
d044edf3b96..4368b83ba55 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -6,7 +6,7 @@ # relative from ...client.api import ServerIdentity from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import NewBaseStash from ...store.document_store_errors import NotFoundException diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py index 4090e62dec1..008ac8970c9 100644 --- a/packages/syft/src/syft/service/sync/sync_stash.py +++ b/packages/syft/src/syft/service/sync/sync_stash.py @@ -1,7 +1,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings from ...store.document_store_errors import StashException diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 34d5071fa7a..9072c67b192 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -11,7 +11,7 @@ from ...serde.serializable import serializable from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject diff --git a/packages/syft/src/syft/service/worker/image_registry_service.py b/packages/syft/src/syft/service/worker/image_registry_service.py index e9e80f21892..c4ea071356d 100644 --- a/packages/syft/src/syft/service/worker/image_registry_service.py +++ b/packages/syft/src/syft/service/worker/image_registry_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...types.errors import SyftException from ...types.uid import UID from ..context import AuthedServiceContext diff --git a/packages/syft/src/syft/service/worker/worker_image_service.py b/packages/syft/src/syft/service/worker/worker_image_service.py index ab9cf5a250f..e8ac5d10d47 100644 --- a/packages/syft/src/syft/service/worker/worker_image_service.py +++ b/packages/syft/src/syft/service/worker/worker_image_service.py @@ -10,7 +10,7 @@ from ...custom_worker.config import WorkerConfig from ...custom_worker.k8s import IN_KUBERNETES from ...serde.serializable import serializable -from ...store.db.sqlite_db import DBManager +from ...store.db.base import DBManager from ...types.datetime import DateTime from ...types.dicttuple import DictTuple from ...types.errors import SyftException diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index ae46958997d..5160c0cd966 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -12,7 +12,7 @@ from ...custom_worker.k8s import IN_KUBERNETES from ...custom_worker.runner_k8s import KubernetesRunner from ...serde.serializable import 
serializable
-from ...store.db.sqlite_db import DBManager
+from ...store.db.base import DBManager
 from ...store.document_store_errors import NotFoundException
 from ...store.document_store_errors import StashException
 from ...store.linked_obj import LinkedObject
diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py
index 28f443cb5fa..09b7e147526 100644
--- a/packages/syft/src/syft/service/worker/worker_service.py
+++ b/packages/syft/src/syft/service/worker/worker_service.py
@@ -13,7 +13,7 @@
 from ...custom_worker.runner_k8s import KubernetesRunner
 from ...serde.serializable import serializable
 from ...server.credentials import SyftVerifyKey
-from ...store.db.sqlite_db import DBManager
+from ...store.db.base import DBManager
 from ...store.document_store import SyftSuccess
 from ...store.document_store_errors import StashException
 from ...types.errors import SyftException
diff --git a/packages/syft/src/syft/store/db/base.py b/packages/syft/src/syft/store/db/base.py
new file mode 100644
index 00000000000..8a987861070
--- /dev/null
+++ b/packages/syft/src/syft/store/db/base.py
@@ -0,0 +1,53 @@
+# stdlib
+import logging
+
+# third party
+from pydantic import BaseModel
+from sqlalchemy import create_engine
+from sqlalchemy.orm import sessionmaker
+
+# relative
+from ...serde.serializable import serializable
+from ...server.credentials import SyftVerifyKey
+from ...types.uid import UID
+
+logger = logging.getLogger(__name__)
+
+
+@serializable(canonical_name="DBConfig", version=1)
+class DBConfig(BaseModel):
+    reset: bool = False
+
+    @property
+    def connection_string(self) -> str:
+        raise NotImplementedError("Subclasses must implement this method.")
+
+
+class DBManager:
+    def __init__(
+        self,
+        config: DBConfig,
+        server_uid: UID,
+        root_verify_key: SyftVerifyKey,
+    ) -> None:
+        self.config = config
+        self.root_verify_key = root_verify_key
+        self.server_uid = server_uid
+        self.engine = create_engine(
+            config.connection_string,
+            # json_serializer=dumps,
+            # json_deserializer=loads,
+        )
+        logger.info(f"Connecting to {config.connection_string}")
+        self.sessionmaker = sessionmaker(bind=self.engine)
+        logger.info(f"Successfully connected to {config.connection_string}")
+        self.update_settings()
+
+    def update_settings(self) -> None:
+        pass
+
+    def init_tables(self) -> None:
+        pass
+
+    def reset(self) -> None:
+        pass
diff --git a/packages/syft/src/syft/store/db/postgres_db.py b/packages/syft/src/syft/store/db/postgres_db.py
new file mode 100644
index 00000000000..6166a8eb665
--- /dev/null
+++ b/packages/syft/src/syft/store/db/postgres_db.py
@@ -0,0 +1,61 @@
+# third party
+from sqlalchemy import URL
+
+# relative
+from ...serde.serializable import serializable
+from ...server.credentials import SyftVerifyKey
+from ...types.uid import UID
+from .base import DBManager
+from .schema import Base
+from .sqlite_db import DBConfig
+
+
+@serializable(canonical_name="PostgresDBConfig", version=1)
+class PostgresDBConfig(DBConfig):
+    host: str
+    port: int
+    user: str
+    password: str
+    database: str
+
+    @property
+    def connection_string(self) -> str:
+        return URL.create(
+            "postgresql",
+            username=self.user,
+            password=self.password,
+            host=self.host,
+            port=self.port,
+            database=self.database,
+        ).render_as_string(hide_password=False)
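
The connection_string property above just renders a standard SQLAlchemy URL. A minimal usage sketch of what the new config produces; the host value is hypothetical, while the user/password/database values mirror defaults that appear elsewhere in this series:

# a sketch, not part of the patch; values are illustrative
from syft.store.db.postgres_db import PostgresDBConfig

config = PostgresDBConfig(
    host="localhost",  # hypothetical; deployments inject the real host
    port=5432,
    user="syft_postgres",
    password="example",
    database="syftdb_postgres",
)
print(config.connection_string)
# postgresql://syft_postgres:example@localhost:5432/syftdb_postgres

+
+
+class PostgresDBManager(DBManager):
+    def update_settings(self) -> None:
+        return super().update_settings()
+
+    def init_tables(self) -> None:
+        if self.config.reset:
+            # drop all tables that we know 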
about + Base.metadata.drop_all(bind=self.engine) + self.config.reset = False + + Base.metadata.create_all(self.engine) + + def reset(self) -> None: + Base.metadata.drop_all(bind=self.engine) + Base.metadata.create_all(self.engine) + + @classmethod + def random( + cls: type, + *, + config: PostgresDBConfig, + server_uid: UID | None = None, + root_verify_key: SyftVerifyKey | None = None, + ) -> "PostgresDBManager": + root_verify_key = root_verify_key or SyftVerifyKey.generate() + server_uid = server_uid or UID() + return PostgresDBManager( + config=config, server_uid=server_uid, root_verify_key=root_verify_key + ) diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py index 8dc13f1ccb5..608b3ea5dc9 100644 --- a/packages/syft/src/syft/store/db/query.py +++ b/packages/syft/src/syft/store/db/query.py @@ -1,6 +1,7 @@ # stdlib from abc import ABC from abc import abstractmethod +from collections.abc import Callable import enum from typing import Any from typing import Literal @@ -33,6 +34,8 @@ class FilterOperator(enum.Enum): class Query(ABC): + json_quote: Callable | None = None + def __init__(self, object_type: type[SyftObject]) -> None: self.object_type: type = object_type self.table: Table = self._get_table(object_type) @@ -272,7 +275,10 @@ def _eq_filter( return table.c.id == UID(value) json_value = serialize_json(value) - return table.c.fields[field] == func.json_quote(json_value) + if self.json_quote: + return table.c.fields[field] == self.json_quote(json_value) + else: + return table.c.fields[field].astext == json_value @abstractmethod def _contains_filter( @@ -297,6 +303,8 @@ def _get_column(self, column: str) -> Column: class SQLiteQuery(Query): + json_quote = func.json_quote + def _make_permissions_clause( self, permission: ActionObjectPermission, @@ -315,7 +323,9 @@ def _contains_filter( value: Any, ) -> sa.sql.elements.BinaryExpression: field_value = serialize_json(value) - return table.c.fields[field].contains(func.json_quote(field_value)) + return table.c.fields[field].contains( + self.json_quote(field_value) if self.json_quote else field_value + ) class PostgresQuery(Query): diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index 28043c89001..c4cfea48189 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -4,29 +4,19 @@ import uuid # third party -from pydantic import BaseModel from pydantic import Field import sqlalchemy as sa -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker # relative from ...serde.serializable import serializable from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey from ...types.uid import UID +from .base import DBConfig +from .base import DBManager from .schema import Base -@serializable(canonical_name="DBConfig", version=1) -class DBConfig(BaseModel): - reset: bool = False - - @property - def connection_string(self) -> str: - raise NotImplementedError("Subclasses must implement this method.") - - @serializable(canonical_name="SQLiteDBConfig", version=1) class SQLiteDBConfig(DBConfig): filename: str = Field(default_factory=lambda: f"{uuid.uuid4()}.db") @@ -38,49 +28,6 @@ def connection_string(self) -> str: return f"sqlite:///{filepath.resolve()}" -@serializable(canonical_name="PostgresDBConfig", version=1) -class PostgresDBConfig(DBConfig): - host: str = "postgres" - port: int = 5432 - user: str = "syft_postgres" - password: str = 
"example" - database: str = "syftdb_postgres" - - @property - def connection_string(self) -> str: - return f"postgresql://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" - - -class DBManager: - def __init__( - self, - config: SQLiteDBConfig, - server_uid: UID, - root_verify_key: SyftVerifyKey, - ) -> None: - self.config = config - self.root_verify_key = root_verify_key - self.server_uid = server_uid - self.engine = create_engine( - config.connection_string, - # json_serializer=dumps, - # json_deserializer=loads, - ) - print(f"Connecting to {config.connection_string}") - self.sessionmaker = sessionmaker(bind=self.engine) - - self.update_settings() - - def update_settings(self) -> None: - pass - - def init_tables(self) -> None: - pass - - def reset(self) -> None: - pass - - class SQLiteDBManager(DBManager): def update_settings(self) -> None: # TODO split SQLite / PostgresDBManager diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 7f5510d92c2..98f4669ee69 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -38,10 +38,10 @@ from ...types.uid import UID from ..document_store_errors import NotFoundException from ..document_store_errors import StashException +from .base import DBManager from .query import Query from .schema import Base from .schema import create_table -from .sqlite_db import DBManager from .sqlite_db import SQLiteDBManager StashT = TypeVar("StashT", bound=SyftObject) diff --git a/packages/syft/src/syft/store/postgres_pool_connection.py b/packages/syft/src/syft/store/postgres_pool_connection.py deleted file mode 100644 index 41b462b262b..00000000000 --- a/packages/syft/src/syft/store/postgres_pool_connection.py +++ /dev/null @@ -1,116 +0,0 @@ -# stdlib -from collections.abc import Generator -from contextlib import contextmanager -import logging -import time - -# third party -from psycopg import Connection -from psycopg_pool import ConnectionPool -from psycopg_pool import PoolTimeout - -# relative -from ..types.errors import SyftException - -logger = logging.getLogger(__name__) - - -MIN_DB_POOL_SIZE = 1 -MAX_DB_POOL_SIZE = 10 -DEFAULT_POOL_CONN_TIMEOUT = 30 -CONN_RETRY_INTERVAL = 1 - - -class PostgresPoolConnection: - def __init__( - self, - client_config: dict, - min_size: int = MIN_DB_POOL_SIZE, - max_size: int = MAX_DB_POOL_SIZE, - timeout: int = DEFAULT_POOL_CONN_TIMEOUT, - retry_interval: int = CONN_RETRY_INTERVAL, - pool_kwargs: dict | None = None, - ) -> None: - connect_kwargs = self._connection_kwargs_from_config(client_config) - - # https://www.psycopg.org/psycopg3/docs/advanced/prepare.html#using-prepared-statements-with-pgbouncer - # This should default to None to allow the connection pool to manage the prepare threshold - connect_kwargs["prepare_threshold"] = None - - self.pool = ConnectionPool( - kwargs=connect_kwargs, - open=False, - check=ConnectionPool.check_connection, - min_size=min_size, - max_size=max_size, - **pool_kwargs, - ) - logger.info( - f"Connection pool created with min_size={self.min_size} and max_size={self.max_size}" - ) - logger.info(f"Connected to {self.store_config.client_config.dbname}") - logger.info(f"PostgreSQL Pool connection: {self.pool.get_stats()}") - self.timeout = timeout - self.retry_interval = retry_interval - - @contextmanager - def get_connection(self) -> Generator[Connection, None, None]: - """Provide a connection from the pool, waiting if necessary until one is available.""" - conn = None - start_time = 
time.time() - - try: - while True: - try: - conn = self.pool.getconn(timeout=self.retry_interval) - if conn: - yield conn # Return the connection object to be used in the context - break - except PoolTimeout as e: - elapsed_time = time.time() - start_time - if elapsed_time >= self.timeout: - message = f"Could not get a connection from database pool within {self.timeout} seconds." - raise SyftException.from_exception( - e, - public_message=message, - ) - logger.warning( - f"Connection not available, retrying... ({elapsed_time:.2f} seconds elapsed)" - ) - time.sleep(self.retry_interval) - - except Exception as e: - logger.error(f"Error getting connection from pool: {e}") - yield None - finally: - if conn: - self.pool.putconn(conn) - - def release_connection(self, conn: Connection) -> None: - """Release a connection back to the pool.""" - try: - if conn.closed or conn.broken: - self.pool.putconn(conn, close=True) - logger.info("Broken connection closed and removed from pool.") - else: - self.pool.putconn(conn) - logger.info("Connection released back to pool.") - except Exception as e: - logger.error(f"Error releasing connection: {e}") - - def _connection_kwargs_from_config(self, config: dict) -> dict: - return { - "dbname": config.get("dbname"), - "user": config.get("user"), - "password": config.get("password"), - "host": config.get("host"), - "port": config.get("port"), - } - - def close_all_connections(self) -> None: - """Close all connections in the pool and shut down the pool.""" - try: - self.pool.close() - logger.info("All connections closed and pool shut down.") - except Exception as e: - logger.error(f"Error closing connection pool: {e}") diff --git a/packages/syft/src/syft/store/postgres_query_executor.py b/packages/syft/src/syft/store/postgres_query_executor.py deleted file mode 100644 index dea9bb849b0..00000000000 --- a/packages/syft/src/syft/store/postgres_query_executor.py +++ /dev/null @@ -1,82 +0,0 @@ -# stdlib -import logging -from typing import Any - -# third party -import psycopg -from psycopg import Cursor - -# relative -from .postgres_pool_connection import PostgresPoolConnection - -logger = logging.getLogger(__name__) - - -MAX_QUERY_RETRIES = 3 -QUERY_RETRY_DELAY = 5 - - -class PostgresQueryExecutor: - def __init__( - self, - connection_pool: PostgresPoolConnection, - retries: int = MAX_QUERY_RETRIES, - retry_delay: int = QUERY_RETRY_DELAY, - ) -> None: - self.connection_pool = connection_pool - self.retries = retries - self.retry_delay = retry_delay - - def execute_query(self, query: str, args: list[Any] | None = None) -> Cursor | None: - """ - Execute a query on the database using a context-managed connection. - Handles `InFailedSqlTransaction` errors by rolling back the transaction. - Returns a cursor object after execution for further handling by the caller. - - :param query: SQL query to execute. - :param params: Query parameters (optional). - :return: Cursor object or None if an error occurs. - """ - attempt = 0 - while attempt < self.retries: - try: - # Using the context manager for the connection - with self.connection_pool.get_connection() as conn: - if conn is None: - return None - - cur = conn.cursor() - - # Check if connection is in failed state (i.e., in a failed transaction) - if conn.status == psycopg.extensions.STATUS_IN_FAILED_TRANSACTION: - logger.warning( - "Transaction is in a failed state. Rolling back." 
- ) - conn.rollback() - - cur.execute(query, args) - - conn.commit() - - return cur # Return the cursor object - - except psycopg.errors.InFailedSqlTransaction as e: - logger.error(f"Transaction failed and is in an invalid state: {e}") - if conn and not conn.closed: - conn.rollback() # Roll back the transaction - attempt += 1 # Retry the query after rollback - - except (psycopg.OperationalError, psycopg.errors.AdminShutdown) as e: - logger.error( - f"Server error or termination: {e}. Retrying ({attempt + 1}/{self.retries})..." - ) - attempt += 1 - - except Exception as e: - logger.error(f"Error executing query: {e}") - if conn and not conn.closed: - conn.rollback() # Roll back on any general error - return None - - logger.error(f"Query failed after {self.retries} attempts.") - return None diff --git a/packages/syft/src/syft/store/postgresql_document_store.py b/packages/syft/src/syft/store/postgresql_document_store.py deleted file mode 100644 index 75dee5776ca..00000000000 --- a/packages/syft/src/syft/store/postgresql_document_store.py +++ /dev/null @@ -1,313 +0,0 @@ -# stdlib -import logging -from typing import Any - -# third party -import psycopg -from psycopg import Connection -from psycopg import Cursor -from psycopg.errors import DuplicateTable -from psycopg.errors import InFailedSqlTransaction -from pydantic import Field -from typing_extensions import Self - -# relative -from ..serde.deserialize import _deserialize -from ..serde.serializable import serializable -from ..serde.serialize import _serialize -from ..types.errors import SyftException -from ..types.result import as_result -from ..types.uid import UID -from .document_store import DocumentStore -from .document_store import PartitionSettings -from .document_store import StoreClientConfig -from .document_store import StoreConfig -from .kv_document_store import KeyValueBackingStore -from .locks import LockingConfig -from .locks import NoLockingConfig -from .locks import SyftLock -from .sqlite_document_store import SQLiteBackingStore -from .sqlite_document_store import SQLiteStorePartition -from .sqlite_document_store import _repr_debug_ -from .sqlite_document_store import cache_key -from .sqlite_document_store import special_exception_public_message - -logger = logging.getLogger(__name__) -_CONNECTION_POOL_DB: dict[str, Connection] = {} - - -# https://www.psycopg.org/docs/module.html#psycopg2.connect -@serializable(canonical_name="PostgreSQLStoreClientConfig", version=1) -class PostgreSQLStoreClientConfig(StoreClientConfig): - dbname: str - username: str - password: str - host: str - port: int - - # makes hashabel - class Config: - frozen = True - - def __hash__(self) -> int: - return hash((self.dbname, self.username, self.password, self.host, self.port)) - - def __str__(self) -> str: - return f"dbname={self.dbname} user={self.username} password={self.password} host={self.host} port={self.port}" - - -@serializable(canonical_name="PostgreSQLStorePartition", version=1) -class PostgreSQLStorePartition(SQLiteStorePartition): - pass - - -@serializable(canonical_name="PostgreSQLDocumentStore", version=1) -class PostgreSQLDocumentStore(DocumentStore): - partition_type = PostgreSQLStorePartition - - -@serializable( - attrs=["index_name", "settings", "store_config"], - canonical_name="PostgreSQLBackingStore", - version=1, -) -class PostgreSQLBackingStore(SQLiteBackingStore): - def __init__( - self, - index_name: str, - settings: PartitionSettings, - store_config: StoreConfig, - ddtype: type | None = None, - ) -> None: - self.index_name = 
index_name - self.settings = settings - self.store_config = store_config - self.store_config_hash = hash(store_config.client_config) - self._ddtype = ddtype - if self.store_config.client_config: - self.dbname = self.store_config.client_config.dbname - - self.lock = SyftLock(NoLockingConfig()) - self.create_table() - self.subs_char = r"%s" # thanks postgresql - - def _connect(self) -> None: - if self.store_config.client_config: - connection = psycopg.connect( - dbname=self.store_config.client_config.dbname, - user=self.store_config.client_config.username, - password=self.store_config.client_config.password, - host=self.store_config.client_config.host, - port=self.store_config.client_config.port, - # This should default to None, - # https://www.psycopg.org/psycopg3/docs/advanced/prepare.html#using-prepared-statements-with-pgbouncer - prepare_threshold=None, - ) - _CONNECTION_POOL_DB[cache_key(self.dbname)] = connection - logger.info(f"Connected to {self.store_config.client_config.dbname}") - logger.info(f"PostgreSQL database connection: {connection.info.dsn}") - - def create_table(self) -> None: - db = self.db - try: - with self.lock: - with db.cursor() as cur: - cur.execute( - f"CREATE TABLE IF NOT EXISTS {self.table_name} (uid VARCHAR(32) NOT NULL PRIMARY KEY, " # nosec - + "repr TEXT NOT NULL, value BYTEA NOT NULL, " # nosec - + "sqltime TIMESTAMP NOT NULL DEFAULT NOW())" # nosec - ) - cur.connection.commit() - except DuplicateTable: - pass - except InFailedSqlTransaction: - db.rollback() - except Exception as e: - public_message = special_exception_public_message(self.table_name, e) - raise SyftException.from_exception(e, public_message=public_message) - - @property - def db(self) -> Connection: - if cache_key(self.dbname) not in _CONNECTION_POOL_DB: - self._connect() - return _CONNECTION_POOL_DB[cache_key(self.dbname)] - - @property - def cur(self) -> Cursor: - return self.db.cursor() - - @staticmethod - @as_result(SyftException) - def _execute( - lock: SyftLock, - cursor: Cursor, - db: Connection, - table_name: str, - sql: str, - args: list[Any] | None, - ) -> Cursor: - with lock: - try: - cursor.execute(sql, args) # Execute the SQL with arguments - db.commit() # Commit if everything went ok - except InFailedSqlTransaction as ie: - db.rollback() # Rollback if something went wrong - raise SyftException( - public_message=f"Transaction `{sql}` failed and was rolled back. \n" - f"Error: {ie}." 
- ) - except Exception as e: - logger.debug(f"Rolling back SQL: {sql} with args: {args}") - db.rollback() # Rollback on any other exception to maintain clean state - public_message = special_exception_public_message(table_name, e) - logger.error(public_message) - raise SyftException.from_exception(e, public_message=public_message) - return cursor - - def _set(self, key: UID, value: Any) -> None: - if self._exists(key): - self._update(key, value) - else: - insert_sql = ( - f"insert into {self.table_name} (uid, repr, value) VALUES " # nosec - f"({self.subs_char}, {self.subs_char}, {self.subs_char})" # nosec - ) - data = _serialize(value, to_bytes=True) - with self.cur as cur: - self._execute( - self.lock, - cur, - cur.connection, - self.table_name, - insert_sql, - [str(key), _repr_debug_(value), data], - ).unwrap() - - def _update(self, key: UID, value: Any) -> None: - insert_sql = ( - f"update {self.table_name} set uid = {self.subs_char}, " # nosec - f"repr = {self.subs_char}, value = {self.subs_char} " # nosec - f"where uid = {self.subs_char}" # nosec - ) - data = _serialize(value, to_bytes=True) - with self.cur as cur: - self._execute( - self.lock, - cur, - cur.connection, - self.table_name, - insert_sql, - [str(key), _repr_debug_(value), data, str(key)], - ).unwrap() - - def _get(self, key: UID) -> Any: - select_sql = ( - f"select * from {self.table_name} where uid = {self.subs_char} " # nosec - "order by sqltime" - ) - with self.cur as cur: - cursor = self._execute( - self.lock, cur, cur.connection, self.table_name, select_sql, [str(key)] - ).unwrap(public_message=f"Query {select_sql} failed") - row = cursor.fetchone() - if row is None or len(row) == 0: - raise KeyError(f"{key} not in {type(self)}") - data = row[2] - return _deserialize(data, from_bytes=True) - - def _exists(self, key: UID) -> bool: - select_sql = f"select uid from {self.table_name} where uid = {self.subs_char}" # nosec - row = None - with self.cur as cur: - cursor = self._execute( - self.lock, cur, cur.connection, self.table_name, select_sql, [str(key)] - ).unwrap() - row = cursor.fetchone() # type: ignore - if row is None: - return False - return bool(row) - - def _get_all(self) -> Any: - select_sql = f"select * from {self.table_name} order by sqltime" # nosec - keys = [] - data = [] - with self.cur as cur: - cursor = self._execute( - self.lock, cur, cur.connection, self.table_name, select_sql, [] - ).unwrap() - rows = cursor.fetchall() # type: ignore - if not rows: - return {} - - for row in rows: - keys.append(UID(row[0])) - data.append(_deserialize(row[2], from_bytes=True)) - - return dict(zip(keys, data)) - - def _get_all_keys(self) -> Any: - select_sql = f"select uid from {self.table_name} order by sqltime" # nosec - with self.cur as cur: - cursor = self._execute( - self.lock, cur, cur.connection, self.table_name, select_sql, [] - ).unwrap() - rows = cursor.fetchall() # type: ignore - if not rows: - return [] - keys = [UID(row[0]) for row in rows] - return keys - - def _delete(self, key: UID) -> None: - select_sql = f"delete from {self.table_name} where uid = {self.subs_char}" # nosec - with self.cur as cur: - self._execute( - self.lock, cur, cur.connection, self.table_name, select_sql, [str(key)] - ).unwrap() - - def _delete_all(self) -> None: - select_sql = f"delete from {self.table_name}" # nosec - with self.cur as cur: - self._execute( - self.lock, cur, cur.connection, self.table_name, select_sql, [] - ).unwrap() - - def _len(self) -> int: - select_sql = f"select count(uid) from {self.table_name}" # nosec - 
with self.cur as cur: - cursor = self._execute( - self.lock, cur, cur.connection, self.table_name, select_sql, [] - ).unwrap() - cnt = cursor.fetchone()[0] - return cnt - - def _close(self) -> None: - self._commit() - if cache_key(self.dbname) in _CONNECTION_POOL_DB: - conn = _CONNECTION_POOL_DB[cache_key(self.dbname)] - conn.close() - _CONNECTION_POOL_DB.pop(cache_key(self.dbname), None) - - def _commit(self) -> None: - self.db.commit() - - -@serializable() -class PostgreSQLStoreConfig(StoreConfig): - __canonical_name__ = "PostgreSQLStorePartition" - - client_config: PostgreSQLStoreClientConfig - store_type: type[DocumentStore] = PostgreSQLDocumentStore - backing_store: type[KeyValueBackingStore] = PostgreSQLBackingStore - locking_config: LockingConfig = Field(default_factory=NoLockingConfig) - - @classmethod - def from_dict(cls, client_config_dict: dict) -> Self: - postgresql_client_config = PostgreSQLStoreClientConfig( - dbname=client_config_dict["POSTGRESQL_DBNAME"], - host=client_config_dict["POSTGRESQL_HOST"], - port=client_config_dict["POSTGRESQL_PORT"], - username=client_config_dict["POSTGRESQL_USERNAME"], - password=client_config_dict["POSTGRESQL_PASSWORD"], - ) - - return PostgreSQLStoreConfig(client_config=postgresql_client_config) diff --git a/packages/syft/tests/syft/migrations/protocol_communication_test.py b/packages/syft/tests/syft/migrations/protocol_communication_test.py index da8ca1fd37a..f111cbcc598 100644 --- a/packages/syft/tests/syft/migrations/protocol_communication_test.py +++ b/packages/syft/tests/syft/migrations/protocol_communication_test.py @@ -20,7 +20,7 @@ from syft.service.service import ServiceConfigRegistry from syft.service.service import service_method from syft.service.user.user_roles import GUEST_ROLE_LEVEL -from syft.store.db.sqlite_db import DBManager +from syft.store.db.base import DBManager from syft.store.document_store import DocumentStore from syft.store.document_store import NewBaseStash from syft.store.document_store import PartitionSettings diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py index d40fa15e8ca..f9872edb590 100644 --- a/packages/syft/tests/syft/stores/action_store_test.py +++ b/packages/syft/tests/syft/stores/action_store_test.py @@ -16,7 +16,7 @@ from syft.service.user.user import User from syft.service.user.user_roles import ServiceRole from syft.service.user.user_stash import UserStash -from syft.store.db.sqlite_db import DBManager +from syft.store.db.base import DBManager from syft.types.uid import UID # relative From 8eb48a002e1fc834852f585215e8b140513f65df Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 12 Sep 2024 11:11:37 +0200 Subject: [PATCH 148/197] add a test for todo item --- .../src/syft/service/queue/queue_stash.py | 1 - .../syft/tests/syft/stores/base_stash_test.py | 19 +++++++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index 463addbdfdf..5b2ef9f3318 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -147,7 +147,6 @@ def pop_on_complete(self, credentials: SyftVerifyKey, uid: UID) -> QueueItem: def get_by_status( self, credentials: SyftVerifyKey, status: Status ) -> list[QueueItem]: - # TODO do we need json serialization for Status? 
return self.get_all( credentials=credentials, filters={"status": status}, diff --git a/packages/syft/tests/syft/stores/base_stash_test.py b/packages/syft/tests/syft/stores/base_stash_test.py index 9999848fe5e..ee331dbd952 100644 --- a/packages/syft/tests/syft/stores/base_stash_test.py +++ b/packages/syft/tests/syft/stores/base_stash_test.py @@ -12,6 +12,7 @@ # syft absolute from syft.serde.serializable import serializable +from syft.service.queue.queue_stash import Status from syft.service.request.request_service import RequestService from syft.store.db.db import SQLiteDBConfig from syft.store.db.db import SQLiteDBManager @@ -35,6 +36,7 @@ class MockObject(SyftObject): importance: int value: int linked_obj: LinkedObject | None = None + status: Status = Status.CREATED __attr_searchable__ = ["id", "name", "desc", "importance"] __attr_unique__ = ["id", "name"] @@ -295,6 +297,23 @@ def test_basestash_query_one( ).unwrap() +def test_basestash_query_enum( + root_verify_key, base_stash: MockStash, mock_object: MockObject +) -> None: + base_stash.set(root_verify_key, mock_object).unwrap() + result = base_stash.get_one( + root_verify_key, + filters={"status": Status.CREATED}, + ).unwrap() + + assert result == mock_object + with pytest.raises(NotFoundException): + result = base_stash.get_one( + root_verify_key, + filters={"status": Status.PROCESSING}, + ).unwrap() + + def test_basestash_query_linked_obj( root_verify_key, base_stash: MockStash, mock_object: MockObject ) -> None: From 3451ead0fc6b2760cf1197e780f3c6d20424516d Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 12 Sep 2024 11:37:09 +0200 Subject: [PATCH 149/197] clean up todo notes --- packages/syft/src/syft/store/db/stash.py | 4 ++-- packages/syft/tests/syft/settings/settings_service_test.py | 1 - packages/syft/tests/syft/settings/settings_stash_test.py | 1 - packages/syft/tests/syft/worker_test.py | 1 - 4 files changed, 2 insertions(+), 5 deletions(-) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 9a618fad7ac..097d616cf22 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -386,7 +386,8 @@ def set( fields = serialize_json(obj) try: # check if the fields are deserializable - # PR NOTE: Is this too much extra work? + # TODO: Ideally, we want to make sure we don't serialize what we cannot deserialize + # and remove this check. deserialize_json(fields) except Exception as e: raise StashException( @@ -689,7 +690,6 @@ def has_permissions( self, permissions: list[ActionObjectPermission], session: Session = None ) -> bool: # TODO: we should use a permissions table to check all permissions at once - # TODO: should check for compound permissions permission_filters = [ sa.and_( diff --git a/packages/syft/tests/syft/settings/settings_service_test.py b/packages/syft/tests/syft/settings/settings_service_test.py index 0c9306aaff0..b856f6040c5 100644 --- a/packages/syft/tests/syft/settings/settings_service_test.py +++ b/packages/syft/tests/syft/settings/settings_service_test.py @@ -99,7 +99,6 @@ def test_settingsservice_set_success( ) -> None: response = settings_service.set(authed_context, settings) assert isinstance(response, ServerSettings) - # PR NOTE do we write syft_client_verify_key and syft_server_location to the stash or not? 
response.syft_client_verify_key = None response.syft_server_location = None response.pwd_token_config.syft_client_verify_key = None diff --git a/packages/syft/tests/syft/settings/settings_stash_test.py b/packages/syft/tests/syft/settings/settings_stash_test.py index bb003757c11..2d976b52108 100644 --- a/packages/syft/tests/syft/settings/settings_stash_test.py +++ b/packages/syft/tests/syft/settings/settings_stash_test.py @@ -4,7 +4,6 @@ from syft.service.settings.settings_stash import SettingsStash -# NOTE: Is this test necessary? It is just testing set and update methods def test_settingsstash_set( root_verify_key, settings_stash: SettingsStash, diff --git a/packages/syft/tests/syft/worker_test.py b/packages/syft/tests/syft/worker_test.py index 92faf0d36b7..7b53c47ad2f 100644 --- a/packages/syft/tests/syft/worker_test.py +++ b/packages/syft/tests/syft/worker_test.py @@ -97,7 +97,6 @@ def test_action_store(action_object_stash: ActionObjectStash) -> None: test_verify_key = test_signing_key.verify_key raw_data = np.array([1, 2, 3]) test_object = ActionObject.from_obj(raw_data) - # PR NOTE: Why `uid` was not `uid = test_object.id`? uid = test_object.id action_object_stash.set_or_update( From d065151e6f25fb7efc64d97b39ef54a3d561e130 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 12 Sep 2024 15:19:33 +0530 Subject: [PATCH 150/197] rename postgres_db to postgres.py --- packages/grid/backend/grid/core/server.py | 2 +- packages/syft/src/syft/server/server.py | 4 +-- packages/syft/src/syft/store/db/base.py | 19 +++++++++++--- .../store/db/{postgres_db.py => postgres.py} | 15 +---------- packages/syft/src/syft/store/db/sqlite_db.py | 26 ++++--------------- 5 files changed, 24 insertions(+), 42 deletions(-) rename packages/syft/src/syft/store/db/{postgres_db.py => postgres.py} (74%) diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py index 94f209613b7..0208bdadba3 100644 --- a/packages/grid/backend/grid/core/server.py +++ b/packages/grid/backend/grid/core/server.py @@ -16,7 +16,7 @@ from syft.service.queue.zmq_client import ZMQQueueConfig from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig from syft.store.blob_storage.seaweedfs import SeaweedFSConfig -from syft.store.db.postgres_db import PostgresDBConfig +from syft.store.db.postgres import PostgresDBConfig from syft.store.db.sqlite_db import SQLiteDBConfig from syft.types.uid import UID diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 5462c33f729..53bad805c49 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -88,8 +88,8 @@ from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit from ..store.db.base import DBConfig -from ..store.db.postgres_db import PostgresDBConfig -from ..store.db.postgres_db import PostgresDBManager +from ..store.db.postgres import PostgresDBConfig +from ..store.db.postgres import PostgresDBManager from ..store.db.sqlite_db import SQLiteDBConfig from ..store.db.sqlite_db import SQLiteDBManager from ..store.db.stash import ObjectStash diff --git a/packages/syft/src/syft/store/db/base.py b/packages/syft/src/syft/store/db/base.py index 8a987861070..640e04a0b6a 100644 --- a/packages/syft/src/syft/store/db/base.py +++ b/packages/syft/src/syft/store/db/base.py @@ -1,5 +1,7 @@ # stdlib import logging +from typing import Generic +from typing import TypeVar # third party from pydantic import 
BaseModel @@ -10,6 +12,7 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...types.uid import UID +from .schema import Base logger = logging.getLogger(__name__) @@ -23,7 +26,10 @@ def connection_string(self) -> str: raise NotImplementedError("Subclasses must implement this method.") -class DBManager: +ConfigT = TypeVar("ConfigT", bound=DBConfig) + + +class DBManager(Generic[ConfigT]): def __init__( self, config: DBConfig, @@ -40,14 +46,19 @@ def __init__( ) logger.info(f"Connecting to {config.connection_string}") self.sessionmaker = sessionmaker(bind=self.engine) - logger.info(f"Successfully connected to {config.connection_string}") self.update_settings() + logger.info(f"Successfully connected to {config.connection_string}") def update_settings(self) -> None: pass def init_tables(self) -> None: - pass + if self.config.reset: + # drop all tables that we know about + Base.metadata.drop_all(bind=self.engine) + self.config.reset = False + Base.metadata.create_all(self.engine) def reset(self) -> None: - pass + Base.metadata.drop_all(bind=self.engine) + Base.metadata.create_all(self.engine) diff --git a/packages/syft/src/syft/store/db/postgres_db.py b/packages/syft/src/syft/store/db/postgres.py similarity index 74% rename from packages/syft/src/syft/store/db/postgres_db.py rename to packages/syft/src/syft/store/db/postgres.py index 6166a8eb665..3f14fee89c1 100644 --- a/packages/syft/src/syft/store/db/postgres_db.py +++ b/packages/syft/src/syft/store/db/postgres.py @@ -6,7 +6,6 @@ from ...server.credentials import SyftVerifyKey from ...types.uid import UID from .base import DBManager -from .schema import Base from .sqlite_db import DBConfig @@ -30,22 +29,10 @@ def connection_string(self) -> str: ).render_as_string(hide_password=False) -class PostgresDBManager(DBManager): +class PostgresDBManager(DBManager[PostgresDBConfig]): def update_settings(self) -> None: return super().update_settings() - def init_tables(self) -> None: - if self.config.reset: - # drop all tables that we know about - Base.metadata.drop_all(bind=self.engine) - self.config.reset = False - - Base.metadata.create_all(self.engine) - - def reset(self) -> None: - Base.metadata.drop_all(bind=self.engine) - Base.metadata.create_all(self.engine) - @classmethod def random( cls: type, diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index c4cfea48189..323cc411801 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -14,7 +14,6 @@ from ...types.uid import UID from .base import DBConfig from .base import DBManager -from .schema import Base @serializable(canonical_name="SQLiteDBConfig", version=1) @@ -28,28 +27,13 @@ def connection_string(self) -> str: return f"sqlite:///{filepath.resolve()}" -class SQLiteDBManager(DBManager): +class SQLiteDBManager(DBManager[SQLiteDBConfig]): def update_settings(self) -> None: - # TODO split SQLite / PostgresDBManager connection = self.engine.connect() - - if self.engine.dialect.name == "sqlite": - connection.execute(sa.text("PRAGMA journal_mode = WAL")) - connection.execute(sa.text("PRAGMA busy_timeout = 5000")) - # TODO check - connection.execute(sa.text("PRAGMA temp_store = 2")) - connection.execute(sa.text("PRAGMA synchronous = 1")) - - def init_tables(self) -> None: - if self.config.reset: - # drop all tables that we know about - Base.metadata.drop_all(bind=self.engine) - self.config.reset = False - 
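Base.metadata.create_all(self.engine)
-
-    def reset(self) -> None:
-        Base.metadata.drop_all(bind=self.engine)
-        Base.metadata.create_all(self.engine)
+        connection.execute(sa.text("PRAGMA journal_mode = WAL"))
+        connection.execute(sa.text("PRAGMA busy_timeout = 5000"))
+        connection.execute(sa.text("PRAGMA temp_store = 2"))
+        connection.execute(sa.text("PRAGMA synchronous = 1"))

     @classmethod
     def random(

With this commit the manager hierarchy is generic over its config type (DBManager[ConfigT]), so each subclass gets a precisely typed self.config. A minimal standalone sketch of the pattern, using simplified stand-in classes rather than the real Syft types:

# standalone sketch of the DBManager[ConfigT] pattern; the stand-in
# classes here are simplified, not the real Syft implementations
from typing import Generic, TypeVar

from pydantic import BaseModel


class DBConfig(BaseModel):
    reset: bool = False


class SQLiteDBConfig(DBConfig):
    filename: str = "syft.db"  # the real config adds more fields


ConfigT = TypeVar("ConfigT", bound=DBConfig)


class DBManager(Generic[ConfigT]):
    def __init__(self, config: ConfigT) -> None:
        # typed as ConfigT, not as the base DBConfig
        self.config = config


class SQLiteDBManager(DBManager[SQLiteDBConfig]):
    def describe(self) -> str:
        # a type checker accepts .filename because self.config is
        # known to be a SQLiteDBConfig, not just a DBConfig
        return f"sqlite db at {self.config.filename}"


manager = SQLiteDBManager(SQLiteDBConfig())
print(manager.describe())  # sqlite db at syft.db
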
From f9caf281ed78cb6bb42827846cb39db371b2fd97 Mon Sep 17 00:00:00 2001
From: Shubham Gupta
Date: Thu, 12 Sep 2024 15:25:38 +0530
Subject: [PATCH 151/197] rename base.py to sqlite_db.py, rename sqlite_db.py to sqlite.py

---
 packages/grid/backend/grid/core/server.py | 2 +-
 packages/syft/src/syft/server/server.py | 6 +-
 packages/syft/src/syft/server/uvicorn.py | 2 +-
 .../syft/src/syft/server/worker_settings.py | 2 +-
 .../src/syft/service/action/action_service.py | 2 +-
 .../syft/src/syft/service/api/api_service.py | 2 +-
 .../attestation/attestation_service.py | 2 +-
 .../src/syft/service/blob_storage/service.py | 2 +-
 .../src/syft/service/code/status_service.py | 2 +-
 .../syft/service/code/user_code_service.py | 2 +-
 .../code_history/code_history_service.py | 2 +-
 .../data_subject_member_service.py | 2 +-
 .../data_subject/data_subject_service.py | 2 +-
 .../syft/service/dataset/dataset_service.py | 2 +-
 .../syft/service/enclave/enclave_service.py | 2 +-
 .../syft/src/syft/service/job/job_service.py | 2 +-
 .../syft/src/syft/service/log/log_service.py | 2 +-
 .../syft/service/metadata/metadata_service.py | 2 +-
 .../service/migration/migration_service.py | 2 +-
 .../syft/service/network/network_service.py | 2 +-
 .../notification/notification_service.py | 2 +-
 .../syft/service/notifier/notifier_service.py | 2 +-
 .../src/syft/service/output/output_service.py | 2 +-
 .../src/syft/service/policy/policy_service.py | 2 +-
 .../syft/service/project/project_service.py | 2 +-
 .../src/syft/service/queue/queue_service.py | 2 +-
 .../syft/service/request/request_service.py | 2 +-
 .../syft/service/settings/settings_service.py | 2 +-
 .../src/syft/service/sync/sync_service.py | 2 +-
 .../syft/src/syft/service/sync/sync_stash.py | 2 +-
 .../src/syft/service/user/user_service.py | 2 +-
 .../service/worker/image_registry_service.py | 2 +-
 .../service/worker/worker_image_service.py | 2 +-
 .../service/worker/worker_pool_service.py | 2 +-
 .../src/syft/service/worker/worker_service.py | 2 +-
 packages/syft/src/syft/store/db/base.py | 64 --------------
 packages/syft/src/syft/store/db/postgres.py | 4 +-
 packages/syft/src/syft/store/db/sqlite.py | 53 ++++++++++++
 packages/syft/src/syft/store/db/sqlite_db.py | 85 +++++++++++--------
 packages/syft/src/syft/store/db/stash.py | 4 +-
 .../migrations/protocol_communication_test.py | 2 +-
 .../tests/syft/stores/action_store_test.py | 2 +-
 .../syft/tests/syft/stores/base_stash_test.py | 4 +-
 .../tests/syft/stores/store_fixtures_test.py | 4 +-
 packages/syft/tests/syft/worker_test.py | 2 +-
 45 files changed, 149 insertions(+), 149 deletions(-)
 delete mode 100644 packages/syft/src/syft/store/db/base.py
 create mode 100644 packages/syft/src/syft/store/db/sqlite.py
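
After this rename, the generic DBConfig and DBManager live in sqlite_db.py (formerly base.py), while the SQLite implementations move to the new sqlite.py. A quick smoke-test sketch of the resulting layout, assuming only the import paths listed in the stat above:

# layout check; a sketch, not part of the patch
from syft.store.db.sqlite import SQLiteDBConfig, SQLiteDBManager
from syft.store.db.sqlite_db import DBConfig, DBManager

# the concrete SQLite classes still subclass the renamed base classes
assert issubclass(SQLiteDBConfig, DBConfig)
assert issubclass(SQLiteDBManager, DBManager)

diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py
index 0208bdadba3..d337b0357dc 100644
--- a/packages/grid/backend/grid/core/server.py
+++ b/packages/grid/backend/grid/core/server.py
@@ -17,7 +17,7 @@
 from syft.store.blob_storage.seaweedfs import SeaweedFSClientConfig
 from syft.store.blob_storage.seaweedfs import SeaweedFSConfig
 from syft.store.db.postgres import PostgresDBConfig
-from 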
syft.store.db.sqlite_db import SQLiteDBConfig +from syft.store.db.sqlite import SQLiteDBConfig from syft.types.uid import UID # server absolute diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 53bad805c49..6ad7f9e6e4a 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -87,11 +87,11 @@ from ..store.blob_storage.on_disk import OnDiskBlobStorageClientConfig from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit -from ..store.db.base import DBConfig from ..store.db.postgres import PostgresDBConfig from ..store.db.postgres import PostgresDBManager -from ..store.db.sqlite_db import SQLiteDBConfig -from ..store.db.sqlite_db import SQLiteDBManager +from ..store.db.sqlite import SQLiteDBConfig +from ..store.db.sqlite import SQLiteDBManager +from ..store.db.sqlite_db import DBConfig from ..store.db.stash import ObjectStash from ..store.document_store import StoreConfig from ..store.document_store_errors import NotFoundException diff --git a/packages/syft/src/syft/server/uvicorn.py b/packages/syft/src/syft/server/uvicorn.py index 13fff13d973..20e8ac4c43a 100644 --- a/packages/syft/src/syft/server/uvicorn.py +++ b/packages/syft/src/syft/server/uvicorn.py @@ -27,7 +27,7 @@ from ..abstract_server import ServerSideType from ..client.client import API_PATH from ..deployment_type import DeploymentType -from ..store.db.base import DBConfig +from ..store.db.sqlite_db import DBConfig from ..util.autoreload import enable_autoreload from ..util.constants import DEFAULT_TIMEOUT from ..util.telemetry import TRACING_ENABLED diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py index e2a3c64a0e8..c76e5617d0d 100644 --- a/packages/syft/src/syft/server/worker_settings.py +++ b/packages/syft/src/syft/server/worker_settings.py @@ -13,7 +13,7 @@ from ..server.credentials import SyftSigningKey from ..service.queue.base_queue import QueueConfig from ..store.blob_storage import BlobStorageConfig -from ..store.db.base import DBConfig +from ..store.db.sqlite_db import DBConfig from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftObject from ..types.uid import UID diff --git a/packages/syft/src/syft/service/action/action_service.py b/packages/syft/src/syft/service/action/action_service.py index 6fb49804d1e..38b53cf977d 100644 --- a/packages/syft/src/syft/service/action/action_service.py +++ b/packages/syft/src/syft/service/action/action_service.py @@ -9,7 +9,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.datetime import DateTime diff --git a/packages/syft/src/syft/service/api/api_service.py b/packages/syft/src/syft/service/api/api_service.py index 0f63e50f01b..cce0d53698a 100644 --- a/packages/syft/src/syft/service/api/api_service.py +++ b/packages/syft/src/syft/service/api/api_service.py @@ -10,7 +10,7 @@ from ...serde.serializable import serializable from ...service.action.action_endpoint import CustomEndpointActionObject from ...service.action.action_object import ActionObject -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from 
...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException diff --git a/packages/syft/src/syft/service/attestation/attestation_service.py b/packages/syft/src/syft/service/attestation/attestation_service.py index 4fcc434900e..1794ef6b03a 100644 --- a/packages/syft/src/syft/service/attestation/attestation_service.py +++ b/packages/syft/src/syft/service/attestation/attestation_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...types.errors import SyftException from ...types.result import as_result from ...util.util import str_to_bool diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index 3067810e099..f8069566268 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -11,7 +11,7 @@ from ...store.blob_storage import BlobRetrieval from ...store.blob_storage.on_disk import OnDiskBlobDeposit from ...store.blob_storage.seaweedfs import SeaweedFSBlobDeposit -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...types.blob_storage import AzureSecureFilePathLocation from ...types.blob_storage import BlobFileType from ...types.blob_storage import BlobStorageEntry diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index dca5f93b04a..f6709ebab3a 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings from ...types.uid import UID diff --git a/packages/syft/src/syft/service/code/user_code_service.py b/packages/syft/src/syft/service/code/user_code_service.py index d6ee23a9923..29fa9093f14 100644 --- a/packages/syft/src/syft/service/code/user_code_service.py +++ b/packages/syft/src/syft/service/code/user_code_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index b6f820cbff1..fa7da4976fa 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -3,7 +3,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...types.uid import UID from ..code.user_code import SubmitUserCode diff --git a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py index 01bbfd3b20c..ec0241b54e0 100644 --- 
a/packages/syft/src/syft/service/data_subject/data_subject_member_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_member_service.py @@ -3,7 +3,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import StashException from ...types.result import as_result diff --git a/packages/syft/src/syft/service/data_subject/data_subject_service.py b/packages/syft/src/syft/service/data_subject/data_subject_service.py index a6817a36382..8c67cd84849 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_service.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import StashException from ...types.result import as_result diff --git a/packages/syft/src/syft/service/dataset/dataset_service.py b/packages/syft/src/syft/service/dataset/dataset_service.py index 0bfbf535954..93ab6f8ad5b 100644 --- a/packages/syft/src/syft/service/dataset/dataset_service.py +++ b/packages/syft/src/syft/service/dataset/dataset_service.py @@ -5,7 +5,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...types.dicttuple import DictTuple from ...types.uid import UID from ..action.action_permissions import ActionObjectPermission diff --git a/packages/syft/src/syft/service/enclave/enclave_service.py b/packages/syft/src/syft/service/enclave/enclave_service.py index f0f449578ea..76d8c455e52 100644 --- a/packages/syft/src/syft/service/enclave/enclave_service.py +++ b/packages/syft/src/syft/service/enclave/enclave_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ..service import AbstractService diff --git a/packages/syft/src/syft/service/job/job_service.py b/packages/syft/src/syft/service/job/job_service.py index c1e1c68e18f..3b31f168f17 100644 --- a/packages/syft/src/syft/service/job/job_service.py +++ b/packages/syft/src/syft/service/job/job_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable from ...server.worker_settings import WorkerSettings -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...types.errors import SyftException from ...types.uid import UID from ..action.action_object import ActionObject diff --git a/packages/syft/src/syft/service/log/log_service.py b/packages/syft/src/syft/service/log/log_service.py index ec6d9b57ad6..9d25679b895 100644 --- a/packages/syft/src/syft/service/log/log_service.py +++ b/packages/syft/src/syft/service/log/log_service.py @@ -1,6 +1,6 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...types.uid import UID from ..action.action_permissions import StoragePermission from ..context import AuthedServiceContext diff --git a/packages/syft/src/syft/service/metadata/metadata_service.py b/packages/syft/src/syft/service/metadata/metadata_service.py 
index 82043f80461..603091ca4ee 100644 --- a/packages/syft/src/syft/service/metadata/metadata_service.py +++ b/packages/syft/src/syft/service/metadata/metadata_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ..context import AuthedServiceContext from ..service import AbstractService from ..service import service_method diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index e5566123c9e..581c3561c92 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -6,7 +6,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...types.blob_storage import BlobStorageEntry diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 03eae13f3bd..d14fadb25ed 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -15,7 +15,7 @@ from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings from ...service.settings.settings import ServerSettings -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException diff --git a/packages/syft/src/syft/service/notification/notification_service.py b/packages/syft/src/syft/service/notification/notification_service.py index afe3d97c662..9da8ae4a935 100644 --- a/packages/syft/src/syft/service/notification/notification_service.py +++ b/packages/syft/src/syft/service/notification/notification_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import StashException from ...types.errors import SyftException from ...types.result import as_result diff --git a/packages/syft/src/syft/service/notifier/notifier_service.py b/packages/syft/src/syft/service/notifier/notifier_service.py index ce922aa26f9..34a964b4a03 100644 --- a/packages/syft/src/syft/service/notifier/notifier_service.py +++ b/packages/syft/src/syft/service/notifier/notifier_service.py @@ -8,7 +8,7 @@ # relative from ...abstract_server import AbstractServer from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.errors import SyftException diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 3d8c10ef699..7f59cf4e1e8 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -7,7 +7,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.base import DBManager +from 
...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey from ...store.document_store_errors import StashException diff --git a/packages/syft/src/syft/service/policy/policy_service.py b/packages/syft/src/syft/service/policy/policy_service.py index e900cf0850b..da8759a3540 100644 --- a/packages/syft/src/syft/service/policy/policy_service.py +++ b/packages/syft/src/syft/service/policy/policy_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...types.uid import UID from ..context import AuthedServiceContext from ..response import SyftSuccess diff --git a/packages/syft/src/syft/service/project/project_service.py b/packages/syft/src/syft/service/project/project_service.py index 7b513f0c367..b04ec27bf8b 100644 --- a/packages/syft/src/syft/service/project/project_service.py +++ b/packages/syft/src/syft/service/project/project_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject diff --git a/packages/syft/src/syft/service/queue/queue_service.py b/packages/syft/src/syft/service/queue/queue_service.py index aab81efa5a7..d47f21052a8 100644 --- a/packages/syft/src/syft/service/queue/queue_service.py +++ b/packages/syft/src/syft/service/queue/queue_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ..service import AbstractService from .queue_stash import QueueStash diff --git a/packages/syft/src/syft/service/request/request_service.py b/packages/syft/src/syft/service/request/request_service.py index 7ada7772875..37ceeeae02a 100644 --- a/packages/syft/src/syft/service/request/request_service.py +++ b/packages/syft/src/syft/service/request/request_service.py @@ -4,7 +4,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.linked_obj import LinkedObject from ...types.errors import SyftException from ...types.result import as_result diff --git a/packages/syft/src/syft/service/settings/settings_service.py b/packages/syft/src/syft/service/settings/settings_service.py index 5e5b8a2dbfb..a0adbbf8575 100644 --- a/packages/syft/src/syft/service/settings/settings_service.py +++ b/packages/syft/src/syft/service/settings/settings_service.py @@ -8,7 +8,7 @@ # relative from ...abstract_server import ServerSideType from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.sqlite_document_store import SQLiteStoreConfig diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 4368b83ba55..d044edf3b96 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -6,7 +6,7 @@ # relative from ...client.api import ServerIdentity from 
...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import NewBaseStash from ...store.document_store_errors import NotFoundException diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py index 008ac8970c9..4090e62dec1 100644 --- a/packages/syft/src/syft/service/sync/sync_stash.py +++ b/packages/syft/src/syft/service/sync/sync_stash.py @@ -1,7 +1,7 @@ # relative from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings from ...store.document_store_errors import StashException diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 9072c67b192..34d5071fa7a 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -11,7 +11,7 @@ from ...serde.serializable import serializable from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject diff --git a/packages/syft/src/syft/service/worker/image_registry_service.py b/packages/syft/src/syft/service/worker/image_registry_service.py index c4ea071356d..e9e80f21892 100644 --- a/packages/syft/src/syft/service/worker/image_registry_service.py +++ b/packages/syft/src/syft/service/worker/image_registry_service.py @@ -2,7 +2,7 @@ # relative from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...types.errors import SyftException from ...types.uid import UID from ..context import AuthedServiceContext diff --git a/packages/syft/src/syft/service/worker/worker_image_service.py b/packages/syft/src/syft/service/worker/worker_image_service.py index e8ac5d10d47..ab9cf5a250f 100644 --- a/packages/syft/src/syft/service/worker/worker_image_service.py +++ b/packages/syft/src/syft/service/worker/worker_image_service.py @@ -10,7 +10,7 @@ from ...custom_worker.config import WorkerConfig from ...custom_worker.k8s import IN_KUBERNETES from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...types.datetime import DateTime from ...types.dicttuple import DictTuple from ...types.errors import SyftException diff --git a/packages/syft/src/syft/service/worker/worker_pool_service.py b/packages/syft/src/syft/service/worker/worker_pool_service.py index 5160c0cd966..ae46958997d 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_service.py +++ b/packages/syft/src/syft/service/worker/worker_pool_service.py @@ -12,7 +12,7 @@ from ...custom_worker.k8s import IN_KUBERNETES from ...custom_worker.runner_k8s import KubernetesRunner from ...serde.serializable import serializable -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from 
...store.linked_obj import LinkedObject diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 09b7e147526..28f443cb5fa 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -13,7 +13,7 @@ from ...custom_worker.runner_k8s import KubernetesRunner from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.db.base import DBManager +from ...store.db.sqlite_db import DBManager from ...store.document_store import SyftSuccess from ...store.document_store_errors import StashException from ...types.errors import SyftException diff --git a/packages/syft/src/syft/store/db/base.py b/packages/syft/src/syft/store/db/base.py deleted file mode 100644 index 640e04a0b6a..00000000000 --- a/packages/syft/src/syft/store/db/base.py +++ /dev/null @@ -1,64 +0,0 @@ -# stdlib -import logging -from typing import Generic -from typing import TypeVar - -# third party -from pydantic import BaseModel -from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker - -# relative -from ...serde.serializable import serializable -from ...server.credentials import SyftVerifyKey -from ...types.uid import UID -from .schema import Base - -logger = logging.getLogger(__name__) - - -@serializable(canonical_name="DBConfig", version=1) -class DBConfig(BaseModel): - reset: bool = False - - @property - def connection_string(self) -> str: - raise NotImplementedError("Subclasses must implement this method.") - - -ConfigT = TypeVar("ConfigT", bound=DBConfig) - - -class DBManager(Generic[ConfigT]): - def __init__( - self, - config: DBConfig, - server_uid: UID, - root_verify_key: SyftVerifyKey, - ) -> None: - self.config = config - self.root_verify_key = root_verify_key - self.server_uid = server_uid - self.engine = create_engine( - config.connection_string, - # json_serializer=dumps, - # json_deserializer=loads, - ) - logger.info(f"Connecting to {config.connection_string}") - self.sessionmaker = sessionmaker(bind=self.engine) - self.update_settings() - logger.info(f"Successfully connected to {config.connection_string}") - - def update_settings(self) -> None: - pass - - def init_tables(self) -> None: - if self.config.reset: - # drop all tables that we know about - Base.metadata.drop_all(bind=self.engine) - self.config.reset = False - Base.metadata.create_all(self.engine) - - def reset(self) -> None: - Base.metadata.drop_all(bind=self.engine) - Base.metadata.create_all(self.engine) diff --git a/packages/syft/src/syft/store/db/postgres.py b/packages/syft/src/syft/store/db/postgres.py index 3f14fee89c1..c0a13a7dfb8 100644 --- a/packages/syft/src/syft/store/db/postgres.py +++ b/packages/syft/src/syft/store/db/postgres.py @@ -5,8 +5,8 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...types.uid import UID -from .base import DBManager -from .sqlite_db import DBConfig +from .sqlite import DBConfig +from .sqlite_db import DBManager @serializable(canonical_name="PostgresDBConfig", version=1) diff --git a/packages/syft/src/syft/store/db/sqlite.py b/packages/syft/src/syft/store/db/sqlite.py new file mode 100644 index 00000000000..705a86c3f53 --- /dev/null +++ b/packages/syft/src/syft/store/db/sqlite.py @@ -0,0 +1,53 @@ +# stdlib +from pathlib import Path +import tempfile +import uuid + +# third party +from pydantic import Field +import sqlalchemy as sa + +# relative +from 
...serde.serializable import serializable +from ...server.credentials import SyftSigningKey +from ...server.credentials import SyftVerifyKey +from ...types.uid import UID +from .sqlite_db import DBConfig +from .sqlite_db import DBManager + + +@serializable(canonical_name="SQLiteDBConfig", version=1) +class SQLiteDBConfig(DBConfig): + filename: str = Field(default_factory=lambda: f"{uuid.uuid4()}.db") + path: Path = Field(default_factory=lambda: Path(tempfile.gettempdir())) + + @property + def connection_string(self) -> str: + filepath = self.path / self.filename + return f"sqlite:///{filepath.resolve()}" + + +class SQLiteDBManager(DBManager[SQLiteDBConfig]): + def update_settings(self) -> None: + connection = self.engine.connect() + connection.execute(sa.text("PRAGMA journal_mode = WAL")) + connection.execute(sa.text("PRAGMA busy_timeout = 5000")) + connection.execute(sa.text("PRAGMA temp_store = 2")) + connection.execute(sa.text("PRAGMA synchronous = 1")) + + @classmethod + def random( + cls, + *, + config: SQLiteDBConfig | None = None, + server_uid: UID | None = None, + root_verify_key: SyftVerifyKey | None = None, + ) -> "SQLiteDBManager": + root_verify_key = root_verify_key or SyftSigningKey.generate().verify_key + server_uid = server_uid or UID() + config = config or SQLiteDBConfig() + return SQLiteDBManager( + config=config, + server_uid=server_uid, + root_verify_key=root_verify_key, + ) diff --git a/packages/syft/src/syft/store/db/sqlite_db.py b/packages/syft/src/syft/store/db/sqlite_db.py index 323cc411801..640e04a0b6a 100644 --- a/packages/syft/src/syft/store/db/sqlite_db.py +++ b/packages/syft/src/syft/store/db/sqlite_db.py @@ -1,53 +1,64 @@ # stdlib -from pathlib import Path -import tempfile -import uuid +import logging +from typing import Generic +from typing import TypeVar # third party -from pydantic import Field -import sqlalchemy as sa +from pydantic import BaseModel +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker # relative from ...serde.serializable import serializable -from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey from ...types.uid import UID -from .base import DBConfig -from .base import DBManager +from .schema import Base +logger = logging.getLogger(__name__) -@serializable(canonical_name="SQLiteDBConfig", version=1) -class SQLiteDBConfig(DBConfig): - filename: str = Field(default_factory=lambda: f"{uuid.uuid4()}.db") - path: Path = Field(default_factory=lambda: Path(tempfile.gettempdir())) + +@serializable(canonical_name="DBConfig", version=1) +class DBConfig(BaseModel): + reset: bool = False @property def connection_string(self) -> str: - filepath = self.path / self.filename - return f"sqlite:///{filepath.resolve()}" + raise NotImplementedError("Subclasses must implement this method.") -class SQLiteDBManager(DBManager[SQLiteDBConfig]): - def update_settings(self) -> None: - connection = self.engine.connect() - connection.execute(sa.text("PRAGMA journal_mode = WAL")) - connection.execute(sa.text("PRAGMA busy_timeout = 5000")) - connection.execute(sa.text("PRAGMA temp_store = 2")) - connection.execute(sa.text("PRAGMA synchronous = 1")) - - @classmethod - def random( - cls, - *, - config: SQLiteDBConfig | None = None, - server_uid: UID | None = None, - root_verify_key: SyftVerifyKey | None = None, - ) -> "SQLiteDBManager": - root_verify_key = root_verify_key or SyftSigningKey.generate().verify_key - server_uid = server_uid or UID() - config = config or SQLiteDBConfig() - return 
SQLiteDBManager( - config=config, - server_uid=server_uid, - root_verify_key=root_verify_key, +ConfigT = TypeVar("ConfigT", bound=DBConfig) + + +class DBManager(Generic[ConfigT]): + def __init__( + self, + config: DBConfig, + server_uid: UID, + root_verify_key: SyftVerifyKey, + ) -> None: + self.config = config + self.root_verify_key = root_verify_key + self.server_uid = server_uid + self.engine = create_engine( + config.connection_string, + # json_serializer=dumps, + # json_deserializer=loads, ) + logger.info(f"Connecting to {config.connection_string}") + self.sessionmaker = sessionmaker(bind=self.engine) + self.update_settings() + logger.info(f"Successfully connected to {config.connection_string}") + + def update_settings(self) -> None: + pass + + def init_tables(self) -> None: + if self.config.reset: + # drop all tables that we know about + Base.metadata.drop_all(bind=self.engine) + self.config.reset = False + Base.metadata.create_all(self.engine) + + def reset(self) -> None: + Base.metadata.drop_all(bind=self.engine) + Base.metadata.create_all(self.engine) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 98f4669ee69..900f15e96fa 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -38,11 +38,11 @@ from ...types.uid import UID from ..document_store_errors import NotFoundException from ..document_store_errors import StashException -from .base import DBManager from .query import Query from .schema import Base from .schema import create_table -from .sqlite_db import SQLiteDBManager +from .sqlite import SQLiteDBManager +from .sqlite_db import DBManager StashT = TypeVar("StashT", bound=SyftObject) T = TypeVar("T") diff --git a/packages/syft/tests/syft/migrations/protocol_communication_test.py b/packages/syft/tests/syft/migrations/protocol_communication_test.py index f111cbcc598..da8ca1fd37a 100644 --- a/packages/syft/tests/syft/migrations/protocol_communication_test.py +++ b/packages/syft/tests/syft/migrations/protocol_communication_test.py @@ -20,7 +20,7 @@ from syft.service.service import ServiceConfigRegistry from syft.service.service import service_method from syft.service.user.user_roles import GUEST_ROLE_LEVEL -from syft.store.db.base import DBManager +from syft.store.db.sqlite_db import DBManager from syft.store.document_store import DocumentStore from syft.store.document_store import NewBaseStash from syft.store.document_store import PartitionSettings diff --git a/packages/syft/tests/syft/stores/action_store_test.py b/packages/syft/tests/syft/stores/action_store_test.py index f9872edb590..d40fa15e8ca 100644 --- a/packages/syft/tests/syft/stores/action_store_test.py +++ b/packages/syft/tests/syft/stores/action_store_test.py @@ -16,7 +16,7 @@ from syft.service.user.user import User from syft.service.user.user_roles import ServiceRole from syft.service.user.user_stash import UserStash -from syft.store.db.base import DBManager +from syft.store.db.sqlite_db import DBManager from syft.types.uid import UID # relative diff --git a/packages/syft/tests/syft/stores/base_stash_test.py b/packages/syft/tests/syft/stores/base_stash_test.py index 451db6fbd74..3487bfb5dca 100644 --- a/packages/syft/tests/syft/stores/base_stash_test.py +++ b/packages/syft/tests/syft/stores/base_stash_test.py @@ -12,8 +12,8 @@ # syft absolute from syft.serde.serializable import serializable -from syft.store.db.sqlite_db import SQLiteDBConfig -from syft.store.db.sqlite_db import SQLiteDBManager +from 
syft.store.db.sqlite import SQLiteDBConfig +from syft.store.db.sqlite import SQLiteDBManager from syft.store.db.stash import ObjectStash from syft.store.document_store import PartitionKey from syft.store.document_store_errors import NotFoundException diff --git a/packages/syft/tests/syft/stores/store_fixtures_test.py b/packages/syft/tests/syft/stores/store_fixtures_test.py index 7225be989b7..47452e9740e 100644 --- a/packages/syft/tests/syft/stores/store_fixtures_test.py +++ b/packages/syft/tests/syft/stores/store_fixtures_test.py @@ -9,8 +9,8 @@ from syft.service.user.user import UserCreate from syft.service.user.user_roles import ServiceRole from syft.service.user.user_stash import UserStash -from syft.store.db.sqlite_db import SQLiteDBConfig -from syft.store.db.sqlite_db import SQLiteDBManager +from syft.store.db.sqlite import SQLiteDBConfig +from syft.store.db.sqlite import SQLiteDBManager from syft.store.document_store import DocumentStore from syft.types.uid import UID diff --git a/packages/syft/tests/syft/worker_test.py b/packages/syft/tests/syft/worker_test.py index 816ce02b909..a25acc622aa 100644 --- a/packages/syft/tests/syft/worker_test.py +++ b/packages/syft/tests/syft/worker_test.py @@ -24,7 +24,7 @@ from syft.service.user.user import UserCreate from syft.service.user.user import UserView from syft.service.user.user_stash import UserStash -from syft.store.db.sqlite_db import SQLiteDBManager +from syft.store.db.sqlite import SQLiteDBManager from syft.types.errors import SyftException from syft.types.result import Ok From 12c0f177405cde48cbb1b89466d8f369cb2ae8ed Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 12 Sep 2024 15:37:12 +0530 Subject: [PATCH 152/197] fix imports Co-authored-by: khoaguin --- packages/syft/src/syft/server/uvicorn.py | 2 +- packages/syft/src/syft/store/db/postgres.py | 2 +- packages/syft/src/syft/store/db/sqlite.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/server/uvicorn.py b/packages/syft/src/syft/server/uvicorn.py index 20e8ac4c43a..676e6222dbc 100644 --- a/packages/syft/src/syft/server/uvicorn.py +++ b/packages/syft/src/syft/server/uvicorn.py @@ -27,7 +27,7 @@ from ..abstract_server import ServerSideType from ..client.client import API_PATH from ..deployment_type import DeploymentType -from ..store.db.sqlite_db import DBConfig +from ..store.db.db import DBConfig from ..util.autoreload import enable_autoreload from ..util.constants import DEFAULT_TIMEOUT from ..util.telemetry import TRACING_ENABLED diff --git a/packages/syft/src/syft/store/db/postgres.py b/packages/syft/src/syft/store/db/postgres.py index c0a13a7dfb8..630155e29de 100644 --- a/packages/syft/src/syft/store/db/postgres.py +++ b/packages/syft/src/syft/store/db/postgres.py @@ -5,8 +5,8 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...types.uid import UID +from .db import DBManager from .sqlite import DBConfig -from .sqlite_db import DBManager @serializable(canonical_name="PostgresDBConfig", version=1) diff --git a/packages/syft/src/syft/store/db/sqlite.py b/packages/syft/src/syft/store/db/sqlite.py index 705a86c3f53..485a5f7d64b 100644 --- a/packages/syft/src/syft/store/db/sqlite.py +++ b/packages/syft/src/syft/store/db/sqlite.py @@ -12,8 +12,8 @@ from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey from ...types.uid import UID -from .sqlite_db import DBConfig -from .sqlite_db import DBManager +from .db import DBConfig +from .db 
import DBManager @serializable(canonical_name="SQLiteDBConfig", version=1) From 000092d78a33e14ce7a6c08edbfa8493359d3eea Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Thu, 12 Sep 2024 16:05:22 +0530 Subject: [PATCH 153/197] update postgres version to 16.1 update init_stores to return db manager --- .../templates/postgres/postgres-statefuleset.yaml | 2 +- packages/syft/src/syft/server/server.py | 14 ++++++++------ 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml b/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml index 425f9e88770..db9416bbdc3 100644 --- a/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml +++ b/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml @@ -32,7 +32,7 @@ spec: {{- end }} containers: - name: postgres-container - image: postgres:13 + image: postgres:16.1 imagePullPolicy: Always resources: {{ include "common.resources.set" (dict "resources" .Values.postgres.resources "preset" .Values.postgres.resourcesPreset) | nindent 12 }} env: diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index feae7c893e8..eb535600307 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -88,6 +88,7 @@ from ..store.blob_storage.on_disk import OnDiskBlobStorageConfig from ..store.blob_storage.seaweedfs import SeaweedFSBlobDeposit from ..store.db.db import DBConfig +from ..store.db.db import DBManager from ..store.db.postgres import PostgresDBConfig from ..store.db.postgres import PostgresDBManager from ..store.db.sqlite import SQLiteDBConfig @@ -417,9 +418,8 @@ def __init__( db_config.reset = True self.db_config = db_config - self.db: PostgresDBManager | SQLiteDBManager | None = None - self.init_stores(db_config=self.db_config) + self.db = self.init_stores(db_config=self.db_config) # construct services only after init stores self.services: ServiceRegistry = ServiceRegistry.for_server(self) @@ -896,15 +896,15 @@ def reload_user_code() -> None: if ti is not None: CODE_RELOADER[ti] = reload_user_code - def init_stores(self, db_config: DBConfig) -> None: + def init_stores(self, db_config: DBConfig) -> DBManager: if isinstance(db_config, SQLiteDBConfig): - self.db = SQLiteDBManager( + db = SQLiteDBManager( config=db_config, server_uid=self.id, root_verify_key=self.verify_key, ) elif isinstance(db_config, PostgresDBConfig): - self.db = PostgresDBManager( + db = PostgresDBManager( config=db_config, server_uid=self.id, root_verify_key=self.verify_key, @@ -912,7 +912,9 @@ def init_stores(self, db_config: DBConfig) -> None: else: raise SyftException(public_message=f"Unsupported DB config: {db_config}") - self.queue_stash = QueueStash(store=self.db) + self.queue_stash = QueueStash(store=db) + + return db @property def job_stash(self) -> JobStash: From f1cd69a21163d2ba1b7123a10715339810922b56 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Thu, 12 Sep 2024 17:41:00 +0700 Subject: [PATCH 154/197] [deps] scanning postgres:16.1 security --- .github/workflows/container-scan.yml | 6 +++--- packages/syftcli/manifest.yml | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/container-scan.yml b/.github/workflows/container-scan.yml index 0e12d35d357..211a8022f62 100644 --- a/.github/workflows/container-scan.yml +++ b/.github/workflows/container-scan.yml @@ -238,7 +238,7 @@ jobs: continue-on-error: true uses: aquasecurity/trivy-action@master with: - 
image-ref: "postgres:13" + image-ref: "postgres:16.1" format: "cyclonedx" output: "postgres-trivy-results.sbom.json" timeout: "10m0s" @@ -255,7 +255,7 @@ jobs: continue-on-error: true uses: aquasecurity/trivy-action@master with: - image-ref: "postgres:13" + image-ref: "postgres:16.1" format: "sarif" output: "trivy-results.sarif" timeout: "10m0s" @@ -281,7 +281,7 @@ jobs: # This is where you will need to introduce the Snyk API token created with your Snyk account SNYK_TOKEN: ${{ secrets.SNYK_TOKEN }} with: - image: postgres:13 + image: postgres:16.1 args: --sarif-file-output=snyk-code.sarif # Replace any "undefined" security severity values with 0. The undefined value is used in the case diff --git a/packages/syftcli/manifest.yml b/packages/syftcli/manifest.yml index bfc5224ce48..496e65edf66 100644 --- a/packages/syftcli/manifest.yml +++ b/packages/syftcli/manifest.yml @@ -6,7 +6,7 @@ dockerTag: 0.9.2-beta.2 images: - docker.io/openmined/syft-frontend:0.9.2-beta.2 - docker.io/openmined/syft-backend:0.9.2-beta.2 - - docker.io/library/postgres:13 + - docker.io/library/postgres:16.1 - docker.io/traefik:v2.11.0 configFiles: From 011a1bb69534b2f410a6227b9b9f35f1f93ff95c Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 12 Sep 2024 13:43:06 +0200 Subject: [PATCH 155/197] fixes --- packages/syft/src/syft/serde/json_serde.py | 1 - .../src/syft/service/action/action_store.py | 1 + .../syft/src/syft/service/api/api_stash.py | 6 ++--- .../code_history/code_history_stash.py | 2 +- .../syft/service/policy/user_policy_stash.py | 2 +- .../syft/service/worker/worker_pool_stash.py | 2 +- packages/syft/tests/syft/types/errors_test.py | 22 +++++++++++++++++++ 7 files changed, 29 insertions(+), 7 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index 000ff04085a..a150d3c186e 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -45,7 +45,6 @@ class JSONSerdeError(SyftException): @dataclass class JSONSerde(Generic[T]): - # TODO add json schema klass: type[T] serialize_fn: Callable[[T], Json] | None = None deserialize_fn: Callable[[Json], T] | None = None diff --git a/packages/syft/src/syft/service/action/action_store.py b/packages/syft/src/syft/service/action/action_store.py index bd5fcf45e10..e6597d19d25 100644 --- a/packages/syft/src/syft/service/action/action_store.py +++ b/packages/syft/src/syft/service/action/action_store.py @@ -23,6 +23,7 @@ @serializable(canonical_name="ActionObjectSQLStore", version=1) class ActionObjectStash(ObjectStash[ActionObject]): + # We are storing ActionObject, Action, TwinObject allow_any_type = True @as_result(NotFoundException, SyftException) diff --git a/packages/syft/src/syft/service/api/api_stash.py b/packages/syft/src/syft/service/api/api_stash.py index d88b1c37504..0c0c6f73020 100644 --- a/packages/syft/src/syft/service/api/api_stash.py +++ b/packages/syft/src/syft/service/api/api_stash.py @@ -18,13 +18,13 @@ def get_by_path(self, credentials: SyftVerifyKey, path: str) -> TwinAPIEndpoint: res = self.get_one( credentials=credentials, filters={"path": path}, - ).unwrap() + ) - if res is None: + if res.is_err(): raise NotFoundException( public_message=MISSING_PATH_STRING.format(path=path) ) - return res + return res.unwrap() @as_result(StashException) def path_exists(self, credentials: SyftVerifyKey, path: str) -> bool: diff --git a/packages/syft/src/syft/service/code_history/code_history_stash.py 
b/packages/syft/src/syft/service/code_history/code_history_stash.py index 14ca4e89720..69dfd272717 100644 --- a/packages/syft/src/syft/service/code_history/code_history_stash.py +++ b/packages/syft/src/syft/service/code_history/code_history_stash.py @@ -19,7 +19,7 @@ def get_by_service_func_name_and_verify_key( return self.get_one( credentials=credentials, filters={ - "user_verify_key": str(user_verify_key), + "user_verify_key": user_verify_key, "service_func_name": service_func_name, }, ).unwrap() diff --git a/packages/syft/src/syft/service/policy/user_policy_stash.py b/packages/syft/src/syft/service/policy/user_policy_stash.py index 860472273b8..9e3a103280b 100644 --- a/packages/syft/src/syft/service/policy/user_policy_stash.py +++ b/packages/syft/src/syft/service/policy/user_policy_stash.py @@ -18,5 +18,5 @@ def get_all_by_user_verify_key( ) -> list[UserPolicy]: return self.get_all( credentials=credentials, - filters={"user_verify_key": str(user_verify_key)}, + filters={"user_verify_key": user_verify_key}, ).unwrap() diff --git a/packages/syft/src/syft/service/worker/worker_pool_stash.py b/packages/syft/src/syft/service/worker/worker_pool_stash.py index 699ee19a47a..81a4f4741d2 100644 --- a/packages/syft/src/syft/service/worker/worker_pool_stash.py +++ b/packages/syft/src/syft/service/worker/worker_pool_stash.py @@ -67,5 +67,5 @@ def get_by_image_uid( ) -> list[WorkerPool]: return self.get_all( credentials=credentials, - filters={"image_id": image_uid.no_dash}, + filters={"image_id": image_uid}, ).unwrap() diff --git a/packages/syft/tests/syft/types/errors_test.py b/packages/syft/tests/syft/types/errors_test.py index 4ac185ba421..ca8e557ef11 100644 --- a/packages/syft/tests/syft/types/errors_test.py +++ b/packages/syft/tests/syft/types/errors_test.py @@ -5,6 +5,7 @@ import pytest # syft absolute +import syft from syft.service.context import AuthedServiceContext from syft.service.user.user_roles import ServiceRole from syft.types.errors import SyftException @@ -52,3 +53,24 @@ def test_get_message(role, private_msg, public_msg, expected_message): mock_context.dev_mode = False exception = SyftException(private_msg, public_message=public_msg) assert exception.get_message(mock_context) == expected_message + + +def test_syfterror_raise_works_in_pytest(): + """ + SyftError has its own exception handler that wasn't working in notebook testing environments, this is just a sanity check to make sure it works in pytest.
+ """ + with pytest.raises(SyftException): + raise SyftException(public_message="-") + + with syft.raises(SyftException(public_message="-")): + raise SyftException(public_message="-") + + # syft.raises works with wildcard + with syft.raises(SyftException(public_message="*test message*")): + raise SyftException(public_message="longer test message") + + # syft.raises with different public message should raise + with pytest.raises(AssertionError): + with syft.raises(SyftException(public_message="*different message*")): + raise SyftException(public_message="longer test message") From 8d1d482955ea621cf56e0c90be851dfd10525001 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 12 Sep 2024 13:49:44 +0200 Subject: [PATCH 156/197] fix network service --- packages/syft/src/syft/protocol/protocol_version.json | 2 +- packages/syft/src/syft/service/network/network_service.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index a8d210436e3..227c908e9b2 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -486,7 +486,7 @@ "WorkerSettings": { "2": { "version": 2, - "hash": "13c6e022b939778ab37b594dbc5094aba9f54564c90d3cb0c21115382b155bfe", + "hash": "91c375dd40d06c81fc6403751ee48cbc94b9877f91e65a7e302303218dfe71fa", "action": "add" } }, diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 1f4d656201c..aa2264dbe10 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -129,7 +129,7 @@ def get_by_verify_key( def get_by_server_type( self, credentials: SyftVerifyKey, server_type: ServerType ) -> list[ServerPeer]: - return self.get( + return self.get_all( credentials=credentials, filters={"server_type": server_type}, ).unwrap() From 5d8ecd00a69fb810c556f7f0a52bc66abf170823 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 12 Sep 2024 14:09:39 +0200 Subject: [PATCH 157/197] update protocol file --- .../src/syft/protocol/protocol_version.json | 168 ++++++++---------- 1 file changed, 77 insertions(+), 91 deletions(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 227c908e9b2..72976eb9c66 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -90,27 +90,6 @@ "action": "add" } }, - "StoreConfig": { - "1": { - "version": 1, - "hash": "a9997fce6a8a0ed2884c58b8eb9382f8554bdd18fff61f8bf0451945bcff12c7", - "action": "add" - } - }, - "MongoDict": { - "1": { - "version": 1, - "hash": "57e36f57eed75e62b29e2bac1295035a9bf2c0e3c56719dac24cb6cc685be00b", - "action": "add" - } - }, - "MongoStoreConfig": { - "1": { - "version": 1, - "hash": "53342b27d34165b7e2699f8e7ad70d13d125875e6a75e8fa18f5796428f41036", - "action": "add" - } - }, "LinkedObject": { "1": { "version": 1, @@ -172,6 +151,13 @@ "action": "add" } }, + "StoreConfig": { + "1": { + "version": 1, + "hash": "a9997fce6a8a0ed2884c58b8eb9382f8554bdd18fff61f8bf0451945bcff12c7", + "action": "add" + } + }, "BaseConfig": { "1": { "version": 1, @@ -544,13 +530,6 @@ "action": "add" } }, - "DataSubjectMemberRelationship": { - "1": { - "version": 1, - "hash": "0810483ea76ea10c8f286c6035dc0b2085291f345183be50c179f3a05a577110", - 
"action": "add" - } - }, "Contributor": { "1": { "version": 1, @@ -818,69 +797,6 @@ "action": "add" } }, - "OnDiskBlobDeposit": { - "1": { - "version": 1, - "hash": "817bf1bee4a35bfa1cd25d6779a10d8d180b1b3f1e837952f81f48b9411d1970", - "action": "add" - } - }, - "RemoteConfig": { - "1": { - "version": 1, - "hash": "179d067099a178d748c6d9a0477e8de7c3b55577439669eca7150258f2409567", - "action": "add" - } - }, - "AzureRemoteConfig": { - "1": { - "version": 1, - "hash": "a143811fec0da5fd881e927643ef667c91c78a2c90519cf88da7da20738bd187", - "action": "add" - } - }, - "SeaweedFSBlobDeposit": { - "1": { - "version": 1, - "hash": "febeb2a2ce81aa2c512e4c6b611b582984042aafa0541403d4584662273a166c", - "action": "add" - } - }, - "NumpyArrayObject": { - "1": { - "version": 1, - "hash": "05dd2917b7692b3daf4e7ad083a46fa7ec7a2be8faac8d4a654809189c986443", - "action": "add" - } - }, - "NumpyScalarObject": { - "1": { - "version": 1, - "hash": "8753e5c78270a5cacbf0439447724772f4765351a4a8b58b0a5c416a6b2c8b6e", - "action": "add" - } - }, - "NumpyBoolObject": { - "1": { - "version": 1, - "hash": "331c44f8fa3d0a077f1aaad7313bae2c43b386d04def7b8bedae9fdf7690134d", - "action": "add" - } - }, - "PandasDataframeObject": { - "1": { - "version": 1, - "hash": "5e8018364cea31d5f185a901da4ab89846b02153ee7d041ee8a6d305ece31f90", - "action": "add" - } - }, - "PandasSeriesObject": { - "1": { - "version": 1, - "hash": "b8bd482bf16fc7177e9778292cd42f8835b6ced2ce8dc88908b4b8e6d7c7528f", - "action": "add" - } - }, "Change": { "1": { "version": 1, @@ -1084,6 +1000,34 @@ "action": "add" } }, + "OnDiskBlobDeposit": { + "1": { + "version": 1, + "hash": "817bf1bee4a35bfa1cd25d6779a10d8d180b1b3f1e837952f81f48b9411d1970", + "action": "add" + } + }, + "RemoteConfig": { + "1": { + "version": 1, + "hash": "179d067099a178d748c6d9a0477e8de7c3b55577439669eca7150258f2409567", + "action": "add" + } + }, + "AzureRemoteConfig": { + "1": { + "version": 1, + "hash": "a143811fec0da5fd881e927643ef667c91c78a2c90519cf88da7da20738bd187", + "action": "add" + } + }, + "SeaweedFSBlobDeposit": { + "1": { + "version": 1, + "hash": "febeb2a2ce81aa2c512e4c6b611b582984042aafa0541403d4584662273a166c", + "action": "add" + } + }, "SQLiteStoreConfig": { "1": { "version": 1, @@ -1091,6 +1035,48 @@ "action": "add" } }, + "NumpyArrayObject": { + "1": { + "version": 1, + "hash": "05dd2917b7692b3daf4e7ad083a46fa7ec7a2be8faac8d4a654809189c986443", + "action": "add" + } + }, + "NumpyScalarObject": { + "1": { + "version": 1, + "hash": "8753e5c78270a5cacbf0439447724772f4765351a4a8b58b0a5c416a6b2c8b6e", + "action": "add" + } + }, + "NumpyBoolObject": { + "1": { + "version": 1, + "hash": "331c44f8fa3d0a077f1aaad7313bae2c43b386d04def7b8bedae9fdf7690134d", + "action": "add" + } + }, + "PandasDataframeObject": { + "1": { + "version": 1, + "hash": "5e8018364cea31d5f185a901da4ab89846b02153ee7d041ee8a6d305ece31f90", + "action": "add" + } + }, + "PandasSeriesObject": { + "1": { + "version": 1, + "hash": "b8bd482bf16fc7177e9778292cd42f8835b6ced2ce8dc88908b4b8e6d7c7528f", + "action": "add" + } + }, + "DataSubjectMemberRelationship": { + "1": { + "version": 1, + "hash": "0810483ea76ea10c8f286c6035dc0b2085291f345183be50c179f3a05a577110", + "action": "add" + } + }, "ProjectEvent": { "1": { "version": 1, From 7deb51758fa15f27c24e0b5974580998592f6302 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Thu, 12 Sep 2024 15:03:36 +0200 Subject: [PATCH 158/197] add db config --- packages/grid/devspace.yaml | 1 + packages/syft/src/syft/server/server.py | 4 +++- 2 files changed, 4 
insertions(+), 1 deletion(-) diff --git a/packages/grid/devspace.yaml b/packages/grid/devspace.yaml index 400239d3f21..0d3cff05232 100644 --- a/packages/grid/devspace.yaml +++ b/packages/grid/devspace.yaml @@ -94,6 +94,7 @@ dev: - port: "8888" # filer - port: "8333" # S3 - port: "4001" # mount azure + - port: "5432" # mount postgres backend: labelSelector: app.kubernetes.io/name: syft diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index eb535600307..f4d1a1cae3d 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -744,6 +744,7 @@ def named( association_request_auto_approval: bool = False, background_tasks: bool = False, consumer_type: ConsumerType | None = None, + db_config: DBConfig | None = None, ) -> Server: uid = get_named_server_uid(name) name_hash = hashlib.sha256(name.encode("utf8")).digest() @@ -775,6 +776,7 @@ def named( association_request_auto_approval=association_request_auto_approval, background_tasks=background_tasks, consumer_type=consumer_type, + db_config=db_config, ) def is_root(self, credentials: SyftVerifyKey) -> bool: @@ -904,7 +906,7 @@ def init_stores(self, db_config: DBConfig) -> DBManager: root_verify_key=self.verify_key, ) elif isinstance(db_config, PostgresDBConfig): - db = PostgresDBManager( + db = PostgresDBManager( # type: ignore config=db_config, server_uid=self.id, root_verify_key=self.verify_key, From b35a5039f75e99c9d89687409e6c18d65362954f Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Thu, 12 Sep 2024 15:55:34 +0200 Subject: [PATCH 159/197] fix migration tests --- .../src/syft/protocol/protocol_version.json | 28 +++++- packages/syft/src/syft/server/server.py | 1 - .../syft/src/syft/server/service_registry.py | 2 + .../syft/src/syft/server/worker_settings.py | 35 ++++++++ .../service/blob_storage/remote_profile.py | 10 +++ .../service/migration/migration_service.py | 13 ++- .../migration/object_migration_state.py | 8 ++ .../src/syft/service/queue/queue_stash.py | 87 ++++++++++++++++++- .../syft/src/syft/service/user/user_stash.py | 17 ---- packages/syft/src/syft/store/db/stash.py | 22 ++++- 10 files changed, 191 insertions(+), 32 deletions(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index a8d210436e3..3ed423c5727 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -484,9 +484,14 @@ } }, "WorkerSettings": { + "1": { + "version": 1, + "hash": "dca33003904a71688e5b07db65f8833eb4de8135aade7154076b8eafbb94d26b", + "action": "add" + }, "2": { "version": 2, - "hash": "13c6e022b939778ab37b594dbc5094aba9f54564c90d3cb0c21115382b155bfe", + "hash": "91c375dd40d06c81fc6403751ee48cbc94b9877f91e65a7e302303218dfe71fa", "action": "add" } }, @@ -1059,21 +1064,36 @@ "QueueItem": { "1": { "version": 1, - "hash": "1db212c46b6c56ccc5579cfe2141b693f0cd9286e2ede71210393e8455379bf1", + "hash": "6ba7a6e0413a59cf1997dc94c67615d6acab89bceee989f70239eea556789c5a", + "action": "add" + }, + "2": { + "version": 2, + "hash": "1d8615f6daabcd2a285b2f36fd7bef1df76cdd119dd49c02069c50fd1b9c3ff4", "action": "add" } }, "ActionQueueItem": { "1": { "version": 1, - "hash": "396d579dfc2e2b36b9fbed2f204bffcca1bea7ee2db7175045dd3328ebf08718", + "hash": "a06effcbba3b76435daf6ca518611433bb603d62f913d703685a65fc49d2b0e9", + "action": "add" + }, + "2": { + "version": 2, + "hash": "bfda6ef87e4045d663324bb91a215ea06e1f173aec1fb4d9ddd337cdc1f0787f", 
"action": "add" } }, "APIEndpointQueueItem": { "1": { "version": 1, - "hash": "f04b3990a8d29c116d301e70df54d58f188895307a411dc13a666ff764ffd8dd", + "hash": "626341cefd3543be351e6060a5ff273a5f23bc257467e4a1ee1b6d63951cfb33", + "action": "add" + }, + "2": { + "version": 2, + "hash": "3a46370205152fa23a7d2bfa47130dbf2e2bc7ef31f6d3fe4c92fd8d683770b5", "action": "add" } }, diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 2743fce3fc7..42b3a7d0ee0 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -419,7 +419,6 @@ def __init__( # construct services only after init stores self.services: ServiceRegistry = ServiceRegistry.for_server(self) self.db.init_tables() - # self.services.user.stash.init_root_user() self.action_store = self.services.action.stash create_root_admin( diff --git a/packages/syft/src/syft/server/service_registry.py b/packages/syft/src/syft/server/service_registry.py index 23052804385..dfb7f331972 100644 --- a/packages/syft/src/syft/server/service_registry.py +++ b/packages/syft/src/syft/server/service_registry.py @@ -11,6 +11,7 @@ from ..service.action.action_service import ActionService from ..service.api.api_service import APIService from ..service.attestation.attestation_service import AttestationService +from ..service.blob_storage.remote_profile import RemoteProfileService from ..service.blob_storage.service import BlobStorageService from ..service.code.status_service import UserCodeStatusService from ..service.code.user_code_service import UserCodeService @@ -83,6 +84,7 @@ class ServiceRegistry: sync: SyncService output: OutputService user_code_status: UserCodeStatusService + remote_profile: RemoteProfileService services: list[AbstractService] = field(default_factory=list, init=False) service_path_map: dict[str, AbstractService] = field( diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py index f06359ec95d..79241d9d8dc 100644 --- a/packages/syft/src/syft/server/worker_settings.py +++ b/packages/syft/src/syft/server/worker_settings.py @@ -1,6 +1,9 @@ # future from __future__ import annotations +# stdlib +from collections.abc import Callable + # third party from typing_extensions import Self @@ -14,8 +17,13 @@ from ..service.queue.base_queue import QueueConfig from ..store.blob_storage import BlobStorageConfig from ..store.db.db import DBConfig +from ..store.document_store import StoreConfig +from ..types.syft_migration import migrate +from ..types.syft_object import SYFT_OBJECT_VERSION_1 from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftObject +from ..types.transforms import drop +from ..types.transforms import make_set_default from ..types.uid import UID @@ -50,3 +58,30 @@ def from_server(cls, server: AbstractServer) -> Self: log_level=server.log_level, deployment_type=server.deployment_type, ) + + +@serializable() +class WorkerSettingsV1(SyftObject): + __canonical_name__ = "WorkerSettings" + __version__ = SYFT_OBJECT_VERSION_1 + + id: UID + name: str + server_type: ServerType + server_side_type: ServerSideType + deployment_type: DeploymentType = DeploymentType.REMOTE + signing_key: SyftSigningKey + document_store_config: StoreConfig + action_store_config: StoreConfig + blob_store_config: BlobStorageConfig | None = None + queue_config: QueueConfig | None = None + log_level: int | None = None + + +@migrate(WorkerSettingsV1, WorkerSettings) +def migrate_workersettings_v1_to_v2() 
-> list[Callable]: + return [ + drop("document_store_config"), + drop("action_store_config"), + make_set_default("db_config", DBConfig()), + ] diff --git a/packages/syft/src/syft/service/blob_storage/remote_profile.py b/packages/syft/src/syft/service/blob_storage/remote_profile.py index e64ec8529a2..76abe869ae6 100644 --- a/packages/syft/src/syft/service/blob_storage/remote_profile.py +++ b/packages/syft/src/syft/service/blob_storage/remote_profile.py @@ -1,8 +1,10 @@ # relative from ...serde.serializable import serializable +from ...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftObject +from ..service import AbstractService @serializable() @@ -25,3 +27,11 @@ class AzureRemoteProfile(RemoteProfile): @serializable(canonical_name="RemoteProfileSQLStash", version=1) class RemoteProfileStash(ObjectStash[RemoteProfile]): pass + + +@serializable(canonical_name="RemoteProfileService", version=1) +class RemoteProfileService(AbstractService): + stash: RemoteProfileStash + + def __init__(self, store: DBManager) -> None: + self.stash = RemoteProfileStash(store=store) diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 68ac808473f..b62d2923ff5 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -159,8 +159,8 @@ def _update_store_metadata_for_klass( for server_uid in server_uids ] - stash.add_permissions(permissions) - stash.add_storage_permissions(storage_permissions) + stash.add_permissions(permissions, ignore_missing=True).unwrap() + stash.add_storage_permissions(storage_permissions, ignore_missing=True).unwrap() @as_result(SyftException) def _update_store_metadata( @@ -210,6 +210,9 @@ def _get_migration_objects( def _search_stash_for_klass( self, context: AuthedServiceContext, klass: type[SyftObject] ) -> ObjectStash: + if issubclass(klass, ActionObject | TwinObject | Action): + return context.server.services.action.stash + stashes: dict[str, ObjectStash] = { t.__canonical_name__: stash for t, stash in context.server.services.stashes.items() @@ -219,7 +222,11 @@ def _search_stash_for_klass( class_index = 0 object_stash = None while len(mro) > class_index: - canonical_name = mro[class_index].__canonical_name__ + try: + canonical_name = mro[class_index].__canonical_name__ + except AttributeError: + # Classes without cname dont have a stash + break object_stash = stashes.get(canonical_name) if object_stash is not None: break diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py index dbb9cd4df91..22363d867f2 100644 --- a/packages/syft/src/syft/service/migration/object_migration_state.py +++ b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -244,6 +244,14 @@ def get_items_by_canonical_name(self, canonical_name: str) -> list[SyftObject]: return v return [] + def get_metadata_by_canonical_name(self, canonical_name: str) -> StoreMetadata: + for k, v in self.metadata.items(): + if k.__canonical_name__ == canonical_name: + return v + return StoreMetadata( + object_type=SyftObject, permissions={}, storage_permissions={} + ) + def copy_without_workerpools(self) -> "MigrationData": items_to_exclude = [ WorkerPool.__canonical_name__, diff --git a/packages/syft/src/syft/service/queue/queue_stash.py 
b/packages/syft/src/syft/service/queue/queue_stash.py index 5b2ef9f3318..eda247b6139 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -1,4 +1,5 @@ # stdlib +from collections.abc import Callable from enum import Enum from typing import Any @@ -6,6 +7,7 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings +from ...server.worker_settings import WorkerSettingsV1 from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey from ...store.document_store_errors import NotFoundException @@ -13,8 +15,11 @@ from ...store.linked_obj import LinkedObject from ...types.errors import SyftException from ...types.result import as_result +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject +from ...types.transforms import TransformContext from ...types.uid import UID from ..action.action_permissions import ActionObjectPermission @@ -35,7 +40,7 @@ class Status(str, Enum): @serializable() class QueueItem(SyftObject): __canonical_name__ = "QueueItem" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 __attr_searchable__ = ["status", "worker_pool_id"] @@ -78,7 +83,7 @@ def action(self) -> Any: @serializable() class ActionQueueItem(QueueItem): __canonical_name__ = "ActionQueueItem" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 method: str = "execute" service: str = "actionservice" @@ -87,7 +92,7 @@ class ActionQueueItem(QueueItem): @serializable() class APIEndpointQueueItem(QueueItem): __canonical_name__ = "APIEndpointQueueItem" - __version__ = SYFT_OBJECT_VERSION_1 + __version__ = SYFT_OBJECT_VERSION_2 method: str service: str = "apiservice" @@ -162,3 +167,79 @@ def _get_by_worker_pool( credentials=credentials, filters={"worker_pool_id": worker_pool_id}, ).unwrap() + + +@serializable() +class QueueItemV1(SyftObject): + __canonical_name__ = "QueueItem" + __version__ = SYFT_OBJECT_VERSION_1 + + __attr_searchable__ = ["status", "worker_pool_id"] + + id: UID + server_uid: UID + result: Any | None = None + resolved: bool = False + status: Status = Status.CREATED + + method: str + service: str + args: list + kwargs: dict[str, Any] + job_id: UID | None = None + worker_settings: WorkerSettingsV1 | None = None + has_execute_permissions: bool = False + worker_pool: LinkedObject + + +@serializable() +class ActionQueueItemV1(QueueItemV1): + __canonical_name__ = "ActionQueueItem" + __version__ = SYFT_OBJECT_VERSION_1 + + method: str = "execute" + service: str = "actionservice" + + +@serializable() +class APIEndpointQueueItemV1(QueueItemV1): + __canonical_name__ = "APIEndpointQueueItem" + __version__ = SYFT_OBJECT_VERSION_1 + + method: str + service: str = "apiservice" + + +def migrate_worker_settings_v1_to_v2(context: TransformContext) -> TransformContext: + if context.output is None: + return context + + worker_settings_old: WorkerSettingsV1 | None = context.output.get( + "worker_settings", None + ) + if worker_settings_old is None: + return context + + if not isinstance(worker_settings_old, WorkerSettingsV1): + raise ValueError( + f"Expected WorkerSettingsV1, but got {type(worker_settings_old)}" + ) + worker_settings = worker_settings_old.migrate_to(WorkerSettings.__version__) + context.output["worker_settings"] = 
worker_settings + + return context + + +@migrate(QueueItemV1, QueueItem) +def migrate_queue_item_v1_to_v2() -> list[Callable]: + return [migrate_worker_settings_v1_to_v2] + + +@migrate(ActionQueueItemV1, ActionQueueItem) +def migrate_action_queue_item_v1_to_v2() -> list[Callable]: + return migrate_queue_item_v1_to_v2() + + +@migrate(APIEndpointQueueItemV1, APIEndpointQueueItem) +def migrate_api_endpoint_queue_item_v1_to_v2() -> list[Callable]: + return migrate_queue_item_v1_to_v2() diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index 90364a60565..3cf8acace75 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -6,7 +6,6 @@ from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result -from ...types.uid import UID from ...util.trace_decorator import instrument from .user import User from .user_roles import ServiceRole @@ -15,22 +14,6 @@ @instrument @serializable(canonical_name="UserStashSQL", version=1) class UserStash(ObjectStash[User]): - @as_result(StashException) - def init_root_user(self) -> None: - # start a transaction - users = self.get_all(self.root_verify_key, has_permission=True).unwrap() - if not users: - # NOTE this is not thread safe, should use a session and transaction - super().set( - self.root_verify_key, - User( - id=UID(), - email="_internal@root.com", - role=ServiceRole.ADMIN, - verify_key=self.root_verify_key, - ), - ) - @as_result(StashException, NotFoundException) def admin_user(self) -> User: # TODO: This returns only one user, the first user with the role ADMIN diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 097d616cf22..409b91ad2d8 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -616,12 +616,19 @@ def get_ownership_permissions( @as_result(NotFoundException) @with_session def add_permissions( - self, permissions: list[ActionObjectPermission], session: Session = None + self, + permissions: list[ActionObjectPermission], + session: Session = None, + ignore_missing: bool = False, ) -> None: # TODO: should do this in a single transaction # TODO add error handling for permission in permissions: - self.add_permission(permission, session=session).unwrap() + try: + self.add_permission(permission, session=session).unwrap() + except NotFoundException: + if not ignore_missing: + raise return None @as_result(NotFoundException) @with_session @@ -833,7 +840,14 @@ def _get_storage_permissions_for_uid( @as_result(NotFoundException) @with_session def add_storage_permissions( - self, permissions: list[StoragePermission], session: Session = None + self, + permissions: list[StoragePermission], + session: Session = None, + ignore_missing: bool = False, ) -> None: for permission in permissions: - self.add_storage_permission(permission, session=session).unwrap() + try: + self.add_storage_permission(permission, session=session).unwrap() + except NotFoundException: + if not ignore_missing: + raise From 070805401096ebb75706ae9a50ae3d2e32de2e88 Mon Sep 17 00:00:00 2001 From: dk Date: Fri, 13 Sep 2024 10:50:08 +0700 Subject: [PATCH 160/197] - [db] fix failed stack.notebook test by changing `SQLITE_PATH` to `/tmp/syft` instead of `$HOME/data/db/` - [dockerfile] fix `WARN: FromAsCasing: 'as' and 'FROM' keywords' casing do not match` - [deps] update `pyarrow` version --- 
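Note on why the path change fixes the test: `os.path.expandvars` substitutes environment variables and passes literal paths through unchanged, so the old default moved around with the caller's `$HOME` (which can differ or be unset inside the notebook stack's containers), while `/tmp/syft/` resolves identically everywhere. A minimal sketch of the difference (plain Python; the `$HOME` value below is hypothetical):

    import os

    os.environ["HOME"] = "/home/ci"              # hypothetical container environment
    print(os.path.expandvars("$HOME/data/db/"))  # /home/ci/data/db/ (varies per environment)
    print(os.path.expandvars("/tmp/syft/"))      # /tmp/syft/ (no variables to expand)
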
packages/grid/backend/backend.dockerfile | 4 ++-- packages/grid/backend/grid/core/config.py | 2 +- packages/syft/setup.cfg | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/packages/grid/backend/backend.dockerfile b/packages/grid/backend/backend.dockerfile index aafee9a43bb..984d2f174f4 100644 --- a/packages/grid/backend/backend.dockerfile +++ b/packages/grid/backend/backend.dockerfile @@ -9,7 +9,7 @@ ARG TORCH_VERSION="2.2.2" # ==================== [BUILD STEP] Python Dev Base ==================== # -FROM cgr.dev/chainguard/wolfi-base as syft_deps +FROM cgr.dev/chainguard/wolfi-base AS syft_deps ARG PYTHON_VERSION ARG UV_VERSION @@ -45,7 +45,7 @@ RUN --mount=type=cache,target=/root/.cache,sharing=locked \ # ==================== [Final] Setup Syft Server ==================== # -FROM cgr.dev/chainguard/wolfi-base as backend +FROM cgr.dev/chainguard/wolfi-base AS backend ARG PYTHON_VERSION ARG UV_VERSION diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py index 55a81aee24a..91e18b46c47 100644 --- a/packages/grid/backend/grid/core/config.py +++ b/packages/grid/backend/grid/core/config.py @@ -138,7 +138,7 @@ def get_emails_enabled(self) -> Self: True if os.getenv("CREATE_PRODUCER", "false").lower() == "true" else False ) N_CONSUMERS: int = int(os.getenv("N_CONSUMERS", 1)) - SQLITE_PATH: str = os.path.expandvars("$HOME/data/db/") + SQLITE_PATH: str = os.path.expandvars("/tmp/syft/") SINGLE_CONTAINER_MODE: bool = str_to_bool(os.getenv("SINGLE_CONTAINER_MODE", False)) CONSUMER_SERVICE_NAME: str | None = os.getenv("CONSUMER_SERVICE_NAME") INMEMORY_WORKERS: bool = str_to_bool(os.getenv("INMEMORY_WORKERS", True)) diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index 090934d74a0..f11303e258d 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -31,7 +31,7 @@ syft = boto3==1.34.56 forbiddenfruit==0.1.4 packaging>=23.0 - pyarrow==15.0.0 + pyarrow==17.0.0 pycapnp==2.0.0 pydantic[email]==2.6.0 pydantic-settings==2.2.1 From 0ddbc534351037b1a2daedae7d694dcdfcab1dc0 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 13 Sep 2024 09:54:24 +0200 Subject: [PATCH 161/197] add different base to handle two dbs --- packages/syft/src/syft/orchestra.py | 4 + packages/syft/src/syft/server/server.py | 16 ++-- .../src/syft/service/sync/sync_service.py | 8 +- packages/syft/src/syft/store/db/db.py | 26 +++++- packages/syft/src/syft/store/db/query.py | 83 +++++++++++-------- packages/syft/src/syft/store/db/schema.py | 12 ++- packages/syft/src/syft/store/db/stash.py | 24 +++--- 7 files changed, 114 insertions(+), 59 deletions(-) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 4dd7e7fe713..3497717c372 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -185,6 +185,7 @@ def deploy_to_python( migrate: bool = False, store_client_config: dict | None = None, consumer_type: ConsumerType | None = None, + db_url: str | None = None, ) -> ServerHandle: worker_classes = { ServerType.DATASITE: Datasite, @@ -218,6 +219,7 @@ def deploy_to_python( "deployment_type": deployment_type_enum, "store_client_config": store_client_config, "consumer_type": consumer_type, + "db_url": db_url, } if port: @@ -332,6 +334,7 @@ def launch( store_client_config: dict | None = None, from_state_folder: str | Path | None = None, consumer_type: ConsumerType | None = None, + db_url: str | None = None, ) -> ServerHandle: if from_state_folder is not None: with 
open(f"{from_state_folder}/config.json") as f: @@ -382,6 +385,7 @@ def launch( migrate=migrate, store_client_config=store_client_config, consumer_type=consumer_type, + db_url=db_url, ) display( SyftInfo( diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index dd41050901a..db059c06379 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -300,7 +300,6 @@ class Server(AbstractServer): signing_key: SyftSigningKey | None required_signed_calls: bool = True packages: str - db_config: DBConfig def __init__( self, @@ -340,6 +339,7 @@ def __init__( association_request_auto_approval: bool = False, background_tasks: bool = False, consumer_type: ConsumerType | None = None, + db_url: str | None = None, ): # 🟡 TODO 22: change our ENV variable format and default init args to make this # less horrible or add some convenience functions @@ -407,6 +407,7 @@ def __init__( action_store_config = action_store_config or self.get_default_store( store_type="Action Store", ) + db_config = DBConfig.from_connection_string(db_url) if db_url else db_config if db_config is None: db_config = SQLiteDBConfig( @@ -513,10 +514,10 @@ def runs_in_docker(self) -> bool: def get_default_store(self, store_type: str) -> StoreConfig: path = self.get_temp_dir("db") file_name: str = f"{self.id}.sqlite" - if self.dev_mode: - # leave this until the logger shows this in the notebook - print(f"{store_type}'s SQLite DB path: {path/file_name}") - logger.debug(f"{store_type}'s SQLite DB path: {path/file_name}") + # if self.dev_mode: + # leave this until the logger shows this in the notebook + # print(f"{store_type}'s SQLite DB path: {path/file_name}") + # logger.debug(f"{store_type}'s SQLite DB path: {path/file_name}") return SQLiteStoreConfig( client_config=SQLiteStoreClientConfig( filename=file_name, @@ -743,6 +744,7 @@ def named( association_request_auto_approval: bool = False, background_tasks: bool = False, consumer_type: ConsumerType | None = None, + db_url: str | None = None, db_config: DBConfig | None = None, ) -> Server: uid = get_named_server_uid(name) @@ -775,7 +777,7 @@ def named( association_request_auto_approval=association_request_auto_approval, background_tasks=background_tasks, consumer_type=consumer_type, - db_config=db_config, + db_url=db_url, ) def is_root(self, credentials: SyftVerifyKey) -> bool: @@ -915,6 +917,8 @@ def init_stores(self, db_config: DBConfig) -> DBManager: self.queue_stash = QueueStash(store=db) + print(f"Using {db_config.__class__.__name__} and {db_config.connection_string}") + return db @property diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index bc7326bd974..ddafd86b1d3 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -59,14 +59,14 @@ def add_actionobject_read_permissions( action_object: ActionObject, new_permissions: list[ActionObjectPermission], ) -> None: - store_to = context.server.services.action.stash # type: ignore + action_stash = context.server.services.action.stash for permission in new_permissions: if permission.permission == ActionPermission.READ: - store_to.add_permission(permission) + action_stash.add_permission(permission) blob_id = action_object.syft_blob_storage_entry_id if blob_id: - store_to_blob = context.server.services.blob_sotrage.stash.partition # type: ignore + blob_stash = context.server.services.blob_storage.stash for permission in 
new_permissions: if permission.permission == ActionPermission.READ: permission_blob = ActionObjectPermission( @@ -74,7 +74,7 @@ def add_actionobject_read_permissions( permission=permission.permission, credentials=permission.credentials, ) - store_to_blob.add_permission(permission_blob) + blob_stash.add_permission(permission_blob) def set_obj_ids(self, context: AuthedServiceContext, x: Any) -> None: if hasattr(x, "__dict__") and isinstance(x, SyftObject): diff --git a/packages/syft/src/syft/store/db/db.py b/packages/syft/src/syft/store/db/db.py index 713e190d48d..fe0d7c9b099 100644 --- a/packages/syft/src/syft/store/db/db.py +++ b/packages/syft/src/syft/store/db/db.py @@ -2,6 +2,7 @@ import logging from typing import Generic from typing import TypeVar +from urllib.parse import urlparse # third party from pydantic import BaseModel @@ -12,7 +13,8 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...types.uid import UID -from .schema import Base +from .schema import PostgresBase +from .schema import SQLiteBase logger = logging.getLogger(__name__) @@ -25,6 +27,26 @@ class DBConfig(BaseModel): def connection_string(self) -> str: raise NotImplementedError("Subclasses must implement this method.") + @classmethod + def from_connection_string(cls, conn_str: str) -> "DBConfig": + # relative + from .postgres import PostgresDBConfig + from .sqlite import SQLiteDBConfig + + parsed = urlparse(conn_str) + if parsed.scheme == "postgresql": + return PostgresDBConfig( + host=parsed.hostname, + port=parsed.port, + user=parsed.username, + password=parsed.password, + database=parsed.path.lstrip("/"), + ) + elif parsed.scheme == "sqlite": + return SQLiteDBConfig(path=parsed.path) + else: + raise ValueError(f"Unsupported database scheme {parsed.scheme}") + ConfigT = TypeVar("ConfigT", bound=DBConfig) @@ -53,6 +75,7 @@ def update_settings(self) -> None: pass def init_tables(self) -> None: + Base = SQLiteBase if self.engine.dialect.name == "sqlite" else PostgresBase if self.config.reset: # drop all tables that we know about Base.metadata.drop_all(bind=self.engine) @@ -60,5 +83,6 @@ def init_tables(self) -> None: Base.metadata.create_all(self.engine) def reset(self) -> None: + Base = SQLiteBase if self.engine.dialect.name == "sqlite" else PostgresBase Base.metadata.drop_all(bind=self.engine) Base.metadata.create_all(self.engine) diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py index b97b0602d6f..372f5ab96c4 100644 --- a/packages/syft/src/syft/store/db/query.py +++ b/packages/syft/src/syft/store/db/query.py @@ -1,7 +1,6 @@ # stdlib from abc import ABC from abc import abstractmethod -from collections.abc import Callable import enum from typing import Any from typing import Literal @@ -25,7 +24,8 @@ from ...service.user.user_roles import ServiceRole from ...types.syft_object import SyftObject from ...types.uid import UID -from .schema import Base +from .schema import PostgresBase +from .schema import SQLiteBase class FilterOperator(enum.Enum): @@ -34,19 +34,11 @@ class FilterOperator(enum.Enum): class Query(ABC): - json_quote: Callable | None = None - def __init__(self, object_type: type[SyftObject]) -> None: self.object_type: type = object_type self.table: Table = self._get_table(object_type) self.stmt: Select = self.table.select() - def _get_table(self, object_type: type[SyftObject]) -> Table: - cname = object_type.__canonical_name__ - if cname not in Base.metadata.tables: - raise ValueError(f"Table for {cname} not 
found") - return Base.metadata.tables[cname] - @staticmethod def get_query_class(dialect: str | Dialect) -> "type[Query]": if isinstance(dialect, Dialect): @@ -265,25 +257,6 @@ def _make_permissions_clause( ) -> sa.sql.elements.BinaryExpression: pass - def _eq_filter( - self, - table: Table, - field: str, - value: Any, - ) -> sa.sql.elements.BinaryExpression: - if field == "id": - return table.c.id == UID(value) - - if "." in field: - # magic! - field = field.split(".") # type: ignore - - json_value = serialize_json(value) - if self.json_quote: - return table.c.fields[field] == self.json_quote(json_value) - else: - return table.c.fields[field].astext == json_value - @abstractmethod def _contains_filter( self, @@ -307,8 +280,6 @@ def _get_column(self, column: str) -> Column: class SQLiteQuery(Query): - json_quote = func.json_quote - def _make_permissions_clause( self, permission: ActionObjectPermission, @@ -320,6 +291,12 @@ def _make_permissions_clause( self.table.c.permissions.contains(compound_permission_string), ) + def _get_table(self, object_type: type[SyftObject]) -> Table: + cname = object_type.__canonical_name__ + if cname not in SQLiteBase.metadata.tables: + raise ValueError(f"Table for {cname} not found") + return SQLiteBase.metadata.tables[cname] + def _contains_filter( self, table: Table, @@ -327,9 +304,23 @@ def _contains_filter( value: Any, ) -> sa.sql.elements.BinaryExpression: field_value = serialize_json(value) - return table.c.fields[field].contains( - self.json_quote(field_value) if self.json_quote else field_value - ) + return table.c.fields[field].contains(func.json_quote(field_value)) + + def _eq_filter( + self, + table: Table, + field: str, + value: Any, + ) -> sa.sql.elements.BinaryExpression: + if field == "id": + return table.c.id == UID(value) + + if "." in field: + # magic! + field = field.split(".") # type: ignore + + json_value = serialize_json(value) + return table.c.fields[field] == func.json_quote(json_value) class PostgresQuery(Query): @@ -349,5 +340,27 @@ def _contains_filter( field: str, value: Any, ) -> sa.sql.elements.BinaryExpression: - field_value = [serialize_json(value)] + field_value = serialize_json(value) return table.c.fields[field].contains(field_value) + + def _get_table(self, object_type: type[SyftObject]) -> Table: + cname = object_type.__canonical_name__ + if cname not in PostgresBase.metadata.tables: + raise ValueError(f"Table for {cname} not found") + return PostgresBase.metadata.tables[cname] + + def _eq_filter( + self, + table: Table, + field: str, + value: Any, + ) -> sa.sql.elements.BinaryExpression: + if field == "id": + return table.c.id == UID(value) + + if "." in field: + # magic! 
+ field = field.split(".") # type: ignore + + json_value = serialize_json(value) + return table.c.fields[field].astext == json_value diff --git a/packages/syft/src/syft/store/db/schema.py b/packages/syft/src/syft/store/db/schema.py index 6f687800752..86626c144c3 100644 --- a/packages/syft/src/syft/store/db/schema.py +++ b/packages/syft/src/syft/store/db/schema.py @@ -18,7 +18,11 @@ from ...types.uid import UID -class Base(DeclarativeBase): +class SQLiteBase(DeclarativeBase): + pass + + +class PostgresBase(DeclarativeBase): pass @@ -56,18 +60,20 @@ def create_table( dialect_name = dialect.name fields_type = JSON if dialect_name == "sqlite" else postgresql.JSONB - permissons_type = JSON if dialect_name == "sqlite" else postgresql.ARRAY(sa.String) + permissions_type = JSON if dialect_name == "sqlite" else postgresql.ARRAY(sa.String) storage_permissions_type = ( JSON if dialect_name == "sqlite" else postgresql.ARRAY(sa.String) ) + Base = SQLiteBase if dialect_name == "sqlite" else PostgresBase + if table_name not in Base.metadata.tables: Table( object_type.__canonical_name__, Base.metadata, Column("id", UIDTypeDecorator, primary_key=True, default=uuid.uuid4), Column("fields", fields_type, default={}), - Column("permissions", permissons_type, default=[]), + Column("permissions", permissions_type, default=[]), Column( "storage_permissions", storage_permissions_type, diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 204715aca76..e107694b2bb 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -42,7 +42,8 @@ from ..document_store_errors import StashException from .db import DBManager from .query import Query -from .schema import Base +from .schema import PostgresBase +from .schema import SQLiteBase from .schema import create_table from .sqlite import SQLiteDBManager @@ -91,8 +92,6 @@ def wrapper(self: "ObjectStash[StashT]", *args: Any, **kwargs: Any) -> Any: class ObjectStash(Generic[StashT]): - table: Table - object_type: type[SyftObject] allow_any_type: bool = False def __init__(self, store: DBManager) -> None: @@ -126,6 +125,7 @@ def __len__(self, session: Session = None) -> int: @classmethod def random(cls, **kwargs: dict) -> Self: + """Create a random stash with a random server_uid and root_verify_key. 
Useful for development.""" db_manager = SQLiteDBManager.random(**kwargs) stash = cls(store=db_manager) stash.db.init_tables() @@ -163,6 +163,7 @@ def session(self) -> Session: def _drop_table(self) -> None: table_name = self.object_type.__canonical_name__ + Base = SQLiteBase if self._is_sqlite() else PostgresBase if table_name in Base.metadata.tables: Base.metadata.tables[table_name].drop(self.db.engine) else: @@ -288,6 +289,8 @@ def get_role( # relative from ...service.user.user import User + Base = SQLiteBase if self._is_sqlite() else PostgresBase + # TODO error handling if Base.metadata.tables.get("User") is None: # if User table does not exist, we assume the user is a guest @@ -624,17 +627,18 @@ def add_permissions( # TODO: should do this in a single transaction # TODO add error handling for permission in permissions: - try: - self.add_permission(permission, session=session).unwrap() - except NotFoundException: - if not ignore_missing: - raise + self.add_permission( + permission, session=session, ignore_missing=ignore_missing + ).unwrap() return None @as_result(NotFoundException) @with_session def add_permission( - self, permission: ActionObjectPermission, session: Session = None + self, + permission: ActionObjectPermission, + session: Session = None, + ignore_missing: bool = False, ) -> None: # TODO add error handling stmt = self.table.update().where(self.table.c.id == permission.uid) @@ -655,7 +659,7 @@ def add_permission( result = session.execute(stmt) session.commit() - if result.rowcount == 0: + if result.rowcount == 0 and not ignore_missing: raise NotFoundException( f"{self.object_type.__name__}: {permission.uid} not found or no permission to change." ) From 43ca0d3a4e1092f35eeef38ca7ca6581516fb027 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 13 Sep 2024 10:19:15 +0200 Subject: [PATCH 162/197] add get_table to basecls --- packages/syft/src/syft/store/db/query.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py index 372f5ab96c4..9222e8069c6 100644 --- a/packages/syft/src/syft/store/db/query.py +++ b/packages/syft/src/syft/store/db/query.py @@ -39,6 +39,10 @@ def __init__(self, object_type: type[SyftObject]) -> None: self.table: Table = self._get_table(object_type) self.stmt: Select = self.table.select() + @abstractmethod + def _get_table(self, object_type: type[SyftObject]) -> Table: + raise NotImplementedError + @staticmethod def get_query_class(dialect: str | Dialect) -> "type[Query]": if isinstance(dialect, Dialect): From 1ef65067ea056514c301c67a52b4fbb5a7d9d7d1 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 13 Sep 2024 10:36:18 +0200 Subject: [PATCH 163/197] fix postgres error --- packages/syft/src/syft/server/server.py | 5 ++--- packages/syft/src/syft/store/db/stash.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index db059c06379..bbda5754ccb 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -415,15 +415,14 @@ def __init__( path=self.get_temp_dir("db"), ) - if reset: - db_config.reset = True - self.db_config = db_config self.db = self.init_stores(db_config=self.db_config) # construct services only after init stores self.services: ServiceRegistry = ServiceRegistry.for_server(self) + if reset: + self.db.reset() self.db.init_tables() self.action_store = self.services.action.stash diff --git 
a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index e107694b2bb..8c677d0383a 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -770,7 +770,7 @@ def add_storage_permission( ) else: stmt = stmt.values( - permissions=func.array_append( + storage_permissions=func.array_append( self.table.c.storage_permissions, permission.server_uid.no_dash ) ) From 8a6921d3d49c16db1843132fd0086f484f53afc3 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 13 Sep 2024 11:01:11 +0200 Subject: [PATCH 164/197] fix queries for postgres --- packages/syft/src/syft/store/db/query.py | 5 +++-- packages/syft/src/syft/store/db/schema.py | 8 +++----- packages/syft/src/syft/store/db/stash.py | 2 +- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py index 372f5ab96c4..939ea199687 100644 --- a/packages/syft/src/syft/store/db/query.py +++ b/packages/syft/src/syft/store/db/query.py @@ -320,7 +320,7 @@ def _eq_filter( field = field.split(".") # type: ignore json_value = serialize_json(value) - return table.c.fields[field] == func.json_quote(json_value) + return table.c.fields[field] == func.to_jsonb(json_value) class PostgresQuery(Query): @@ -363,4 +363,5 @@ def _eq_filter( field = field.split(".") # type: ignore json_value = serialize_json(value) - return table.c.fields[field].astext == json_value + # NOTE: there might be a bug with casting everything to text + return table.c.fields[field].astext == sa.cast(json_value, sa.Text) diff --git a/packages/syft/src/syft/store/db/schema.py b/packages/syft/src/syft/store/db/schema.py index 86626c144c3..7f81e39802e 100644 --- a/packages/syft/src/syft/store/db/schema.py +++ b/packages/syft/src/syft/store/db/schema.py @@ -59,11 +59,9 @@ def create_table( table_name = object_type.__canonical_name__ dialect_name = dialect.name - fields_type = JSON if dialect_name == "sqlite" else postgresql.JSONB - permissions_type = JSON if dialect_name == "sqlite" else postgresql.ARRAY(sa.String) - storage_permissions_type = ( - JSON if dialect_name == "sqlite" else postgresql.ARRAY(sa.String) - ) + fields_type = JSON if dialect_name == "sqlite" else postgresql.JSON + permissions_type = JSON if dialect_name == "sqlite" else postgresql.JSONB + storage_permissions_type = JSON if dialect_name == "sqlite" else postgresql.JSONB Base = SQLiteBase if dialect_name == "sqlite" else PostgresBase diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 8c677d0383a..a20ad44e654 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -254,7 +254,7 @@ def _get_field_filter( if self.db.engine.dialect.name == "sqlite": return table.c.fields[field_name] == func.json_quote(json_value) elif self.db.engine.dialect.name == "postgresql": - return table.c.fields[field_name].astext == json_value + return table.c.fields[field_name].astext == cast(json_value, sa.String) @as_result(SyftException, StashException, NotFoundException) def get_index( From f9334699781d44c1e5723d31c1ef05db5d2aa8a8 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 13 Sep 2024 11:10:48 +0200 Subject: [PATCH 165/197] fix sqlite --- packages/syft/src/syft/service/migration/migration_service.py | 2 +- packages/syft/src/syft/store/db/query.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git 
a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index b62d2923ff5..b1d461e4e80 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -213,7 +213,7 @@ def _search_stash_for_klass( if issubclass(klass, ActionObject | TwinObject | Action): return context.server.services.action.stash - stashes: dict[str, ObjectStash] = { + stashes: dict[str, ObjectStash] = { # type: ignore t.__canonical_name__: stash for t, stash in context.server.services.stashes.items() } diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py index 60ea83c30ea..3813f17d4df 100644 --- a/packages/syft/src/syft/store/db/query.py +++ b/packages/syft/src/syft/store/db/query.py @@ -324,7 +324,7 @@ def _eq_filter( field = field.split(".") # type: ignore json_value = serialize_json(value) - return table.c.fields[field] == func.to_jsonb(json_value) + return table.c.fields[field] == func.json_quote(json_value) class PostgresQuery(Query): From 501861b4d919cf922afc7a584d4dad09313d7984 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 13 Sep 2024 11:48:12 +0200 Subject: [PATCH 166/197] fix ordering in postgres --- packages/syft/src/syft/service/user/user_service.py | 2 +- packages/syft/src/syft/store/db/query.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/user/user_service.py b/packages/syft/src/syft/service/user/user_service.py index 6ffee6756c9..a0d27a00786 100644 --- a/packages/syft/src/syft/service/user/user_service.py +++ b/packages/syft/src/syft/service/user/user_service.py @@ -365,7 +365,7 @@ def get_role_for_credentials( return role elif isinstance(credentials, SyftSigningKey): user = self.stash.get_by_signing_key( - credentials=credentials, + credentials=credentials.verify_key, signing_key=credentials, # type: ignore ).unwrap() else: diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py index 3813f17d4df..c15c645583d 100644 --- a/packages/syft/src/syft/store/db/query.py +++ b/packages/syft/src/syft/store/db/query.py @@ -226,10 +226,12 @@ def order_by( order = order or default_order column = self._get_column(field) + if order.lower() == "asc": - self.stmt = self.stmt.order_by(column) + self.stmt = self.stmt.order_by(sa.cast(column, sa.String).asc()) + elif order.lower() == "desc": - self.stmt = self.stmt.order_by(column.desc()) + self.stmt = self.stmt.order_by(sa.cast(column, sa.String).desc()) else: raise ValueError(f"Invalid sort order {order}") From 0dc509e2be6d15526b3e77f088cf4353aaa79bc5 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 13 Sep 2024 12:40:39 +0200 Subject: [PATCH 167/197] fix string casting in postgres --- packages/syft/src/syft/server/server.py | 1 + packages/syft/src/syft/server/uvicorn.py | 3 +++ packages/syft/src/syft/store/db/query.py | 4 +++- .../syft/util/notebook_ui/components/tabulator_template.py | 2 +- packages/syft/tests/syft/users/user_test.py | 4 ++++ 5 files changed, 12 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index bbda5754ccb..438997b1ad2 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -777,6 +777,7 @@ def named( background_tasks=background_tasks, consumer_type=consumer_type, db_url=db_url, + db_config=db_config, ) def 
is_root(self, credentials: SyftVerifyKey) -> bool: diff --git a/packages/syft/src/syft/server/uvicorn.py b/packages/syft/src/syft/server/uvicorn.py index 676e6222dbc..e1982953a32 100644 --- a/packages/syft/src/syft/server/uvicorn.py +++ b/packages/syft/src/syft/server/uvicorn.py @@ -68,6 +68,7 @@ class AppSettings(BaseSettings): association_request_auto_approval: bool = False background_tasks: bool = False db_config: DBConfig | None = None + db_url: str | None = None model_config = SettingsConfigDict(env_prefix="SYFT_", env_parse_none_str="None") @@ -233,6 +234,7 @@ def serve_server( association_request_auto_approval: bool = False, background_tasks: bool = False, debug: bool = False, + db_url: str | None = None, ) -> tuple[Callable, Callable]: starting_uvicorn_event = multiprocessing.Event() @@ -262,6 +264,7 @@ def serve_server( "debug": debug, "starting_uvicorn_event": starting_uvicorn_event, "deployment_type": deployment_type, + "db_url": db_url, }, ) diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py index c15c645583d..5c9457bb29c 100644 --- a/packages/syft/src/syft/store/db/query.py +++ b/packages/syft/src/syft/store/db/query.py @@ -347,7 +347,9 @@ def _contains_filter( value: Any, ) -> sa.sql.elements.BinaryExpression: field_value = serialize_json(value) - return table.c.fields[field].contains(field_value) + col = sa.cast(table.c.fields[field], sa.Text) + val = sa.cast(field_value, sa.Text) + return col.contains(val) def _get_table(self, object_type: type[SyftObject]) -> Table: cname = object_type.__canonical_name__ diff --git a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py index 4e93e82c45a..538614b4cb8 100644 --- a/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py +++ b/packages/syft/src/syft/util/notebook_ui/components/tabulator_template.py @@ -24,7 +24,7 @@ def make_links(text: str) -> str: file_pattern = re.compile(r"([\w/.-]+\.py)\", line (\d+)") - return file_pattern.sub(r'\1, line \2', text) + return file_pattern.sub(r'\1, line \2', text) DEFAULT_ID_WIDTH = 110 diff --git a/packages/syft/tests/syft/users/user_test.py b/packages/syft/tests/syft/users/user_test.py index cdc348c9f4f..4c727b4f1fe 100644 --- a/packages/syft/tests/syft/users/user_test.py +++ b/packages/syft/tests/syft/users/user_test.py @@ -1,5 +1,6 @@ # stdlib from secrets import token_hex +import time # third party from faker import Faker @@ -388,6 +389,9 @@ def test_user_view_set_role(worker: Worker, guest_client: DatasiteClient) -> Non admin_client = get_mock_client(worker.root_client, ServiceRole.ADMIN) assert admin_client.account.role == ServiceRole.ADMIN + # wait for the user to be created for sorting purposes + time.sleep(0.01) + admin_client.register( name="Sheldon Cooper", email="sheldon@caltech.edu", From 66851f48490ce12051084f10e4d75811f22684fa Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 13 Sep 2024 12:45:00 +0200 Subject: [PATCH 168/197] deduplicate permissions --- packages/syft/src/syft/store/db/stash.py | 179 +++++++++++------------ 1 file changed, 85 insertions(+), 94 deletions(-) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index e107694b2bb..ae51759d6da 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -618,51 +618,43 @@ def get_ownership_permissions( @as_result(NotFoundException) @with_session - def add_permissions( + 
def add_permission( self, - permissions: list[ActionObjectPermission], + permission: ActionObjectPermission, session: Session = None, ignore_missing: bool = False, ) -> None: - # TODO: should do this in a single transaction - # TODO add error handling - for permission in permissions: - self.add_permission( - permission, session=session, ignore_missing=ignore_missing + try: + existing_permissions = self._get_permissions_for_uid( + permission.uid, session=session ).unwrap() + except NotFoundException: + if ignore_missing: + return None + raise + + existing_permissions.add(permission.permission_string) + + stmt = self.table.update().where(self.table.c.id == permission.uid) + stmt = stmt.values(permissions=list(existing_permissions)) + session.execute(stmt) + session.commit() + return None @as_result(NotFoundException) @with_session - def add_permission( + def add_permissions( self, - permission: ActionObjectPermission, - session: Session = None, + permissions: list[ActionObjectPermission], ignore_missing: bool = False, + session: Session = None, ) -> None: - # TODO add error handling - stmt = self.table.update().where(self.table.c.id == permission.uid) - if self._is_sqlite(): - stmt = stmt.values( - permissions=func.json_insert( - self.table.c.permissions, - "$[#]", - permission.permission_string, - ) - ) - else: - stmt = stmt.values( - permissions=func.array_append( - self.table.c.permissions, permission.permission_string - ) - ) - - result = session.execute(stmt) - session.commit() - if result.rowcount == 0 and not ignore_missing: - raise NotFoundException( - f"{self.object_type.__name__}: {permission.uid} not found or no permission to change." - ) + for permission in permissions: + self.add_permission( + permission, session=session, ignore_missing=ignore_missing + ).unwrap() + return None @with_session def remove_permission( @@ -741,48 +733,6 @@ def has_storage_permission( ) -> bool: return self.has_storage_permissions([permission], session=session) - @as_result(StashException) - @with_session - def get_all_storage_permissions( - self, session: Session = None - ) -> dict[UID, Set[UID]]: # noqa: UP006 - stmt = select(self.table.c.id, self.table.c.storage_permissions) - results = session.execute(stmt).all() - - return { - UID(row.id): {UID(uid) for uid in row.storage_permissions} - for row in results - } - - @as_result(NotFoundException) - @with_session - def add_storage_permission( - self, permission: StoragePermission, session: Session = None - ) -> None: - stmt = self.table.update().where(self.table.c.id == permission.uid) - if self._is_sqlite(): - stmt = stmt.values( - storage_permissions=func.json_insert( - self.table.c.storage_permissions, - "$[#]", - permission.permission_string, - ) - ) - else: - stmt = stmt.values( - permissions=func.array_append( - self.table.c.storage_permissions, permission.server_uid.no_dash - ) - ) - - result = session.execute(stmt) - session.commit() - if result.rowcount == 0: - raise NotFoundException( - f"{self.object_type.__name__}: {permission.uid} not found or no permission to change." 
- ) - return None - @with_session def has_storage_permissions( self, permissions: list[StoragePermission], session: Session = None @@ -807,15 +757,71 @@ def has_storage_permissions( result = session.execute(stmt).first() return result is not None + @as_result(StashException) + @with_session + def get_all_storage_permissions( + self, session: Session = None + ) -> dict[UID, Set[UID]]: # noqa: UP006 + stmt = select(self.table.c.id, self.table.c.storage_permissions) + results = session.execute(stmt).all() + + return { + UID(row.id): {UID(uid) for uid in row.storage_permissions} + for row in results + } + + @as_result(NotFoundException) + @with_session + def add_storage_permissions( + self, + permissions: list[StoragePermission], + session: Session = None, + ignore_missing: bool = False, + ) -> None: + for permission in permissions: + self.add_storage_permission( + permission, session=session, ignore_missing=ignore_missing + ).unwrap() + + return None + + @as_result(NotFoundException) + @with_session + def add_storage_permission( + self, + permission: StoragePermission, + session: Session = None, + ignore_missing: bool = False, + ) -> None: + try: + existing_permissions = self._get_storage_permissions_for_uid( + permission.uid, session=session + ).unwrap() + except NotFoundException: + if ignore_missing: + return None + raise + + existing_permissions.add(permission.server_uid) + + stmt = ( + self.table.update() + .where(self.table.c.id == permission.uid) + .values(storage_permissions=[str(uid) for uid in existing_permissions]) + ) + + session.execute(stmt) + @with_session def remove_storage_permission( self, permission: StoragePermission, session: Session = None ) -> None: - # TODO not threadsafe try: - permissions = self._get_storage_permissions_for_uid(permission.uid).unwrap() - permissions.remove(permission.server_uid) - except (NotFoundException, KeyError): + permissions = self._get_storage_permissions_for_uid( + permission.uid, session=session + ).unwrap() + permissions.discard(permission.server_uid) + except NotFoundException: # TODO add error handling to permissions return None @@ -840,18 +846,3 @@ def _get_storage_permissions_for_uid( if result is None: raise NotFoundException(f"No storage permissions found for uid: {uid}") return {UID(uid) for uid in result.storage_permissions} - - @as_result(NotFoundException) - @with_session - def add_storage_permissions( - self, - permissions: list[StoragePermission], - session: Session = None, - ignore_missing: bool = False, - ) -> None: - for permission in permissions: - try: - self.add_storage_permission(permission, session=session).unwrap() - except NotFoundException: - if not ignore_missing: - raise From 638132c30aef6fcd1429aceffbb02440a1a5e7a5 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 13 Sep 2024 13:15:11 +0200 Subject: [PATCH 169/197] fix getlatest --- packages/syft/src/syft/service/sync/sync_stash.py | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py index d6824a3ac89..114ee209af1 100644 --- a/packages/syft/src/syft/service/sync/sync_stash.py +++ b/packages/syft/src/syft/service/sync/sync_stash.py @@ -27,6 +27,7 @@ def get_latest(self, credentials: SyftVerifyKey) -> SyncState | None: states = self.get_all( credentials=credentials, + order_by="created_at", sort_order="desc", limit=1, ).unwrap() From 72256f73dbd79dccdb89fda55a2237a0fc336301 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 13 Sep 2024 13:27:16 +0200 Subject: [PATCH 
170/197] fix orderby non-json --- packages/syft/src/syft/store/db/query.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py index 5c9457bb29c..fe96c3fba00 100644 --- a/packages/syft/src/syft/store/db/query.py +++ b/packages/syft/src/syft/store/db/query.py @@ -227,11 +227,14 @@ def order_by( column = self._get_column(field) + if isinstance(column.type, sa.JSON): + column = sa.cast(column, sa.String) + if order.lower() == "asc": - self.stmt = self.stmt.order_by(sa.cast(column, sa.String).asc()) + self.stmt = self.stmt.order_by(column.asc()) elif order.lower() == "desc": - self.stmt = self.stmt.order_by(sa.cast(column, sa.String).desc()) + self.stmt = self.stmt.order_by(column.desc()) else: raise ValueError(f"Invalid sort order {order}") From 86153aecf5e053eb40c73019287d5bb948d39176 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 13 Sep 2024 13:37:37 +0200 Subject: [PATCH 171/197] fix lint CI --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 6e37f408ded..0fc268805f5 100644 --- a/tox.ini +++ b/tox.ini @@ -231,7 +231,7 @@ commands = [testenv:syft.protocol.check] description = Syft Protocol Check deps = - {[testenv:syft-minimal]deps} + {[testenv:syft]deps} changedir = {toxinidir}/packages/syft allowlist_externals = bash From 7bf645211f31e770fb256dfac145fbd9cc3ed9fa Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 13 Sep 2024 13:52:39 +0200 Subject: [PATCH 172/197] fix lint --- tox.ini | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tox.ini b/tox.ini index 0fc268805f5..dd99c917301 100644 --- a/tox.ini +++ b/tox.ini @@ -247,7 +247,7 @@ commands = [testenv:syft.api.snapshot] description = Syft API Snapshot Check deps = - {[testenv:syft-minimal]deps} + {[testenv:syft]deps} changedir = {toxinidir}/packages/syft allowlist_externals = bash From 4865ec6496efb32b3c84f83157e65496ac393312 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 13 Sep 2024 14:12:28 +0200 Subject: [PATCH 173/197] remove uses of find_all --- packages/syft/src/syft/service/blob_storage/service.py | 4 ++-- .../src/syft/service/code_history/code_history_service.py | 6 ++++-- packages/syft/src/syft/store/db/db.py | 1 + 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/blob_storage/service.py b/packages/syft/src/syft/service/blob_storage/service.py index 78da77d7eab..055e4d946e4 100644 --- a/packages/syft/src/syft/service/blob_storage/service.py +++ b/packages/syft/src/syft/service/blob_storage/service.py @@ -136,8 +136,8 @@ def mount_azure( def get_files_from_bucket( self, context: AuthedServiceContext, bucket_name: str ) -> list: - bse_list = self.stash.find_all( - context.credentials, bucket_name=bucket_name + bse_list = self.stash.get_all( + context.credentials, filters={"bucket_name": bucket_name} ).unwrap() blob_files = [] diff --git a/packages/syft/src/syft/service/code_history/code_history_service.py b/packages/syft/src/syft/service/code_history/code_history_service.py index e6cea5a4a21..ff2967f169b 100644 --- a/packages/syft/src/syft/service/code_history/code_history_service.py +++ b/packages/syft/src/syft/service/code_history/code_history_service.py @@ -187,11 +187,13 @@ def get_by_func_name_and_user_email( ) -> list[CodeHistory]: user_verify_key = context.server.services.user.user_verify_key(user_email) - kwargs = { + filters = { "id": user_id, "email": user_email, "verify_key": 
user_verify_key, "service_func_name": service_func_name, } - return self.stash.find_all(credentials=context.credentials, **kwargs).unwrap() + return self.stash.get_all( + credentials=context.credentials, filters=filters + ).unwrap() diff --git a/packages/syft/src/syft/store/db/db.py b/packages/syft/src/syft/store/db/db.py index fe0d7c9b099..d2adca212c8 100644 --- a/packages/syft/src/syft/store/db/db.py +++ b/packages/syft/src/syft/store/db/db.py @@ -76,6 +76,7 @@ def update_settings(self) -> None: def init_tables(self) -> None: Base = SQLiteBase if self.engine.dialect.name == "sqlite" else PostgresBase + if self.config.reset: # drop all tables that we know about Base.metadata.drop_all(bind=self.engine) From 5e1882a55053a44692b88ebc1407016e5fa6d67b Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Fri, 13 Sep 2024 14:19:44 +0200 Subject: [PATCH 174/197] fix instrument imports --- packages/syft/src/syft/service/code/status_service.py | 2 -- packages/syft/src/syft/service/code/user_code_stash.py | 2 -- packages/syft/src/syft/service/dataset/dataset_stash.py | 2 -- packages/syft/src/syft/service/job/job_stash.py | 2 -- packages/syft/src/syft/service/output/output_service.py | 2 -- packages/syft/src/syft/service/project/project_stash.py | 2 -- packages/syft/src/syft/service/request/request_stash.py | 2 -- packages/syft/src/syft/service/user/user_stash.py | 2 -- packages/syft/src/syft/store/db/stash.py | 2 ++ tox.ini | 4 ++-- 10 files changed, 4 insertions(+), 18 deletions(-) diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index a1fa9300bdf..d6c1a56e801 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -8,7 +8,6 @@ from ...store.db.stash import ObjectStash from ...store.document_store import PartitionSettings from ...types.uid import UID -from ...util.trace_decorator import instrument from ..context import AuthedServiceContext from ..response import SyftSuccess from ..service import AbstractService @@ -19,7 +18,6 @@ from .user_code import UserCodeStatusCollection -@instrument @serializable(canonical_name="StatusSQLStash", version=1) class StatusStash(ObjectStash[UserCodeStatusCollection]): settings: PartitionSettings = PartitionSettings( diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py index ba86950d61f..232342bd8d5 100644 --- a/packages/syft/src/syft/service/code/user_code_stash.py +++ b/packages/syft/src/syft/service/code/user_code_stash.py @@ -6,11 +6,9 @@ from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result -from ...util.trace_decorator import instrument from .user_code import UserCode -@instrument @serializable(canonical_name="UserCodeSQLStash", version=1) class UserCodeStash(ObjectStash[UserCode]): settings: PartitionSettings = PartitionSettings( diff --git a/packages/syft/src/syft/service/dataset/dataset_stash.py b/packages/syft/src/syft/service/dataset/dataset_stash.py index bc7af00afdc..aee2a280372 100644 --- a/packages/syft/src/syft/service/dataset/dataset_stash.py +++ b/packages/syft/src/syft/service/dataset/dataset_stash.py @@ -6,11 +6,9 @@ from ...store.document_store_errors import StashException from ...types.result import as_result from ...types.uid import UID -from ...util.trace_decorator import instrument from .dataset import Dataset -@instrument 
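+# NOTE: tracing is inherited from the ObjectStash base class, which is decorated with @instrument (from ...util.telemetry) in syft/store/db/stash.py.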
@serializable(canonical_name="DatasetStashSQL", version=1) class DatasetStash(ObjectStash[Dataset]): @as_result(StashException, NotFoundException) diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index fb85cd810d7..358834470fb 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -34,7 +34,6 @@ from ...types.syncable_object import SyncableSyftObject from ...types.uid import UID from ...util.markdown import as_markdown_code -from ...util.trace_decorator import instrument from ...util.util import prompt_warning_message from ..action.action_object import Action from ..action.action_object import ActionObject @@ -735,7 +734,6 @@ def from_job( return info -@instrument @serializable(canonical_name="JobStashSQL", version=1) class JobStash(ObjectStash[Job]): settings: PartitionSettings = PartitionSettings( diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 3d80a8e4148..5d26ff2cb3e 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -17,7 +17,6 @@ from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syncable_object import SyncableSyftObject from ...types.uid import UID -from ...util.trace_decorator import instrument from ..action.action_object import ActionObject from ..action.action_permissions import ActionObjectREAD from ..context import AuthedServiceContext @@ -180,7 +179,6 @@ def get_sync_dependencies(self, context: AuthedServiceContext) -> list[UID]: return res -@instrument @serializable(canonical_name="OutputStashSQL", version=1) class OutputStash(ObjectStash[ExecutionOutput]): @as_result(StashException) diff --git a/packages/syft/src/syft/service/project/project_stash.py b/packages/syft/src/syft/service/project/project_stash.py index 051c293362d..13dab37bdea 100644 --- a/packages/syft/src/syft/service/project/project_stash.py +++ b/packages/syft/src/syft/service/project/project_stash.py @@ -9,11 +9,9 @@ from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result -from ...util.trace_decorator import instrument from .project import Project -@instrument # noqa: F821 @serializable(canonical_name="ProjectSQLStash", version=1) class ProjectStash(ObjectStash[Project]): @as_result(StashException) diff --git a/packages/syft/src/syft/service/request/request_stash.py b/packages/syft/src/syft/service/request/request_stash.py index 7a2e27c603e..a28fd5842e1 100644 --- a/packages/syft/src/syft/service/request/request_stash.py +++ b/packages/syft/src/syft/service/request/request_stash.py @@ -5,11 +5,9 @@ from ...types.errors import SyftException from ...types.result import as_result from ...types.uid import UID -from ...util.trace_decorator import instrument from .request import Request -@instrument @serializable(canonical_name="RequestStashSQL", version=1) class RequestStash(ObjectStash[Request]): @as_result(SyftException) diff --git a/packages/syft/src/syft/service/user/user_stash.py b/packages/syft/src/syft/service/user/user_stash.py index 3cf8acace75..92fb87d37b3 100644 --- a/packages/syft/src/syft/service/user/user_stash.py +++ b/packages/syft/src/syft/service/user/user_stash.py @@ -6,12 +6,10 @@ from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import 
StashException from ...types.result import as_result -from ...util.trace_decorator import instrument from .user import User from .user_roles import ServiceRole -@instrument @serializable(canonical_name="UserStashSQL", version=1) class UserStash(ObjectStash[User]): @as_result(StashException, NotFoundException) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 693d7456acb..90dfcdcf254 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -38,6 +38,7 @@ from ...types.syft_metaclass import Empty from ...types.syft_object import SyftObject from ...types.uid import UID +from ...util.telemetry import instrument from ..document_store_errors import NotFoundException from ..document_store_errors import StashException from .db import DBManager @@ -91,6 +92,7 @@ def wrapper(self: "ObjectStash[StashT]", *args: Any, **kwargs: Any) -> Any: return wrapper # type: ignore +@instrument class ObjectStash(Generic[StashT]): allow_any_type: bool = False diff --git a/tox.ini b/tox.ini index dd99c917301..6e37f408ded 100644 --- a/tox.ini +++ b/tox.ini @@ -231,7 +231,7 @@ commands = [testenv:syft.protocol.check] description = Syft Protocol Check deps = - {[testenv:syft]deps} + {[testenv:syft-minimal]deps} changedir = {toxinidir}/packages/syft allowlist_externals = bash @@ -247,7 +247,7 @@ commands = [testenv:syft.api.snapshot] description = Syft API Snapshot Check deps = - {[testenv:syft]deps} + {[testenv:syft-minimal]deps} changedir = {toxinidir}/packages/syft allowlist_externals = bash From 3a1957709e1f22f2e472b1a5f703fb3e9fc844df Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Fri, 13 Sep 2024 14:56:26 +0200 Subject: [PATCH 175/197] trying to fix flaky test --- packages/syft/tests/conftest.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index c9d118b1b34..eacf21eb616 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -115,7 +115,7 @@ def faker(): @pytest.fixture(scope="function") def worker() -> Worker: - worker = sy.Worker.named(name=token_hex(8)) + worker = sy.Worker.named(name=token_hex(16)) yield worker worker.cleanup() del worker @@ -124,7 +124,7 @@ def worker() -> Worker: @pytest.fixture(scope="function") def second_worker() -> Worker: # Used in server syncing tests - worker = sy.Worker.named(name=token_hex(8)) + worker = sy.Worker.named(name=token_hex(16)) yield worker worker.cleanup() del worker From 39a06b752becd4a1c3b740ab1de61d177046d6b4 Mon Sep 17 00:00:00 2001 From: dk Date: Mon, 16 Sep 2024 14:18:27 +0700 Subject: [PATCH 176/197] [protocol_version] define old version for `SyftWorkerImage` and `Notification` - revert protocol json files back to dev's state Co-authored-by: Shubham Gupta --- .../src/syft/protocol/protocol_version.json | 1171 +--------------- .../src/syft/protocol/releases/0.9.1.json | 1178 +++++++++++++++++ .../service/notification/notifications.py | 28 +- .../src/syft/service/worker/worker_image.py | 28 +- 4 files changed, 1234 insertions(+), 1171 deletions(-) create mode 100644 packages/syft/src/syft/protocol/releases/0.9.1.json diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 400df20c628..5f9f6a8fab1 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -1,1172 +1,5 @@ { - "dev": { - 
"object_versions": { - "SyftObjectVersioned": { - "1": { - "version": 1, - "hash": "7c842dcdbb57e2528ffa690ea18c19fff3c8a591811d40cad2b19be3100e2ff4", - "action": "add" - } - }, - "BaseDateTime": { - "1": { - "version": 1, - "hash": "614db484b1950be729902b1861bd3a7b33899176507c61cef11dc0d44611cfd3", - "action": "add" - } - }, - "SyftObject": { - "1": { - "version": 1, - "hash": "bb70d874355988908d3a92a3941d6613a6995a4850be3b6a0147f4d387724406", - "action": "add" - } - }, - "PartialSyftObject": { - "1": { - "version": 1, - "hash": "19a995fcc2833f4fab24584fd99b71a80c2ef1f13c06f83af79e4482846b1656", - "action": "add" - } - }, - "ServerMetadata": { - "1": { - "version": 1, - "hash": "1691c7667eca86b20c4189e90ce4e643dd41fd3682cdb69c6308878f2a6f135c", - "action": "add" - } - }, - "User": { - "1": { - "version": 1, - "hash": "2df4b68182c558dba5485a8a6867acf2a5c341b249ad67373a504098aa8c4343", - "action": "add" - }, - "2": { - "version": 2, - "hash": "af6fb5b2e1606e97838f4a60f0536ad95db606d455e94acbd1977df866608a2c", - "action": "add" - } - }, - "UserUpdate": { - "1": { - "version": 1, - "hash": "1bf6707c69b809c804fb939c7c37d787c2f6889508a4bec37d24221af2eb777a", - "action": "add" - } - }, - "UserCreate": { - "1": { - "version": 1, - "hash": "49d6087e2309ba59987f3126e286e74b3a66492a08ad82fa507ea17d52ce78e3", - "action": "add" - } - }, - "UserSearch": { - "1": { - "version": 1, - "hash": "9ac946338cca68d00d1696a57943442f062628ec3daf53077d0bdd3f72cd9fa0", - "action": "add" - } - }, - "UserView": { - "1": { - "version": 1, - "hash": "0b52d758e31d5889c9cd88afb467aae4a74e34a5276924e07012243c34d300fe", - "action": "add" - } - }, - "UserViewPage": { - "1": { - "version": 1, - "hash": "1cd6528d02ec180f080d5c35f0da760d8a59af9da7baaa9c17c1c7cedcc858fa", - "action": "add" - } - }, - "UserPrivateKey": { - "1": { - "version": 1, - "hash": "4817d8147aba94373f320dcd90e65f097cf6e5a2ef353aa8520e23128d522b5d", - "action": "add" - } - }, - "LinkedObject": { - "1": { - "version": 1, - "hash": "d80f5ac7f51a9383be1a3cb334d56ae50e49733ed3199f3b6b5d6febd9de410b", - "action": "add" - } - }, - "DateTime": { - "1": { - "version": 1, - "hash": "394abb554114ead4d63c36e3fe83ac018dead4b21a8465174009577c46d54c58", - "action": "add" - } - }, - "ReplyNotification": { - "1": { - "version": 1, - "hash": "84102dfc59d711b03c2f3d3a6ecaca000b6835f1bbdd9af801057f7aacb5f1d0", - "action": "add" - } - }, - "Notification": { - "1": { - "version": 1, - "hash": "cbd07fe9c549c5f05d0880d1417f18a5bfaa71eba65c63ff9cb761e274fffc54", - "action": "add" - } - }, - "CreateNotification": { - "1": { - "version": 1, - "hash": "7e426c946b7d5db6f9427960ec16042f3018091d835ca5966f3568c324a2ab53", - "action": "add" - } - }, - "UserNotificationActivity": { - "1": { - "version": 1, - "hash": "422fd01c6d9af38688a9982abd34e80794a1f6ddd444cca225d77f49189847a9", - "action": "add" - } - }, - "NotificationPreferences": { - "1": { - "version": 1, - "hash": "a42f06b367e7c6cbabcbf3cfcc84d1ca0873e457d972ebd060e87c9d6185f62b", - "action": "add" - } - }, - "NotifierSettings": { - "1": { - "version": 1, - "hash": "65c8ab814d35fac32f68d3000756692592cc59940f30e3af3dcdfa2328755b9d", - "action": "add" - }, - "2": { - "version": 2, - "hash": "be8b52597fc628d1b7cd22b776ee81416e1adbb04a45188778eb0e32ed1416b4", - "action": "add" - } - }, - "StoreConfig": { - "1": { - "version": 1, - "hash": "a9997fce6a8a0ed2884c58b8eb9382f8554bdd18fff61f8bf0451945bcff12c7", - "action": "add" - } - }, - "BaseConfig": { - "1": { - "version": 1, - "hash": 
"10bd7566041d0f0a3aa295367785fdcc2c5bbf0ded984ac9230754f37496a6a7", - "action": "add" - }, - "2": { - "version": 2, - "hash": "890d2879ac44611db9b88ba9334a721130d0ac3aa18a303fa9e4081f14b9b8c7", - "action": "add" - } - }, - "ServiceConfig": { - "1": { - "version": 1, - "hash": "28af8a296f5ff63de50438277eaa1f4380682e6aca9f2ca28320d7a444825e88", - "action": "add" - }, - "2": { - "version": 2, - "hash": "93dfab144e0b0884c602358b3a9ce889bb29ab96e3b4adcfe3cef47a31694a9a", - "action": "add" - } - }, - "LibConfig": { - "1": { - "version": 1, - "hash": "ee8f0e3f6aae81948d72e30226645e8eb5d312a6770411a1edca748168c467c0", - "action": "add" - }, - "2": { - "version": 2, - "hash": "a8a78a8d726ee9e79f95614f3d0fa5b85edc6fce7be7651715208669be93e0e3", - "action": "add" - } - }, - "SyftImageRegistry": { - "1": { - "version": 1, - "hash": "67e18903e41cba1afe136adf29d404b63ec04fea6e928abb2533ec4fa52b246b", - "action": "add" - } - }, - "SyftWorkerImage": { - "1": { - "version": 1, - "hash": "3771910757bf7534369ceb3e6cf1fc2e3a5438d8d094a67b0b548d501a0ec63f", - "action": "add" - } - }, - "SyftWorker": { - "1": { - "version": 1, - "hash": "9d897f6039eabe48dfa8e8d5c5cdcb283b0375b4c64571b457777eaaf3fb1920", - "action": "add" - } - }, - "WorkerPool": { - "1": { - "version": 1, - "hash": "16efc5dd2596ae744fd611c8f46af9eaec1bd5729eb20e85e9fd2f31df402564", - "action": "add" - } - }, - "MarkdownDescription": { - "1": { - "version": 1, - "hash": "31a73f8824cad1636a55d14b6a1074cdb071d0d4e16e86baaa3d4f63a7e80134", - "action": "add" - } - }, - "HTMLObject": { - "1": { - "version": 1, - "hash": "97f2e93f5ceaa88015047186f66a17ff13df2a6b7925b41331f9e19d5a515a9f", - "action": "add" - } - }, - "PwdTokenResetConfig": { - "1": { - "version": 1, - "hash": "0415a272428f22add4896c64aa9f29c8c1d35619e2433da6564eb5f1faff39ac", - "action": "add" - } - }, - "ServerSettingsUpdate": { - "1": { - "version": 1, - "hash": "1e4260ad879ae80728c3ffae2cd1d48759abd51f9d0960d4b25855cdbb4c506b", - "action": "add" - }, - "2": { - "version": 2, - "hash": "23b2716e9dceca667e228408e2416c82f11821e322e5bccf1f83406f3d09abdc", - "action": "add" - }, - "3": { - "version": 3, - "hash": "335c7946f2e52d09c7b26f511120cd340717c74c5cca9107e84f839da993c55c", - "action": "add" - }, - "4": { - "version": 4, - "hash": "8d7a41992c39c287fcb46383bed429ce75d3c9524ced8c86b88c26dd0232e2fe", - "action": "add" - } - }, - "ServerSettings": { - "1": { - "version": 1, - "hash": "5a1e7470cbeaaae5b80ac9beecb743734f7e4e42d429a09ea8defa569a5ddff1", - "action": "add" - }, - "2": { - "version": 2, - "hash": "7727ea54e494dc9deaa0d1bd38ac8a6180bc192b74eec5659adbc338a19e21f5", - "action": "add" - }, - "3": { - "version": 3, - "hash": "997667e1cba22d151857aacc2caba6b1ca73c1648adbd03461dc74a0c0c372b3", - "action": "add" - }, - "4": { - "version": 4, - "hash": "b8067777967a0e06733433e179e549caaf501419d62f7e8474ee33b839e3890d", - "action": "add" - } - }, - "APIEndpoint": { - "1": { - "version": 1, - "hash": "faa1cf9336a0d1233868c8c57745ff38c0be60399dc1acd0c0e8dd440e405dbd", - "action": "add" - } - }, - "LibEndpoint": { - "1": { - "version": 1, - "hash": "a585c83a33a019d363ae5a0c6d4197193654307c19a4829dfbf8a8cfd2c1842a", - "action": "add" - } - }, - "SignedSyftAPICall": { - "1": { - "version": 1, - "hash": "2f959455f7130f4e59360b8aa58f19785b76eaa0f8a5a9188a6cbf32b31311ca", - "action": "add" - } - }, - "SyftAPICall": { - "1": { - "version": 1, - "hash": "59e89e7b9ea30deaed64d1ffd9bc0769b999d3082b305428432c1f5be36c6343", - "action": "add" - } - }, - "SyftAPIData": { - "1": { - "version": 1, - 
"hash": "820b279c581cafd9bb5009702d4e3db22ec3a3156676426304b9038dad260a24", - "action": "add" - } - }, - "SyftAPI": { - "1": { - "version": 1, - "hash": "cc13ab058ee36748c14b0d4bd9b9e894c7566fff09cfa4170b3eece520169f15", - "action": "add" - } - }, - "HTTPConnection": { - "1": { - "version": 1, - "hash": "bf10f81646c71069c76292b1237b4a3de1e507264392c5c591d067636ce6fb46", - "action": "add" - } - }, - "PythonConnection": { - "1": { - "version": 1, - "hash": "28010778b5e3463ff6960a0e2224818de00bc7b5e6f892192e02e399ccbe18b5", - "action": "add" - } - }, - "ActionDataEmpty": { - "1": { - "version": 1, - "hash": "e0e4a5cf18d05b6b747addc048515c6f2a5f35f0766ebaee96d898cb971e1c5b", - "action": "add" - } - }, - "ObjectNotReady": { - "1": { - "version": 1, - "hash": "8cf471e205cd0893d6aae5f0227d14db7df1c9698da08a3ab991f59132d17fe9", - "action": "add" - } - }, - "ActionDataLink": { - "1": { - "version": 1, - "hash": "3469478343439e411b761c270eec63eb3d533e459ad72d0965158c3a6cdf3b9a", - "action": "add" - } - }, - "Action": { - "1": { - "version": 1, - "hash": "021826d7c6f69bd0283d025d40661f3ffbeba8810ca94de01344f6afbdae62cd", - "action": "add" - } - }, - "ActionObject": { - "1": { - "version": 1, - "hash": "0a5f4bc343cb114a251f06686ecdbb59d74bfb3d29a098b176699deb35a1e683", - "action": "add" - } - }, - "AnyActionObject": { - "1": { - "version": 1, - "hash": "b3c44c7788c59c03fa1baeec656c2ca6e633f4cbd4b23ff7ece6ee94c38449f0", - "action": "add" - } - }, - "BlobFile": { - "1": { - "version": 1, - "hash": "d99239100f1cb0b73c69b2ad7cab01a06909cc3a4976ba2b3b67cf6fe5e2f516", - "action": "add" - } - }, - "BlobFileOBject": { - "1": { - "version": 1, - "hash": "6c40dab2c8d2220d4fff7cc653d76cc026a856db7e2b5713b6341e255adc7ea2", - "action": "add" - } - }, - "SecureFilePathLocation": { - "1": { - "version": 1, - "hash": "ea5978b98d7773d221665b450454c9130c103a5c850669a0acd620607cd614b7", - "action": "add" - } - }, - "SeaweedSecureFilePathLocation": { - "1": { - "version": 1, - "hash": "3fc9bfc8c1b1cf660c9747e8c1fe3eb2220e78d4e3b5d6b5c5f29a07a77ebf3e", - "action": "add" - } - }, - "AzureSecureFilePathLocation": { - "1": { - "version": 1, - "hash": "090a9e962eeb655586ee966c5651d8996363969818a38f9a486fd64d33047e05", - "action": "add" - } - }, - "BlobStorageEntry": { - "1": { - "version": 1, - "hash": "afdc6a1d8a24b1ee1ed9d3e79f5bac64b4f0d9d36800f07f10be0b896470345f", - "action": "add" - } - }, - "BlobStorageMetadata": { - "1": { - "version": 1, - "hash": "9d4b61ac4ea1910c2f7c767a50a6a52544a24663548f069e79bd906f11b538e4", - "action": "add" - } - }, - "CreateBlobStorageEntry": { - "1": { - "version": 1, - "hash": "ffc3cbfeade67d074dc5bf7d655a1eb8c83630076028a72b3cc4548f3b413e14", - "action": "add" - } - }, - "BlobRetrieval": { - "1": { - "version": 1, - "hash": "c422c74b89a9349742acaa848566fe18bfef1a83333458b858c074baed37a859", - "action": "add" - } - }, - "SyftObjectRetrieval": { - "1": { - "version": 1, - "hash": "b2b62447445adc4cd0b77ab59d6fa56624dd316fb50281e570daad07556b6db2", - "action": "add" - } - }, - "BlobRetrievalByURL": { - "1": { - "version": 1, - "hash": "4db0e3b7a6334d3835356d8393866711e243e360af25a95f3cc4066f032404b5", - "action": "add" - } - }, - "BlobDeposit": { - "1": { - "version": 1, - "hash": "6eb5cc57dc763126bfc6ec5a2b79d02e77eadf9d9efb1888a5c366b7799c1c24", - "action": "add" - } - }, - "WorkerSettings": { - "1": { - "version": 1, - "hash": "dca33003904a71688e5b07db65f8833eb4de8135aade7154076b8eafbb94d26b", - "action": "add" - }, - "2": { - "version": 2, - "hash": 
"91c375dd40d06c81fc6403751ee48cbc94b9877f91e65a7e302303218dfe71fa", - "action": "add" - } - }, - "HTTPServerRoute": { - "1": { - "version": 1, - "hash": "938245604a9c7e50001299afff5b669b2548364e356fed22a22780497831bf81", - "action": "add" - } - }, - "PythonServerRoute": { - "1": { - "version": 1, - "hash": "a068d8f942d55ecb6d45af88a27c6ebf208584275bf589cbc308df3f774ab9a9", - "action": "add" - } - }, - "VeilidServerRoute": { - "1": { - "version": 1, - "hash": "e676bc165601d2ede69707a4b6168ed4674f3f98887026d098a2dd4da4dfd097", - "action": "add" - } - }, - "EnclaveMetadata": { - "1": { - "version": 1, - "hash": "8d2dfafa01ec909c080a790cf15a8fc78e00382d3bfe6207098ceb25a60b9c53", - "action": "add" - } - }, - "CustomEndpointActionObject": { - "1": { - "version": 1, - "hash": "c7addbaf2777707f3e91e5c1e092343476cd22efc4ec8617f39ccf76e61a5a14", - "action": "add" - }, - "2": { - "version": 2, - "hash": "846ba36e8737a1bec16853c9de54c4948450009278e0b76fe7e3355ef9e70089", - "action": "add" - } - }, - "DataSubject": { - "1": { - "version": 1, - "hash": "582cdf9e82b5d6915b7f09f7c0d5f08328b11a2ce9b0198e5083f1672c2e2bf5", - "action": "add" - } - }, - "DataSubjectCreate": { - "1": { - "version": 1, - "hash": "5a8423c2690d55f425bfeecc87cd4a797a75d88ebb5fbda754d4f269b62d2ceb", - "action": "add" - } - }, - "Contributor": { - "1": { - "version": 1, - "hash": "30c32bd44098f00e0b15496be441763b6e50af8b12d3d2bef33aca6287193876", - "action": "add" - } - }, - "Asset": { - "1": { - "version": 1, - "hash": "000abc78719611c106295cf12b1690b7e5411dc1bb9db9d4afd22956da90d1f4", - "action": "add" - } - }, - "CreateAsset": { - "1": { - "version": 1, - "hash": "357d52576cb12b24fb3980342bb49a562b065c0e4419e87d34176340628c7309", - "action": "add" - } - }, - "Dataset": { - "1": { - "version": 1, - "hash": "0ca6b0b4a3aebb2c8f351668075b44951bb20d1e23a779b82109124f334ce3a4", - "action": "add" - } - }, - "DatasetPageView": { - "1": { - "version": 1, - "hash": "aa0dd69637281b80d5523b4409a2c7e89db114c9fe79c858063c6dadff8977d1", - "action": "add" - }, - "2": { - "version": 2, - "hash": "be1ca6dcd0b3aa0481ce5dce737e78432d06a78ad0c701aaf136be407c798352", - "action": "add" - } - }, - "CreateDataset": { - "1": { - "version": 1, - "hash": "7e02dfa89540c3dbebacbb13810d95cdc4e36db31d56cffed7ab54abe25716c9", - "action": "add" - } - }, - "SyftLog": { - "1": { - "version": 1, - "hash": "1bcd71e5bf3f0db3bba0996f33b6b2bde3489b9c71f11e6b30c3495c76a8f53f", - "action": "add" - } - }, - "JobItem": { - "2": { - "version": 2, - "hash": "b087d0c62b7d304c6ca80e4fb0e8a7f2a444be8f8cba57490dc09aeb98033105", - "action": "add" - } - }, - "ExecutionOutput": { - "1": { - "version": 1, - "hash": "e36c71685edf5276a3427cb6749550486d3a177c1dcf73dd337ab2a73c0ce6b5", - "action": "add" - } - }, - "TwinObject": { - "1": { - "version": 1, - "hash": "4f31243fb348dbb083579afd6f638d75af010cb53d19bfba59b74afff41ccbbb", - "action": "add" - } - }, - "PolicyRule": { - "1": { - "version": 1, - "hash": "44d1ca1db97be46f66558aa1a729ff31bf8e113c6a913b11aedf9d6b6ad5b7b5", - "action": "add" - } - }, - "CreatePolicyRule": { - "1": { - "version": 1, - "hash": "342bb723526d445151a0435f57d251f4c1219f8ae7cca3e8e9fce52e2ee1b8b1", - "action": "add" - } - }, - "CreatePolicyRuleConstant": { - "1": { - "version": 1, - "hash": "78b54832cb0468a87013bc36bc11d4759874ca1b5065a1b711f1e5ef5d94c2df", - "action": "add" - } - }, - "Matches": { - "1": { - "version": 1, - "hash": "dd6d91ddb2ec5eaf60be2b0899ecfdb9a15f7904aa39d2f4d9bb2d7b793040e6", - "action": "add" - } - }, - "PreFill": { - "1": { - 
"version": 1, - "hash": "c7aefb11dc4c4569dcd1e6988371047a32a8be1b32ad46d12adba419a19769ad", - "action": "add" - } - }, - "UserOwned": { - "1": { - "version": 1, - "hash": "c8738dc3d8c2a5ef461b85a0467c3dff53dab16b54a4d12b44b1477906aef51d", - "action": "add" - } - }, - "MixedInputPolicy": { - "1": { - "version": 1, - "hash": "37bb12d950518d9579c8ec7c4cc22ac731ea82caf8c1370dd0b0a82b46462dde", - "action": "add" - } - }, - "ExactMatch": { - "1": { - "version": 1, - "hash": "5eb37edbf5e451d942e599247f3eaed923c1fe9d91eefdba02bf06503f6cc08d", - "action": "add" - } - }, - "OutputHistory": { - "1": { - "version": 1, - "hash": "9366db79d131f8c65e5a4ff12c90e2aa0c11e302debe06e46eeb93b26e2aaf61", - "action": "add" - } - }, - "OutputPolicyExecuteCount": { - "1": { - "version": 1, - "hash": "2a77e5ed5c7b0391147562651ad4061e20b11745c191fbc34cb549da37ba72dd", - "action": "add" - } - }, - "OutputPolicyExecuteOnce": { - "1": { - "version": 1, - "hash": "5589c00d127d9eb1f5ccf3a16def8219737784d57bb3bf9be5cb6d83325ef436", - "action": "add" - } - }, - "EmptyInputPolicy": { - "1": { - "version": 1, - "hash": "7ef81cfd223be0064600e1503f8b04bafc16385e27730e9319466e68a077c68b", - "action": "add" - } - }, - "UserPolicy": { - "1": { - "version": 1, - "hash": "74373bb71a334f4dcf77623ae10ff5b1c7e5b3006f65f2051ffb1e01f422f982", - "action": "add" - } - }, - "SubmitUserPolicy": { - "1": { - "version": 1, - "hash": "ec4e808eb39613bcdbbbf9ffb3267612084a9d99880a2f3bee3ef32d46329c02", - "action": "add" - } - }, - "UserCodeStatusCollection": { - "1": { - "version": 1, - "hash": "735ecf2d4abb1e7d19b2e751d880f32b01ce267ba10e417ef1b440be3d94d8f1", - "action": "add" - } - }, - "UserCode": { - "1": { - "version": 1, - "hash": "3bcd14413b9c4fbde7c5612c2ed713518340280b5cff89cf2aaaf1c77c4037a8", - "action": "add" - } - }, - "SubmitUserCode": { - "1": { - "version": 1, - "hash": "d2bb8cfe12f070b4adafded78ce01900c5409bd83f055f94b1e285745ef65a76", - "action": "add" - } - }, - "UserCodeExecutionResult": { - "1": { - "version": 1, - "hash": "1f4cbc62caac4dd193f427306405dc7a099ae744bea5830cf57149ce71c1e589", - "action": "add" - } - }, - "UserCodeExecutionOutput": { - "1": { - "version": 1, - "hash": "c1d53300a39dbbb437d7d5a1257bd175a067b1065f4099a0938fac7540035258", - "action": "add" - }, - "2": { - "version": 2, - "hash": "3e104e39b4ab53c950e61e4f7e92ce935cf96a5100de301de9bf297eb7e5787e", - "action": "add" - } - }, - "CodeHistory": { - "1": { - "version": 1, - "hash": "e3ef5346f108257828f364d22b12d9311812c9cf843200afef5dc4d9302f9b21", - "action": "add" - } - }, - "CodeHistoryView": { - "1": { - "version": 1, - "hash": "8b8b97d334b51d1ce0a9efab722411ff25caa3f12be319105954497e0a306eb2", - "action": "add" - } - }, - "CodeHistoriesDict": { - "1": { - "version": 1, - "hash": "01d7dcd4b21525a06e4484d8699a4a34a5c84f1f6026ec55e32eb30412742601", - "action": "add" - } - }, - "UsersCodeHistoriesDict": { - "1": { - "version": 1, - "hash": "4ed8b83973258ea19a1f91feb2590ff73b801be86f4296cc3db48f6929ff784c", - "action": "add" - } - }, - "SyftObjectMigrationState": { - "1": { - "version": 1, - "hash": "ee83315828551f18904bab18e0cac48896493620561215b04cc448e6ce5834af", - "action": "add" - } - }, - "StoreMetadata": { - "1": { - "version": 1, - "hash": "8de9a22a2765ef976bc161cb0704347d30350c085da8c8ffa876065cfca3e5fd", - "action": "add" - } - }, - "MigrationData": { - "1": { - "version": 1, - "hash": "cb96b8c8413609e1224341d1b0dd1efb08387c0ff7b0ff65eba36c0b104c9ed1", - "action": "add" - }, - "2": { - "version": 2, - "hash": 
"1d1b14c196221ecf6d644d7dcaa32ac9e90361b2687fa83161ff399ebc6df1bd", - "action": "add" - } - }, - "Change": { - "1": { - "version": 1, - "hash": "75fb9a5cd4e76b189ebe130a421d3921a0c251947a48bbb92a2ef1c315dc3c16", - "action": "add" - } - }, - "ChangeStatus": { - "1": { - "version": 1, - "hash": "c914a6f7637b555a51b71e8e197e591f7a2e28121e29b5dd586f87e0383d179d", - "action": "add" - } - }, - "ActionStoreChange": { - "1": { - "version": 1, - "hash": "1a803bb08924b49f3114fd46e0e132f819d4d56be5e03a27e9fe90947ca26e85", - "action": "add" - } - }, - "CreateCustomImageChange": { - "1": { - "version": 1, - "hash": "c3dbea3f49979fdcc517c0d13cd02739ca2fe86b370c42496a224f142ae31562", - "action": "add" - } - }, - "CreateCustomWorkerPoolChange": { - "1": { - "version": 1, - "hash": "0355793dd58b364dcb84fff29714b6a26446bead3ba95c6d75e3200008e580f4", - "action": "add" - } - }, - "Request": { - "1": { - "version": 1, - "hash": "1d69f5f0074114f99aa29c5ee77cb20b9151e5b50e77b026f11c3632a12efadf", - "action": "add" - } - }, - "RequestInfo": { - "1": { - "version": 1, - "hash": "779562547744ebed64548f8021647292604fdf4256bf79685dfa14a1e56cc27b", - "action": "add" - } - }, - "RequestInfoFilter": { - "1": { - "version": 1, - "hash": "bb881a003032f4676321218d7cd09580f4d64fccaa1cf9e118fdcd5c73c3d3a8", - "action": "add" - } - }, - "SubmitRequest": { - "1": { - "version": 1, - "hash": "6c38b6ffd0a6f7442746e68b9ace7b21cb1dca7d2031929db5f9a302a280403f", - "action": "add" - } - }, - "ObjectMutation": { - "1": { - "version": 1, - "hash": "ce88096760ce9334599c8194ec97b0a1470651ad680d9d21b8826a0df0af2a36", - "action": "add" - } - }, - "EnumMutation": { - "1": { - "version": 1, - "hash": "5173fda73df17a344eb663b7692cca48bd46bf1773455439836b852cd165448c", - "action": "add" - } - }, - "UserCodeStatusChange": { - "1": { - "version": 1, - "hash": "89aaf7f1368c782e3a1b9e79988877f6eaa05ab84365f7d321b757fde7fe86e7", - "action": "add" - } - }, - "SyncedUserCodeStatusChange": { - "1": { - "version": 1, - "hash": "d9ad2d341eb645bd50d06330cd30fd4c266f93e37b9f5391d58b78365fc440e6", - "action": "add" - } - }, - "TwinAPIContextView": { - "1": { - "version": 1, - "hash": "e099eef32cb3a8a806cbdc54cc7fca96bed3d60344bd571163ec049db407938b", - "action": "add" - } - }, - "CustomAPIView": { - "1": { - "version": 1, - "hash": "769e96bebd05736ab860591670fb6da19406239b0104ddc71bd092a134335146", - "action": "add" - } - }, - "CustomApiEndpoint": { - "1": { - "version": 1, - "hash": "ec4a217585336d1b59c93c18570443a63f4fbb24d2c088fbacf80bcf389d23e8", - "action": "add" - } - }, - "PrivateAPIEndpoint": { - "1": { - "version": 1, - "hash": "6d7d143432c2811c520ab6dade005ba40173b590e5c676be04f5921b970ef938", - "action": "add" - } - }, - "PublicAPIEndpoint": { - "1": { - "version": 1, - "hash": "3bf51fc33aa8feb1abc9d0ef792e8889da31a57050430e0bd8e17f2065ff8734", - "action": "add" - } - }, - "UpdateTwinAPIEndpoint": { - "1": { - "version": 1, - "hash": "851e59412716e73c7f70a696619e0b375ce136b43f6fe2ea784747091caba5d8", - "action": "add" - } - }, - "CreateTwinAPIEndpoint": { - "1": { - "version": 1, - "hash": "3d0b84dae95ebcc6647b5aabe54e65b3c6bf957665fde57d8037806a4aac13be", - "action": "add" - } - }, - "TwinAPIEndpoint": { - "1": { - "version": 1, - "hash": "d1947b8f9c80d6c9b443e5a9f0758afa8849a5f12b9a511feefd7e4f82c374f4", - "action": "add" - } - }, - "SyncState": { - "1": { - "version": 1, - "hash": "9a3f0bb973858b55bc766c9770c4d9abcc817898f797d94a89938650c0c67868", - "action": "add" - } - }, - "ServerPeer": { - "1": { - "version": 1, - "hash": 
"0d5f252018e324ea0d2dcb5c2ad8bd15707220565fce4f14de7f63a8f9e4391b", - "action": "add" - } - }, - "ServerPeerUpdate": { - "1": { - "version": 1, - "hash": "0b854b57db7a18118c1fd8f31495b2ba4eeb9fbe4f24c631ff112418a94570d3", - "action": "add" - } - }, - "AssociationRequestChange": { - "1": { - "version": 1, - "hash": "0134ac0002879c85fc9ddb06bed6306a8905c8434b0a40d3a96ce24a7bd4da90", - "action": "add" - } - }, - "QueueItem": { - "1": { - "version": 1, - "hash": "6ba7a6e0413a59cf1997dc94c67615d6acab89bceee989f70239eea556789c5a", - "action": "add" - }, - "2": { - "version": 2, - "hash": "1d8615f6daabcd2a285b2f36fd7bef1df76cdd119dd49c02069c50fd1b9c3ff4", - "action": "add" - } - }, - "ActionQueueItem": { - "1": { - "version": 1, - "hash": "a06effcbba3b76435daf6ca518611433bb603d62f913d703685a65fc49d2b0e9", - "action": "add" - }, - "2": { - "version": 2, - "hash": "bfda6ef87e4045d663324bb91a215ea06e1f173aec1fb4d9ddd337cdc1f0787f", - "action": "add" - } - }, - "APIEndpointQueueItem": { - "1": { - "version": 1, - "hash": "626341cefd3543be351e6060a5ff273a5f23bc257467e4a1ee1b6d63951cfb33", - "action": "add" - }, - "2": { - "version": 2, - "hash": "3a46370205152fa23a7d2bfa47130dbf2e2bc7ef31f6d3fe4c92fd8d683770b5", - "action": "add" - } - }, - "ZMQClientConfig": { - "1": { - "version": 1, - "hash": "36ee8f75067d5144f0ed062cdc79466caae16b7a128231d89b6b430174843bde", - "action": "add" - } - }, - "OnDiskBlobDeposit": { - "1": { - "version": 1, - "hash": "817bf1bee4a35bfa1cd25d6779a10d8d180b1b3f1e837952f81f48b9411d1970", - "action": "add" - } - }, - "RemoteConfig": { - "1": { - "version": 1, - "hash": "179d067099a178d748c6d9a0477e8de7c3b55577439669eca7150258f2409567", - "action": "add" - } - }, - "AzureRemoteConfig": { - "1": { - "version": 1, - "hash": "a143811fec0da5fd881e927643ef667c91c78a2c90519cf88da7da20738bd187", - "action": "add" - } - }, - "SeaweedFSBlobDeposit": { - "1": { - "version": 1, - "hash": "febeb2a2ce81aa2c512e4c6b611b582984042aafa0541403d4584662273a166c", - "action": "add" - } - }, - "SQLiteStoreConfig": { - "1": { - "version": 1, - "hash": "ad062a5f863ae84683867d2a6a5e1d4420c010a64b88bc7b392106e33d71ac03", - "action": "add" - } - }, - "NumpyArrayObject": { - "1": { - "version": 1, - "hash": "05dd2917b7692b3daf4e7ad083a46fa7ec7a2be8faac8d4a654809189c986443", - "action": "add" - } - }, - "NumpyScalarObject": { - "1": { - "version": 1, - "hash": "8753e5c78270a5cacbf0439447724772f4765351a4a8b58b0a5c416a6b2c8b6e", - "action": "add" - } - }, - "NumpyBoolObject": { - "1": { - "version": 1, - "hash": "331c44f8fa3d0a077f1aaad7313bae2c43b386d04def7b8bedae9fdf7690134d", - "action": "add" - } - }, - "PandasDataframeObject": { - "1": { - "version": 1, - "hash": "5e8018364cea31d5f185a901da4ab89846b02153ee7d041ee8a6d305ece31f90", - "action": "add" - } - }, - "PandasSeriesObject": { - "1": { - "version": 1, - "hash": "b8bd482bf16fc7177e9778292cd42f8835b6ced2ce8dc88908b4b8e6d7c7528f", - "action": "add" - } - }, - "DataSubjectMemberRelationship": { - "1": { - "version": 1, - "hash": "0810483ea76ea10c8f286c6035dc0b2085291f345183be50c179f3a05a577110", - "action": "add" - } - }, - "ProjectEvent": { - "1": { - "version": 1, - "hash": "dc0486c52daebd5e98c2b3b03ffd9a9a14bc3d86d8dc0c23e41ebf6c31fe2ffb", - "action": "add" - } - }, - "ProjectThreadMessage": { - "1": { - "version": 1, - "hash": "99256d7592577d1e37df94a06eabc0a287f2d79e144c51fd719315e278edb46d", - "action": "add" - } - }, - "ProjectMessage": { - "1": { - "version": 1, - "hash": "b5004b6354f71b19c81dd5f4b20bf446e0b959f5608a22707e96b944dd8175b0", - 
"action": "add" - } - }, - "ProjectRequestResponse": { - "1": { - "version": 1, - "hash": "52162a8a779a4a301d8755691bf4cf994c86b9f650f9e8c8a923b44e635b1bc0", - "action": "add" - } - }, - "ProjectRequest": { - "1": { - "version": 1, - "hash": "dc684135d5a5a48e5fc7988598c1e6e0de76cf1c5995f1c283fcf63d0eb4d24f", - "action": "add" - } - }, - "AnswerProjectPoll": { - "1": { - "version": 1, - "hash": "c83d83a5ba6cc034d5061df200b3f1d029aa770b1e13dbef959bb1790323dc6e", - "action": "add" - } - }, - "ProjectPoll": { - "1": { - "version": 1, - "hash": "ecf69b3b324e0bee9c82295796d44c4e8f796496cdc9db6d4302c2f160566466", - "action": "add" - } - }, - "Project": { - "1": { - "version": 1, - "hash": "de86a1163ddbcd1cc3cc2b1b5dfcb85a8ad9f9d4bbc759c2b1f92a0d0a2ff184", - "action": "add" - } - }, - "ProjectSubmit": { - "1": { - "version": 1, - "hash": "7555ba11ee5a814dcd9c45647300020f7359efc1081559940990cbd745936cac", - "action": "add" - } - }, - "Plan": { - "1": { - "version": 1, - "hash": "ed05cb87aec832098fc464ac36cd6bceaab705463d0d2fa1b2d8e1ccc510018c", - "action": "add" - } - } - } + "1": { + "release_name": "0.9.1.json" } } diff --git a/packages/syft/src/syft/protocol/releases/0.9.1.json b/packages/syft/src/syft/protocol/releases/0.9.1.json new file mode 100644 index 00000000000..9c33a5d3a88 --- /dev/null +++ b/packages/syft/src/syft/protocol/releases/0.9.1.json @@ -0,0 +1,1178 @@ +{ + "1": { + "object_versions": { + "SyftObjectVersioned": { + "1": { + "version": 1, + "hash": "7c842dcdbb57e2528ffa690ea18c19fff3c8a591811d40cad2b19be3100e2ff4", + "action": "add" + } + }, + "BaseDateTime": { + "1": { + "version": 1, + "hash": "614db484b1950be729902b1861bd3a7b33899176507c61cef11dc0d44611cfd3", + "action": "add" + } + }, + "SyftObject": { + "1": { + "version": 1, + "hash": "bb70d874355988908d3a92a3941d6613a6995a4850be3b6a0147f4d387724406", + "action": "add" + } + }, + "PartialSyftObject": { + "1": { + "version": 1, + "hash": "19a995fcc2833f4fab24584fd99b71a80c2ef1f13c06f83af79e4482846b1656", + "action": "add" + } + }, + "ServerMetadata": { + "1": { + "version": 1, + "hash": "1691c7667eca86b20c4189e90ce4e643dd41fd3682cdb69c6308878f2a6f135c", + "action": "add" + } + }, + "StoreConfig": { + "1": { + "version": 1, + "hash": "a9997fce6a8a0ed2884c58b8eb9382f8554bdd18fff61f8bf0451945bcff12c7", + "action": "add" + } + }, + "MongoDict": { + "1": { + "version": 1, + "hash": "57e36f57eed75e62b29e2bac1295035a9bf2c0e3c56719dac24cb6cc685be00b", + "action": "add" + } + }, + "MongoStoreConfig": { + "1": { + "version": 1, + "hash": "53342b27d34165b7e2699f8e7ad70d13d125875e6a75e8fa18f5796428f41036", + "action": "add" + } + }, + "LinkedObject": { + "1": { + "version": 1, + "hash": "d80f5ac7f51a9383be1a3cb334d56ae50e49733ed3199f3b6b5d6febd9de410b", + "action": "add" + } + }, + "BaseConfig": { + "1": { + "version": 1, + "hash": "10bd7566041d0f0a3aa295367785fdcc2c5bbf0ded984ac9230754f37496a6a7", + "action": "add" + }, + "2": { + "version": 2, + "hash": "890d2879ac44611db9b88ba9334a721130d0ac3aa18a303fa9e4081f14b9b8c7", + "action": "add" + } + }, + "ServiceConfig": { + "1": { + "version": 1, + "hash": "28af8a296f5ff63de50438277eaa1f4380682e6aca9f2ca28320d7a444825e88", + "action": "add" + }, + "2": { + "version": 2, + "hash": "93dfab144e0b0884c602358b3a9ce889bb29ab96e3b4adcfe3cef47a31694a9a", + "action": "add" + } + }, + "LibConfig": { + "1": { + "version": 1, + "hash": "ee8f0e3f6aae81948d72e30226645e8eb5d312a6770411a1edca748168c467c0", + "action": "add" + }, + "2": { + "version": 2, + "hash": 
"a8a78a8d726ee9e79f95614f3d0fa5b85edc6fce7be7651715208669be93e0e3", + "action": "add" + } + }, + "APIEndpoint": { + "1": { + "version": 1, + "hash": "faa1cf9336a0d1233868c8c57745ff38c0be60399dc1acd0c0e8dd440e405dbd", + "action": "add" + } + }, + "LibEndpoint": { + "1": { + "version": 1, + "hash": "a585c83a33a019d363ae5a0c6d4197193654307c19a4829dfbf8a8cfd2c1842a", + "action": "add" + } + }, + "SignedSyftAPICall": { + "1": { + "version": 1, + "hash": "2f959455f7130f4e59360b8aa58f19785b76eaa0f8a5a9188a6cbf32b31311ca", + "action": "add" + } + }, + "SyftAPICall": { + "1": { + "version": 1, + "hash": "59e89e7b9ea30deaed64d1ffd9bc0769b999d3082b305428432c1f5be36c6343", + "action": "add" + } + }, + "SyftAPIData": { + "1": { + "version": 1, + "hash": "820b279c581cafd9bb5009702d4e3db22ec3a3156676426304b9038dad260a24", + "action": "add" + } + }, + "SyftAPI": { + "1": { + "version": 1, + "hash": "cc13ab058ee36748c14b0d4bd9b9e894c7566fff09cfa4170b3eece520169f15", + "action": "add" + } + }, + "User": { + "1": { + "version": 1, + "hash": "2df4b68182c558dba5485a8a6867acf2a5c341b249ad67373a504098aa8c4343", + "action": "add" + }, + "2": { + "version": 2, + "hash": "af6fb5b2e1606e97838f4a60f0536ad95db606d455e94acbd1977df866608a2c", + "action": "add" + } + }, + "UserUpdate": { + "1": { + "version": 1, + "hash": "1bf6707c69b809c804fb939c7c37d787c2f6889508a4bec37d24221af2eb777a", + "action": "add" + } + }, + "UserCreate": { + "1": { + "version": 1, + "hash": "49d6087e2309ba59987f3126e286e74b3a66492a08ad82fa507ea17d52ce78e3", + "action": "add" + } + }, + "UserSearch": { + "1": { + "version": 1, + "hash": "9ac946338cca68d00d1696a57943442f062628ec3daf53077d0bdd3f72cd9fa0", + "action": "add" + } + }, + "UserView": { + "1": { + "version": 1, + "hash": "0b52d758e31d5889c9cd88afb467aae4a74e34a5276924e07012243c34d300fe", + "action": "add" + } + }, + "UserViewPage": { + "1": { + "version": 1, + "hash": "1cd6528d02ec180f080d5c35f0da760d8a59af9da7baaa9c17c1c7cedcc858fa", + "action": "add" + } + }, + "UserPrivateKey": { + "1": { + "version": 1, + "hash": "4817d8147aba94373f320dcd90e65f097cf6e5a2ef353aa8520e23128d522b5d", + "action": "add" + } + }, + "DateTime": { + "1": { + "version": 1, + "hash": "394abb554114ead4d63c36e3fe83ac018dead4b21a8465174009577c46d54c58", + "action": "add" + } + }, + "ReplyNotification": { + "1": { + "version": 1, + "hash": "84102dfc59d711b03c2f3d3a6ecaca000b6835f1bbdd9af801057f7aacb5f1d0", + "action": "add" + } + }, + "Notification": { + "1": { + "version": 1, + "hash": "af4cb232bff390c431e399975f048b34da7e940ace8b23b940a3b398c91c5326", + "action": "add" + } + }, + "CreateNotification": { + "1": { + "version": 1, + "hash": "7e426c946b7d5db6f9427960ec16042f3018091d835ca5966f3568c324a2ab53", + "action": "add" + } + }, + "UserNotificationActivity": { + "1": { + "version": 1, + "hash": "422fd01c6d9af38688a9982abd34e80794a1f6ddd444cca225d77f49189847a9", + "action": "add" + } + }, + "NotificationPreferences": { + "1": { + "version": 1, + "hash": "a42f06b367e7c6cbabcbf3cfcc84d1ca0873e457d972ebd060e87c9d6185f62b", + "action": "add" + } + }, + "NotifierSettings": { + "1": { + "version": 1, + "hash": "65c8ab814d35fac32f68d3000756692592cc59940f30e3af3dcdfa2328755b9d", + "action": "add" + }, + "2": { + "version": 2, + "hash": "be8b52597fc628d1b7cd22b776ee81416e1adbb04a45188778eb0e32ed1416b4", + "action": "add" + } + }, + "SyftImageRegistry": { + "1": { + "version": 1, + "hash": "67e18903e41cba1afe136adf29d404b63ec04fea6e928abb2533ec4fa52b246b", + "action": "add" + } + }, + "SyftWorkerImage": { + "1": { + 
"version": 1, + "hash": "44da7badfbe573d5403d3ab78c077f17dbefc560b81fdf927b671815be047441", + "action": "add" + } + }, + "SyftWorker": { + "1": { + "version": 1, + "hash": "9d897f6039eabe48dfa8e8d5c5cdcb283b0375b4c64571b457777eaaf3fb1920", + "action": "add" + } + }, + "WorkerPool": { + "1": { + "version": 1, + "hash": "16efc5dd2596ae744fd611c8f46af9eaec1bd5729eb20e85e9fd2f31df402564", + "action": "add" + } + }, + "MarkdownDescription": { + "1": { + "version": 1, + "hash": "31a73f8824cad1636a55d14b6a1074cdb071d0d4e16e86baaa3d4f63a7e80134", + "action": "add" + } + }, + "HTMLObject": { + "1": { + "version": 1, + "hash": "97f2e93f5ceaa88015047186f66a17ff13df2a6b7925b41331f9e19d5a515a9f", + "action": "add" + } + }, + "PwdTokenResetConfig": { + "1": { + "version": 1, + "hash": "0415a272428f22add4896c64aa9f29c8c1d35619e2433da6564eb5f1faff39ac", + "action": "add" + } + }, + "ServerSettingsUpdate": { + "1": { + "version": 1, + "hash": "1e4260ad879ae80728c3ffae2cd1d48759abd51f9d0960d4b25855cdbb4c506b", + "action": "add" + }, + "2": { + "version": 2, + "hash": "23b2716e9dceca667e228408e2416c82f11821e322e5bccf1f83406f3d09abdc", + "action": "add" + }, + "3": { + "version": 3, + "hash": "335c7946f2e52d09c7b26f511120cd340717c74c5cca9107e84f839da993c55c", + "action": "add" + }, + "4": { + "version": 4, + "hash": "8d7a41992c39c287fcb46383bed429ce75d3c9524ced8c86b88c26dd0232e2fe", + "action": "add" + } + }, + "ServerSettings": { + "1": { + "version": 1, + "hash": "5a1e7470cbeaaae5b80ac9beecb743734f7e4e42d429a09ea8defa569a5ddff1", + "action": "add" + }, + "2": { + "version": 2, + "hash": "7727ea54e494dc9deaa0d1bd38ac8a6180bc192b74eec5659adbc338a19e21f5", + "action": "add" + }, + "3": { + "version": 3, + "hash": "997667e1cba22d151857aacc2caba6b1ca73c1648adbd03461dc74a0c0c372b3", + "action": "add" + }, + "4": { + "version": 4, + "hash": "b8067777967a0e06733433e179e549caaf501419d62f7e8474ee33b839e3890d", + "action": "add" + } + }, + "HTTPConnection": { + "1": { + "version": 1, + "hash": "bf10f81646c71069c76292b1237b4a3de1e507264392c5c591d067636ce6fb46", + "action": "add" + } + }, + "PythonConnection": { + "1": { + "version": 1, + "hash": "28010778b5e3463ff6960a0e2224818de00bc7b5e6f892192e02e399ccbe18b5", + "action": "add" + } + }, + "ActionDataEmpty": { + "1": { + "version": 1, + "hash": "e0e4a5cf18d05b6b747addc048515c6f2a5f35f0766ebaee96d898cb971e1c5b", + "action": "add" + } + }, + "ObjectNotReady": { + "1": { + "version": 1, + "hash": "8cf471e205cd0893d6aae5f0227d14db7df1c9698da08a3ab991f59132d17fe9", + "action": "add" + } + }, + "ActionDataLink": { + "1": { + "version": 1, + "hash": "3469478343439e411b761c270eec63eb3d533e459ad72d0965158c3a6cdf3b9a", + "action": "add" + } + }, + "Action": { + "1": { + "version": 1, + "hash": "021826d7c6f69bd0283d025d40661f3ffbeba8810ca94de01344f6afbdae62cd", + "action": "add" + } + }, + "ActionObject": { + "1": { + "version": 1, + "hash": "0a5f4bc343cb114a251f06686ecdbb59d74bfb3d29a098b176699deb35a1e683", + "action": "add" + } + }, + "AnyActionObject": { + "1": { + "version": 1, + "hash": "b3c44c7788c59c03fa1baeec656c2ca6e633f4cbd4b23ff7ece6ee94c38449f0", + "action": "add" + } + }, + "CustomEndpointActionObject": { + "1": { + "version": 1, + "hash": "c7addbaf2777707f3e91e5c1e092343476cd22efc4ec8617f39ccf76e61a5a14", + "action": "add" + }, + "2": { + "version": 2, + "hash": "846ba36e8737a1bec16853c9de54c4948450009278e0b76fe7e3355ef9e70089", + "action": "add" + } + }, + "DataSubject": { + "1": { + "version": 1, + "hash": 
"582cdf9e82b5d6915b7f09f7c0d5f08328b11a2ce9b0198e5083f1672c2e2bf5", + "action": "add" + } + }, + "DataSubjectCreate": { + "1": { + "version": 1, + "hash": "5a8423c2690d55f425bfeecc87cd4a797a75d88ebb5fbda754d4f269b62d2ceb", + "action": "add" + } + }, + "DataSubjectMemberRelationship": { + "1": { + "version": 1, + "hash": "0810483ea76ea10c8f286c6035dc0b2085291f345183be50c179f3a05a577110", + "action": "add" + } + }, + "Contributor": { + "1": { + "version": 1, + "hash": "30c32bd44098f00e0b15496be441763b6e50af8b12d3d2bef33aca6287193876", + "action": "add" + } + }, + "Asset": { + "1": { + "version": 1, + "hash": "000abc78719611c106295cf12b1690b7e5411dc1bb9db9d4afd22956da90d1f4", + "action": "add" + } + }, + "CreateAsset": { + "1": { + "version": 1, + "hash": "357d52576cb12b24fb3980342bb49a562b065c0e4419e87d34176340628c7309", + "action": "add" + } + }, + "Dataset": { + "1": { + "version": 1, + "hash": "0ca6b0b4a3aebb2c8f351668075b44951bb20d1e23a779b82109124f334ce3a4", + "action": "add" + } + }, + "DatasetPageView": { + "1": { + "version": 1, + "hash": "aa0dd69637281b80d5523b4409a2c7e89db114c9fe79c858063c6dadff8977d1", + "action": "add" + }, + "2": { + "version": 2, + "hash": "be1ca6dcd0b3aa0481ce5dce737e78432d06a78ad0c701aaf136be407c798352", + "action": "add" + } + }, + "CreateDataset": { + "1": { + "version": 1, + "hash": "7e02dfa89540c3dbebacbb13810d95cdc4e36db31d56cffed7ab54abe25716c9", + "action": "add" + } + }, + "SyftLog": { + "1": { + "version": 1, + "hash": "1bcd71e5bf3f0db3bba0996f33b6b2bde3489b9c71f11e6b30c3495c76a8f53f", + "action": "add" + } + }, + "JobItem": { + "1": { + "version": 1, + "hash": "0b32277b7d3b9bdc14a2a51cc9005f8254e7f7b6ec059ddcccbcd681a807afb6", + "action": "add" + }, + "2": { + "version": 2, + "hash": "b087d0c62b7d304c6ca80e4fb0e8a7f2a444be8f8cba57490dc09aeb98033105", + "action": "add" + } + }, + "ExecutionOutput": { + "1": { + "version": 1, + "hash": "e36c71685edf5276a3427cb6749550486d3a177c1dcf73dd337ab2a73c0ce6b5", + "action": "add" + } + }, + "TwinObject": { + "1": { + "version": 1, + "hash": "4f31243fb348dbb083579afd6f638d75af010cb53d19bfba59b74afff41ccbbb", + "action": "add" + } + }, + "PolicyRule": { + "1": { + "version": 1, + "hash": "44d1ca1db97be46f66558aa1a729ff31bf8e113c6a913b11aedf9d6b6ad5b7b5", + "action": "add" + } + }, + "CreatePolicyRule": { + "1": { + "version": 1, + "hash": "342bb723526d445151a0435f57d251f4c1219f8ae7cca3e8e9fce52e2ee1b8b1", + "action": "add" + } + }, + "CreatePolicyRuleConstant": { + "1": { + "version": 1, + "hash": "78b54832cb0468a87013bc36bc11d4759874ca1b5065a1b711f1e5ef5d94c2df", + "action": "add" + } + }, + "Matches": { + "1": { + "version": 1, + "hash": "dd6d91ddb2ec5eaf60be2b0899ecfdb9a15f7904aa39d2f4d9bb2d7b793040e6", + "action": "add" + } + }, + "PreFill": { + "1": { + "version": 1, + "hash": "c7aefb11dc4c4569dcd1e6988371047a32a8be1b32ad46d12adba419a19769ad", + "action": "add" + } + }, + "UserOwned": { + "1": { + "version": 1, + "hash": "c8738dc3d8c2a5ef461b85a0467c3dff53dab16b54a4d12b44b1477906aef51d", + "action": "add" + } + }, + "MixedInputPolicy": { + "1": { + "version": 1, + "hash": "37bb12d950518d9579c8ec7c4cc22ac731ea82caf8c1370dd0b0a82b46462dde", + "action": "add" + } + }, + "ExactMatch": { + "1": { + "version": 1, + "hash": "5eb37edbf5e451d942e599247f3eaed923c1fe9d91eefdba02bf06503f6cc08d", + "action": "add" + } + }, + "OutputHistory": { + "1": { + "version": 1, + "hash": "9366db79d131f8c65e5a4ff12c90e2aa0c11e302debe06e46eeb93b26e2aaf61", + "action": "add" + } + }, + "OutputPolicyExecuteCount": { + "1": { + 
"version": 1, + "hash": "2a77e5ed5c7b0391147562651ad4061e20b11745c191fbc34cb549da37ba72dd", + "action": "add" + } + }, + "OutputPolicyExecuteOnce": { + "1": { + "version": 1, + "hash": "5589c00d127d9eb1f5ccf3a16def8219737784d57bb3bf9be5cb6d83325ef436", + "action": "add" + } + }, + "EmptyInputPolicy": { + "1": { + "version": 1, + "hash": "7ef81cfd223be0064600e1503f8b04bafc16385e27730e9319466e68a077c68b", + "action": "add" + } + }, + "UserPolicy": { + "1": { + "version": 1, + "hash": "74373bb71a334f4dcf77623ae10ff5b1c7e5b3006f65f2051ffb1e01f422f982", + "action": "add" + } + }, + "SubmitUserPolicy": { + "1": { + "version": 1, + "hash": "ec4e808eb39613bcdbbbf9ffb3267612084a9d99880a2f3bee3ef32d46329c02", + "action": "add" + } + }, + "UserCodeStatusCollection": { + "1": { + "version": 1, + "hash": "735ecf2d4abb1e7d19b2e751d880f32b01ce267ba10e417ef1b440be3d94d8f1", + "action": "add" + } + }, + "UserCode": { + "1": { + "version": 1, + "hash": "3bcd14413b9c4fbde7c5612c2ed713518340280b5cff89cf2aaaf1c77c4037a8", + "action": "add" + } + }, + "SubmitUserCode": { + "1": { + "version": 1, + "hash": "d2bb8cfe12f070b4adafded78ce01900c5409bd83f055f94b1e285745ef65a76", + "action": "add" + } + }, + "UserCodeExecutionResult": { + "1": { + "version": 1, + "hash": "1f4cbc62caac4dd193f427306405dc7a099ae744bea5830cf57149ce71c1e589", + "action": "add" + } + }, + "UserCodeExecutionOutput": { + "1": { + "version": 1, + "hash": "c1d53300a39dbbb437d7d5a1257bd175a067b1065f4099a0938fac7540035258", + "action": "add" + }, + "2": { + "version": 2, + "hash": "3e104e39b4ab53c950e61e4f7e92ce935cf96a5100de301de9bf297eb7e5787e", + "action": "add" + } + }, + "CodeHistory": { + "1": { + "version": 1, + "hash": "e3ef5346f108257828f364d22b12d9311812c9cf843200afef5dc4d9302f9b21", + "action": "add" + } + }, + "CodeHistoryView": { + "1": { + "version": 1, + "hash": "8b8b97d334b51d1ce0a9efab722411ff25caa3f12be319105954497e0a306eb2", + "action": "add" + } + }, + "CodeHistoriesDict": { + "1": { + "version": 1, + "hash": "01d7dcd4b21525a06e4484d8699a4a34a5c84f1f6026ec55e32eb30412742601", + "action": "add" + } + }, + "UsersCodeHistoriesDict": { + "1": { + "version": 1, + "hash": "4ed8b83973258ea19a1f91feb2590ff73b801be86f4296cc3db48f6929ff784c", + "action": "add" + } + }, + "BlobFile": { + "1": { + "version": 1, + "hash": "d99239100f1cb0b73c69b2ad7cab01a06909cc3a4976ba2b3b67cf6fe5e2f516", + "action": "add" + } + }, + "BlobFileOBject": { + "1": { + "version": 1, + "hash": "6c40dab2c8d2220d4fff7cc653d76cc026a856db7e2b5713b6341e255adc7ea2", + "action": "add" + } + }, + "SecureFilePathLocation": { + "1": { + "version": 1, + "hash": "ea5978b98d7773d221665b450454c9130c103a5c850669a0acd620607cd614b7", + "action": "add" + } + }, + "SeaweedSecureFilePathLocation": { + "1": { + "version": 1, + "hash": "3fc9bfc8c1b1cf660c9747e8c1fe3eb2220e78d4e3b5d6b5c5f29a07a77ebf3e", + "action": "add" + } + }, + "AzureSecureFilePathLocation": { + "1": { + "version": 1, + "hash": "090a9e962eeb655586ee966c5651d8996363969818a38f9a486fd64d33047e05", + "action": "add" + } + }, + "BlobStorageEntry": { + "1": { + "version": 1, + "hash": "afdc6a1d8a24b1ee1ed9d3e79f5bac64b4f0d9d36800f07f10be0b896470345f", + "action": "add" + } + }, + "BlobStorageMetadata": { + "1": { + "version": 1, + "hash": "9d4b61ac4ea1910c2f7c767a50a6a52544a24663548f069e79bd906f11b538e4", + "action": "add" + } + }, + "CreateBlobStorageEntry": { + "1": { + "version": 1, + "hash": "ffc3cbfeade67d074dc5bf7d655a1eb8c83630076028a72b3cc4548f3b413e14", + "action": "add" + } + }, + "SyftObjectMigrationState": { 
+ "1": { + "version": 1, + "hash": "ee83315828551f18904bab18e0cac48896493620561215b04cc448e6ce5834af", + "action": "add" + } + }, + "StoreMetadata": { + "1": { + "version": 1, + "hash": "8de9a22a2765ef976bc161cb0704347d30350c085da8c8ffa876065cfca3e5fd", + "action": "add" + } + }, + "MigrationData": { + "1": { + "version": 1, + "hash": "cb96b8c8413609e1224341d1b0dd1efb08387c0ff7b0ff65eba36c0b104c9ed1", + "action": "add" + }, + "2": { + "version": 2, + "hash": "1d1b14c196221ecf6d644d7dcaa32ac9e90361b2687fa83161ff399ebc6df1bd", + "action": "add" + } + }, + "BlobRetrieval": { + "1": { + "version": 1, + "hash": "c422c74b89a9349742acaa848566fe18bfef1a83333458b858c074baed37a859", + "action": "add" + } + }, + "SyftObjectRetrieval": { + "1": { + "version": 1, + "hash": "b2b62447445adc4cd0b77ab59d6fa56624dd316fb50281e570daad07556b6db2", + "action": "add" + } + }, + "BlobRetrievalByURL": { + "1": { + "version": 1, + "hash": "4db0e3b7a6334d3835356d8393866711e243e360af25a95f3cc4066f032404b5", + "action": "add" + } + }, + "BlobDeposit": { + "1": { + "version": 1, + "hash": "6eb5cc57dc763126bfc6ec5a2b79d02e77eadf9d9efb1888a5c366b7799c1c24", + "action": "add" + } + }, + "OnDiskBlobDeposit": { + "1": { + "version": 1, + "hash": "817bf1bee4a35bfa1cd25d6779a10d8d180b1b3f1e837952f81f48b9411d1970", + "action": "add" + } + }, + "RemoteConfig": { + "1": { + "version": 1, + "hash": "179d067099a178d748c6d9a0477e8de7c3b55577439669eca7150258f2409567", + "action": "add" + } + }, + "AzureRemoteConfig": { + "1": { + "version": 1, + "hash": "a143811fec0da5fd881e927643ef667c91c78a2c90519cf88da7da20738bd187", + "action": "add" + } + }, + "SeaweedFSBlobDeposit": { + "1": { + "version": 1, + "hash": "febeb2a2ce81aa2c512e4c6b611b582984042aafa0541403d4584662273a166c", + "action": "add" + } + }, + "DictStoreConfig": { + "1": { + "version": 1, + "hash": "2e1365c5535fa51c22eef79f67dd6444789bc829c27881367e3050e06e2ffbfe", + "action": "add" + } + }, + "NumpyArrayObject": { + "1": { + "version": 1, + "hash": "05dd2917b7692b3daf4e7ad083a46fa7ec7a2be8faac8d4a654809189c986443", + "action": "add" + } + }, + "NumpyScalarObject": { + "1": { + "version": 1, + "hash": "8753e5c78270a5cacbf0439447724772f4765351a4a8b58b0a5c416a6b2c8b6e", + "action": "add" + } + }, + "NumpyBoolObject": { + "1": { + "version": 1, + "hash": "331c44f8fa3d0a077f1aaad7313bae2c43b386d04def7b8bedae9fdf7690134d", + "action": "add" + } + }, + "PandasDataframeObject": { + "1": { + "version": 1, + "hash": "5e8018364cea31d5f185a901da4ab89846b02153ee7d041ee8a6d305ece31f90", + "action": "add" + } + }, + "PandasSeriesObject": { + "1": { + "version": 1, + "hash": "b8bd482bf16fc7177e9778292cd42f8835b6ced2ce8dc88908b4b8e6d7c7528f", + "action": "add" + } + }, + "Change": { + "1": { + "version": 1, + "hash": "75fb9a5cd4e76b189ebe130a421d3921a0c251947a48bbb92a2ef1c315dc3c16", + "action": "add" + } + }, + "ChangeStatus": { + "1": { + "version": 1, + "hash": "c914a6f7637b555a51b71e8e197e591f7a2e28121e29b5dd586f87e0383d179d", + "action": "add" + } + }, + "ActionStoreChange": { + "1": { + "version": 1, + "hash": "1a803bb08924b49f3114fd46e0e132f819d4d56be5e03a27e9fe90947ca26e85", + "action": "add" + } + }, + "CreateCustomImageChange": { + "1": { + "version": 1, + "hash": "c3dbea3f49979fdcc517c0d13cd02739ca2fe86b370c42496a224f142ae31562", + "action": "add" + } + }, + "CreateCustomWorkerPoolChange": { + "1": { + "version": 1, + "hash": "0355793dd58b364dcb84fff29714b6a26446bead3ba95c6d75e3200008e580f4", + "action": "add" + } + }, + "Request": { + "1": { + "version": 1, + "hash": 
"1d69f5f0074114f99aa29c5ee77cb20b9151e5b50e77b026f11c3632a12efadf", + "action": "add" + } + }, + "RequestInfo": { + "1": { + "version": 1, + "hash": "779562547744ebed64548f8021647292604fdf4256bf79685dfa14a1e56cc27b", + "action": "add" + } + }, + "RequestInfoFilter": { + "1": { + "version": 1, + "hash": "bb881a003032f4676321218d7cd09580f4d64fccaa1cf9e118fdcd5c73c3d3a8", + "action": "add" + } + }, + "SubmitRequest": { + "1": { + "version": 1, + "hash": "6c38b6ffd0a6f7442746e68b9ace7b21cb1dca7d2031929db5f9a302a280403f", + "action": "add" + } + }, + "ObjectMutation": { + "1": { + "version": 1, + "hash": "ce88096760ce9334599c8194ec97b0a1470651ad680d9d21b8826a0df0af2a36", + "action": "add" + } + }, + "EnumMutation": { + "1": { + "version": 1, + "hash": "5173fda73df17a344eb663b7692cca48bd46bf1773455439836b852cd165448c", + "action": "add" + } + }, + "UserCodeStatusChange": { + "1": { + "version": 1, + "hash": "89aaf7f1368c782e3a1b9e79988877f6eaa05ab84365f7d321b757fde7fe86e7", + "action": "add" + } + }, + "SyncedUserCodeStatusChange": { + "1": { + "version": 1, + "hash": "d9ad2d341eb645bd50d06330cd30fd4c266f93e37b9f5391d58b78365fc440e6", + "action": "add" + } + }, + "TwinAPIContextView": { + "1": { + "version": 1, + "hash": "e099eef32cb3a8a806cbdc54cc7fca96bed3d60344bd571163ec049db407938b", + "action": "add" + } + }, + "CustomAPIView": { + "1": { + "version": 1, + "hash": "769e96bebd05736ab860591670fb6da19406239b0104ddc71bd092a134335146", + "action": "add" + } + }, + "CustomApiEndpoint": { + "1": { + "version": 1, + "hash": "ec4a217585336d1b59c93c18570443a63f4fbb24d2c088fbacf80bcf389d23e8", + "action": "add" + } + }, + "PrivateAPIEndpoint": { + "1": { + "version": 1, + "hash": "6d7d143432c2811c520ab6dade005ba40173b590e5c676be04f5921b970ef938", + "action": "add" + } + }, + "PublicAPIEndpoint": { + "1": { + "version": 1, + "hash": "3bf51fc33aa8feb1abc9d0ef792e8889da31a57050430e0bd8e17f2065ff8734", + "action": "add" + } + }, + "UpdateTwinAPIEndpoint": { + "1": { + "version": 1, + "hash": "851e59412716e73c7f70a696619e0b375ce136b43f6fe2ea784747091caba5d8", + "action": "add" + } + }, + "CreateTwinAPIEndpoint": { + "1": { + "version": 1, + "hash": "3d0b84dae95ebcc6647b5aabe54e65b3c6bf957665fde57d8037806a4aac13be", + "action": "add" + } + }, + "TwinAPIEndpoint": { + "1": { + "version": 1, + "hash": "d1947b8f9c80d6c9b443e5a9f0758afa8849a5f12b9a511feefd7e4f82c374f4", + "action": "add" + } + }, + "SyncState": { + "1": { + "version": 1, + "hash": "9a3f0bb973858b55bc766c9770c4d9abcc817898f797d94a89938650c0c67868", + "action": "add" + } + }, + "WorkerSettings": { + "1": { + "version": 1, + "hash": "dca33003904a71688e5b07db65f8833eb4de8135aade7154076b8eafbb94d26b", + "action": "add" + } + }, + "HTTPServerRoute": { + "1": { + "version": 1, + "hash": "938245604a9c7e50001299afff5b669b2548364e356fed22a22780497831bf81", + "action": "add" + } + }, + "PythonServerRoute": { + "1": { + "version": 1, + "hash": "a068d8f942d55ecb6d45af88a27c6ebf208584275bf589cbc308df3f774ab9a9", + "action": "add" + } + }, + "VeilidServerRoute": { + "1": { + "version": 1, + "hash": "e676bc165601d2ede69707a4b6168ed4674f3f98887026d098a2dd4da4dfd097", + "action": "add" + } + }, + "ServerPeer": { + "1": { + "version": 1, + "hash": "0d5f252018e324ea0d2dcb5c2ad8bd15707220565fce4f14de7f63a8f9e4391b", + "action": "add" + } + }, + "ServerPeerUpdate": { + "1": { + "version": 1, + "hash": "0b854b57db7a18118c1fd8f31495b2ba4eeb9fbe4f24c631ff112418a94570d3", + "action": "add" + } + }, + "AssociationRequestChange": { + "1": { + "version": 1, + "hash": 
"0134ac0002879c85fc9ddb06bed6306a8905c8434b0a40d3a96ce24a7bd4da90", + "action": "add" + } + }, + "QueueItem": { + "1": { + "version": 1, + "hash": "1db212c46b6c56ccc5579cfe2141b693f0cd9286e2ede71210393e8455379bf1", + "action": "add" + } + }, + "ActionQueueItem": { + "1": { + "version": 1, + "hash": "396d579dfc2e2b36b9fbed2f204bffcca1bea7ee2db7175045dd3328ebf08718", + "action": "add" + } + }, + "APIEndpointQueueItem": { + "1": { + "version": 1, + "hash": "f04b3990a8d29c116d301e70df54d58f188895307a411dc13a666ff764ffd8dd", + "action": "add" + } + }, + "ZMQClientConfig": { + "1": { + "version": 1, + "hash": "36ee8f75067d5144f0ed062cdc79466caae16b7a128231d89b6b430174843bde", + "action": "add" + } + }, + "SQLiteStoreConfig": { + "1": { + "version": 1, + "hash": "ad062a5f863ae84683867d2a6a5e1d4420c010a64b88bc7b392106e33d71ac03", + "action": "add" + } + }, + "ProjectEvent": { + "1": { + "version": 1, + "hash": "dc0486c52daebd5e98c2b3b03ffd9a9a14bc3d86d8dc0c23e41ebf6c31fe2ffb", + "action": "add" + } + }, + "ProjectThreadMessage": { + "1": { + "version": 1, + "hash": "99256d7592577d1e37df94a06eabc0a287f2d79e144c51fd719315e278edb46d", + "action": "add" + } + }, + "ProjectMessage": { + "1": { + "version": 1, + "hash": "b5004b6354f71b19c81dd5f4b20bf446e0b959f5608a22707e96b944dd8175b0", + "action": "add" + } + }, + "ProjectRequestResponse": { + "1": { + "version": 1, + "hash": "52162a8a779a4a301d8755691bf4cf994c86b9f650f9e8c8a923b44e635b1bc0", + "action": "add" + } + }, + "ProjectRequest": { + "1": { + "version": 1, + "hash": "dc684135d5a5a48e5fc7988598c1e6e0de76cf1c5995f1c283fcf63d0eb4d24f", + "action": "add" + } + }, + "AnswerProjectPoll": { + "1": { + "version": 1, + "hash": "c83d83a5ba6cc034d5061df200b3f1d029aa770b1e13dbef959bb1790323dc6e", + "action": "add" + } + }, + "ProjectPoll": { + "1": { + "version": 1, + "hash": "ecf69b3b324e0bee9c82295796d44c4e8f796496cdc9db6d4302c2f160566466", + "action": "add" + } + }, + "Project": { + "1": { + "version": 1, + "hash": "de86a1163ddbcd1cc3cc2b1b5dfcb85a8ad9f9d4bbc759c2b1f92a0d0a2ff184", + "action": "add" + } + }, + "ProjectSubmit": { + "1": { + "version": 1, + "hash": "7555ba11ee5a814dcd9c45647300020f7359efc1081559940990cbd745936cac", + "action": "add" + } + }, + "Plan": { + "1": { + "version": 1, + "hash": "ed05cb87aec832098fc464ac36cd6bceaab705463d0d2fa1b2d8e1ccc510018c", + "action": "add" + } + }, + "EnclaveMetadata": { + "1": { + "version": 1, + "hash": "8d2dfafa01ec909c080a790cf15a8fc78e00382d3bfe6207098ceb25a60b9c53", + "action": "add" + } + } + } + } +} diff --git a/packages/syft/src/syft/service/notification/notifications.py b/packages/syft/src/syft/service/notification/notifications.py index 3a552df424b..e168d3083c2 100644 --- a/packages/syft/src/syft/service/notification/notifications.py +++ b/packages/syft/src/syft/service/notification/notifications.py @@ -9,6 +9,7 @@ from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext from ...types.transforms import add_credentials_for_key @@ -48,10 +49,35 @@ class ReplyNotification(SyftObject): @serializable() -class Notification(SyftObject): +class NotificationV1(SyftObject): __canonical_name__ = "Notification" __version__ = SYFT_OBJECT_VERSION_1 + subject: str + server_uid: UID + from_user_verify_key: SyftVerifyKey + to_user_verify_key: SyftVerifyKey + created_at: DateTime + 
status: NotificationStatus = NotificationStatus.UNREAD + linked_obj: LinkedObject | None = None + notifier_types: list[NOTIFIERS] = [] + email_template: type[EmailTemplate] | None = None + replies: list[ReplyNotification] | None = [] + + __attr_searchable__ = [ + "from_user_verify_key", + "to_user_verify_key", + "status", + ] + __repr_attrs__ = ["subject", "status", "created_at", "linked_obj"] + __table_sort_attr__ = "Created at" + + +@serializable() +class Notification(SyftObject): + __canonical_name__ = "Notification" + __version__ = SYFT_OBJECT_VERSION_2 + subject: str server_uid: UID from_user_verify_key: SyftVerifyKey diff --git a/packages/syft/src/syft/service/worker/worker_image.py b/packages/syft/src/syft/service/worker/worker_image.py index 85eaf38d00f..99ca3ad6040 100644 --- a/packages/syft/src/syft/service/worker/worker_image.py +++ b/packages/syft/src/syft/service/worker/worker_image.py @@ -7,16 +7,42 @@ from ...server.credentials import SyftVerifyKey from ...types.datetime import DateTime from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject from ...types.uid import UID from .image_identifier import SyftWorkerImageIdentifier @serializable() -class SyftWorkerImage(SyftObject): +class SyftWorkerImageV1(SyftObject): __canonical_name__ = "SyftWorkerImage" __version__ = SYFT_OBJECT_VERSION_1 + __attr_unique__ = ["config"] + __attr_searchable__ = ["config", "image_hash", "created_by"] + + __repr_attrs__ = [ + "image_identifier", + "image_hash", + "created_at", + "built_at", + "config", + ] + + id: UID + config: WorkerConfig + created_by: SyftVerifyKey + created_at: DateTime = DateTime.now() + image_identifier: SyftWorkerImageIdentifier | None = None + image_hash: str | None = None + built_at: DateTime | None = None + + +@serializable() +class SyftWorkerImage(SyftObject): + __canonical_name__ = "SyftWorkerImage" + __version__ = SYFT_OBJECT_VERSION_2 + __attr_unique__ = ["config_hash"] __attr_searchable__ = [ "config", From 1742cf08e36857da6514e12a60b957e6f5753a2d Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 16 Sep 2024 13:09:49 +0530 Subject: [PATCH 177/197] reset queue version to 1 and remove migrations for queue and subclasses - extract db config from server - stage protocol version --- .../src/syft/protocol/protocol_version.json | 53 +++++++++++ .../syft/src/syft/server/worker_settings.py | 10 ++- .../src/syft/service/queue/queue_stash.py | 87 +------------------ 3 files changed, 64 insertions(+), 86 deletions(-) diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index 5f9f6a8fab1..e3c913dee39 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -1,5 +1,58 @@ { "1": { "release_name": "0.9.1.json" + }, + "dev": { + "object_versions": { + "Notification": { + "2": { + "version": 2, + "hash": "812d3a612422fb1cf53caa13ec34a7bdfcf033a7c24b7518f527af144cb45f3c", + "action": "add" + } + }, + "SyftWorkerImage": { + "2": { + "version": 2, + "hash": "afd3a69719cd6d08b1121676ca8d80ca37be96ee5ed5893dc73733fbf47fd035", + "action": "add" + } + }, + "WorkerSettings": { + "2": { + "version": 2, + "hash": "91c375dd40d06c81fc6403751ee48cbc94b9877f91e65a7e302303218dfe71fa", + "action": "add" + } + }, + "MongoDict": { + "1": { + "version": 1, + "hash": "57e36f57eed75e62b29e2bac1295035a9bf2c0e3c56719dac24cb6cc685be00b", + "action": "remove" + } + 
}, + "MongoStoreConfig": { + "1": { + "version": 1, + "hash": "53342b27d34165b7e2699f8e7ad70d13d125875e6a75e8fa18f5796428f41036", + "action": "remove" + } + }, + "JobItem": { + "1": { + "version": 1, + "hash": "0b32277b7d3b9bdc14a2a51cc9005f8254e7f7b6ec059ddcccbcd681a807afb6", + "action": "remove" + } + }, + "DictStoreConfig": { + "1": { + "version": 1, + "hash": "2e1365c5535fa51c22eef79f67dd6444789bc829c27881367e3050e06e2ffbfe", + "action": "remove" + } + } + } } } diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py index 79241d9d8dc..cd75c9f1fdb 100644 --- a/packages/syft/src/syft/server/worker_settings.py +++ b/packages/syft/src/syft/server/worker_settings.py @@ -22,8 +22,8 @@ from ..types.syft_object import SYFT_OBJECT_VERSION_1 from ..types.syft_object import SYFT_OBJECT_VERSION_2 from ..types.syft_object import SyftObject +from ..types.transforms import TransformContext from ..types.transforms import drop -from ..types.transforms import make_set_default from ..types.uid import UID @@ -78,10 +78,16 @@ class WorkerSettingsV1(SyftObject): log_level: int | None = None +def set_db_config(context: TransformContext) -> TransformContext: + if context.output and context.output["db_config"] is None: + context.output["db_config"] = context.server.db_config + return context + + @migrate(WorkerSettingsV1, WorkerSettings) def migrate_workersettings_v1_to_v2() -> list[Callable]: return [ drop("document_store_config"), drop("action_store_config"), - make_set_default("db_config", DBConfig()), + set_db_config, ] diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index eda247b6139..5b2ef9f3318 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -1,5 +1,4 @@ # stdlib -from collections.abc import Callable from enum import Enum from typing import Any @@ -7,7 +6,6 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings -from ...server.worker_settings import WorkerSettingsV1 from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey from ...store.document_store_errors import NotFoundException @@ -15,11 +13,8 @@ from ...store.linked_obj import LinkedObject from ...types.errors import SyftException from ...types.result import as_result -from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 -from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject -from ...types.transforms import TransformContext from ...types.uid import UID from ..action.action_permissions import ActionObjectPermission @@ -40,7 +35,7 @@ class Status(str, Enum): @serializable() class QueueItem(SyftObject): __canonical_name__ = "QueueItem" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_1 __attr_searchable__ = ["status", "worker_pool_id"] @@ -83,7 +78,7 @@ def action(self) -> Any: @serializable() class ActionQueueItem(QueueItem): __canonical_name__ = "ActionQueueItem" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = SYFT_OBJECT_VERSION_1 method: str = "execute" service: str = "actionservice" @@ -92,7 +87,7 @@ class ActionQueueItem(QueueItem): @serializable() class APIEndpointQueueItem(QueueItem): __canonical_name__ = "APIEndpointQueueItem" - __version__ = SYFT_OBJECT_VERSION_2 + __version__ = 
SYFT_OBJECT_VERSION_1 method: str service: str = "apiservice" @@ -167,79 +162,3 @@ def _get_by_worker_pool( credentials=credentials, filters={"worker_pool_id": worker_pool_id}, ).unwrap() - - -@serializable() -class QueueItemV1(SyftObject): - __canonical_name__ = "QueueItem" - __version__ = SYFT_OBJECT_VERSION_1 - - __attr_searchable__ = ["status", "worker_pool_id"] - - id: UID - server_uid: UID - result: Any | None = None - resolved: bool = False - status: Status = Status.CREATED - - method: str - service: str - args: list - kwargs: dict[str, Any] - job_id: UID | None = None - worker_settings: WorkerSettingsV1 | None = None - has_execute_permissions: bool = False - worker_pool: LinkedObject - - -@serializable() -class ActionQueueItemV1(QueueItemV1): - __canonical_name__ = "ActionQueueItem" - __version__ = SYFT_OBJECT_VERSION_1 - - method: str = "execute" - service: str = "actionservice" - - -@serializable() -class APIEndpointQueueItemV1(QueueItemV1): - __canonical_name__ = "APIEndpointQueueItem" - __version__ = SYFT_OBJECT_VERSION_1 - - method: str - service: str = "apiservice" - - -def migrate_worker_settings_v1_to_v2(context: TransformContext) -> TransformContext: - if context.output is None: - return context - - worker_settings_old: WorkerSettingsV1 | None = context.output.get( - "worker_settings", None - ) - if worker_settings_old is None: - return context - - if not isinstance(worker_settings_old, WorkerSettingsV1): - raise ValueError( - f"Expected WorkerSettingsV1, but got {type(worker_settings_old)}" - ) - worker_settings = worker_settings_old.migrate_to(WorkerSettings.__version__) - context.output["worker_settings"] = worker_settings - - return context - - -@migrate(QueueItemV1, QueueItem) -def migrate_queue_item_v1_to_v2() -> list[Callable]: - return [migrate_worker_settings_v1_to_v2] - - -@migrate(ActionQueueItemV1, ActionQueueItem) -def migrate_action_queue_item_v1_to_v2() -> list[Callable]: - return migrate_queue_item_v1_to_v2() - - -@migrate(APIEndpointQueueItemV1, APIEndpointQueueItem) -def migrate_api_endpoint_queue_item_v1_to_v2() -> list[Callable]: - return migrate_queue_item_v1_to_v2() From cbd320d8e7ffbdd8eba9dddb7284f460f9683e32 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 16 Sep 2024 13:25:03 +0530 Subject: [PATCH 178/197] fix lint --- packages/syft/src/syft/abstract_server.py | 2 ++ packages/syft/src/syft/server/worker_settings.py | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/abstract_server.py b/packages/syft/src/syft/abstract_server.py index c222cf4ea5a..3b7885f0a0e 100644 --- a/packages/syft/src/syft/abstract_server.py +++ b/packages/syft/src/syft/abstract_server.py @@ -5,6 +5,7 @@ # relative from .serde.serializable import serializable +from .store.db.db import DBConfig from .types.uid import UID if TYPE_CHECKING: @@ -41,6 +42,7 @@ class AbstractServer: server_side_type: ServerSideType | None in_memory_workers: bool services: "ServiceRegistry" + db_config: DBConfig def get_service(self, path_or_func: str | Callable) -> "AbstractService": raise NotImplementedError diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py index cd75c9f1fdb..b7f28aa878b 100644 --- a/packages/syft/src/syft/server/worker_settings.py +++ b/packages/syft/src/syft/server/worker_settings.py @@ -80,7 +80,9 @@ class WorkerSettingsV1(SyftObject): def set_db_config(context: TransformContext) -> TransformContext: if context.output and context.output["db_config"] is None: - 
context.output["db_config"] = context.server.db_config + context.output["db_config"] = ( + context.server.db_config if context.server is not None else DBConfig() + ) return context From 74f0edb777223afe42a9d6f97aae0c61a3aaa449 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 16 Sep 2024 13:58:45 +0530 Subject: [PATCH 179/197] fix arg to refer to canonical name if exists - redefine migrations for queue items and subclasses --- .../syft/src/syft/protocol/data_protocol.py | 8 +- .../src/syft/protocol/protocol_version.json | 21 +++++ .../src/syft/service/queue/queue_stash.py | 81 ++++++++++++++++++- 3 files changed, 107 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index 1ea9d1ae203..0c848585119 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ -81,8 +81,14 @@ def handle_annotation_repr_(annotation: type) -> str: """Handle typing representation.""" origin = typing.get_origin(annotation) args = typing.get_args(annotation) + + def get_annotation_repr_for_arg(arg: type) -> str: + if hasattr(arg, "__canonical_name__"): + return arg.__canonical_name__ + return getattr(arg, "__name__", str(arg)) + if origin and args: - args_repr = ", ".join(getattr(arg, "__name__", str(arg)) for arg in args) + args_repr = ", ".join(get_annotation_repr_for_arg(arg) for arg in args) origin_repr = getattr(origin, "__name__", str(origin)) # Handle typing.Union and types.UnionType diff --git a/packages/syft/src/syft/protocol/protocol_version.json b/packages/syft/src/syft/protocol/protocol_version.json index e3c913dee39..326aa7e2c82 100644 --- a/packages/syft/src/syft/protocol/protocol_version.json +++ b/packages/syft/src/syft/protocol/protocol_version.json @@ -52,6 +52,27 @@ "hash": "2e1365c5535fa51c22eef79f67dd6444789bc829c27881367e3050e06e2ffbfe", "action": "remove" } + }, + "QueueItem": { + "2": { + "version": 2, + "hash": "1d8615f6daabcd2a285b2f36fd7bef1df76cdd119dd49c02069c50fd1b9c3ff4", + "action": "add" + } + }, + "ActionQueueItem": { + "2": { + "version": 2, + "hash": "bfda6ef87e4045d663324bb91a215ea06e1f173aec1fb4d9ddd337cdc1f0787f", + "action": "add" + } + }, + "APIEndpointQueueItem": { + "2": { + "version": 2, + "hash": "3a46370205152fa23a7d2bfa47130dbf2e2bc7ef31f6d3fe4c92fd8d683770b5", + "action": "add" + } } } } diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index 5b2ef9f3318..d79cd3a35e1 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -1,4 +1,5 @@ # stdlib +from collections.abc import Callable from enum import Enum from typing import Any @@ -6,6 +7,7 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings +from ...server.worker_settings import WorkerSettingsV1 from ...store.db.stash import ObjectStash from ...store.document_store import PartitionKey from ...store.document_store_errors import NotFoundException @@ -13,11 +15,16 @@ from ...store.linked_obj import LinkedObject from ...types.errors import SyftException from ...types.result import as_result +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 +from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject +from ...types.transforms import TransformContext 
from ...types.uid import UID from ..action.action_permissions import ActionObjectPermission +__all__ = ["QueueItem"] + @serializable(canonical_name="Status", version=1) class Status(str, Enum): @@ -33,10 +40,33 @@ class Status(str, Enum): @serializable() -class QueueItem(SyftObject): +class QueueItemV1(SyftObject): __canonical_name__ = "QueueItem" __version__ = SYFT_OBJECT_VERSION_1 + __attr_searchable__ = ["status", "worker_pool"] + + id: UID + server_uid: UID + result: Any | None = None + resolved: bool = False + status: Status = Status.CREATED + + method: str + service: str + args: list + kwargs: dict[str, Any] + job_id: UID | None = None + worker_settings: WorkerSettingsV1 | None = None + has_execute_permissions: bool = False + worker_pool: LinkedObject + + +@serializable() +class QueueItem(SyftObject): + __canonical_name__ = "QueueItem" + __version__ = SYFT_OBJECT_VERSION_2 + __attr_searchable__ = ["status", "worker_pool_id"] id: UID @@ -77,6 +107,24 @@ def action(self) -> Any: @serializable() class ActionQueueItem(QueueItem): + __canonical_name__ = "ActionQueueItem" + __version__ = SYFT_OBJECT_VERSION_2 + + method: str = "execute" + service: str = "actionservice" + + +@serializable() +class APIEndpointQueueItemV1(QueueItem): + __canonical_name__ = "APIEndpointQueueItem" + __version__ = SYFT_OBJECT_VERSION_2 + + method: str + service: str = "apiservice" + + +@serializable() +class ActionQueueItemV1(QueueItemV1): __canonical_name__ = "ActionQueueItem" __version__ = SYFT_OBJECT_VERSION_1 @@ -85,7 +133,7 @@ class ActionQueueItem(QueueItem): @serializable() -class APIEndpointQueueItem(QueueItem): +class APIEndpointQueueItem(QueueItemV1): __canonical_name__ = "APIEndpointQueueItem" __version__ = SYFT_OBJECT_VERSION_1 @@ -162,3 +210,32 @@ def _get_by_worker_pool( credentials=credentials, filters={"worker_pool_id": worker_pool_id}, ).unwrap() + + +def upgrade_worker_settings_for_queue(context: TransformContext) -> TransformContext: + if context.output and context.output["worker_settings"] is None: + worker_settings_old: WorkerSettingsV1 | None = context.output["worker_settings"] + if worker_settings_old is None: + return context + + worker_settings = worker_settings_old.migrate_to( + WorkerSettings.__version__, context=context.to_server_context() + ) + context.output["worker_settings"] = worker_settings + + return context + + +@migrate(QueueItemV1, QueueItem) +def migrate_queue_item_from_v1_to_v2() -> list[Callable]: + return [upgrade_worker_settings_for_queue] + + +@migrate(ActionQueueItemV1, ActionQueueItem) +def migrate_action_queue_item_v1_to_v2() -> list[Callable]: + return [upgrade_worker_settings_for_queue] + + +@migrate(APIEndpointQueueItemV1, APIEndpointQueueItem) +def migrate_api_endpoint_queue_item_v1_to_v2() -> list[Callable]: + return [upgrade_worker_settings_for_queue] From 3b3023d3aa3f3a9914b276a0663f28c4feebe0a0 Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 16 Sep 2024 15:31:03 +0700 Subject: [PATCH 180/197] [clean] remove `store_client_config` arg using for debug --- packages/syft/src/syft/orchestra.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/packages/syft/src/syft/orchestra.py b/packages/syft/src/syft/orchestra.py index 3497717c372..efed6023ab8 100644 --- a/packages/syft/src/syft/orchestra.py +++ b/packages/syft/src/syft/orchestra.py @@ -183,7 +183,6 @@ def deploy_to_python( log_level: str | int | None = None, debug: bool = False, migrate: bool = False, - store_client_config: dict | None = None, consumer_type: ConsumerType | None = None, db_url: str | None 
= None, ) -> ServerHandle: @@ -217,7 +216,6 @@ def deploy_to_python( "debug": debug, "migrate": migrate, "deployment_type": deployment_type_enum, - "store_client_config": store_client_config, "consumer_type": consumer_type, "db_url": db_url, } @@ -331,7 +329,6 @@ def launch( background_tasks: bool = False, debug: bool = False, migrate: bool = False, - store_client_config: dict | None = None, from_state_folder: str | Path | None = None, consumer_type: ConsumerType | None = None, db_url: str | None = None, @@ -383,7 +380,6 @@ def launch( background_tasks=background_tasks, debug=debug, migrate=migrate, - store_client_config=store_client_config, consumer_type=consumer_type, db_url=db_url, ) From b93f7d7783543ac86049f6ef343d79a5d2f2bfc6 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 16 Sep 2024 14:10:36 +0530 Subject: [PATCH 181/197] fix migration logic for worker settings - add migration method for notification class --- packages/syft/src/syft/server/worker_settings.py | 2 +- .../syft/src/syft/service/notification/notifications.py | 6 ++++++ packages/syft/src/syft/service/queue/queue_stash.py | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/server/worker_settings.py b/packages/syft/src/syft/server/worker_settings.py index b7f28aa878b..3e10cc7d5fa 100644 --- a/packages/syft/src/syft/server/worker_settings.py +++ b/packages/syft/src/syft/server/worker_settings.py @@ -79,7 +79,7 @@ class WorkerSettingsV1(SyftObject): def set_db_config(context: TransformContext) -> TransformContext: - if context.output and context.output["db_config"] is None: + if context.output: context.output["db_config"] = ( context.server.db_config if context.server is not None else DBConfig() ) diff --git a/packages/syft/src/syft/service/notification/notifications.py b/packages/syft/src/syft/service/notification/notifications.py index e168d3083c2..3fbddf6eb98 100644 --- a/packages/syft/src/syft/service/notification/notifications.py +++ b/packages/syft/src/syft/service/notification/notifications.py @@ -8,6 +8,7 @@ from ...server.credentials import SyftVerifyKey from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime +from ...types.syft_migration import migrate from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SYFT_OBJECT_VERSION_2 from ...types.syft_object import SyftObject @@ -172,3 +173,8 @@ def createnotification_to_notification() -> list[Callable]: add_credentials_for_key("from_user_verify_key"), add_server_uid_for_key("server_uid"), ] + + +@migrate(NotificationV1, Notification) +def migrate_nofitication_v1_to_v2() -> list[Callable]: + return [] # skip migration, no changes in the class diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index d79cd3a35e1..3c33f8a8a78 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -213,7 +213,7 @@ def _get_by_worker_pool( def upgrade_worker_settings_for_queue(context: TransformContext) -> TransformContext: - if context.output and context.output["worker_settings"] is None: + if context.output and context.output["worker_settings"] is not None: worker_settings_old: WorkerSettingsV1 | None = context.output["worker_settings"] if worker_settings_old is None: return context From 34a9ca063dc02a56d409335e5b25d5cb1652a74d Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 16 Sep 2024 10:43:58 +0200 Subject: [PATCH 182/197] fix 
partial updates --- packages/syft/src/syft/store/db/stash.py | 3 ++- packages/syft/tests/conftest.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 90dfcdcf254..ca5b319ae29 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -36,6 +36,7 @@ from ...types.errors import SyftException from ...types.result import as_result from ...types.syft_metaclass import Empty +from ...types.syft_object import PartialSyftObject from ...types.syft_object import SyftObject from ...types.uid import UID from ...util.telemetry import instrument @@ -446,7 +447,7 @@ def update( - To fix, we either need db-supported computed fields, or know in our ORM which fields should be re-computed. """ - if not isinstance(obj, self.object_type): + if issubclass(type(obj), PartialSyftObject): original_obj = self.get_by_uid( credentials, obj.id, session=session ).unwrap() diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index eacf21eb616..bf555ebce80 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -271,7 +271,7 @@ def big_dataset() -> Dataset: scope="function", params=[ "tODOsqlite_address", - "TODOpostgres_address", + # "TODOpostgres_address", # will be used when we have a postgres CI tests ], ) def queue_stash(request): From 48668e4d7fe3754318562f107719f8cfe9366d38 Mon Sep 17 00:00:00 2001 From: Shubham Gupta Date: Mon, 16 Sep 2024 14:18:11 +0530 Subject: [PATCH 183/197] fix return statement in queue stash migration --- packages/syft/src/syft/service/queue/queue_stash.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index 3c33f8a8a78..d0c24400ee1 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -223,7 +223,7 @@ def upgrade_worker_settings_for_queue(context: TransformContext) -> TransformCon ) context.output["worker_settings"] = worker_settings - return context + return context @migrate(QueueItemV1, QueueItem) From ef4d05f33f0ec5ab46a8b1774d636bb3aa1bc437 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 16 Sep 2024 10:53:31 +0200 Subject: [PATCH 184/197] rename create_root_admin --- packages/syft/src/syft/server/server.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 438997b1ad2..f90cd0cab4d 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -426,7 +426,7 @@ def __init__( self.db.init_tables() self.action_store = self.services.action.stash - create_root_admin( + create_root_admin_if_not_exists( name=root_username, email=root_email, password=root_password, # nosec @@ -1722,7 +1722,7 @@ def create_initial_settings(self, admin_email: str) -> ServerSettings: ).unwrap() -def create_root_admin( +def create_root_admin_if_not_exists( name: str, email: str, password: str, From 8c1c5fc35298ca5c98e9646d157ffe938e9d269b Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 16 Sep 2024 10:56:00 +0200 Subject: [PATCH 185/197] rename searchable attrs serde --- packages/syft/src/syft/serde/json_serde.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py 
b/packages/syft/src/syft/serde/json_serde.py index a150d3c186e..2386d72b138 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -180,7 +180,7 @@ def _serialize_pydantic_to_json(obj: pydantic.BaseModel) -> dict[str, Json]: continue result[key] = serialize_json(getattr(obj, key), type_.annotation) - result = _serialize_searchable_attrs(obj, result, raise_errors=False) + result = _add_searchable_and_unique_attrs(obj, result, raise_errors=False) return result @@ -198,11 +198,11 @@ def get_property_return_type(obj: Any, attr_name: str) -> Any: return None -def _serialize_searchable_attrs( +def _add_searchable_and_unique_attrs( obj: pydantic.BaseModel, obj_dict: dict[str, Json], raise_errors: bool = True ) -> dict[str, Json]: """ - Add searchable attrs and unique attrs to the serialized object dict, if they are not already present. + Add searchable attrs and unique attrs to the serialized object dict, if they are not already present. Needed for adding non-field attributes (like @property) Args: From 1040fa59913237aab92f6f91af091ef848885839 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Mon, 16 Sep 2024 11:03:26 +0200 Subject: [PATCH 186/197] jsonserde default noop --- packages/syft/src/syft/serde/json_serde.py | 31 +++++++++------------- 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/packages/syft/src/syft/serde/json_serde.py b/packages/syft/src/syft/serde/json_serde.py index 2386d72b138..ee86241716c 100644 --- a/packages/syft/src/syft/serde/json_serde.py +++ b/packages/syft/src/syft/serde/json_serde.py @@ -22,7 +22,6 @@ from ..server.credentials import SyftSigningKey from ..server.credentials import SyftVerifyKey from ..types.datetime import DateTime -from ..types.errors import SyftException from ..types.syft_object import BaseDateTime from ..types.syft_object_registry import SyftObjectRegistry from ..types.uid import LineageID @@ -39,31 +38,21 @@ Json = JsonPrimitive | list["Json"] | dict[str, "Json"] -class JSONSerdeError(SyftException): - pass +def _noop_fn(obj: Any) -> Any: + return obj @dataclass class JSONSerde(Generic[T]): klass: type[T] - serialize_fn: Callable[[T], Json] | None = None - deserialize_fn: Callable[[Json], T] | None = None - - def _check_type(self, obj: Any) -> None: - if not isinstance(obj, self.klass): - raise JSONSerdeError(f"Expected {self.klass}, got {type(obj)}") + serialize_fn: Callable[[T], Json] + deserialize_fn: Callable[[Json], T] def serialize(self, obj: T) -> Json: - if self.serialize_fn is None: - return obj # type: ignore - else: - return self.serialize_fn(obj) + return self.serialize_fn(obj) def deserialize(self, obj: Json) -> T: - if self.deserialize_fn is None: - return obj # type: ignore - else: - return self.deserialize_fn(obj) # type: ignore + return self.deserialize_fn(obj) JSON_SERDE_REGISTRY: dict[type[T], JSONSerde[T]] = {} @@ -77,7 +66,13 @@ def register_json_serde( if type_ in JSON_SERDE_REGISTRY: raise ValueError(f"Type {type_} is already registered") - JSON_SERDE_REGISTRY[(type_)] = JSONSerde( + if serialize is None: + serialize = _noop_fn + + if deserialize is None: + deserialize = _noop_fn + + JSON_SERDE_REGISTRY[type_] = JSONSerde( klass=type_, serialize_fn=serialize, deserialize_fn=deserialize, From cc8e9762d75e300ce522daf007b40db65c4eacee Mon Sep 17 00:00:00 2001 From: khoaguin Date: Mon, 16 Sep 2024 16:08:46 +0700 Subject: [PATCH 187/197] change sqlite db path in container mode and postgresql mount path in k8s to `tmp/data/db` Co-authored-by: Shubham Gupta --- 
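The no-op default introduced in PATCH 186 above means JSON-native types can be registered without converters. Below is a self-contained sketch of that registry pattern; the names mirror the diff, but the snippet (and the int/bytes registrations at the end) are illustrative, not the Syft module, which also handles SyftObject subclasses and union types.

# Sketch of the json_serde registry with a no-op fallback (PATCH 186 above).
from dataclasses import dataclass
from typing import Any, Callable, Generic, TypeVar

T = TypeVar("T")


def _noop_fn(obj: Any) -> Any:
    return obj


@dataclass
class JSONSerde(Generic[T]):
    klass: type[T]
    serialize_fn: Callable[[T], Any]
    deserialize_fn: Callable[[Any], T]


JSON_SERDE_REGISTRY: dict[type, JSONSerde] = {}


def register_json_serde(
    type_: type,
    serialize: Callable | None = None,
    deserialize: Callable | None = None,
) -> None:
    if type_ in JSON_SERDE_REGISTRY:
        raise ValueError(f"Type {type_} is already registered")
    JSON_SERDE_REGISTRY[type_] = JSONSerde(
        klass=type_,
        serialize_fn=serialize or _noop_fn,  # JSON-native: identity
        deserialize_fn=deserialize or _noop_fn,
    )


register_json_serde(int)  # no converters needed: no-op in both directions
register_json_serde(bytes, serialize=lambda b: b.hex(), deserialize=bytes.fromhex)

assert JSON_SERDE_REGISTRY[int].serialize_fn(42) == 42
assert JSON_SERDE_REGISTRY[bytes].deserialize_fn("ff") == b"\xff"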
packages/grid/backend/grid/core/config.py | 2 +-
 packages/grid/backend/grid/core/server.py | 6 ++++++
 .../helm/syft/templates/postgres/postgres-statefuleset.yaml | 2 +-
 3 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/packages/grid/backend/grid/core/config.py b/packages/grid/backend/grid/core/config.py
index 91e18b46c47..e92d6783ae7 100644
--- a/packages/grid/backend/grid/core/config.py
+++ b/packages/grid/backend/grid/core/config.py
@@ -138,7 +138,7 @@ def get_emails_enabled(self) -> Self:
         True if os.getenv("CREATE_PRODUCER", "false").lower() == "true" else False
     )
     N_CONSUMERS: int = int(os.getenv("N_CONSUMERS", 1))
-    SQLITE_PATH: str = os.path.expandvars("/tmp/syft/")
+    SQLITE_PATH: str = os.path.expandvars("/tmp/data/db")
     SINGLE_CONTAINER_MODE: bool = str_to_bool(os.getenv("SINGLE_CONTAINER_MODE", False))
     CONSUMER_SERVICE_NAME: str | None = os.getenv("CONSUMER_SERVICE_NAME")
     INMEMORY_WORKERS: bool = str_to_bool(os.getenv("INMEMORY_WORKERS", True))
diff --git a/packages/grid/backend/grid/core/server.py b/packages/grid/backend/grid/core/server.py
index d337b0357dc..7d8d011de5d 100644
--- a/packages/grid/backend/grid/core/server.py
+++ b/packages/grid/backend/grid/core/server.py
@@ -1,4 +1,5 @@
 # stdlib
+from pathlib import Path
 
 # syft absolute
 from syft.abstract_server import ServerType
@@ -37,6 +38,11 @@ def queue_config() -> ZMQQueueConfig:
 
 
 def sql_store_config() -> SQLiteDBConfig:
+    # Check if the directory exists, and create it if it doesn't
+    sqlite_path = Path(settings.SQLITE_PATH)
+    if not sqlite_path.exists():
+        sqlite_path.mkdir(parents=True, exist_ok=True)
+
     return SQLiteDBConfig(
         filename=f"{UID.from_string(get_server_uid_env())}.sqlite",
         path=settings.SQLITE_PATH,
diff --git a/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml b/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml
index db9416bbdc3..986031b17e9 100644
--- a/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml
+++ b/packages/grid/helm/syft/templates/postgres/postgres-statefuleset.yaml
@@ -49,7 +49,7 @@ spec:
             {{- toYaml .Values.postgres.env | nindent 12 }}
           {{- end }}
           volumeMounts:
-            - mountPath: /data/db
+            - mountPath: tmp/data/db
               name: postgres-data
               readOnly: false
               subPath: ''

From a6ea3d54f87dc5ea8819ecb9f515805687b59586 Mon Sep 17 00:00:00 2001
From: Shubham Gupta
Date: Mon, 16 Sep 2024 15:48:12 +0530
Subject: [PATCH 188/197] fix class name for APIEndpointQueueItem version 1

---
 packages/syft/src/syft/service/queue/queue_stash.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py
index d0c24400ee1..aa5b872b226 100644
--- a/packages/syft/src/syft/service/queue/queue_stash.py
+++ b/packages/syft/src/syft/service/queue/queue_stash.py
@@ -115,7 +115,7 @@ class ActionQueueItem(QueueItem):
 
 
 @serializable()
-class APIEndpointQueueItemV1(QueueItem):
+class APIEndpointQueueItem(QueueItem):
     __canonical_name__ = "APIEndpointQueueItem"
     __version__ = SYFT_OBJECT_VERSION_2
 
@@ -133,7 +133,7 @@ class ActionQueueItemV1(QueueItemV1):
 
 
 @serializable()
-class APIEndpointQueueItem(QueueItemV1):
+class APIEndpointQueueItemV1(QueueItemV1):
     __canonical_name__ = "APIEndpointQueueItem"
     __version__ = SYFT_OBJECT_VERSION_1

From 7a6458883145d1efd971513921cb83a4e4ca8114 Mon Sep 17 00:00:00 2001
From: Shubham Gupta
Date: Mon, 16 Sep 2024 16:21:22 +0530
Subject: [PATCH 189/197] add sqlalchemy opentelemetry instrumentation package to pypi

---
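The diff below only adds the dependency; where Syft calls it is not shown in this commit, so the wiring here is an assumption based on the package's public API. The engine is a stand-in, and with no TracerProvider configured the spans are non-recording, so this is safe to run as-is.

# Hedged sketch: typical wiring for opentelemetry-instrumentation-sqlalchemy.
from opentelemetry.instrumentation.sqlalchemy import SQLAlchemyInstrumentor
from sqlalchemy import create_engine

engine = create_engine("sqlite://")  # stand-in for the stash engine
# Instrument a specific engine so each executed statement emits a span.
SQLAlchemyInstrumentor().instrument(engine=engine)

with engine.connect() as conn:
    conn.exec_driver_sql("SELECT 1")  # traced once a TracerProvider is configured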
packages/syft/setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/packages/syft/setup.cfg b/packages/syft/setup.cfg index d0aa8276924..d4ee2cff521 100644 --- a/packages/syft/setup.cfg +++ b/packages/syft/setup.cfg @@ -116,6 +116,7 @@ telemetry = opentelemetry-instrumentation-fastapi==0.48b0 opentelemetry-instrumentation-botocore==0.48b0 opentelemetry-instrumentation-logging==0.48b0 + opentelemetry-instrumentation-sqlalchemy==0.48b0 ; opentelemetry-instrumentation-asyncio==0.48b0 ; opentelemetry-instrumentation-sqlite3==0.48b0 ; opentelemetry-instrumentation-threading==0.48b0 From 5b4c63a29912fa63af38cf86ae18d071e8cb2b69 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Mon, 16 Sep 2024 13:56:14 +0200 Subject: [PATCH 190/197] refactor db reset logic --- packages/syft/src/syft/server/server.py | 4 +--- packages/syft/src/syft/store/db/db.py | 18 +++++------------- packages/syft/src/syft/store/db/stash.py | 8 -------- packages/syft/tests/syft/worker_test.py | 2 +- 4 files changed, 7 insertions(+), 25 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 438997b1ad2..684fc35073b 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -421,9 +421,7 @@ def __init__( # construct services only after init stores self.services: ServiceRegistry = ServiceRegistry.for_server(self) - if reset: - self.db.reset() - self.db.init_tables() + self.db.init_tables(reset=reset) self.action_store = self.services.action.stash create_root_admin( diff --git a/packages/syft/src/syft/store/db/db.py b/packages/syft/src/syft/store/db/db.py index d2adca212c8..cc82e5a3f4e 100644 --- a/packages/syft/src/syft/store/db/db.py +++ b/packages/syft/src/syft/store/db/db.py @@ -21,8 +21,6 @@ @serializable(canonical_name="DBConfig", version=1) class DBConfig(BaseModel): - reset: bool = False - @property def connection_string(self) -> str: raise NotImplementedError("Subclasses must implement this method.") @@ -74,16 +72,10 @@ def __init__( def update_settings(self) -> None: pass - def init_tables(self) -> None: + def init_tables(self, reset: bool = False) -> None: Base = SQLiteBase if self.engine.dialect.name == "sqlite" else PostgresBase - if self.config.reset: - # drop all tables that we know about - Base.metadata.drop_all(bind=self.engine) - self.config.reset = False - Base.metadata.create_all(self.engine) - - def reset(self) -> None: - Base = SQLiteBase if self.engine.dialect.name == "sqlite" else PostgresBase - Base.metadata.drop_all(bind=self.engine) - Base.metadata.create_all(self.engine) + with self.sessionmaker().begin() as _: + if reset: + Base.metadata.drop_all(bind=self.engine) + Base.metadata.create_all(self.engine) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index ca5b319ae29..22259f2e65a 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -164,14 +164,6 @@ def check_type(self, obj: T, type_: type) -> T: def session(self) -> Session: return self.db.session - def _drop_table(self) -> None: - table_name = self.object_type.__canonical_name__ - Base = SQLiteBase if self._is_sqlite() else PostgresBase - if table_name in Base.metadata.tables: - Base.metadata.tables[table_name].drop(self.db.engine) - else: - raise StashException(f"Table {table_name} does not exist") - def _print_query(self, stmt: sa.sql.select) -> None: print( stmt.compile( diff --git a/packages/syft/tests/syft/worker_test.py 
b/packages/syft/tests/syft/worker_test.py index 1c92da156d5..f52772038cf 100644 --- a/packages/syft/tests/syft/worker_test.py +++ b/packages/syft/tests/syft/worker_test.py @@ -80,7 +80,7 @@ def test_signing_key() -> None: scope="function", params=[ "tODOsqlite_address", - "TODOpostgres_address", + # "TODOpostgres_address", # will be used when we have a postgres CI tests ], ) def action_object_stash() -> ActionObjectStash: From 8a47d997a4f64569871235c5ef9e00f832104349 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 17 Sep 2024 08:32:12 +0200 Subject: [PATCH 191/197] use in memory sqlite --- packages/syft/src/syft/store/db/sqlite.py | 3 +++ packages/syft/tests/conftest.py | 12 ++++++++---- .../tests/syft/blob_storage/blob_storage_test.py | 16 +++------------- .../tests/syft/migrations/data_migration_test.py | 2 +- .../tests/syft/settings/settings_service_test.py | 6 +++--- .../syft/tests/syft/users/user_service_test.py | 8 ++++---- 6 files changed, 22 insertions(+), 25 deletions(-) diff --git a/packages/syft/src/syft/store/db/sqlite.py b/packages/syft/src/syft/store/db/sqlite.py index 485a5f7d64b..e9b3468efa3 100644 --- a/packages/syft/src/syft/store/db/sqlite.py +++ b/packages/syft/src/syft/store/db/sqlite.py @@ -23,6 +23,9 @@ class SQLiteDBConfig(DBConfig): @property def connection_string(self) -> str: + if self.path == Path("."): + # Use in-memory database + return "sqlite://" filepath = self.path / self.filename return f"sqlite:///{filepath.resolve()}" diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index bf555ebce80..2b45201d490 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -7,6 +7,7 @@ import sys from tempfile import gettempdir from unittest import mock +from uuid import uuid4 # third party from faker import Faker @@ -115,7 +116,7 @@ def faker(): @pytest.fixture(scope="function") def worker() -> Worker: - worker = sy.Worker.named(name=token_hex(16)) + worker = sy.Worker.named(name=token_hex(16), db_url="sqlite://") yield worker worker.cleanup() del worker @@ -124,7 +125,7 @@ def worker() -> Worker: @pytest.fixture(scope="function") def second_worker() -> Worker: # Used in server syncing tests - worker = sy.Worker.named(name=token_hex(16)) + worker = sy.Worker.named(name=uuid4().hex, db_url="sqlite://") yield worker worker.cleanup() del worker @@ -133,7 +134,7 @@ def second_worker() -> Worker: @pytest.fixture(scope="function") def high_worker() -> Worker: worker = sy.Worker.named( - name=token_hex(8), server_side_type=ServerSideType.HIGH_SIDE + name=token_hex(8), server_side_type=ServerSideType.HIGH_SIDE, db_url="sqlite://" ) yield worker worker.cleanup() @@ -143,7 +144,10 @@ def high_worker() -> Worker: @pytest.fixture(scope="function") def low_worker() -> Worker: worker = sy.Worker.named( - name=token_hex(8), server_side_type=ServerSideType.LOW_SIDE, dev_mode=True + name=token_hex(8), + server_side_type=ServerSideType.LOW_SIDE, + dev_mode=True, + db_url="sqlite://", ) yield worker worker.cleanup() diff --git a/packages/syft/tests/syft/blob_storage/blob_storage_test.py b/packages/syft/tests/syft/blob_storage/blob_storage_test.py index 8b8613498fb..47e33f7926d 100644 --- a/packages/syft/tests/syft/blob_storage/blob_storage_test.py +++ b/packages/syft/tests/syft/blob_storage/blob_storage_test.py @@ -1,6 +1,5 @@ # stdlib import io -import random # third party import numpy as np @@ -42,10 +41,7 @@ def test_blob_storage_allocate(authed_context, blob_storage): assert isinstance(blob_deposit, BlobDeposit) 
-def test_blob_storage_write(): - random.seed() - name = "".join(str(random.randint(0, 9)) for i in range(8)) - worker = sy.Worker.named(name=name) +def test_blob_storage_write(worker): blob_storage = worker.services.blob_storage authed_context = AuthedServiceContext( server=worker, credentials=worker.signing_key.verify_key @@ -60,10 +56,7 @@ def test_blob_storage_write(): worker.cleanup() -def test_blob_storage_write_syft_object(): - random.seed() - name = "".join(str(random.randint(0, 9)) for i in range(8)) - worker = sy.Worker.named(name=name) +def test_blob_storage_write_syft_object(worker): blob_storage = worker.services.blob_storage authed_context = AuthedServiceContext( server=worker, credentials=worker.signing_key.verify_key @@ -78,10 +71,7 @@ def test_blob_storage_write_syft_object(): worker.cleanup() -def test_blob_storage_read(): - random.seed() - name = "".join(str(random.randint(0, 9)) for i in range(8)) - worker = sy.Worker.named(name=name) +def test_blob_storage_read(worker): blob_storage = worker.services.blob_storage authed_context = AuthedServiceContext( server=worker, credentials=worker.signing_key.verify_key diff --git a/packages/syft/tests/syft/migrations/data_migration_test.py b/packages/syft/tests/syft/migrations/data_migration_test.py index 708c56ac75a..a5203e2a0f8 100644 --- a/packages/syft/tests/syft/migrations/data_migration_test.py +++ b/packages/syft/tests/syft/migrations/data_migration_test.py @@ -115,7 +115,7 @@ def test_get_migration_data(worker, tmp_path): @contextmanager def named_worker_context(name): # required to launch worker with same name twice within the same test + ensure cleanup - worker = sy.Worker.named(name=name) + worker = sy.Worker.named(name=name, db_url="sqlite://") try: yield worker finally: diff --git a/packages/syft/tests/syft/settings/settings_service_test.py b/packages/syft/tests/syft/settings/settings_service_test.py index b856f6040c5..7555aadd91e 100644 --- a/packages/syft/tests/syft/settings/settings_service_test.py +++ b/packages/syft/tests/syft/settings/settings_service_test.py @@ -255,7 +255,7 @@ def test_settings_allow_guest_registration( new_callable=mock.PropertyMock, return_value=mock_server_settings, ): - worker = syft.Worker.named(name=faker.name(), reset=True) + worker = syft.Worker.named(name=faker.name(), reset=True, db_url="sqlite://") guest_datasite_client = worker.guest_client root_datasite_client = worker.root_client @@ -289,7 +289,7 @@ def test_settings_allow_guest_registration( new_callable=mock.PropertyMock, return_value=mock_server_settings, ): - worker = syft.Worker.named(name=faker.name(), reset=True) + worker = syft.Worker.named(name=faker.name(), reset=True, db_url="sqlite://") guest_datasite_client = worker.guest_client root_datasite_client = worker.root_client @@ -348,7 +348,7 @@ def get_mock_client(faker, root_client, role): new_callable=mock.PropertyMock, return_value=mock_server_settings, ): - worker = syft.Worker.named(name=faker.name(), reset=True) + worker = syft.Worker.named(name=faker.name(), reset=True, db_url="sqlite://") root_client = worker.root_client emails_added = [] diff --git a/packages/syft/tests/syft/users/user_service_test.py b/packages/syft/tests/syft/users/user_service_test.py index efa69cd1c2b..59e905ee657 100644 --- a/packages/syft/tests/syft/users/user_service_test.py +++ b/packages/syft/tests/syft/users/user_service_test.py @@ -554,7 +554,7 @@ def mock_get_by_email(credentials: SyftVerifyKey, email) -> User: new_callable=mock.PropertyMock, 
return_value=settings_with_signup_enabled(worker), ): - mock_worker = Worker.named(name="mock-server") + mock_worker = Worker.named(name="mock-server", db_url="sqlite://") server_context = ServerServiceContext(server=mock_worker) with pytest.raises(SyftException) as exc: @@ -584,7 +584,7 @@ def mock_get_by_email(credentials: SyftVerifyKey, email) -> NoReturn: new_callable=mock.PropertyMock, return_value=settings_with_signup_enabled(worker), ): - mock_worker = Worker.named(name="mock-server") + mock_worker = Worker.named(name="mock-server", db_url="sqlite://") server_context = ServerServiceContext(server=mock_worker) with pytest.raises(StashException) as exc: @@ -613,7 +613,7 @@ def mock_set(*args, **kwargs) -> User: new_callable=mock.PropertyMock, return_value=settings_with_signup_enabled(worker), ): - mock_worker = Worker.named(name="mock-server") + mock_worker = Worker.named(name="mock-server", db_url="sqlite://") server_context = ServerServiceContext(server=mock_worker) monkeypatch.setattr(user_service.stash, "get_by_email", mock_get_by_email) @@ -652,7 +652,7 @@ def mock_set( new_callable=mock.PropertyMock, return_value=settings_with_signup_enabled(worker), ): - mock_worker = Worker.named(name="mock-server") + mock_worker = Worker.named(name="mock-server", db_url="sqlite://") server_context = ServerServiceContext(server=mock_worker) monkeypatch.setattr(user_service.stash, "get_by_email", mock_get_by_email) From 8417653217fb14883f0baf13d40599acd291b4a5 Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 17 Sep 2024 10:31:14 +0200 Subject: [PATCH 192/197] manage db path externally --- packages/syft/src/syft/server/server.py | 6 ++++ packages/syft/tests/conftest.py | 30 ++++++++++++++----- .../syft/migrations/data_migration_test.py | 2 +- 3 files changed, 29 insertions(+), 9 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index 3f0161057bf..a013fa2f5db 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -719,6 +719,12 @@ def remove_consumer_with_id(self, syft_worker_id: UID) -> None: if consumer_to_pop is not None: consumers.pop(consumer_to_pop) + def remove_all_consumers(self) -> None: + for consumers in self.queue_manager.consumers.values(): + for consumer in consumers: + consumer.close() + consumers.clear() + @classmethod def named( cls: type[Server], diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index 2b45201d490..6382a319e45 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -6,6 +6,7 @@ import shutil import sys from tempfile import gettempdir +from tempfile import mkstemp from unittest import mock from uuid import uuid4 @@ -25,6 +26,7 @@ from syft.server.worker import Worker from syft.service.queue.queue_stash import QueueStash from syft.service.user import user +from syft.store.db.sqlite import SQLiteDBConfig # our version of mongomock that has a fix for CodecOptions and custom TypeRegistry Support @@ -115,26 +117,38 @@ def faker(): @pytest.fixture(scope="function") -def worker() -> Worker: - worker = sy.Worker.named(name=token_hex(16), db_url="sqlite://") +def db_config() -> SQLiteDBConfig: + file_handle, path = mkstemp(suffix=".db") + # close the file handle + os.close(file_handle) + temp_path = Path(path) + temp_path.parent.mkdir(parents=True, exist_ok=True) + yield SQLiteDBConfig(filename=temp_path.name, path=temp_path.parent) + temp_path.unlink() + + +@pytest.fixture(scope="function") +def 
worker(db_config: SQLiteDBConfig) -> Worker: + worker = sy.Worker.named(name=token_hex(16), db_config=db_config) yield worker worker.cleanup() del worker @pytest.fixture(scope="function") -def second_worker() -> Worker: +def second_worker(db_config: SQLiteDBConfig) -> Worker: # Used in server syncing tests - worker = sy.Worker.named(name=uuid4().hex, db_url="sqlite://") + worker = sy.Worker.named(name=uuid4().hex, db_config=db_config) yield worker worker.cleanup() del worker @pytest.fixture(scope="function") -def high_worker() -> Worker: +def high_worker(db_config: SQLiteDBConfig) -> Worker: worker = sy.Worker.named( - name=token_hex(8), server_side_type=ServerSideType.HIGH_SIDE, db_url="sqlite://" + name=token_hex(8), + server_side_type=ServerSideType.HIGH_SIDE, ) yield worker worker.cleanup() @@ -142,12 +156,12 @@ def high_worker() -> Worker: @pytest.fixture(scope="function") -def low_worker() -> Worker: +def low_worker(db_config: SQLiteDBConfig) -> Worker: worker = sy.Worker.named( name=token_hex(8), server_side_type=ServerSideType.LOW_SIDE, dev_mode=True, - db_url="sqlite://", + db_config=db_config, ) yield worker worker.cleanup() diff --git a/packages/syft/tests/syft/migrations/data_migration_test.py b/packages/syft/tests/syft/migrations/data_migration_test.py index a5203e2a0f8..708c56ac75a 100644 --- a/packages/syft/tests/syft/migrations/data_migration_test.py +++ b/packages/syft/tests/syft/migrations/data_migration_test.py @@ -115,7 +115,7 @@ def test_get_migration_data(worker, tmp_path): @contextmanager def named_worker_context(name): # required to launch worker with same name twice within the same test + ensure cleanup - worker = sy.Worker.named(name=name, db_url="sqlite://") + worker = sy.Worker.named(name=name) try: yield worker finally: From e5ff154cc0e4c78019e9149071d814a04aa1c7de Mon Sep 17 00:00:00 2001 From: Aziz Berkay Yesilyurt Date: Tue, 17 Sep 2024 10:40:40 +0200 Subject: [PATCH 193/197] Revert "manage db path externally" This reverts commit 8417653217fb14883f0baf13d40599acd291b4a5. 
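The revert trades per-test database files for the in-memory URL the remaining commits standardize on. A hedged sketch of the two fixture styles being weighed here; the fixture names are illustrative, not from the Syft test suite.

# File-backed vs in-memory SQLite test fixtures (the trade-off behind this revert).
import os
from pathlib import Path
from tempfile import mkstemp

import pytest


@pytest.fixture
def file_backed_db_path():
    # File-backed: visible to any connection that opens the path,
    # but the file must be created and cleaned up around each test.
    handle, path = mkstemp(suffix=".db")
    os.close(handle)
    db_path = Path(path)
    yield db_path
    db_path.unlink()


@pytest.fixture
def in_memory_db_url() -> str:
    # In-memory: nothing to clean up, but every new connection gets a fresh,
    # empty database - the caveat PATCH 197 documents at the end of the series.
    return "sqlite://"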
--- packages/syft/src/syft/server/server.py | 6 ---- packages/syft/tests/conftest.py | 30 +++++-------------- .../syft/migrations/data_migration_test.py | 2 +- 3 files changed, 9 insertions(+), 29 deletions(-) diff --git a/packages/syft/src/syft/server/server.py b/packages/syft/src/syft/server/server.py index a013fa2f5db..3f0161057bf 100644 --- a/packages/syft/src/syft/server/server.py +++ b/packages/syft/src/syft/server/server.py @@ -719,12 +719,6 @@ def remove_consumer_with_id(self, syft_worker_id: UID) -> None: if consumer_to_pop is not None: consumers.pop(consumer_to_pop) - def remove_all_consumers(self) -> None: - for consumers in self.queue_manager.consumers.values(): - for consumer in consumers: - consumer.close() - consumers.clear() - @classmethod def named( cls: type[Server], diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index 6382a319e45..2b45201d490 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -6,7 +6,6 @@ import shutil import sys from tempfile import gettempdir -from tempfile import mkstemp from unittest import mock from uuid import uuid4 @@ -26,7 +25,6 @@ from syft.server.worker import Worker from syft.service.queue.queue_stash import QueueStash from syft.service.user import user -from syft.store.db.sqlite import SQLiteDBConfig # our version of mongomock that has a fix for CodecOptions and custom TypeRegistry Support @@ -117,38 +115,26 @@ def faker(): @pytest.fixture(scope="function") -def db_config() -> SQLiteDBConfig: - file_handle, path = mkstemp(suffix=".db") - # close the file handle - os.close(file_handle) - temp_path = Path(path) - temp_path.parent.mkdir(parents=True, exist_ok=True) - yield SQLiteDBConfig(filename=temp_path.name, path=temp_path.parent) - temp_path.unlink() - - -@pytest.fixture(scope="function") -def worker(db_config: SQLiteDBConfig) -> Worker: - worker = sy.Worker.named(name=token_hex(16), db_config=db_config) +def worker() -> Worker: + worker = sy.Worker.named(name=token_hex(16), db_url="sqlite://") yield worker worker.cleanup() del worker @pytest.fixture(scope="function") -def second_worker(db_config: SQLiteDBConfig) -> Worker: +def second_worker() -> Worker: # Used in server syncing tests - worker = sy.Worker.named(name=uuid4().hex, db_config=db_config) + worker = sy.Worker.named(name=uuid4().hex, db_url="sqlite://") yield worker worker.cleanup() del worker @pytest.fixture(scope="function") -def high_worker(db_config: SQLiteDBConfig) -> Worker: +def high_worker() -> Worker: worker = sy.Worker.named( - name=token_hex(8), - server_side_type=ServerSideType.HIGH_SIDE, + name=token_hex(8), server_side_type=ServerSideType.HIGH_SIDE, db_url="sqlite://" ) yield worker worker.cleanup() @@ -156,12 +142,12 @@ def high_worker(db_config: SQLiteDBConfig) -> Worker: @pytest.fixture(scope="function") -def low_worker(db_config: SQLiteDBConfig) -> Worker: +def low_worker() -> Worker: worker = sy.Worker.named( name=token_hex(8), server_side_type=ServerSideType.LOW_SIDE, dev_mode=True, - db_config=db_config, + db_url="sqlite://", ) yield worker worker.cleanup() diff --git a/packages/syft/tests/syft/migrations/data_migration_test.py b/packages/syft/tests/syft/migrations/data_migration_test.py index 708c56ac75a..a5203e2a0f8 100644 --- a/packages/syft/tests/syft/migrations/data_migration_test.py +++ b/packages/syft/tests/syft/migrations/data_migration_test.py @@ -115,7 +115,7 @@ def test_get_migration_data(worker, tmp_path): @contextmanager def named_worker_context(name): # required to launch worker 
with same name twice within the same test + ensure cleanup - worker = sy.Worker.named(name=name) + worker = sy.Worker.named(name=name, db_url="sqlite://") try: yield worker finally: From f489ee45ed662df330a7f3f31d76f842b58a96e3 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 17 Sep 2024 12:04:47 +0200 Subject: [PATCH 194/197] more logs in custom endpoint notebook --- .../api/0.8/12-custom-api-endpoint.ipynb | 26 ++++++++++++++----- 1 file changed, 20 insertions(+), 6 deletions(-) diff --git a/notebooks/api/0.8/12-custom-api-endpoint.ipynb b/notebooks/api/0.8/12-custom-api-endpoint.ipynb index aa60e30dd87..c58c78c1795 100644 --- a/notebooks/api/0.8/12-custom-api-endpoint.ipynb +++ b/notebooks/api/0.8/12-custom-api-endpoint.ipynb @@ -657,6 +657,9 @@ "# stdlib\n", "import time\n", "\n", + "# syft absolute\n", + "from syft.service.job.job_stash import JobStatus\n", + "\n", "# Iterate over the Jobs waiting them to finish their pipelines.\n", "job_pool = [\n", " (log_call_job, \"Logging Private Function Call\"),\n", @@ -665,13 +668,19 @@ "]\n", "for job, expected_log in job_pool:\n", " updated_job = datasite_client.api.services.job.get(job.id)\n", - " while updated_job.status.value != \"completed\":\n", + " while updated_job.status in {JobStatus.CREATED, JobStatus.PROCESSING}:\n", " updated_job = datasite_client.api.services.job.get(job.id)\n", " time.sleep(1)\n", - " # If they're completed. Then, check if the TwinAPI print appears in the job logs.\n", - " assert expected_log in datasite_client.api.services.job.get(job.id).logs(\n", - " _print=False\n", - " )" + "\n", + " assert (\n", + " updated_job.status == JobStatus.COMPLETED\n", + " ), f\"Job {updated_job.id} exited with status {updated_job.status} and result {updated_job.result}\"\n", + " if updated_job.status == JobStatus.COMPLETED:\n", + " print(f\"Job {updated_job.id} completed\")\n", + " # If they're completed. 
Then, check if the TwinAPI print appears in the job logs.\n", + " assert expected_log in datasite_client.api.services.job.get(job.id).logs(\n", + " _print=False\n", + " )" ] }, { @@ -683,6 +692,11 @@ } ], "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, "language_info": { "codemirror_mode": { "name": "ipython", @@ -693,7 +707,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.5" + "version": "3.10.13" } }, "nbformat": 4, From f1e8d41c936c83b362533ae0a990dbb67ee7e01f Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 17 Sep 2024 12:32:41 +0200 Subject: [PATCH 195/197] debug partition_by_server --- packages/syft/src/syft/service/policy/policy.py | 6 ++++-- packages/syft/src/syft/store/db/stash.py | 7 ++++++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index db57609af93..1e33755418e 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -174,7 +174,6 @@ def partition_by_server(kwargs: dict[str, Any]) -> dict[ServerIdentity, dict[str from ..action.action_object import ActionObject # fetches the all the current api's connected - api_list = APIRegistry.get_all_api() output_kwargs = {} for k, v in kwargs.items(): uid = v @@ -190,7 +189,7 @@ def partition_by_server(kwargs: dict[str, Any]) -> dict[ServerIdentity, dict[str raise Exception(f"Input {k} must have a UID not {type(v)}") _obj_exists = False - for api in api_list: + for identity, api in APIRegistry.__api_registry__.items(): try: if api.services.action.exists(uid): server_identity = ServerIdentity.from_api(api) @@ -205,6 +204,9 @@ def partition_by_server(kwargs: dict[str, Any]) -> dict[ServerIdentity, dict[str # To handle the cases , where there an old api objects in # in APIRegistry continue + except Exception as e: + print(f"Error in partition_by_server with identity {identity}", e) + raise e if not _obj_exists: raise Exception(f"Input data {k}:{uid} does not belong to any Datasite") diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 22259f2e65a..aec2a2ed9c5 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -292,7 +292,12 @@ def get_role( # this happens when we create stashes in tests return ServiceRole.GUEST - query = self.query(User).filter("verify_key", "eq", credentials) + try: + query = self.query(User).filter("verify_key", "eq", credentials) + except Exception as e: + print("Error getting role", e) + raise e + user = query.execute(session).first() if user is None: return ServiceRole.GUEST From c6d0e0b4a3149ec2417e368f9a97adb8e4c42742 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 17 Sep 2024 13:15:48 +0200 Subject: [PATCH 196/197] wrap DatabaseError --- packages/syft/src/syft/store/db/errors.py | 31 +++++++++++++++++++++++ packages/syft/src/syft/store/db/query.py | 7 ++++- 2 files changed, 37 insertions(+), 1 deletion(-) create mode 100644 packages/syft/src/syft/store/db/errors.py diff --git a/packages/syft/src/syft/store/db/errors.py b/packages/syft/src/syft/store/db/errors.py new file mode 100644 index 00000000000..8f9a4ca048a --- /dev/null +++ b/packages/syft/src/syft/store/db/errors.py @@ -0,0 +1,31 @@ +# stdlib +import logging + +# third party +from sqlalchemy.exc import DatabaseError +from typing_extensions import 
Self + +# relative +from ..document_store_errors import StashException + +logger = logging.getLogger(__name__) + + +class StashDBException(StashException): + """ + See https://docs.sqlalchemy.org/en/20/errors.html#databaseerror + + StashDBException converts a SQLAlchemy DatabaseError into a StashException, + DatabaseErrors are errors thrown by the database itself, for example when a + query fails because a table is missing. + """ + + public_message = "There was an error retrieving data. Contact your admin." + + @classmethod + def from_sqlalchemy_error(cls, e: DatabaseError) -> Self: + logger.exception(e) + + error_type = e.__class__.__name__ + private_message = f"{error_type}: {str(e)}" + return cls(private_message=private_message) diff --git a/packages/syft/src/syft/store/db/query.py b/packages/syft/src/syft/store/db/query.py index fe96c3fba00..04864a4a74a 100644 --- a/packages/syft/src/syft/store/db/query.py +++ b/packages/syft/src/syft/store/db/query.py @@ -13,6 +13,7 @@ from sqlalchemy import Select from sqlalchemy import Table from sqlalchemy import func +from sqlalchemy.exc import DatabaseError from sqlalchemy.orm import Session from typing_extensions import Self @@ -24,6 +25,7 @@ from ...service.user.user_roles import ServiceRole from ...types.syft_object import SyftObject from ...types.uid import UID +from .errors import StashDBException from .schema import PostgresBase from .schema import SQLiteBase @@ -63,7 +65,10 @@ def create(cls, object_type: type[SyftObject], dialect: str | Dialect) -> "Query def execute(self, session: Session) -> Result: """Execute the query using the given session.""" - return session.execute(self.stmt) + try: + return session.execute(self.stmt) + except DatabaseError as e: + raise StashDBException.from_sqlalchemy_error(e) from e def with_permissions( self, From b59fede6eb69c460b8f3efec04afaa2c3171bf41 Mon Sep 17 00:00:00 2001 From: eelcovdw Date: Tue, 17 Sep 2024 13:23:47 +0200 Subject: [PATCH 197/197] error handling for Query db errors --- packages/syft/src/syft/store/db/sqlite.py | 7 ++++++- packages/syft/tests/conftest.py | 5 +++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/packages/syft/src/syft/store/db/sqlite.py b/packages/syft/src/syft/store/db/sqlite.py index e9b3468efa3..fbcf87ce47b 100644 --- a/packages/syft/src/syft/store/db/sqlite.py +++ b/packages/syft/src/syft/store/db/sqlite.py @@ -23,8 +23,13 @@ class SQLiteDBConfig(DBConfig): @property def connection_string(self) -> str: + """ + NOTE in-memory sqlite is not shared between connections, so: + - using 2 workers (high/low) will not share a db + - re-using a connection (e.g. for a Job worker) will not share a db + """ if self.path == Path("."): - # Use in-memory database + # Use in-memory database, only for unittests return "sqlite://" filepath = self.path / self.filename return f"sqlite:///{filepath.resolve()}" diff --git a/packages/syft/tests/conftest.py b/packages/syft/tests/conftest.py index 2b45201d490..56506b43fad 100644 --- a/packages/syft/tests/conftest.py +++ b/packages/syft/tests/conftest.py @@ -116,6 +116,11 @@ def faker(): @pytest.fixture(scope="function") def worker() -> Worker: + """ + NOTE in-memory sqlite is not shared between connections, so: + - using 2 workers (high/low) will not share a db + - re-using a connection (e.g. for a Job worker) will not share a db + """ worker = sy.Worker.named(name=token_hex(16), db_url="sqlite://") yield worker worker.cleanup()
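PATCH 196 above introduces a wrap-and-log pattern for database errors: the raw SQLAlchemy DatabaseError is logged server-side while callers only see a safe public message. A minimal, self-contained sketch of that pattern follows; the base class is simplified here, whereas in Syft it derives from StashException.

# Wrap a low-level DatabaseError into a domain exception with a safe message.
import logging

from sqlalchemy import create_engine
from sqlalchemy.exc import DatabaseError

logger = logging.getLogger(__name__)


class StashDBException(Exception):  # simplified stand-in for the Syft class
    public_message = "There was an error retrieving data. Contact your admin."

    @classmethod
    def from_sqlalchemy_error(cls, e: DatabaseError) -> "StashDBException":
        logger.exception(e)  # full driver detail stays in the server logs
        return cls(f"{e.__class__.__name__}: {e}")


engine = create_engine("sqlite://")
try:
    with engine.connect() as conn:
        conn.exec_driver_sql("SELECT * FROM missing_table")  # no such table
except DatabaseError as e:
    wrapped = StashDBException.from_sqlalchemy_error(e)
    print(wrapped.public_message)  # what the client is allowed to see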
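The caveat PATCH 197 writes into the docstrings above can be seen directly with two engines. This assumes SQLAlchemy's default pooling for "sqlite://", which reuses one connection per engine within a thread, so data persists within an engine but is never shared across engines (or workers).

# Two in-memory engines never see each other's data.
from sqlalchemy import create_engine

engine_a = create_engine("sqlite://")
engine_b = create_engine("sqlite://")

with engine_a.begin() as conn:
    conn.exec_driver_sql("CREATE TABLE t (x INTEGER)")
    conn.exec_driver_sql("INSERT INTO t VALUES (1)")

with engine_a.connect() as conn:
    # Same engine, same underlying connection: the table is still here.
    print(conn.exec_driver_sql("SELECT count(*) FROM t").scalar())  # 1

with engine_b.connect() as conn:
    # Fresh engine, fresh empty database: the table does not exist here.
    try:
        conn.exec_driver_sql("SELECT count(*) FROM t")
    except Exception as exc:
        print(type(exc).__name__)  # OperationalError: no such table: t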