From 0260756fd77f9781c25633a628841cfda22157ae Mon Sep 17 00:00:00 2001 From: khoaguin Date: Thu, 23 May 2024 17:32:30 +0700 Subject: [PATCH] [syft/network] add checking `NodePeerUpdate` type in `NetworkStash.update` --- .../src/syft/service/network/network_service.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/packages/syft/src/syft/service/network/network_service.py b/packages/syft/src/syft/service/network/network_service.py index 622316fa606..5a44afc69cd 100644 --- a/packages/syft/src/syft/service/network/network_service.py +++ b/packages/syft/src/syft/service/network/network_service.py @@ -50,6 +50,7 @@ from ..warnings import CRUDWarning from .association_request import AssociationRequestChange from .node_peer import NodePeer +from .node_peer import NodePeerUpdate from .routes import HTTPNodeRoute from .routes import NodeRoute from .routes import NodeRouteType @@ -87,12 +88,17 @@ def get_by_name( def update( self, credentials: SyftVerifyKey, - peer: NodePeer, + peer: NodePeer | NodePeerUpdate, has_permission: bool = False, ) -> Result[NodePeer, str]: - valid = self.check_type(peer, NodePeer) - if valid.is_err(): - return Err(SyftError(message=valid.err())) + valid_node_peer = self.check_type(peer, NodePeer) + valid_node_peer_update = self.check_type(peer, NodePeerUpdate) + if valid_node_peer.is_err() and valid_node_peer_update.is_err(): + return Err( + SyftError( + message=f"{type(peer)} does not match required type: NodePeer or NodePeerUpdate" + ) + ) return super().update(credentials, peer, has_permission=has_permission) def create_or_update_peer(