diff --git a/packages/syft/src/syft/__init__.py b/packages/syft/src/syft/__init__.py index 0c9cf55440d..3a0fb8a6bb1 100644 --- a/packages/syft/src/syft/__init__.py +++ b/packages/syft/src/syft/__init__.py @@ -77,8 +77,6 @@ from .service.user.roles import Roles as roles from .service.user.user_service import UserService from .stable_version import LATEST_STABLE_SYFT -from .store.mongo_document_store import MongoStoreConfig -from .store.sqlite_document_store import SQLiteStoreConfig from .types.errors import SyftException from .types.errors import raises from .types.result import as_result diff --git a/packages/syft/src/syft/protocol/data_protocol.py b/packages/syft/src/syft/protocol/data_protocol.py index 2316263112a..e4e9804603a 100644 --- a/packages/syft/src/syft/protocol/data_protocol.py +++ b/packages/syft/src/syft/protocol/data_protocol.py @@ -21,6 +21,7 @@ # syft absolute from syft.types.result import Err from syft.types.result import Ok +from syft.util.util import get_dev_mode # relative from .. import __version__ @@ -29,7 +30,6 @@ from ..types.errors import SyftException from ..types.syft_object import SyftBaseObject from ..types.syft_object_registry import SyftObjectRegistry -from ..util.util import get_dev_mode PROTOCOL_STATE_FILENAME = "protocol_version.json" PROTOCOL_TYPE = str | int diff --git a/packages/syft/src/syft/service/code/status_service.py b/packages/syft/src/syft/service/code/status_service.py index 89871e63c19..e5a32f5c721 100644 --- a/packages/syft/src/syft/service/code/status_service.py +++ b/packages/syft/src/syft/service/code/status_service.py @@ -7,7 +7,6 @@ from ...serde.serializable import serializable from ...store.db.db import DBManager from ...store.db.stash import ObjectStash -from ...store.document_store import PartitionSettings from ...types.syft_object import PartialSyftObject from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.uid import UID @@ -24,10 +23,7 @@ @serializable(canonical_name="StatusSQLStash", version=1) class StatusStash(ObjectStash[UserCodeStatusCollection]): - settings: PartitionSettings = PartitionSettings( - name=UserCodeStatusCollection.__canonical_name__, - object_type=UserCodeStatusCollection, - ) + pass class CodeStatusUpdate(PartialSyftObject): diff --git a/packages/syft/src/syft/service/code/user_code_stash.py b/packages/syft/src/syft/service/code/user_code_stash.py index 232342bd8d5..4ba67e3633b 100644 --- a/packages/syft/src/syft/service/code/user_code_stash.py +++ b/packages/syft/src/syft/service/code/user_code_stash.py @@ -2,7 +2,6 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash -from ...store.document_store import PartitionSettings from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result @@ -11,10 +10,6 @@ @serializable(canonical_name="UserCodeSQLStash", version=1) class UserCodeStash(ObjectStash[UserCode]): - settings: PartitionSettings = PartitionSettings( - name=UserCode.__canonical_name__, object_type=UserCode - ) - @as_result(StashException, NotFoundException) def get_by_code_hash(self, credentials: SyftVerifyKey, code_hash: str) -> UserCode: return self.get_one( diff --git a/packages/syft/src/syft/service/data_subject/data_subject.py b/packages/syft/src/syft/service/data_subject/data_subject.py index e9d039a7b78..f85f80b0069 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject.py +++ 
b/packages/syft/src/syft/service/data_subject/data_subject.py @@ -7,7 +7,6 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import PartitionKey from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftObject from ...types.transforms import TransformContext @@ -17,8 +16,6 @@ from ...types.uid import UID from ...util.markdown import as_markdown_python_code -NamePartitionKey = PartitionKey(key="name", type_=str) - @serializable() class DataSubject(SyftObject): diff --git a/packages/syft/src/syft/service/data_subject/data_subject_member.py b/packages/syft/src/syft/service/data_subject/data_subject_member.py index 82767e4b631..83704fc95ab 100644 --- a/packages/syft/src/syft/service/data_subject/data_subject_member.py +++ b/packages/syft/src/syft/service/data_subject/data_subject_member.py @@ -3,13 +3,9 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import PartitionKey from ...types.syft_object import SYFT_OBJECT_VERSION_1 from ...types.syft_object import SyftObject -ParentPartitionKey = PartitionKey(key="parent", type_=str) -ChildPartitionKey = PartitionKey(key="child", type_=str) - @serializable() class DataSubjectMemberRelationship(SyftObject): diff --git a/packages/syft/src/syft/service/dataset/dataset.py b/packages/syft/src/syft/service/dataset/dataset.py index fbc5318be57..185774fa86c 100644 --- a/packages/syft/src/syft/service/dataset/dataset.py +++ b/packages/syft/src/syft/service/dataset/dataset.py @@ -17,7 +17,6 @@ # relative from ...serde.serializable import serializable -from ...store.document_store import PartitionKey from ...types.datetime import DateTime from ...types.dicttuple import DictTuple from ...types.errors import SyftException @@ -45,7 +44,6 @@ from ..response import SyftSuccess from ..response import SyftWarning -NamePartitionKey = PartitionKey(key="name", type_=str) logger = logging.getLogger(__name__) diff --git a/packages/syft/src/syft/service/job/job_stash.py b/packages/syft/src/syft/service/job/job_stash.py index 358834470fb..36cb42de1c4 100644 --- a/packages/syft/src/syft/service/job/job_stash.py +++ b/packages/syft/src/syft/service/job/job_stash.py @@ -20,7 +20,6 @@ from ...service.context import AuthedServiceContext from ...service.worker.worker_pool import SyftWorker from ...store.db.stash import ObjectStash -from ...store.document_store import PartitionSettings from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.datetime import DateTime @@ -736,10 +735,6 @@ def from_job( @serializable(canonical_name="JobStashSQL", version=1) class JobStash(ObjectStash[Job]): - settings: PartitionSettings = PartitionSettings( - name=Job.__canonical_name__, object_type=Job - ) - @as_result(StashException) def set_result( self, diff --git a/packages/syft/src/syft/service/migration/object_migration_state.py b/packages/syft/src/syft/service/migration/object_migration_state.py index 22363d867f2..9ded4cd497b 100644 --- a/packages/syft/src/syft/service/migration/object_migration_state.py +++ b/packages/syft/src/syft/service/migration/object_migration_state.py @@ -16,7 +16,6 @@ from ...server.credentials import SyftSigningKey from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash -from ...store.document_store import PartitionKey from ...store.document_store_errors import NotFoundException from ...types.blob_storage import BlobStorageEntry from 
...types.blob_storage import CreateBlobStorageEntry @@ -64,9 +63,6 @@ def supported_versions(self) -> list: return SyftObjectRegistry.get_versions(self.canonical_name) -KlassNamePartitionKey = PartitionKey(key="canonical_name", type_=str) - - @serializable(canonical_name="SyftMigrationStateSQLStash", version=1) class SyftMigrationStateStash(ObjectStash[SyftObjectMigrationState]): @as_result(SyftException, NotFoundException) diff --git a/packages/syft/src/syft/service/notifier/notifier_stash.py b/packages/syft/src/syft/service/notifier/notifier_stash.py index 861f3a89975..b3583173e50 100644 --- a/packages/syft/src/syft/service/notifier/notifier_stash.py +++ b/packages/syft/src/syft/service/notifier/notifier_stash.py @@ -6,7 +6,6 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash -from ...store.document_store import PartitionSettings from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result @@ -17,10 +16,6 @@ @instrument @serializable(canonical_name="NotifierSQLStash", version=1) class NotifierStash(ObjectStash[NotifierSettings]): - settings: PartitionSettings = PartitionSettings( - name=NotifierSettings.__canonical_name__, object_type=NotifierSettings - ) - @as_result(StashException, NotFoundException) def get(self, credentials: SyftVerifyKey) -> NotifierSettings: """Get Settings""" diff --git a/packages/syft/src/syft/service/output/output_service.py b/packages/syft/src/syft/service/output/output_service.py index 5d26ff2cb3e..cdf567a5955 100644 --- a/packages/syft/src/syft/service/output/output_service.py +++ b/packages/syft/src/syft/service/output/output_service.py @@ -9,7 +9,6 @@ from ...server.credentials import SyftVerifyKey from ...store.db.db import DBManager from ...store.db.stash import ObjectStash -from ...store.document_store import PartitionKey from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime @@ -26,11 +25,6 @@ from ..user.user_roles import ADMIN_ROLE_LEVEL from ..user.user_roles import GUEST_ROLE_LEVEL -CreatedAtPartitionKey = PartitionKey(key="created_at", type_=DateTime) -UserCodeIdPartitionKey = PartitionKey(key="user_code_id", type_=UID) -JobIdPartitionKey = PartitionKey(key="job_id", type_=UID) -OutputPolicyIdPartitionKey = PartitionKey(key="output_policy_id", type_=UID) - @serializable() class ExecutionOutput(SyncableSyftObject): diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index 4662307c235..f89493c4c0b 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -30,7 +30,6 @@ from ...serde.recursive_primitives import recursive_serde_register_type from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey -from ...store.document_store import PartitionKey from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.datetime import DateTime @@ -84,10 +83,6 @@ class OutputPolicyValidEnum(Enum): DEFAULT_USER_POLICY_VERSION = 1 -PolicyUserVerifyKeyPartitionKey = PartitionKey( - key="user_verify_key", type_=SyftVerifyKey -) - PyCodeObject = Any diff --git a/packages/syft/src/syft/service/queue/base_queue.py b/packages/syft/src/syft/service/queue/base_queue.py index 
3d8773bd8c2..edfc1eea461 100644 --- a/packages/syft/src/syft/service/queue/base_queue.py +++ b/packages/syft/src/syft/service/queue/base_queue.py @@ -1,15 +1,19 @@ # stdlib from typing import Any from typing import ClassVar +from typing import TYPE_CHECKING # relative from ...serde.serializable import serializable from ...service.context import AuthedServiceContext -from ...store.document_store import NewBaseStash from ...types.uid import UID from ..response import SyftSuccess from ..worker.worker_stash import WorkerStash +if TYPE_CHECKING: + # relative + from .queue_stash import QueueStash + @serializable(canonical_name="QueueClientConfig", version=1) class QueueClientConfig: @@ -105,7 +109,7 @@ def create_consumer( def create_producer( self, queue_name: str, - queue_stash: type[NewBaseStash], + queue_stash: "QueueStash", context: AuthedServiceContext, worker_stash: WorkerStash, ) -> QueueProducer: diff --git a/packages/syft/src/syft/service/queue/queue.py b/packages/syft/src/syft/service/queue/queue.py index b2807389cf1..2cd86d7ae5f 100644 --- a/packages/syft/src/syft/service/queue/queue.py +++ b/packages/syft/src/syft/service/queue/queue.py @@ -6,6 +6,7 @@ from threading import Thread import time from typing import Any +from typing import TYPE_CHECKING from typing import cast # third party @@ -17,7 +18,6 @@ from ...server.credentials import SyftVerifyKey from ...server.worker_settings import WorkerSettings from ...service.context import AuthedServiceContext -from ...store.document_store import NewBaseStash from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime from ...types.errors import SyftException @@ -39,6 +39,10 @@ from .queue_stash import QueueItem from .queue_stash import Status +if TYPE_CHECKING: + # relative + from .queue_stash import QueueStash + logger = logging.getLogger(__name__) @@ -129,7 +133,7 @@ def create_consumer( def create_producer( self, queue_name: str, - queue_stash: type[NewBaseStash], + queue_stash: "QueueStash", context: AuthedServiceContext, worker_stash: WorkerStash, ) -> QueueProducer: diff --git a/packages/syft/src/syft/service/queue/queue_stash.py b/packages/syft/src/syft/service/queue/queue_stash.py index aa5b872b226..c43caa302be 100644 --- a/packages/syft/src/syft/service/queue/queue_stash.py +++ b/packages/syft/src/syft/service/queue/queue_stash.py @@ -9,7 +9,6 @@ from ...server.worker_settings import WorkerSettings from ...server.worker_settings import WorkerSettingsV1 from ...store.db.stash import ObjectStash -from ...store.document_store import PartitionKey from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...store.linked_obj import LinkedObject @@ -35,10 +34,6 @@ class Status(str, Enum): INTERRUPTED = "interrupted" -StatusPartitionKey = PartitionKey(key="status", type_=Status) -_WorkerPoolPartitionKey = PartitionKey(key="worker_pool", type_=LinkedObject) - - @serializable() class QueueItemV1(SyftObject): __canonical_name__ = "QueueItem" diff --git a/packages/syft/src/syft/service/service.py b/packages/syft/src/syft/service/service.py index 49749711853..d990e297054 100644 --- a/packages/syft/src/syft/service/service.py +++ b/packages/syft/src/syft/service/service.py @@ -38,7 +38,6 @@ from ..serde.signature import signature_remove_self from ..server.credentials import SyftVerifyKey from ..store.db.stash import ObjectStash -from ..store.document_store import DocumentStore from ..store.linked_obj import LinkedObject from ..types.errors import 
SyftException from ..types.result import as_result @@ -71,7 +70,6 @@ class AbstractService: server: AbstractServer server_uid: UID - store_type: type = DocumentStore stash: ObjectStash @as_result(SyftException) diff --git a/packages/syft/src/syft/service/sync/sync_service.py b/packages/syft/src/syft/service/sync/sync_service.py index 4aadaaa079a..fa926b13966 100644 --- a/packages/syft/src/syft/service/sync/sync_service.py +++ b/packages/syft/src/syft/service/sync/sync_service.py @@ -8,7 +8,6 @@ from ...serde.serializable import serializable from ...store.db.db import DBManager from ...store.db.stash import ObjectStash -from ...store.document_store import NewBaseStash from ...store.document_store_errors import NotFoundException from ...store.linked_obj import LinkedObject from ...types.datetime import DateTime @@ -108,9 +107,10 @@ def transform_item( self.set_obj_ids(context, item) return item + @as_result(ValueError) def get_stash_for_item( self, context: AuthedServiceContext, item: SyftObject - ) -> NewBaseStash: + ) -> ObjectStash: services = list(context.server.service_path_map.values()) # type: ignore all_stashes = {} @@ -119,6 +119,8 @@ def get_stash_for_item( all_stashes[_stash.object_type] = _stash stash = all_stashes.get(type(item), None) + if stash is None: + raise ValueError(f"Could not find stash for {type(item)}") return stash def add_permissions_for_item( @@ -148,7 +150,7 @@ def add_storage_permissions_for_item( def set_object( self, context: AuthedServiceContext, item: SyncableSyftObject ) -> SyftObject: - stash = self.get_stash_for_item(context, item) + stash = self.get_stash_for_item(context, item).unwrap() creds = context.credentials obj = None diff --git a/packages/syft/src/syft/service/sync/sync_stash.py b/packages/syft/src/syft/service/sync/sync_stash.py index 114ee209af1..81bc26db263 100644 --- a/packages/syft/src/syft/service/sync/sync_stash.py +++ b/packages/syft/src/syft/service/sync/sync_stash.py @@ -3,7 +3,6 @@ from ...server.credentials import SyftVerifyKey from ...store.db.db import DBManager from ...store.db.stash import ObjectStash -from ...store.document_store import PartitionSettings from ...store.document_store_errors import StashException from ...types.result import as_result from .sync_state import SyncState @@ -11,11 +10,6 @@ @serializable(canonical_name="SyncStash", version=1) class SyncStash(ObjectStash[SyncState]): - settings: PartitionSettings = PartitionSettings( - name=SyncState.__canonical_name__, - object_type=SyncState, - ) - def __init__(self, store: DBManager) -> None: super().__init__(store) self.last_state: SyncState | None = None diff --git a/packages/syft/src/syft/service/worker/worker_service.py b/packages/syft/src/syft/service/worker/worker_service.py index 300c0b6ed3d..5f806108b26 100644 --- a/packages/syft/src/syft/service/worker/worker_service.py +++ b/packages/syft/src/syft/service/worker/worker_service.py @@ -14,11 +14,11 @@ from ...serde.serializable import serializable from ...server.credentials import SyftVerifyKey from ...store.db.db import DBManager -from ...store.document_store import SyftSuccess from ...store.document_store_errors import StashException from ...types.errors import SyftException from ...types.result import as_result from ...types.uid import UID +from ..response import SyftSuccess from ..service import AbstractService from ..service import AuthedServiceContext from ..service import service_method diff --git a/packages/syft/src/syft/service/worker/worker_stash.py 
b/packages/syft/src/syft/service/worker/worker_stash.py index d64314a5d81..11d2d66bacc 100644 --- a/packages/syft/src/syft/service/worker/worker_stash.py +++ b/packages/syft/src/syft/service/worker/worker_stash.py @@ -10,7 +10,6 @@ from ...server.credentials import SyftVerifyKey from ...store.db.stash import ObjectStash from ...store.db.stash import with_session -from ...store.document_store import PartitionKey from ...store.document_store_errors import NotFoundException from ...store.document_store_errors import StashException from ...types.result import as_result @@ -20,8 +19,6 @@ from .worker_pool import ConsumerState from .worker_pool import SyftWorker -WorkerContainerNamePartitionKey = PartitionKey(key="container_name", type_=str) - @serializable(canonical_name="WorkerSQLStash", version=1) class WorkerStash(ObjectStash[SyftWorker]): diff --git a/packages/syft/src/syft/store/__init__.py b/packages/syft/src/syft/store/__init__.py index e69de29bb2d..42ff4bbd825 100644 --- a/packages/syft/src/syft/store/__init__.py +++ b/packages/syft/src/syft/store/__init__.py @@ -0,0 +1,3 @@ +# relative +from . import mongo_document_store # noqa: F401 +from . import sqlite_document_store # noqa: F401 diff --git a/packages/syft/src/syft/store/document_store.py b/packages/syft/src/syft/store/document_store.py index 98cccb82568..24dbd9969b4 100644 --- a/packages/syft/src/syft/store/document_store.py +++ b/packages/syft/src/syft/store/document_store.py @@ -1,75 +1,16 @@ # future from __future__ import annotations -# stdlib -from collections.abc import Callable -import types -import typing -from typing import Any -from typing import Literal -from typing import TypeVar - # third party from pydantic import BaseModel from pydantic import Field -from typeguard import check_type # relative from ..serde.serializable import serializable -from ..server.credentials import SyftSigningKey -from ..server.credentials import SyftVerifyKey -from ..service.action.action_permissions import ActionObjectPermission -from ..service.action.action_permissions import StoragePermission -from ..service.context import AuthedServiceContext -from ..service.response import SyftSuccess -from ..types.base import SyftBaseModel -from ..types.errors import SyftException -from ..types.result import Ok -from ..types.result import as_result -from ..types.syft_object import BaseDateTime -from ..types.syft_object import PartialSyftObject from ..types.syft_object import SYFT_OBJECT_VERSION_1 from ..types.syft_object import SyftBaseObject -from ..types.syft_object import SyftObject -from ..types.uid import UID -from ..util.telemetry import instrument -from .document_store_errors import NotFoundException -from .document_store_errors import StashException from .locks import LockingConfig from .locks import NoLockingConfig -from .locks import SyftLock - - -@serializable(canonical_name="BasePartitionSettings", version=1) -class BasePartitionSettings(SyftBaseModel): - """Basic Partition Settings - - Parameters: - name: str - Identifier to be used as prefix by stores and for partitioning - """ - - name: str - - -T = TypeVar("T") - - -def new_first_or_none(result: list[T]) -> T | None: - if hasattr(result, "__len__") and len(result) > 0: - return result[0] - return None - - -# todo: remove -def first_or_none(result: Any) -> Ok: - if hasattr(result, "__len__") and len(result) > 0: - return Ok(result[0]) - return Ok(None) - - -def is_generic_alias(t: type) -> bool: - return isinstance(t, types.GenericAlias | typing._GenericAlias) class 
StoreClientConfig(BaseModel): @@ -78,210 +19,6 @@ class StoreClientConfig(BaseModel): pass -@serializable(canonical_name="PartitionKey", version=1) -class PartitionKey(BaseModel): - key: str - type_: type | object - - def __eq__(self, other: Any) -> bool: - return ( - type(other) == type(self) - and self.key == other.key - and self.type_ == other.type_ - ) - - def with_obj(self, obj: Any) -> QueryKey: - return QueryKey.from_obj(partition_key=self, obj=obj) - - def extract_list(self, obj: Any) -> list: - # not a list and matches the internal list type of the _GenericAlias - if not isinstance(obj, list): - if not isinstance(obj, typing.get_args(self.type_)): - obj = getattr(obj, self.key) - if isinstance(obj, types.FunctionType | types.MethodType): - obj = obj() - - if not isinstance(obj, list) and isinstance( - obj, typing.get_args(self.type_) - ): - # still not a list but the right type - obj = [obj] - - # is a list type so lets compare directly - check_type(obj, self.type_) - return obj - - @property - def type_list(self) -> bool: - return is_generic_alias(self.type_) and self.type_.__origin__ == list - - -@serializable(canonical_name="PartitionKeys", version=1) -class PartitionKeys(BaseModel): - pks: PartitionKey | tuple[PartitionKey, ...] | list[PartitionKey] - - @property - def all(self) -> tuple[PartitionKey, ...] | list[PartitionKey]: - # make sure we always return a list even if there's a single value - return self.pks if isinstance(self.pks, tuple | list) else [self.pks] - - def with_obj(self, obj: Any) -> QueryKeys: - return QueryKeys.from_obj(partition_keys=self, obj=obj) - - def with_tuple(self, *args: Any) -> QueryKeys: - return QueryKeys.from_tuple(partition_keys=self, args=args) - - def add(self, pk: PartitionKey) -> PartitionKeys: - return PartitionKeys(pks=list(self.all) + [pk]) - - @staticmethod - def from_dict(cks_dict: dict[str, type]) -> PartitionKeys: - pks = [] - for k, t in cks_dict.items(): - pks.append(PartitionKey(key=k, type_=t)) - return PartitionKeys(pks=pks) - - -@serializable(canonical_name="QueryKey", version=1) -class QueryKey(PartitionKey): - value: Any = None - - def __eq__(self, other: Any) -> bool: - return ( - type(other) == type(self) - and self.key == other.key - and self.type_ == other.type_ - and self.value == other.value - ) - - @property - def partition_key(self) -> PartitionKey: - return PartitionKey(key=self.key, type_=self.type_) - - @staticmethod - def from_obj(partition_key: PartitionKey, obj: Any) -> QueryKey: - pk_key = partition_key.key - pk_type = partition_key.type_ - - # 🟡 TODO: support more advanced types than List[type] - if partition_key.type_list: - pk_value = partition_key.extract_list(obj) - else: - if isinstance(obj, pk_type): - pk_value = obj - else: - pk_value = getattr(obj, pk_key) - # object has a method for getting these types - # we can't use properties because we don't seem to be able to get the - # return types - # TODO: fix the mypy issue - if isinstance(pk_value, types.FunctionType | types.MethodType): # type: ignore[unreachable] - pk_value = pk_value() # type: ignore[unreachable] - - if pk_value and not isinstance(pk_value, pk_type): - raise Exception( - f"PartitionKey {pk_value} of type {type(pk_value)} must be {pk_type}." 
- ) - return QueryKey(key=pk_key, type_=pk_type, value=pk_value) - - @property - def as_dict(self) -> dict[str, Any]: - return {self.key: self.value} - - -@serializable(canonical_name="PartitionKeysWithUID", version=1) -class PartitionKeysWithUID(PartitionKeys): - uid_pk: PartitionKey - - @property - def all(self) -> tuple[PartitionKey, ...] | list[PartitionKey]: - all_keys = list(self.pks) if isinstance(self.pks, tuple | list) else [self.pks] - if self.uid_pk not in all_keys: - all_keys.insert(0, self.uid_pk) - return all_keys - - -@serializable(canonical_name="QueryKeys", version=1) -class QueryKeys(SyftBaseModel): - qks: QueryKey | tuple[QueryKey, ...] | list[QueryKey] - - @property - def all(self) -> tuple[QueryKey, ...] | list[QueryKey]: - # make sure we always return a list even if there's a single value - return self.qks if isinstance(self.qks, tuple | list) else [self.qks] - - @staticmethod - def from_obj(partition_keys: PartitionKeys, obj: SyftObject) -> QueryKeys: - qks = [] - for partition_key in partition_keys.all: - pk_key = partition_key.key # name of the attribute - pk_type = partition_key.type_ - pk_value = getattr(obj, pk_key) - # object has a method for getting these types - # we can't use properties because we don't seem to be able to get the - # return types - if isinstance(pk_value, types.FunctionType | types.MethodType): - pk_value = pk_value() - if partition_key.type_list: - pk_value = partition_key.extract_list(obj) - else: - if pk_value and not isinstance(pk_value, pk_type): - raise Exception( - f"PartitionKey {pk_value} of type {type(pk_value)} must be {pk_type}." - ) - qk = QueryKey(key=pk_key, type_=pk_type, value=pk_value) - qks.append(qk) - return QueryKeys(qks=qks) - - @staticmethod - def from_tuple(partition_keys: PartitionKeys, args: tuple) -> QueryKeys: - qks = [] - for partition_key, pk_value in zip(partition_keys.all, args): - pk_key = partition_key.key - pk_type = partition_key.type_ - if not isinstance(pk_value, pk_type): - raise Exception( - f"PartitionKey {pk_value} of type {type(pk_value)} must be {pk_type}." 
- ) - qk = QueryKey(key=pk_key, type_=pk_type, value=pk_value) - qks.append(qk) - return QueryKeys(qks=qks) - - @staticmethod - def from_dict(qks_dict: dict[str, Any]) -> QueryKeys: - qks = [] - for k, v in qks_dict.items(): - qks.append(QueryKey(key=k, type_=type(v), value=v)) - return QueryKeys(qks=qks) - - @property - def as_dict(self) -> dict: - qk_dict = {} - for qk in self.all: - qk_key = qk.key - qk_value = qk.value - qk_dict[qk_key] = qk_value - return qk_dict - - -UIDPartitionKey = PartitionKey(key="id", type_=UID) - - -@serializable(canonical_name="PartitionSettings", version=1) -class PartitionSettings(BasePartitionSettings): - object_type: type - store_key: PartitionKey = UIDPartitionKey - - @property - def unique_keys(self) -> PartitionKeys: - unique_keys = PartitionKeys.from_dict(self.object_type._syft_unique_keys_dict()) - return unique_keys.add(self.store_key) - - @property - def searchable_keys(self) -> PartitionKeys: - return PartitionKeys.from_dict(self.object_type._syft_searchable_keys_dict()) - - @serializable( attrs=["settings", "store_config", "unique_cks", "searchable_cks"], canonical_name="StorePartition", @@ -289,7 +26,6 @@ def searchable_keys(self) -> PartitionKeys: ) class StorePartition: """Base StorePartition - Parameters: settings: PartitionSettings PySyft specific settings @@ -297,332 +33,6 @@ class StorePartition: Backend specific configuration """ - def __init__( - self, - server_uid: UID, - root_verify_key: SyftVerifyKey | None, - settings: PartitionSettings, - store_config: StoreConfig, - has_admin_permissions: Callable[[SyftVerifyKey], bool] | None = None, - ) -> None: - if root_verify_key is None: - root_verify_key = SyftSigningKey.generate().verify_key - self.server_uid = server_uid - self.root_verify_key = root_verify_key - self.settings = settings - self.store_config = store_config - self.has_admin_permissions = has_admin_permissions - self.init_store().unwrap( - public_message="Something went wrong initializing the store" - ) - store_config.locking_config.lock_name = f"StorePartition-{settings.name}" - self.lock = SyftLock(store_config.locking_config) - - @as_result(SyftException) - def init_store(self) -> bool: - try: - self.unique_cks = self.settings.unique_keys.all - self.searchable_cks = self.settings.searchable_keys.all - except BaseException as e: - raise SyftException.from_exception(e) - return True - - def matches_unique_cks(self, partition_key: PartitionKey) -> bool: - return partition_key in self.unique_cks - - def matches_searchable_cks(self, partition_key: PartitionKey) -> bool: - return partition_key in self.searchable_cks - - def store_query_key(self, obj: Any) -> QueryKey: - return self.settings.store_key.with_obj(obj) - - def store_query_keys(self, objs: Any) -> QueryKeys: - return QueryKeys(qks=[self.store_query_key(obj) for obj in objs]) - - # Thread-safe methods - @as_result(SyftException) - def _thread_safe_cbk(self, cbk: Callable, *args: Any, **kwargs: Any) -> Any: - locked = self.lock.acquire(blocking=True) - if not locked: - raise SyftException( - public_message=f"Failed to acquire lock for the operation {self.lock.lock_name} ({self.lock._lock})" - ) - - try: - result = cbk(*args, **kwargs).unwrap() - except BaseException as e: - raise SyftException.from_exception(e) - finally: - self.lock.release() - - return result - - @as_result(SyftException) - def set( - self, - credentials: SyftVerifyKey, - obj: SyftObject, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - 
ignore_duplicates: bool = False, - ) -> SyftObject: - if obj.created_date is None: - obj.created_date = BaseDateTime.now() - return self._thread_safe_cbk( - self._set, - credentials=credentials, - obj=obj, - add_permissions=add_permissions, - add_storage_permission=add_storage_permission, - ignore_duplicates=ignore_duplicates, - ).unwrap() - - @as_result(SyftException) - def get( - self, - credentials: SyftVerifyKey, - uid: UID, - ) -> SyftObject: - return self._thread_safe_cbk( - self._get, - uid=uid, - credentials=credentials, - ).unwrap() - - @as_result(SyftException) - def find_index_or_search_keys( - self, - credentials: SyftVerifyKey, - index_qks: QueryKeys, - search_qks: QueryKeys, - order_by: PartitionKey | None = None, - ) -> list[SyftObject]: - return self._thread_safe_cbk( - self._find_index_or_search_keys, - credentials, - index_qks=index_qks, - search_qks=search_qks, - order_by=order_by, - ).unwrap() - - @as_result(SyftException) - def remove_keys( - self, - unique_query_keys: QueryKeys, - searchable_query_keys: QueryKeys, - ) -> None: - return self._thread_safe_cbk( - self._remove_keys, - unique_query_keys=unique_query_keys, - searchable_query_keys=searchable_query_keys, - ).unwrap() - - @as_result(SyftException) - def update( - self, - credentials: SyftVerifyKey, - qk: QueryKey, - obj: SyftObject, - has_permission: bool = False, - ) -> SyftObject: - return self._thread_safe_cbk( - self._update, - credentials=credentials, - qk=qk, - obj=obj, - has_permission=has_permission, - ).unwrap() - - @as_result(SyftException) - def get_all_from_store( - self, - credentials: SyftVerifyKey, - qks: QueryKeys, - order_by: PartitionKey | None = None, - ) -> list[SyftObject]: - return self._thread_safe_cbk( - self._get_all_from_store, credentials, qks, order_by - ).unwrap() - - @as_result(SyftException) - def delete( - self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False - ) -> SyftSuccess: - return self._thread_safe_cbk( - self._delete, credentials, qk, has_permission=has_permission - ).unwrap() - - @as_result(SyftException) - def all( - self, - credentials: SyftVerifyKey, - order_by: PartitionKey | None = None, - has_permission: bool | None = False, - ) -> list[NewBaseStash.object_type]: - return self._thread_safe_cbk( - self._all, credentials, order_by, has_permission - ).unwrap() - - @as_result(SyftException) - def migrate_data( - self, - to_klass: SyftObject, - context: AuthedServiceContext, - has_permission: bool | None = False, - ) -> bool: - return self._thread_safe_cbk( - self._migrate_data, to_klass, context, has_permission - ).unwrap() - - # Potentially thread-unsafe methods. - # CAUTION: - # * Don't use self.lock here. - # * Do not call the public thread-safe methods here(with locking). - # These methods are called from the public thread-safe API, and will hang the process. 
- @as_result(SyftException) - def _set( - self, - credentials: SyftVerifyKey, - obj: SyftObject, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - ignore_duplicates: bool = False, - ) -> SyftObject: - raise NotImplementedError - - @as_result(SyftException) - def _update( - self, - credentials: SyftVerifyKey, - qk: QueryKey, - obj: SyftObject, - has_permission: bool = False, - overwrite: bool = False, - allow_missing_keys: bool = False, - ) -> SyftObject: - raise NotImplementedError - - @as_result(SyftException) - def _get_all_from_store( - self, - credentials: SyftVerifyKey, - qks: QueryKeys, - order_by: PartitionKey | None = None, - ) -> list[SyftObject]: - raise NotImplementedError - - @as_result(SyftException) - def _delete( - self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False - ) -> SyftSuccess: - raise NotImplementedError - - @as_result(SyftException) - def _all( - self, - credentials: SyftVerifyKey, - order_by: PartitionKey | None = None, - has_permission: bool | None = False, - ) -> list[NewBaseStash.object_type]: - raise NotImplementedError - - def add_permission(self, permission: ActionObjectPermission) -> None: - raise NotImplementedError - - def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: - raise NotImplementedError - - def remove_permission(self, permission: ActionObjectPermission) -> None: - raise NotImplementedError - - def has_permission(self, permission: ActionObjectPermission) -> bool: - raise NotImplementedError - - @as_result(SyftException) - def get_all_permissions(self) -> dict[UID, set[str]]: - raise NotImplementedError - - def _get_permissions_for_uid(self, uid: UID) -> set[str]: - raise NotImplementedError - - def add_storage_permission(self, permission: StoragePermission) -> None: - raise NotImplementedError - - def add_storage_permissions(self, permissions: list[StoragePermission]) -> None: - raise NotImplementedError - - def remove_storage_permission(self, permission: StoragePermission) -> None: - raise NotImplementedError - - def has_storage_permission(self, permission: StoragePermission | UID) -> bool: - raise NotImplementedError - - def _get_storage_permissions_for_uid(self, uid: UID) -> set[UID]: - raise NotImplementedError - - @as_result(SyftException) - def get_all_storage_permissions(self) -> dict[UID, set[UID]]: - raise NotImplementedError - - @as_result(SyftException) - def _migrate_data( - self, - to_klass: SyftObject, - context: AuthedServiceContext, - has_permission: bool, - ) -> bool: - raise NotImplementedError - - -@serializable(canonical_name="DocumentStore", version=1) -class DocumentStore: - """Base Document Store - - Parameters: - store_config: StoreConfig - Store specific configuration. 
- """ - - partitions: dict[str, StorePartition] - partition_type: type[StorePartition] - - def __init__( - self, - server_uid: UID, - root_verify_key: SyftVerifyKey | None, - store_config: StoreConfig, - ) -> None: - if store_config is None: - raise Exception("must have store config") - self.partitions = {} - self.store_config = store_config - self.server_uid = server_uid - self.root_verify_key = root_verify_key - - def __has_admin_permissions( - self, settings: PartitionSettings - ) -> Callable[[SyftVerifyKey], bool]: - def has_admin_permissions(credentials: SyftVerifyKey) -> bool: - return credentials == self.root_verify_key - - return has_admin_permissions - - def partition(self, settings: PartitionSettings) -> StorePartition: - if settings.name not in self.partitions: - self.partitions[settings.name] = self.partition_type( - server_uid=self.server_uid, - root_verify_key=self.root_verify_key, - settings=settings, - store_config=self.store_config, - has_admin_permissions=self.__has_admin_permissions(settings), - ) - return self.partitions[settings.name] - - def get_partition_object_types(self) -> list[type]: - return [ - partition.settings.object_type for partition in self.partitions.values() - ] - @serializable() class StoreConfig(SyftBaseObject): @@ -645,241 +55,9 @@ class StoreConfig(SyftBaseObject): store_type: type[DocumentStore] client_config: StoreClientConfig | None = None - locking_config: LockingConfig = Field(default_factory=NoLockingConfig) - - -@instrument -class NewBaseStash: - object_type: type[SyftObject] - settings: PartitionSettings - partition: StorePartition - - def __init__(self, store: DocumentStore) -> None: - self.partition = store.partition(type(self).settings) - - @as_result(StashException) - def check_type(self, obj: Any, type_: type) -> Any: - if not isinstance(obj, type_): - raise StashException(f"{type(obj)} does not match required type: {type_}") - return obj - - @as_result(StashException) - def get_all( - self, - credentials: SyftVerifyKey, - order_by: PartitionKey | None = None, - has_permission: bool = False, - ) -> list[NewBaseStash.object_type]: - return self.partition.all(credentials, order_by, has_permission).unwrap() - - def add_permissions(self, permissions: list[ActionObjectPermission]) -> None: - self.partition.add_permissions(permissions) - - def add_permission(self, permission: ActionObjectPermission) -> None: - self.partition.add_permission(permission) - - def remove_permission(self, permission: ActionObjectPermission) -> None: - self.partition.remove_permission(permission) - - def has_permission(self, permission: ActionObjectPermission) -> bool: - return self.partition.has_permission(permission=permission) - - def has_storage_permission(self, permission: StoragePermission) -> bool: - return self.partition.has_storage_permission(permission=permission) - - def __len__(self) -> int: - return len(self.partition) - - @as_result(StashException) - def set( - self, - credentials: SyftVerifyKey, - obj: NewBaseStash.object_type, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - ignore_duplicates: bool = False, - ) -> NewBaseStash.object_type: - return self.partition.set( - credentials=credentials, - obj=obj, - ignore_duplicates=ignore_duplicates, - add_permissions=add_permissions, - add_storage_permission=add_storage_permission, - ).unwrap() - - @as_result(StashException) - def query_all( - self, - credentials: SyftVerifyKey, - qks: QueryKey | QueryKeys, - order_by: PartitionKey | None = None, - ) 
-> list[NewBaseStash.object_type]: - if isinstance(qks, QueryKey): - qks = QueryKeys(qks=qks) - - unique_keys = [] - searchable_keys = [] - - for qk in qks.all: - pk = qk.partition_key - if self.partition.matches_unique_cks(pk): - unique_keys.append(qk) - elif self.partition.matches_searchable_cks(pk): - searchable_keys.append(qk) - else: - raise StashException( - f"{qk} not in {type(self.partition)} unique or searchable keys" - ) + locking_config: LockingConfig = Field(default_factory=NoLockingConfig) # noqa: F821 - index_qks = QueryKeys(qks=unique_keys) - search_qks = QueryKeys(qks=searchable_keys) - return self.partition.find_index_or_search_keys( - credentials=credentials, - index_qks=index_qks, - search_qks=search_qks, - order_by=order_by, - ).unwrap() - - @as_result(StashException) - def query_all_kwargs( - self, - credentials: SyftVerifyKey, - **kwargs: dict[str, Any], - ) -> list[NewBaseStash.object_type]: - order_by = kwargs.pop("order_by", None) - qks = QueryKeys.from_dict(kwargs) - # TODO: Check order_by type... - return self.query_all( - credentials=credentials, qks=qks, order_by=order_by - ).unwrap() - - @as_result(StashException, NotFoundException) - def query_one( - self, - credentials: SyftVerifyKey, - qks: QueryKey | QueryKeys, - order_by: PartitionKey | None = None, - ) -> NewBaseStash.object_type: - result = self.query_all( - credentials=credentials, qks=qks, order_by=order_by - ).unwrap() - value = new_first_or_none(result) - if value is None: - keys = qks.all if isinstance(qks, QueryKeys) else [qks] - keys_str = ", ".join(f"{x.key}: {x.value}" for x in keys) - raise NotFoundException( - public_message=f"Could not find {self.object_type} with {keys_str}" - ) - return value - - @as_result(StashException, NotFoundException) - def query_one_kwargs( - self, - credentials: SyftVerifyKey, - **kwargs: dict[str, Any], - ) -> NewBaseStash.object_type: - result = self.query_all_kwargs(credentials, **kwargs).unwrap() - value = new_first_or_none(result) - if value is None: - raise NotFoundException - return value - - @as_result(StashException) - def find_all( - self, credentials: SyftVerifyKey, **kwargs: dict[str, Any] - ) -> list[NewBaseStash.object_type]: - return self.query_all_kwargs(credentials=credentials, **kwargs).unwrap() - - @as_result(StashException, NotFoundException) - def find_one( - self, credentials: SyftVerifyKey, **kwargs: dict[str, Any] - ) -> NewBaseStash.object_type: - return self.query_one_kwargs(credentials=credentials, **kwargs).unwrap() - - @as_result(StashException, NotFoundException) - def find_and_delete( - self, credentials: SyftVerifyKey, **kwargs: dict[str, Any] - ) -> Literal[True]: - obj = self.query_one_kwargs(credentials=credentials, **kwargs).unwrap() - qk = self.partition.store_query_key(obj) - return self.delete(credentials=credentials, qk=qk).unwrap() - - @as_result(StashException, SyftException) - def delete( - self, credentials: SyftVerifyKey, qk: QueryKey, has_permission: bool = False - ) -> Literal[True]: - # TODO: (error) Check return response - return self.partition.delete( - credentials=credentials, qk=qk, has_permission=has_permission - ).unwrap() - - @as_result(StashException, SyftException) - def update( - self, - credentials: SyftVerifyKey, - obj: NewBaseStash.object_type, - has_permission: bool = False, - ) -> NewBaseStash.object_type: - # TODO: See what breaks: - # this is for when we pass an somelike like a UserUpdate obj - if isinstance(obj, PartialSyftObject): - current = self.find_one(credentials, id=obj.id).unwrap() - 
obj.apply(to=current) - obj = current - - obj = self.check_type(obj, self.object_type).unwrap() - qk = self.partition.store_query_key(obj) - return self.partition.update( - credentials=credentials, qk=qk, obj=obj, has_permission=has_permission - ).unwrap() - - -@instrument -class NewBaseUIDStoreStash(NewBaseStash): - @as_result(SyftException, StashException) - def delete_by_uid( - self, credentials: SyftVerifyKey, uid: UID, has_permission: bool = False - ) -> UID: - qk = UIDPartitionKey.with_obj(uid) - super().delete( - credentials=credentials, qk=qk, has_permission=has_permission - ).unwrap() - return uid - - @as_result(SyftException, StashException, NotFoundException) - def get_by_uid( - self, credentials: SyftVerifyKey, uid: UID - ) -> NewBaseUIDStoreStash.object_type: - # TODO: Could change to query_one, no? - result = self.partition.get(credentials=credentials, uid=uid).unwrap() - if result is None: - raise NotFoundException( - public_message=f"{self.object_type} with uid {uid} not found" - ) - - return result - - @as_result(SyftException, StashException) - def set( # type: ignore [override] - self, - credentials: SyftVerifyKey, - obj: NewBaseUIDStoreStash.object_type, - add_permissions: list[ActionObjectPermission] | None = None, - add_storage_permission: bool = True, - ignore_duplicates: bool = False, - ) -> NewBaseUIDStoreStash.object_type: - self.check_type(obj, self.object_type).unwrap() - return ( - super() - .set( - credentials=credentials, - obj=obj, - ignore_duplicates=ignore_duplicates, - add_permissions=add_permissions, - add_storage_permission=add_storage_permission, - ) - .unwrap( - public_message=f"Failed to set {self.object_type} with uid {obj.id} not found" - ) - ) +@serializable(canonical_name="DocumentStore", version=1) +class DocumentStore: + pass diff --git a/packages/syft/src/syft/store/kv_document_store.py b/packages/syft/src/syft/store/kv_document_store.py index 52f712d842b..30104b9a581 100644 --- a/packages/syft/src/syft/store/kv_document_store.py +++ b/packages/syft/src/syft/store/kv_document_store.py @@ -1,12 +1,9 @@ # relative -from .document_store import StorePartition - - class KeyValueBackingStore: pass -class KeyValueStorePartition(StorePartition): +class KeyValueStorePartition: """Key-Value StorePartition Parameters: `settings`: PartitionSettings diff --git a/packages/syft/src/syft/store/locks.py b/packages/syft/src/syft/store/locks.py index 2494dffa895..b16b1b7bd9b 100644 --- a/packages/syft/src/syft/store/locks.py +++ b/packages/syft/src/syft/store/locks.py @@ -1,39 +1,12 @@ -# stdlib -from collections import defaultdict -import logging -import threading -import time -from typing import Any - # third party from pydantic import BaseModel -from sherlock.lock import BaseLock # relative from ..serde.serializable import serializable -logger = logging.getLogger(__name__) -THREAD_FILE_LOCKS: dict[int, dict[str, int]] = defaultdict(dict) - @serializable(canonical_name="LockingConfig", version=1) class LockingConfig(BaseModel): - """ - Locking config - - Args: - lock_name: str - Lock name - namespace: Optional[str] - Namespace to use for setting lock keys in the backend store. - expire: Optional[int] - Lock expiration time in seconds. If explicitly set to `None`, lock will not expire. - timeout: Optional[int] - Timeout to acquire lock(seconds) - retry_interval: float - Retry interval to retry acquiring a lock if previous attempts failed. 
- """ - lock_name: str = "syft_lock" namespace: str | None = None expire: int | None = 60 @@ -43,195 +16,9 @@ class LockingConfig(BaseModel): @serializable(canonical_name="NoLockingConfig", version=1) class NoLockingConfig(LockingConfig): - """ - No-locking policy - """ - pass @serializable(canonical_name="ThreadingLockingConfig", version=1) class ThreadingLockingConfig(LockingConfig): - """ - Threading-based locking policy - """ - pass - - -class ThreadingLock(BaseLock): - """ - Threading-based Lock. Used to provide the same API as the rest of the locks. - """ - - def __init__(self, expire: int, **kwargs: Any) -> None: - self.expire = expire - self.locked_timestamp: float = 0.0 - self.lock = threading.Lock() - - @property - def _locked(self) -> bool: - """ - Implementation of method to check if lock has been acquired. Must be - :returns: if the lock is acquired or not - :rtype: bool - """ - locked = self.lock.locked() - if ( - locked - and time.time() - self.locked_timestamp >= self.expire - and self.expire != -1 - ): - self._release() - - return self.lock.locked() - - def _acquire(self) -> bool: - """ - Implementation of acquiring a lock in a non-blocking fashion. - :returns: if the lock was successfully acquired or not - :rtype: bool - """ - locked = self.lock.locked() - if ( - locked - and time.time() - self.locked_timestamp > self.expire - and self.expire != -1 - ): - self._release() - - status = self.lock.acquire( - blocking=False - ) # timeout/retries handle in the `acquire` method - if status: - self.locked_timestamp = time.time() - return status - - def _release(self) -> None: - """ - Implementation of releasing an acquired lock. - """ - - try: - return self.lock.release() - except RuntimeError: # already unlocked - pass - - def _renew(self) -> bool: - """ - Implementation of renewing an acquired lock. - """ - return True - - -class SyftLock(BaseLock): - """ - Syft Lock implementations. - - Params: - config: Config specific to a locking strategy. - """ - - def __init__(self, config: LockingConfig): - self.config = config - - self.lock_name = config.lock_name - self.namespace = config.namespace - self.expire = config.expire - self.timeout = config.timeout - self.retry_interval = config.retry_interval - - self.passthrough = False - - self._lock: BaseLock | None = None - - base_params = { - "lock_name": config.lock_name, - "namespace": config.namespace, - "expire": config.expire, - "timeout": config.timeout, - "retry_interval": config.retry_interval, - } - if isinstance(config, NoLockingConfig): - self.passthrough = True - elif isinstance(config, ThreadingLockingConfig): - self._lock = ThreadingLock(**base_params) - else: - raise ValueError("Unsupported config type") - - @property - def _locked(self) -> bool: - """ - Implementation of method to check if lock has been acquired. - - :returns: if the lock is acquired or not - :rtype: bool - """ - if self.passthrough: - return False - return self._lock.locked() if self._lock else False - - def acquire(self, blocking: bool = True) -> bool: - """ - Acquire a lock, blocking or non-blocking. - :param bool blocking: acquire a lock in a blocking or non-blocking - fashion. Defaults to True. 
- :returns: if the lock was successfully acquired or not - :rtype: bool - """ - - if not blocking: - return self._acquire() - - timeout: float = float(self.timeout) - start_time = time.time() - elapsed: float = 0.0 - while timeout >= elapsed: - if not self._acquire(): - time.sleep(self.retry_interval) - elapsed = time.time() - start_time - else: - return True - logger.debug( - f"Timeout elapsed after {self.timeout} seconds while trying to acquiring lock." - ) - # third party - return False - - def _acquire(self) -> bool: - """ - Implementation of acquiring a lock in a non-blocking fashion. - `acquire` makes use of this implementation to provide blocking and non-blocking implementations. - - :returns: if the lock was successfully acquired or not - :rtype: bool - """ - if self.passthrough: - return True - - try: - return self._lock._acquire() if self._lock else False - except BaseException: - return False - - def _release(self) -> bool | None: - """ - Implementation of releasing an acquired lock. - """ - if self.passthrough: - return None - if not self._lock: - return None - try: - return self._lock._release() - except BaseException: - return None - - def _renew(self) -> bool: - """ - Implementation of renewing an acquired lock. - """ - if self.passthrough: - return True - - return self._lock._renew() if self._lock else False diff --git a/packages/syft/tests/syft/locks_test.py b/packages/syft/tests/syft/locks_test.py deleted file mode 100644 index 52a0e2a4102..00000000000 --- a/packages/syft/tests/syft/locks_test.py +++ /dev/null @@ -1,325 +0,0 @@ -# stdlib -from pathlib import Path -from secrets import token_hex -import tempfile -from threading import Thread -import time - -# third party -import pytest - -# syft absolute -from syft.store.locks import LockingConfig -from syft.store.locks import NoLockingConfig -from syft.store.locks import SyftLock -from syft.store.locks import ThreadingLockingConfig - -def_params = { - "lock_name": "testing_lock", - "expire": 5, # seconds, - "timeout": 1, # seconds, - "retry_interval": 0.1, # seconds, -} - - -@pytest.fixture(scope="function") -def locks_nop_config(request): - def_params["lock_name"] = token_hex(8) - yield NoLockingConfig(**def_params) - - -@pytest.fixture(scope="function") -def locks_threading_config(request): - def_params["lock_name"] = token_hex(8) - yield ThreadingLockingConfig(**def_params) - - -@pytest.mark.parametrize( - "config", - [ - pytest.lazy_fixture("locks_nop_config"), - pytest.lazy_fixture("locks_threading_config"), - ], -) -def test_sanity(config: LockingConfig): - lock = SyftLock(config) - - assert lock is not None - - -@pytest.mark.parametrize( - "config", - [ - pytest.lazy_fixture("locks_nop_config"), - ], -) -def test_acquire_nop(config: LockingConfig): - lock = SyftLock(config) - - assert lock.locked() is False - - acq_ok = lock.acquire() - assert acq_ok - - assert lock.locked() is False - - lock.release() - - assert lock.locked() is False - - -@pytest.mark.parametrize( - "config", - [ - pytest.lazy_fixture("locks_threading_config"), - ], -) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_acquire_release(config: LockingConfig): - lock = SyftLock(config) - - expected_not_locked = lock.locked() - - acq_ok = lock.acquire() - assert acq_ok - - expected_locked = lock.locked() - - lock.release() - - expected_not_locked_again = lock.locked() - - assert not expected_not_locked - assert expected_locked - assert not expected_not_locked_again - - -@pytest.mark.parametrize( - "config", - [ - 
pytest.lazy_fixture("locks_threading_config"), - ], -) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_acquire_release_with(config: LockingConfig): - was_locked = True - with SyftLock(config) as lock: - was_locked = lock.locked() - - assert was_locked - - -@pytest.mark.parametrize( - "config", - [ - pytest.lazy_fixture("locks_threading_config"), - ], -) -def test_acquire_expire(config: LockingConfig): - config.expire = 1 # second - lock = SyftLock(config) - - expected_not_locked = lock.locked() - - acq_ok = lock.acquire(blocking=True) - assert acq_ok - - expected_locked = lock.locked() - - time.sleep(config.expire + 1.0) - - expected_not_locked_again = lock.locked() - - assert not expected_not_locked - assert expected_locked - assert not expected_not_locked_again - - -@pytest.mark.parametrize( - "config", - [ - pytest.lazy_fixture("locks_threading_config"), - ], -) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_acquire_double_aqcuire_timeout_fail(config: LockingConfig): - config.timeout = 1 - config.expire = 5 - lock = SyftLock(config) - - acq_ok = lock.acquire(blocking=True) - assert acq_ok - - not_acq = lock.acquire(blocking=True) - - lock.release() - - assert not not_acq - - -@pytest.mark.parametrize( - "config", - [ - pytest.lazy_fixture("locks_threading_config"), - ], -) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_acquire_double_aqcuire_timeout_ok(config: LockingConfig): - config.timeout = 2 - config.expire = 1 - lock = SyftLock(config) - - lock.locked() - - acq_ok = lock.acquire(blocking=True) - assert acq_ok - - also_acq = lock.acquire(blocking=True) - - lock.release() - - assert also_acq - - -@pytest.mark.parametrize( - "config", - [ - pytest.lazy_fixture("locks_threading_config"), - ], -) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_acquire_double_aqcuire_nonblocking(config: LockingConfig): - config.timeout = 2 - config.expire = 1 - lock = SyftLock(config) - - lock.locked() - - acq_ok = lock.acquire(blocking=False) - assert acq_ok - - not_acq = lock.acquire(blocking=False) - - lock.release() - - assert not not_acq - - -@pytest.mark.parametrize( - "config", - [ - pytest.lazy_fixture("locks_threading_config"), - ], -) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_acquire_double_aqcuire_retry_interval(config: LockingConfig): - config.timeout = 2 - config.expire = 1 - config.retry_interval = 3 - lock = SyftLock(config) - - lock.locked() - - acq_ok = lock.acquire(blocking=True) - assert acq_ok - - not_acq = lock.acquire(blocking=True) - - lock.release() - - assert not not_acq - - -@pytest.mark.parametrize( - "config", - [ - pytest.lazy_fixture("locks_threading_config"), - ], -) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_acquire_double_release(config: LockingConfig): - lock = SyftLock(config) - - lock.acquire(blocking=True) - - lock.release() - lock.release() - - -@pytest.mark.parametrize( - "config", - [ - pytest.lazy_fixture("locks_threading_config"), - ], -) -@pytest.mark.flaky(reruns=3, reruns_delay=3) -def test_acquire_same_name_diff_namespace(config: LockingConfig): - config.namespace = "ns1" - lock1 = SyftLock(config) - assert lock1.acquire(blocking=True) - - config.namespace = "ns2" - lock2 = SyftLock(config) - assert lock2.acquire(blocking=True) - - lock2.release() - lock1.release() - - -@pytest.mark.skip(reason="The tests are highly flaky, delaying progress on PR's") -@pytest.mark.parametrize( - "config", - [ - pytest.lazy_fixture("locks_threading_config"), - ], -) -def test_locks_parallel_multithreading(config: 
LockingConfig) -> None: - thread_cnt = 3 - repeats = 5 - - temp_dir = Path(tempfile.TemporaryDirectory().name) - temp_dir.mkdir(parents=True, exist_ok=True) - temp_file = temp_dir / "dbg.txt" - if temp_file.exists(): - temp_file.unlink() - - with open(temp_file, "w") as f: - f.write("0") - - config.timeout = 10 - lock = SyftLock(config) - - def _kv_cbk(tid: int) -> None: - for _idx in range(repeats): - locked = lock.acquire() - if not locked: - continue - - for _retry in range(10): - try: - with open(temp_file) as f: - prev = f.read() - prev = int(prev) - with open(temp_file, "w") as f: - f.write(str(prev + 1)) - f.flush() - break - except BaseException as e: - print("failed ", e) - - lock.release() - - tids = [] - for tid in range(thread_cnt): - thread = Thread(target=_kv_cbk, args=(tid,)) - thread.start() - - tids.append(thread) - - for thread in tids: - thread.join() - - with open(temp_file) as f: - stored = int(f.read()) - - assert stored == thread_cnt * repeats diff --git a/packages/syft/tests/syft/migrations/protocol_communication_test.py b/packages/syft/tests/syft/migrations/protocol_communication_test.py deleted file mode 100644 index 64c670d5ea3..00000000000 --- a/packages/syft/tests/syft/migrations/protocol_communication_test.py +++ /dev/null @@ -1,270 +0,0 @@ -# stdlib -from copy import deepcopy -import os -from pathlib import Path -from unittest import mock - -# third party -import pytest - -# syft absolute -import syft as sy -from syft.protocol.data_protocol import get_data_protocol -from syft.protocol.data_protocol import protocol_release_dir -from syft.protocol.data_protocol import stage_protocol_changes -from syft.serde.recursive import TYPE_BANK -from syft.serde.serializable import serializable -from syft.server.worker import Worker -from syft.service.context import AuthedServiceContext -from syft.service.service import AbstractService -from syft.service.service import ServiceConfigRegistry -from syft.service.service import service_method -from syft.service.user.user_roles import GUEST_ROLE_LEVEL -from syft.store.db.db import DBManager -from syft.store.document_store import DocumentStore -from syft.store.document_store import NewBaseStash -from syft.store.document_store import PartitionSettings -from syft.types.syft_migration import migrate -from syft.types.syft_object import SYFT_OBJECT_VERSION_1 -from syft.types.syft_object import SyftBaseObject -from syft.types.syft_object import SyftObject -from syft.types.transforms import convert_types -from syft.types.transforms import rename -from syft.types.uid import UID -from syft.util.util import index_syft_by_module_name - -MOCK_TYPE_BANK = deepcopy(TYPE_BANK) - - -def get_klass_version_1(): - @serializable() - class SyftMockObjectTestV1(SyftObject): - __canonical_name__ = "SyftMockObjectTest" - __version__ = SYFT_OBJECT_VERSION_1 - - id: UID - name: str - version: int - - return SyftMockObjectTestV1 - - -def get_klass_version_2(): - @serializable() - class SyftMockObjectTestV2(SyftObject): - __canonical_name__ = "SyftMockObjectTest" - __version__ = SYFT_OBJECT_VERSION_1 - - id: UID - full_name: str - version: str - - return SyftMockObjectTestV2 - - -def setup_migration_transforms(mock_klass_v1, mock_klass_v2): - @migrate(mock_klass_v1, mock_klass_v2) - def mock_v1_to_v2(): - return [rename("name", "full_name"), convert_types(["version"], str)] - - @migrate(mock_klass_v2, mock_klass_v1) - def mock_v2_to_v1(): - return [rename("full_name", "name"), convert_types(["version"], int)] - - return mock_v1_to_v2, mock_v2_to_v1 - - 
-def get_stash_klass(syft_object: type[SyftBaseObject]): - @serializable( - canonical_name="SyftMockObjectStash", - version=1, - ) - class SyftMockObjectStash(NewBaseStash): - object_type = syft_object - settings: PartitionSettings = PartitionSettings( - name=object_type.__canonical_name__, - object_type=syft_object, - ) - - def __init__(self, store: DBManager) -> None: - super().__init__(store=store) - - return SyftMockObjectStash - - -def setup_service_method(syft_object): - stash_klass: NewBaseStash = get_stash_klass(syft_object=syft_object) - - @serializable( - canonical_name="SyftMockObjectService", - version=1, - ) - class SyftMockObjectService(AbstractService): - store: DocumentStore - stash: stash_klass # type: ignore - __module__: str = "syft.test" - - def __init__(self, store: DBManager) -> None: - self.stash = stash_klass(store=store) - - @service_method( - path="dummy.syft_object", - name="get", - roles=GUEST_ROLE_LEVEL, - ) - def get(self, context: AuthedServiceContext) -> list[syft_object]: - return self.stash.get_all(context.credentials, has_permission=True) - - return SyftMockObjectService - - -def setup_version_one(server_name: str): - syft_klass_version_one = get_klass_version_1() - sy.stage_protocol_changes() - sy.bump_protocol_version() - - syft_service_klass = setup_service_method( - syft_object=syft_klass_version_one, - ) - - server = sy.orchestra.launch(server_name, dev_mode=True, reset=True) - - worker: Worker = server.python_server - - worker.services.append(syft_service_klass) - worker.service_path_map[syft_service_klass.__name__.lower()] = syft_service_klass( - store=worker.document_store - ) - - return server, syft_klass_version_one - - -def mock_syft_version(): - return f"{sy.__version__}.dev" - - -def setup_version_second(server_name: str, klass_version_one: type): - syft_klass_version_second = get_klass_version_2() - setup_migration_transforms(klass_version_one, syft_klass_version_second) - - sy.stage_protocol_changes() - sy.bump_protocol_version() - - syft_service_klass = setup_service_method(syft_object=syft_klass_version_second) - - server = sy.orchestra.launch(server_name, dev_mode=True) - - worker: Worker = server.python_server - - worker.services.append(syft_service_klass) - worker.service_path_map[syft_service_klass.__name__.lower()] = syft_service_klass( - store=worker.document_store - ) - - return server, syft_klass_version_second - - -@pytest.fixture -def my_stage_protocol(protocol_file: Path): - with mock.patch( - "syft.protocol.data_protocol.PROTOCOL_STATE_FILENAME", - protocol_file.name, - ): - dp = get_data_protocol() - stage_protocol_changes() - yield dp.protocol_history - dp.revert_latest_protocol() - dp.save_history(dp.protocol_history) - - # Cleanup release dir, remove unused released files - if os.path.exists(protocol_release_dir()): - for _file_path in protocol_release_dir().iterdir(): - for version in dp.read_json(_file_path): - if version not in dp.protocol_history.keys(): - _file_path.unlink() - - -@pytest.mark.skip( - reason="Issues running with other tests. Shared release folder causes issues." 
-) -def test_client_server_running_different_protocols(my_stage_protocol): - def patched_index_syft_by_module_name(fully_qualified_name: str): - if klass_v1.__name__ in fully_qualified_name: - return klass_v1 - elif klass_v2.__name__ in fully_qualified_name: - return klass_v2 - - return index_syft_by_module_name(fully_qualified_name) - - server_name = UID().to_string() - with mock.patch("syft.serde.recursive.TYPE_BANK", MOCK_TYPE_BANK): - with mock.patch( - "syft.protocol.data_protocol.TYPE_BANK", - MOCK_TYPE_BANK, - ): - with mock.patch( - "syft.client.api.index_syft_by_module_name", - patched_index_syft_by_module_name, - ): - # Setup mock object version one - nh1, klass_v1 = setup_version_one(server_name) - assert klass_v1.__canonical_name__ == "SyftMockObjectTest" - assert klass_v1.__name__ == "SyftMockObjectTestV1" - - nh1_client = nh1.client - assert nh1_client is not None - result_from_client_1 = nh1_client.api.services.dummy.get() - - protocol_version_with_mock_obj_v1 = get_data_protocol().latest_version - - # No data saved - assert len(result_from_client_1) == 0 - - # Setup mock object version second - with mock.patch( - "syft.protocol.data_protocol.__version__", mock_syft_version() - ): - nh2, klass_v2 = setup_version_second( - server_name, klass_version_one=klass_v1 - ) - - # Create a sample data in version second - sample_data = klass_v2(full_name="John", version=str(1), id=UID()) - - assert isinstance(sample_data, klass_v2) - - # Validate migrations - sample_data_v1 = sample_data.migrate_to( - version=klass_v1.__version__, - ) - assert sample_data_v1.name == sample_data.full_name - assert sample_data_v1.version == int(sample_data.version) - - # Set the sample data in version second - service_klass = nh1.python_server.get_service( - "SyftMockObjectService" - ) - service_klass.stash.set( - nh1.python_server.root_client.verify_key, - sample_data, - ) - - nh2_client = nh2.client - assert nh2_client is not None - # Force communication protocol to when version object is defined - nh2_client.communication_protocol = ( - protocol_version_with_mock_obj_v1 - ) - # Reset api - nh2_client._api = None - - # Call the API with an older communication protocol version - result2 = nh2_client.api.services.dummy.get() - assert isinstance(result2, list) - - # Validate the data received - for data in result2: - assert isinstance(data, klass_v1) - assert data.name == sample_data.full_name - assert data.version == int(sample_data.version) - ServiceConfigRegistry.__service_config_registry__.pop("dummy.syft_object", None) diff --git a/packages/syft/tests/syft/notifications/notification_service_test.py b/packages/syft/tests/syft/notifications/notification_service_test.py index a8319d32a80..7885549a35c 100644 --- a/packages/syft/tests/syft/notifications/notification_service_test.py +++ b/packages/syft/tests/syft/notifications/notification_service_test.py @@ -15,7 +15,6 @@ from syft.service.notification.notifications import Notification from syft.service.notification.notifications import NotificationStatus from syft.service.response import SyftSuccess -from syft.store.document_store import DocumentStore from syft.store.document_store_errors import StashException from syft.store.linked_obj import LinkedObject from syft.types.datetime import DateTime @@ -129,7 +128,7 @@ def test_get_all_success( monkeypatch: MonkeyPatch, notification_service: NotificationService, authed_context: AuthedServiceContext, - document_store: DocumentStore, + document_store, ) -> None: random_signing_key = SyftSigningKey.generate() 
random_verify_key = random_signing_key.verify_key @@ -181,7 +180,7 @@ def mock_get_all_inbox_for_verify_key( def test_get_sent_success( authed_context: AuthedServiceContext, - document_store: DocumentStore, + document_store, ) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key @@ -236,7 +235,7 @@ def test_get_all_for_status_success( monkeypatch: MonkeyPatch, notification_service: NotificationService, authed_context: AuthedServiceContext, - document_store: DocumentStore, + document_store, ) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key @@ -306,7 +305,7 @@ def test_get_all_read_success( monkeypatch: MonkeyPatch, notification_service: NotificationService, authed_context: AuthedServiceContext, - document_store: DocumentStore, + document_store, ) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key @@ -363,7 +362,7 @@ def test_get_all_unread_success( monkeypatch: MonkeyPatch, notification_service: NotificationService, authed_context: AuthedServiceContext, - document_store: DocumentStore, + document_store, ) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key @@ -419,7 +418,7 @@ def test_mark_as_read_success( monkeypatch: MonkeyPatch, notification_service: NotificationService, authed_context: AuthedServiceContext, - document_store: DocumentStore, + document_store, ) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key @@ -458,7 +457,7 @@ def test_mark_as_read_error_on_update_notification_status( monkeypatch: MonkeyPatch, notification_service: NotificationService, authed_context: AuthedServiceContext, - document_store: DocumentStore, + document_store, ) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key @@ -497,7 +496,7 @@ def test_mark_as_unread_success( monkeypatch: MonkeyPatch, notification_service: NotificationService, authed_context: AuthedServiceContext, - document_store: DocumentStore, + document_store, ) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key @@ -537,7 +536,7 @@ def test_mark_as_unread_error_on_update_notification_status( monkeypatch: MonkeyPatch, notification_service: NotificationService, authed_context: AuthedServiceContext, - document_store: DocumentStore, + document_store, ) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key @@ -578,7 +577,7 @@ def test_resolve_object_success( authed_context: AuthedServiceContext, linked_object: LinkedObject, notification_service: NotificationService, - document_store: DocumentStore, + document_store, ) -> None: test_notification_service = NotificationService(document_store) @@ -612,7 +611,7 @@ def test_resolve_object_error_on_resolve_link( monkeypatch: MonkeyPatch, authed_context: AuthedServiceContext, linked_object: LinkedObject, - document_store: DocumentStore, + document_store, notification_service: NotificationService, ) -> None: test_notification_service = NotificationService(document_store) @@ -651,7 +650,7 @@ def test_clear_success( monkeypatch: MonkeyPatch, notification_service: NotificationService, authed_context: AuthedServiceContext, - document_store: DocumentStore, + document_store, ) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key @@ 
-692,7 +691,7 @@ def test_clear_error_on_delete_all_for_verify_key( monkeypatch: MonkeyPatch, notification_service: NotificationService, authed_context: AuthedServiceContext, - document_store: DocumentStore, + document_store, ) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key diff --git a/packages/syft/tests/syft/notifications/notification_stash_test.py b/packages/syft/tests/syft/notifications/notification_stash_test.py index 7864fb10d19..ceca7d3b37b 100644 --- a/packages/syft/tests/syft/notifications/notification_stash_test.py +++ b/packages/syft/tests/syft/notifications/notification_stash_test.py @@ -12,6 +12,7 @@ from syft.service.notification.notifications import Notification from syft.service.notification.notifications import NotificationExpiryStatus from syft.service.notification.notifications import NotificationStatus +from syft.store.db.db import DBManager from syft.store.document_store_errors import StashException from syft.types.datetime import DateTime from syft.types.errors import SyftException @@ -54,7 +55,9 @@ def add_mock_notification( return mock_notification -def test_get_all_inbox_for_verify_key(root_verify_key, document_store) -> None: +def test_get_all_inbox_for_verify_key( + root_verify_key, document_store: DBManager +) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key test_stash = NotificationStash(store=document_store) @@ -92,7 +95,9 @@ def test_get_all_inbox_for_verify_key(root_verify_key, document_store) -> None: assert result == sorted_notification_list -def test_get_all_sent_for_verify_key(root_verify_key, document_store) -> None: +def test_get_all_sent_for_verify_key( + root_verify_key, document_store: DBManager +) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key test_stash = NotificationStash(store=document_store) @@ -121,7 +126,7 @@ def test_get_all_sent_for_verify_key(root_verify_key, document_store) -> None: test_stash.get_all_sent_for_verify_key(root_verify_key, random_signing_key) -def test_get_all_for_verify_key(root_verify_key, document_store) -> None: +def test_get_all_for_verify_key(root_verify_key, document_store: DBManager) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key test_stash = NotificationStash(store=document_store) @@ -157,7 +162,9 @@ def test_get_all_for_verify_key(root_verify_key, document_store) -> None: assert len(result) == 1 -def test_get_all_by_verify_key_for_status(root_verify_key, document_store) -> None: +def test_get_all_by_verify_key_for_status( + root_verify_key, document_store: DBManager +) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key test_stash = NotificationStash(store=document_store) @@ -184,7 +191,7 @@ def test_get_all_by_verify_key_for_status(root_verify_key, document_store) -> No ) -def test_update_notification_status(root_verify_key, document_store) -> None: +def test_update_notification_status(root_verify_key, document_store: DBManager) -> None: random_uid = UID() random_verify_key = SyftSigningKey.generate().verify_key test_stash = NotificationStash(store=document_store) @@ -254,7 +261,7 @@ def mock_get_by_uid(root_verify_key: SyftVerifyKey, uid: UID) -> NoReturn: assert exc.value.public_message == expected_error_msg -def test_delete_all_for_verify_key(root_verify_key, document_store) -> None: +def 
test_delete_all_for_verify_key(root_verify_key, document_store: DBManager) -> None: random_signing_key = SyftSigningKey.generate() random_verify_key = random_signing_key.verify_key test_stash = NotificationStash(store=document_store) diff --git a/packages/syft/tests/syft/request/fixtures.py b/packages/syft/tests/syft/request/fixtures.py index 5e7c2226b9e..6294ff7eae7 100644 --- a/packages/syft/tests/syft/request/fixtures.py +++ b/packages/syft/tests/syft/request/fixtures.py @@ -9,11 +9,10 @@ from syft.service.context import AuthedServiceContext from syft.service.request.request_service import RequestService from syft.service.request.request_stash import RequestStash -from syft.store.document_store import DocumentStore @pytest.fixture -def request_stash(document_store: DocumentStore) -> RequestStash: +def request_stash(document_store) -> RequestStash: yield RequestStash(store=document_store) @@ -26,5 +25,5 @@ def authed_context_guest_datasite_client( @pytest.fixture -def request_service(document_store: DocumentStore): +def request_service(document_store): yield RequestService(store=document_store) diff --git a/packages/syft/tests/syft/stores/base_stash_test.py b/packages/syft/tests/syft/stores/base_stash_test.py index 3873f343261..dd5f6aa2a90 100644 --- a/packages/syft/tests/syft/stores/base_stash_test.py +++ b/packages/syft/tests/syft/stores/base_stash_test.py @@ -20,7 +20,6 @@ from syft.store.db.sqlite import SQLiteDBConfig from syft.store.db.sqlite import SQLiteDBManager from syft.store.db.stash import ObjectStash -from syft.store.document_store import PartitionKey from syft.store.document_store_errors import NotFoundException from syft.store.document_store_errors import StashException from syft.store.linked_obj import LinkedObject @@ -45,11 +44,6 @@ class MockObject(SyftObject): __attr_unique__ = ["id", "name"] -NamePartitionKey = PartitionKey(key="name", type_=str) -DescPartitionKey = PartitionKey(key="desc", type_=str) -ImportancePartitionKey = PartitionKey(key="importance", type_=int) - - class MockStash(ObjectStash[MockObject]): pass diff --git a/packages/syft/tests/syft/stores/store_fixtures_test.py b/packages/syft/tests/syft/stores/store_fixtures_test.py deleted file mode 100644 index 47452e9740e..00000000000 --- a/packages/syft/tests/syft/stores/store_fixtures_test.py +++ /dev/null @@ -1,54 +0,0 @@ -# stdlib -import uuid - -# syft absolute -from syft.server.credentials import SyftVerifyKey -from syft.service.action.action_permissions import ActionObjectPermission -from syft.service.action.action_permissions import ActionPermission -from syft.service.user.user import User -from syft.service.user.user import UserCreate -from syft.service.user.user_roles import ServiceRole -from syft.service.user.user_stash import UserStash -from syft.store.db.sqlite import SQLiteDBConfig -from syft.store.db.sqlite import SQLiteDBManager -from syft.store.document_store import DocumentStore -from syft.types.uid import UID - -# relative -from .store_constants_test import TEST_SIGNING_KEY_NEW_ADMIN -from .store_constants_test import TEST_VERIFY_KEY_NEW_ADMIN - - -def document_store_with_admin( - server_uid: UID, verify_key: SyftVerifyKey -) -> DocumentStore: - config = SQLiteDBConfig() - document_store = SQLiteDBManager( - server_uid=server_uid, root_verify_key=verify_key, config=config - ) - - password = uuid.uuid4().hex - - user_stash = UserStash(store=document_store) - admin_user = UserCreate( - email="mail@example.org", - name="Admin", - password=password, - password_verify=password, - 
role=ServiceRole.ADMIN, - ).to(User) - - admin_user.signing_key = TEST_SIGNING_KEY_NEW_ADMIN - admin_user.verify_key = TEST_VERIFY_KEY_NEW_ADMIN - - user_stash.set( - credentials=verify_key, - obj=admin_user, - add_permissions=[ - ActionObjectPermission( - uid=admin_user.id, permission=ActionPermission.ALL_READ - ), - ], - ) - - return document_store diff --git a/packages/syft/tests/syft/users/fixtures.py b/packages/syft/tests/syft/users/fixtures.py index 8f282b6dbdc..46319b46704 100644 --- a/packages/syft/tests/syft/users/fixtures.py +++ b/packages/syft/tests/syft/users/fixtures.py @@ -16,7 +16,6 @@ from syft.service.user.user_roles import ServiceRole from syft.service.user.user_service import UserService from syft.service.user.user_stash import UserStash -from syft.store.document_store import DocumentStore @pytest.fixture @@ -107,12 +106,12 @@ def guest_user_search(guest_user) -> UserSearch: @pytest.fixture -def user_stash(document_store: DocumentStore) -> UserStash: +def user_stash(document_store) -> UserStash: yield UserStash(store=document_store) @pytest.fixture -def user_service(document_store: DocumentStore): +def user_service(document_store): yield UserService(store=document_store)
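
With `DocumentStore` removed, the now untyped `document_store` fixture is expected to yield a SQL-backed `DBManager`. Below is a minimal sketch of such a fixture, modelled on the deleted `store_fixtures_test.py` helper; the fixture and test names are hypothetical, and it assumes `SQLiteDBConfig()` with no arguments produces a usable temporary database, as in that helper.

```python
# Sketch only: an SQLite-backed DBManager fixture that stashes can consume
# directly, matching the updated fixture signatures above.
import pytest

from syft.server.credentials import SyftSigningKey
from syft.service.user.user_stash import UserStash
from syft.store.db.sqlite import SQLiteDBConfig, SQLiteDBManager
from syft.types.uid import UID


@pytest.fixture
def document_store():
    # Default SQLiteDBConfig, as used by the removed document_store_with_admin().
    root_verify_key = SyftSigningKey.generate().verify_key
    config = SQLiteDBConfig()
    yield SQLiteDBManager(
        server_uid=UID(), root_verify_key=root_verify_key, config=config
    )


def test_user_stash_accepts_db_manager(document_store) -> None:
    # Stashes now take the DBManager directly instead of a DocumentStore.
    stash = UserStash(store=document_store)
    assert stash is not None
```
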