Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

clean up stale document store code #9323

Merged
merged 13 commits into from
Sep 27, 2024
Merged
2 changes: 0 additions & 2 deletions packages/syft/src/syft/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,6 @@
from .service.user.roles import Roles as roles
from .service.user.user_service import UserService
from .stable_version import LATEST_STABLE_SYFT
from .store.mongo_document_store import MongoStoreConfig
from .store.sqlite_document_store import SQLiteStoreConfig
from .types.errors import SyftException
from .types.errors import raises
from .types.result import as_result
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/protocol/data_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
# syft absolute
from syft.types.result import Err
from syft.types.result import Ok
from syft.util.util import get_dev_mode

# relative
from .. import __version__
Expand All @@ -29,7 +30,6 @@
from ..types.errors import SyftException
from ..types.syft_object import SyftBaseObject
from ..types.syft_object_registry import SyftObjectRegistry
from ..util.util import get_dev_mode

PROTOCOL_STATE_FILENAME = "protocol_version.json"
PROTOCOL_TYPE = str | int
Expand Down
6 changes: 1 addition & 5 deletions packages/syft/src/syft/service/code/status_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from ...serde.serializable import serializable
from ...store.db.db import DBManager
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
from ...types.syft_object import PartialSyftObject
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.uid import UID
Expand All @@ -24,10 +23,7 @@

@serializable(canonical_name="StatusSQLStash", version=1)
class StatusStash(ObjectStash[UserCodeStatusCollection]):
settings: PartitionSettings = PartitionSettings(
name=UserCodeStatusCollection.__canonical_name__,
object_type=UserCodeStatusCollection,
)
pass


class CodeStatusUpdate(PartialSyftObject):
Expand Down
5 changes: 0 additions & 5 deletions packages/syft/src/syft/service/code/user_code_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
Expand All @@ -11,10 +10,6 @@

@serializable(canonical_name="UserCodeSQLStash", version=1)
class UserCodeStash(ObjectStash[UserCode]):
settings: PartitionSettings = PartitionSettings(
name=UserCode.__canonical_name__, object_type=UserCode
)

@as_result(StashException, NotFoundException)
def get_by_code_hash(self, credentials: SyftVerifyKey, code_hash: str) -> UserCode:
return self.get_one(
Expand Down
3 changes: 0 additions & 3 deletions packages/syft/src/syft/service/data_subject/data_subject.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

# relative
from ...serde.serializable import serializable
from ...store.document_store import PartitionKey
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SyftObject
from ...types.transforms import TransformContext
Expand All @@ -17,8 +16,6 @@
from ...types.uid import UID
from ...util.markdown import as_markdown_python_code

NamePartitionKey = PartitionKey(key="name", type_=str)


@serializable()
class DataSubject(SyftObject):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,9 @@

# relative
from ...serde.serializable import serializable
from ...store.document_store import PartitionKey
from ...types.syft_object import SYFT_OBJECT_VERSION_1
from ...types.syft_object import SyftObject

ParentPartitionKey = PartitionKey(key="parent", type_=str)
ChildPartitionKey = PartitionKey(key="child", type_=str)


@serializable()
class DataSubjectMemberRelationship(SyftObject):
Expand Down
2 changes: 0 additions & 2 deletions packages/syft/src/syft/service/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@

# relative
from ...serde.serializable import serializable
from ...store.document_store import PartitionKey
from ...types.datetime import DateTime
from ...types.dicttuple import DictTuple
from ...types.errors import SyftException
Expand Down Expand Up @@ -45,7 +44,6 @@
from ..response import SyftSuccess
from ..response import SyftWarning

NamePartitionKey = PartitionKey(key="name", type_=str)
logger = logging.getLogger(__name__)


Expand Down
5 changes: 0 additions & 5 deletions packages/syft/src/syft/service/job/job_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from ...service.context import AuthedServiceContext
from ...service.worker.worker_pool import SyftWorker
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.datetime import DateTime
Expand Down Expand Up @@ -736,10 +735,6 @@ def from_job(

@serializable(canonical_name="JobStashSQL", version=1)
class JobStash(ObjectStash[Job]):
settings: PartitionSettings = PartitionSettings(
name=Job.__canonical_name__, object_type=Job
)

@as_result(StashException)
def set_result(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
from ...server.credentials import SyftSigningKey
from ...server.credentials import SyftVerifyKey
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionKey
from ...store.document_store_errors import NotFoundException
from ...types.blob_storage import BlobStorageEntry
from ...types.blob_storage import CreateBlobStorageEntry
Expand Down Expand Up @@ -64,9 +63,6 @@ def supported_versions(self) -> list:
return SyftObjectRegistry.get_versions(self.canonical_name)


KlassNamePartitionKey = PartitionKey(key="canonical_name", type_=str)


@serializable(canonical_name="SyftMigrationStateSQLStash", version=1)
class SyftMigrationStateStash(ObjectStash[SyftObjectMigrationState]):
@as_result(SyftException, NotFoundException)
Expand Down
5 changes: 0 additions & 5 deletions packages/syft/src/syft/service/notifier/notifier_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
Expand All @@ -17,10 +16,6 @@
@instrument
@serializable(canonical_name="NotifierSQLStash", version=1)
class NotifierStash(ObjectStash[NotifierSettings]):
settings: PartitionSettings = PartitionSettings(
name=NotifierSettings.__canonical_name__, object_type=NotifierSettings
)

@as_result(StashException, NotFoundException)
def get(self, credentials: SyftVerifyKey) -> NotifierSettings:
"""Get Settings"""
Expand Down
6 changes: 0 additions & 6 deletions packages/syft/src/syft/service/output/output_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from ...server.credentials import SyftVerifyKey
from ...store.db.db import DBManager
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionKey
from ...store.document_store_errors import StashException
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
Expand All @@ -26,11 +25,6 @@
from ..user.user_roles import ADMIN_ROLE_LEVEL
from ..user.user_roles import GUEST_ROLE_LEVEL

CreatedAtPartitionKey = PartitionKey(key="created_at", type_=DateTime)
UserCodeIdPartitionKey = PartitionKey(key="user_code_id", type_=UID)
JobIdPartitionKey = PartitionKey(key="job_id", type_=UID)
OutputPolicyIdPartitionKey = PartitionKey(key="output_policy_id", type_=UID)


@serializable()
class ExecutionOutput(SyncableSyftObject):
Expand Down
5 changes: 0 additions & 5 deletions packages/syft/src/syft/service/policy/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from ...serde.recursive_primitives import recursive_serde_register_type
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
from ...store.document_store import PartitionKey
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.datetime import DateTime
Expand Down Expand Up @@ -84,10 +83,6 @@ class OutputPolicyValidEnum(Enum):

DEFAULT_USER_POLICY_VERSION = 1

PolicyUserVerifyKeyPartitionKey = PartitionKey(
key="user_verify_key", type_=SyftVerifyKey
)

PyCodeObject = Any


Expand Down
8 changes: 6 additions & 2 deletions packages/syft/src/syft/service/queue/base_queue.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
# stdlib
from typing import Any
from typing import ClassVar
from typing import TYPE_CHECKING

# relative
from ...serde.serializable import serializable
from ...service.context import AuthedServiceContext
from ...store.document_store import NewBaseStash
from ...types.uid import UID
from ..response import SyftSuccess
from ..worker.worker_stash import WorkerStash

if TYPE_CHECKING:
# relative
from .queue_stash import QueueStash


@serializable(canonical_name="QueueClientConfig", version=1)
class QueueClientConfig:
Expand Down Expand Up @@ -105,7 +109,7 @@ def create_consumer(
def create_producer(
self,
queue_name: str,
queue_stash: type[NewBaseStash],
queue_stash: "QueueStash",
context: AuthedServiceContext,
worker_stash: WorkerStash,
) -> QueueProducer:
Expand Down
8 changes: 6 additions & 2 deletions packages/syft/src/syft/service/queue/queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from threading import Thread
import time
from typing import Any
from typing import TYPE_CHECKING
from typing import cast

# third party
Expand All @@ -17,7 +18,6 @@
from ...server.credentials import SyftVerifyKey
from ...server.worker_settings import WorkerSettings
from ...service.context import AuthedServiceContext
from ...store.document_store import NewBaseStash
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
from ...types.errors import SyftException
Expand All @@ -39,6 +39,10 @@
from .queue_stash import QueueItem
from .queue_stash import Status

if TYPE_CHECKING:
# relative
from .queue_stash import QueueStash

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -129,7 +133,7 @@ def create_consumer(
def create_producer(
self,
queue_name: str,
queue_stash: type[NewBaseStash],
queue_stash: "QueueStash",
context: AuthedServiceContext,
worker_stash: WorkerStash,
) -> QueueProducer:
Expand Down
5 changes: 0 additions & 5 deletions packages/syft/src/syft/service/queue/queue_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from ...server.worker_settings import WorkerSettings
from ...server.worker_settings import WorkerSettingsV1
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionKey
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...store.linked_obj import LinkedObject
Expand All @@ -35,10 +34,6 @@ class Status(str, Enum):
INTERRUPTED = "interrupted"


StatusPartitionKey = PartitionKey(key="status", type_=Status)
_WorkerPoolPartitionKey = PartitionKey(key="worker_pool", type_=LinkedObject)


@serializable()
class QueueItemV1(SyftObject):
__canonical_name__ = "QueueItem"
Expand Down
2 changes: 0 additions & 2 deletions packages/syft/src/syft/service/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from ..serde.signature import signature_remove_self
from ..server.credentials import SyftVerifyKey
from ..store.db.stash import ObjectStash
from ..store.document_store import DocumentStore
from ..store.linked_obj import LinkedObject
from ..types.errors import SyftException
from ..types.result import as_result
Expand Down Expand Up @@ -71,7 +70,6 @@
class AbstractService:
server: AbstractServer
server_uid: UID
store_type: type = DocumentStore
stash: ObjectStash

@as_result(SyftException)
Expand Down
8 changes: 5 additions & 3 deletions packages/syft/src/syft/service/sync/sync_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
from ...serde.serializable import serializable
from ...store.db.db import DBManager
from ...store.db.stash import ObjectStash
from ...store.document_store import NewBaseStash
from ...store.document_store_errors import NotFoundException
from ...store.linked_obj import LinkedObject
from ...types.datetime import DateTime
Expand Down Expand Up @@ -108,9 +107,10 @@ def transform_item(
self.set_obj_ids(context, item)
return item

@as_result(ValueError)
def get_stash_for_item(
self, context: AuthedServiceContext, item: SyftObject
) -> NewBaseStash:
) -> ObjectStash:
services = list(context.server.service_path_map.values()) # type: ignore

all_stashes = {}
Expand All @@ -119,6 +119,8 @@ def get_stash_for_item(
all_stashes[_stash.object_type] = _stash

stash = all_stashes.get(type(item), None)
if stash is None:
raise ValueError(f"Could not find stash for {type(item)}")
return stash

def add_permissions_for_item(
Expand Down Expand Up @@ -148,7 +150,7 @@ def add_storage_permissions_for_item(
def set_object(
self, context: AuthedServiceContext, item: SyncableSyftObject
) -> SyftObject:
stash = self.get_stash_for_item(context, item)
stash = self.get_stash_for_item(context, item).unwrap()
creds = context.credentials

obj = None
Expand Down
6 changes: 0 additions & 6 deletions packages/syft/src/syft/service/sync/sync_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,13 @@
from ...server.credentials import SyftVerifyKey
from ...store.db.db import DBManager
from ...store.db.stash import ObjectStash
from ...store.document_store import PartitionSettings
from ...store.document_store_errors import StashException
from ...types.result import as_result
from .sync_state import SyncState


@serializable(canonical_name="SyncStash", version=1)
class SyncStash(ObjectStash[SyncState]):
settings: PartitionSettings = PartitionSettings(
name=SyncState.__canonical_name__,
object_type=SyncState,
)

def __init__(self, store: DBManager) -> None:
super().__init__(store)
self.last_state: SyncState | None = None
Expand Down
2 changes: 1 addition & 1 deletion packages/syft/src/syft/service/worker/worker_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
from ...serde.serializable import serializable
from ...server.credentials import SyftVerifyKey
from ...store.db.db import DBManager
from ...store.document_store import SyftSuccess
from ...store.document_store_errors import StashException
from ...types.errors import SyftException
from ...types.result import as_result
from ...types.uid import UID
from ..response import SyftSuccess
from ..service import AbstractService
from ..service import AuthedServiceContext
from ..service import service_method
Expand Down
3 changes: 0 additions & 3 deletions packages/syft/src/syft/service/worker/worker_stash.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from ...server.credentials import SyftVerifyKey
from ...store.db.stash import ObjectStash
from ...store.db.stash import with_session
from ...store.document_store import PartitionKey
from ...store.document_store_errors import NotFoundException
from ...store.document_store_errors import StashException
from ...types.result import as_result
Expand All @@ -20,8 +19,6 @@
from .worker_pool import ConsumerState
from .worker_pool import SyftWorker

WorkerContainerNamePartitionKey = PartitionKey(key="container_name", type_=str)


@serializable(canonical_name="WorkerSQLStash", version=1)
class WorkerStash(ObjectStash[SyftWorker]):
Expand Down
3 changes: 3 additions & 0 deletions packages/syft/src/syft/store/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# relative
from . import mongo_document_store # noqa: F401
from . import sqlite_document_store # noqa: F401
Loading
Loading