diff --git a/substrafl/nodes/protocol.py b/substrafl/nodes/protocol.py index 077efa03..8ffc5f0f 100644 --- a/substrafl/nodes/protocol.py +++ b/substrafl/nodes/protocol.py @@ -1,3 +1,4 @@ +from abc import abstractmethod from typing import Any from typing import List from typing import Protocol @@ -16,15 +17,19 @@ class TrainDataNodeProtocol(Protocol): data_manager_key: str data_sample_keys: List[str] + @abstractmethod def init_states(self, *args, **kwargs) -> LocalStateRef: pass + @abstractmethod def update_states(self, operation: RemoteDataOperation, *args, **kwargs) -> (LocalStateRef, Any): pass + @abstractmethod def register_operations(self, client: substra.Client, *args, **kwargs) -> Any: pass + @abstractmethod def summary(self) -> dict: pass @@ -35,12 +40,15 @@ class TestDataNodeProtocol(Protocol): data_manager_key: str data_sample_keys: List[str] + @abstractmethod def update_states(self, operation: RemoteDataOperation, *args, **kwargs) -> None: pass + @abstractmethod def register_operations(self, client: substra.Client, *args, **kwargs) -> Any: pass + @abstractmethod def summary(self) -> dict: pass @@ -49,11 +57,14 @@ def summary(self) -> dict: class AggregationNodeProtocol(Protocol): organization_id: str + @abstractmethod def update_states(self, operation: RemoteOperation, *args, **kwargs) -> Any: pass + @abstractmethod def register_operations(self, client: substra.Client, *args, **kwargs) -> Any: pass + @abstractmethod def summary(self) -> dict: pass