diff --git a/packages/syft/src/syft/service/migration/migration_service.py b/packages/syft/src/syft/service/migration/migration_service.py index 78bcc30458e..aa5c73b5a92 100644 --- a/packages/syft/src/syft/service/migration/migration_service.py +++ b/packages/syft/src/syft/service/migration/migration_service.py @@ -9,6 +9,7 @@ from ...store.db.db import DBManager from ...store.db.stash import ObjectStash from ...store.document_store_errors import NotFoundException +from ...store.document_store_errors import UniqueConstraintException from ...types.blob_storage import BlobStorageEntry from ...types.errors import SyftException from ...types.result import as_result @@ -24,7 +25,6 @@ from ..response import SyftSuccess from ..service import AbstractService from ..service import service_method -from ..user.user import User from ..user.user_roles import ADMIN_ROLE_LEVEL from ..worker.utils import DEFAULT_WORKER_POOL_NAME from .object_migration_state import MigrationData @@ -289,20 +289,17 @@ def _update_migrated_objects( self, context: AuthedServiceContext, migrated_objects: list[SyftObject] ) -> SyftSuccess: for migrated_object in migrated_objects: - if ( - isinstance(migrated_object, User) - and migrated_object.verify_key == context.server.verify_key - ): - self.stash.update_root_user(context, migrated_object).unwrap() - else: - stash = self._search_stash_for_klass( - context, type(migrated_object) - ).unwrap() + stash = self._search_stash_for_klass( + context, type(migrated_object) + ).unwrap() + try: stash.update( context.credentials, obj=migrated_object, ).unwrap() + except UniqueConstraintException as e: + print(f"Failed to update {migrated_object}: {e}") return SyftSuccess(message="Updated migration objects!") diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 1101b09d1f4..85860821866 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -42,6 +42,7 @@ from ...util.telemetry import instrument from ..document_store_errors import NotFoundException from ..document_store_errors import StashException +from ..document_store_errors import UniqueConstraintException from .db import DBManager from .query import Query from .schema import PostgresBase @@ -205,11 +206,6 @@ def is_unique(self, obj: StashT, session: Session = None) -> bool: elif len(results) == 1: result = results[0] res = result.id == obj.id - if not res: - # third party - import ipdb - - ipdb.set_trace() return res return True @@ -434,7 +430,13 @@ def apply_partial_update( self.object_type.model_validate(original_obj) return original_obj - @as_result(StashException, NotFoundException, AttributeError, ValidationError) + @as_result( + StashException, + NotFoundException, + AttributeError, + ValidationError, + UniqueConstraintException, + ) @with_session def update( self, @@ -461,7 +463,9 @@ def update( # TODO has_permission is not used if not self.is_unique(obj): - raise StashException(f"Some fields are not unique for {type(obj).__name__}") + raise UniqueConstraintException( + f"Some fields are not unique for {type(obj).__name__} and unique fields {self.unique_fields}" + ) stmt = self.table.update().where(self._get_field_filter("id", obj.id)) stmt = self._apply_permission_filter( diff --git a/packages/syft/src/syft/store/document_store_errors.py b/packages/syft/src/syft/store/document_store_errors.py index 69da6b73a8f..04fb6777897 100644 --- a/packages/syft/src/syft/store/document_store_errors.py +++ b/packages/syft/src/syft/store/document_store_errors.py @@ -14,6 +14,10 @@ class StashException(SyftException): public_message = "There was an error retrieving data. Contact your admin." +class UniqueConstraintException(StashException): + public_message = "Another item with the same unique constraint already exists." + + class ObjectCRUDPermissionException(SyftException): public_message = "You do not have permission to perform this action." diff --git a/packages/syft/src/syft/types/transforms.py b/packages/syft/src/syft/types/transforms.py index 7ff980e692c..60e9722a029 100644 --- a/packages/syft/src/syft/types/transforms.py +++ b/packages/syft/src/syft/types/transforms.py @@ -30,8 +30,6 @@ class TransformContext(Context): @classmethod def from_context(cls, obj: Any, context: Context | None = None) -> Self: - if isinstance(context, TransformContext): - return context t_context = cls() t_context.obj = obj try: