diff --git a/.github/workflows/pr-tests-stack.yml b/.github/workflows/pr-tests-stack.yml index 5dfb315b565..35c4e40823a 100644 --- a/.github/workflows/pr-tests-stack.yml +++ b/.github/workflows/pr-tests-stack.yml @@ -682,7 +682,6 @@ jobs: tox -e migration.scenarios.test pr-tests-migrations-k8s: - if: false # skipping this job for now strategy: max-parallel: 99 matrix: diff --git a/packages/syft/src/syft/client/datasite_client.py b/packages/syft/src/syft/client/datasite_client.py index 7553344ad5a..0129b4a17ac 100644 --- a/packages/syft/src/syft/client/datasite_client.py +++ b/packages/syft/src/syft/client/datasite_client.py @@ -416,8 +416,14 @@ def get_migration_data(self, include_blobs: bool = True) -> MigrationData: return res - def load_migration_data(self, path: str | Path) -> SyftSuccess: - migration_data = MigrationData.from_file(path) + def load_migration_data( + self, path_or_data: str | Path | MigrationData + ) -> SyftSuccess: + if isinstance(path_or_data, MigrationData): + migration_data = path_or_data + else: + migration_data = MigrationData.from_file(path_or_data) + migration_data._set_obj_location_(self.id, self.verify_key) if self.id != migration_data.server_uid: diff --git a/packages/syft/src/syft/service/api/api_service.py b/packages/syft/src/syft/service/api/api_service.py index 8df26f0ac15..5f6ba0dbe48 100644 --- a/packages/syft/src/syft/service/api/api_service.py +++ b/packages/syft/src/syft/service/api/api_service.py @@ -72,7 +72,7 @@ def set( public_message="An API endpoint already exists at the given path." ) - result = self.stash.upsert(context.credentials, endpoint=new_endpoint).unwrap() + result = self.stash.upsert(context.credentials, obj=new_endpoint).unwrap() action_obj = ActionObject.from_obj( id=new_endpoint.action_object_id, syft_action_data=CustomEndpointActionObject(endpoint_id=result.id), @@ -157,7 +157,7 @@ def update( endpoint.mock_function.view_access = view_access # save changes - self.stash.upsert(context.credentials, endpoint=endpoint).unwrap() + self.stash.upsert(context.credentials, obj=endpoint).unwrap() return SyftSuccess(message="Endpoint successfully updated.") @service_method( @@ -218,7 +218,7 @@ def set_state( if mock and api_endpoint.mock_function: api_endpoint.mock_function.state = state - self.stash.upsert(context.credentials, endpoint=api_endpoint).unwrap() + self.stash.upsert(context.credentials, obj=api_endpoint).unwrap() return SyftSuccess(message=f"APIEndpoint {api_path} state updated.") @service_method( @@ -248,7 +248,7 @@ def set_settings( if mock and api_endpoint.mock_function: api_endpoint.mock_function.settings = settings - self.stash.upsert(context.credentials, endpoint=api_endpoint).unwrap() + self.stash.upsert(context.credentials, obj=api_endpoint).unwrap() return SyftSuccess(message=f"APIEndpoint {api_path} settings updated.") @service_method( diff --git a/packages/syft/src/syft/service/api/api_stash.py b/packages/syft/src/syft/service/api/api_stash.py index 0c0c6f73020..e892d48da61 100644 --- a/packages/syft/src/syft/service/api/api_stash.py +++ b/packages/syft/src/syft/service/api/api_stash.py @@ -33,22 +33,3 @@ def path_exists(self, credentials: SyftVerifyKey, path: str) -> bool: return True except NotFoundException: return False - - @as_result(StashException) - def upsert( - self, - credentials: SyftVerifyKey, - endpoint: TwinAPIEndpoint, - has_permission: bool = False, - ) -> TwinAPIEndpoint: - """Upsert an endpoint.""" - exists = self.path_exists(credentials=credentials, path=endpoint.path).unwrap() - - if exists: - super().delete_by_uid(credentials=credentials, uid=endpoint.id).unwrap() - - return ( - super() - .set(credentials=credentials, obj=endpoint, ignore_duplicates=False) - .unwrap() - ) diff --git a/packages/syft/src/syft/service/policy/policy.py b/packages/syft/src/syft/service/policy/policy.py index 1e33755418e..4662307c235 100644 --- a/packages/syft/src/syft/service/policy/policy.py +++ b/packages/syft/src/syft/service/policy/policy.py @@ -642,15 +642,16 @@ def allowed_ids_only( public_message=f"Invalid server type for code submission: {context.server.server_type}" ) - server_identity = ServerIdentity( - server_name=context.server.name, - server_id=context.server.id, - verify_key=context.server.signing_key.verify_key, - ) - allowed_inputs = allowed_inputs.get(server_identity, {}) + allowed_inputs_for_server = None + for identity, inputs in allowed_inputs.items(): + if identity.server_id == context.server.id: + allowed_inputs_for_server = inputs + break + if allowed_inputs_for_server is None: + allowed_inputs_for_server = {} filtered_kwargs = {} - for key in allowed_inputs.keys(): + for key in allowed_inputs_for_server.keys(): if key in kwargs: value = kwargs[key] uid = value @@ -658,7 +659,7 @@ def allowed_ids_only( if not isinstance(uid, UID): uid = getattr(value, "id", None) - if uid != allowed_inputs[key]: + if uid != allowed_inputs_for_server[key]: raise SyftException( public_message=f"Input with uid: {uid} for `{key}` not in allowed inputs: {allowed_inputs}" ) diff --git a/packages/syft/src/syft/store/db/stash.py b/packages/syft/src/syft/store/db/stash.py index 85860821866..5323f3455c8 100644 --- a/packages/syft/src/syft/store/db/stash.py +++ b/packages/syft/src/syft/store/db/stash.py @@ -75,6 +75,8 @@ def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore """ Decorator to inject a session into the function kwargs if it is not provided. + Make sure to pass session as a keyword argument to the function. + TODO: This decorator is a temporary fix, we want to move to a DI approach instead: move db connection and session to context, and pass context to all stash methods. """ @@ -87,8 +89,9 @@ def with_session(func: Callable[P, T]) -> Callable[P, T]: # type: ignore def wrapper(self: "ObjectStash[StashT]", *args: Any, **kwargs: Any) -> Any: if inject_session and kwargs.get("session") is None: with self.sessionmaker() as session: - kwargs["session"] = session - return func(self, *args, **kwargs) + with session.begin(): + kwargs["session"] = session + return func(self, *args, **kwargs) return func(self, *args, **kwargs) return wrapper # type: ignore @@ -369,11 +372,13 @@ def set( uid = obj.id # check if the object already exists - if self.exists(credentials, uid) or not self.is_unique(obj): + if self.exists(credentials, uid, session=session) or not self.is_unique( + obj, session=session + ): if ignore_duplicates: return obj unique_fields_str = ", ".join(self.unique_fields) - raise StashException( + raise UniqueConstraintException( public_message=f"Duplication Key Error for {obj}.\n" f"The fields that should be unique are {unique_fields_str}." ) @@ -399,7 +404,6 @@ def set( raise StashException( f"Error serializing object: {e}. Some fields are invalid." ) - # create the object with the permissions stmt = self.table.insert().values( id=uid, @@ -408,7 +412,6 @@ def set( storage_permissions=storage_permissions, ) session.execute(stmt) - session.commit() return self.get_by_uid(credentials, uid, session=session).unwrap() @as_result(ValidationError, AttributeError) @@ -462,7 +465,7 @@ def update( ).unwrap() # TODO has_permission is not used - if not self.is_unique(obj): + if not self.is_unique(obj, session=session): raise UniqueConstraintException( f"Some fields are not unique for {type(obj).__name__} and unique fields {self.unique_fields}" ) @@ -483,14 +486,12 @@ def update( f"Error serializing object: {e}. Some fields are invalid." ) stmt = stmt.values(fields=fields) - result = session.execute(stmt) - session.commit() if result.rowcount == 0: raise NotFoundException( f"{self.object_type.__name__}: {obj.id} not found or no permission to update." ) - return self.get_by_uid(credentials, obj.id).unwrap() + return self.get_by_uid(credentials, obj.id, session=session).unwrap() @as_result(StashException, NotFoundException) @with_session @@ -510,7 +511,6 @@ def delete_by_uid( session=session, ) result = session.execute(stmt) - session.commit() if result.rowcount == 0: raise NotFoundException( f"{self.object_type.__name__}: {uid} not found or no permission to delete." @@ -649,8 +649,6 @@ def add_permission( stmt = self.table.update().where(self.table.c.id == permission.uid) stmt = stmt.values(permissions=list(existing_permissions)) session.execute(stmt) - session.commit() - return None @as_result(NotFoundException) @@ -685,7 +683,6 @@ def remove_permission( .values(permissions=list(permissions)) ) session.execute(stmt) - session.commit() return None @with_session @@ -842,7 +839,6 @@ def remove_storage_permission( .values(storage_permissions=[str(uid) for uid in permissions]) ) session.execute(stmt) - session.commit() return None @as_result(StashException) @@ -857,3 +853,26 @@ def _get_storage_permissions_for_uid( if result is None: raise NotFoundException(f"No storage permissions found for uid: {uid}") return {UID(uid) for uid in result.storage_permissions} + + @with_session + @as_result(StashException) + def upsert( + self, + credentials: SyftVerifyKey, + obj: StashT, + session: Session = None, + ) -> StashT: + """Insert or update an object in the stash if it already exists. + Atomic operation when using the same session for both operations. + """ + + try: + return self.set( + credentials=credentials, + obj=obj, + session=session, + ).unwrap() + except UniqueConstraintException: + return self.update( + credentials=credentials, obj=obj, session=session + ).unwrap() diff --git a/packages/syft/tests/syft/stores/base_stash_test.py b/packages/syft/tests/syft/stores/base_stash_test.py index 8ed9a312b86..5344ec8f5ab 100644 --- a/packages/syft/tests/syft/stores/base_stash_test.py +++ b/packages/syft/tests/syft/stores/base_stash_test.py @@ -190,6 +190,30 @@ def test_basestash_update( assert retrieved == updated_obj +def test_basestash_upsert( + root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker +) -> None: + base_stash.set(root_verify_key, mock_object).unwrap() + + updated_obj = mock_object.copy() + updated_obj.name = faker.name() + + retrieved = base_stash.upsert(root_verify_key, updated_obj).unwrap() + assert retrieved == updated_obj + + updated_obj.id = UID() + + with pytest.raises(StashException): + # fails because the name should be unique + base_stash.upsert(root_verify_key, updated_obj).unwrap() + + updated_obj.name = faker.name() + + retrieved = base_stash.upsert(root_verify_key, updated_obj).unwrap() + assert retrieved == updated_obj + assert len(base_stash.get_all(root_verify_key).unwrap()) == 2 + + def test_basestash_cannot_update_non_existent( root_verify_key, base_stash: MockStash, mock_object: MockObject, faker: Faker ) -> None: