From 6476e191ad628d145bc7a4fdfa2e559230887dbd Mon Sep 17 00:00:00 2001 From: ThibaultFy <50656860+ThibaultFy@users.noreply.github.com> Date: Wed, 28 Feb 2024 12:09:06 +0100 Subject: [PATCH] chore: use protocols for nodes (#185) Signed-off-by: ThibaultFy <50656860+ThibaultFy@users.noreply.github.com> Signed-off-by: ThibaultFy --- CHANGELOG.md | 3 + benchmark/camelyon/README.md | 4 +- .../pure_substrafl/register_assets.py | 6 +- docs/api/nodes.rst | 12 ++-- .../algorithms/pytorch/torch_base_algo.py | 14 ++-- substrafl/compute_plan_builder.py | 12 ++-- substrafl/evaluation_strategy.py | 12 ++-- substrafl/experiment.py | 34 ++++----- substrafl/model_loading.py | 2 +- substrafl/nodes/__init__.py | 17 +++-- substrafl/nodes/aggregation_node.py | 25 +++++-- substrafl/nodes/node.py | 55 --------------- substrafl/nodes/protocol.py | 70 +++++++++++++++++++ substrafl/nodes/schemas.py | 21 ++++++ substrafl/nodes/test_data_node.py | 49 ++++++------- substrafl/nodes/train_data_node.py | 28 ++++---- substrafl/remote/substratools_methods.py | 4 +- substrafl/strategies/fed_avg.py | 32 ++++----- substrafl/strategies/fed_pca.py | 28 ++++---- substrafl/strategies/newton_raphson.py | 28 ++++---- substrafl/strategies/scaffold.py | 32 ++++----- substrafl/strategies/single_organization.py | 30 ++++---- substrafl/strategies/strategy.py | 36 +++++----- tests/algorithms/pytorch/test_base_algo.py | 4 +- tests/dependency/test_dependency.py | 4 +- tests/remote/test_decorator.py | 4 +- tests/strategies/test_fed_pca.py | 4 +- tests/strategies/test_newton_raphson.py | 2 +- tests/strategies/test_strategy.py | 2 +- tests/test_evaluation_strategy.py | 59 +++++----------- tests/test_experiment.py | 2 +- tests/test_model_loading.py | 2 +- tests/utils.py | 2 +- 33 files changed, 340 insertions(+), 299 deletions(-) delete mode 100644 substrafl/nodes/node.py create mode 100644 substrafl/nodes/protocol.py create mode 100644 substrafl/nodes/schemas.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6ce5d4e6..897faabb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - BREAKING: the `perform_predict` method of `Strategy` changed in favor of `perform_evaluation` that calls the new `evaluate` method [#177](https://github.com/Substra/substrafl/pull/177) - BREAKING: `metric_functions` are now passed to the `Strategy` instead of the `TestDataNode` [#177](https://github.com/Substra/substrafl/pull/177) - BREAKING: the `predict` method of `Algo` has no `@remote_data` decorator anymore. 
Its signature does not take `prediction_path` anymore, and the predictions are returned by the method [#177](https://github.com/Substra/substrafl/pull/177)
+- Abstract base class `Node` is replaced by `Protocols`, defined in `substrafl.nodes.protocol.py` ([#185](https://github.com/Substra/substrafl/pull/185))
+- BREAKING: rename `test_data_sample_keys`, `test_tasks` and `register_test_operations` to `data_sample_keys`, `tasks` and `register_operations` in `TestDataNodes` ([#185](https://github.com/Substra/substrafl/pull/185))
+- BREAKING: `InputIdentifiers` and `OutputIdentifiers` move from `substrafl.nodes.node` to `substrafl.nodes.schemas` ([#185](https://github.com/Substra/substrafl/pull/185))
 ## [0.43.0](https://github.com/Substra/substrafl/releases/tag/0.43.0) - 2024-02-26
diff --git a/benchmark/camelyon/README.md b/benchmark/camelyon/README.md
index d430c028..cd7db24b 100644
--- a/benchmark/camelyon/README.md
+++ b/benchmark/camelyon/README.md
@@ -184,12 +184,12 @@ In remote, one can choose to reuse some assets by passing their keys in the [key
 {
     "MyOrg1MSP": {
         "dataset_key": "b8d754f0-40a5-4976-ae16-8dd4eca35ffc",
-        "test_data_sample_keys": ["1238452c-a1dd-47ef-84a8-410c0841693a"],
+        "data_sample_keys": ["1238452c-a1dd-47ef-84a8-410c0841693a"],
         "train_data_sample_keys": ["38071944-c974-4b3b-a671-aa4835a0ae62"]
     },
     "MyOrg2MSP": {
         "dataset_key": "fa8e9bf7-5084-4b59-b089-a459495a08be",
-        "test_data_sample_keys": ["73715d69-9447-4270-9d3f-d0b17bb88a87"],
+        "data_sample_keys": ["73715d69-9447-4270-9d3f-d0b17bb88a87"],
         "train_data_sample_keys": ["766d2029-f90b-440e-8b39-2389ab04041d"]
     },
     "metric_key": "e5a99be6-0138-461a-92fe-23f685cdc9e1"
diff --git a/benchmark/camelyon/pure_substrafl/register_assets.py b/benchmark/camelyon/pure_substrafl/register_assets.py
index ea2ccd2e..b3c08f15 100644
--- a/benchmark/camelyon/pure_substrafl/register_assets.py
+++ b/benchmark/camelyon/pure_substrafl/register_assets.py
@@ -111,12 +111,12 @@ def add_duplicated_dataset(
 {
     : {
         "dataset_key": "b8d754f0-40a5-4976-ae16-8dd4eca35ffc",
-        "test_data_sample_keys": ["1238452c-a1dd-47ef-84a8-410c0841693a"],
+        "data_sample_keys": ["1238452c-a1dd-47ef-84a8-410c0841693a"],
         "train_data_sample_keys": ["38071944-c974-4b3b-a671-aa4835a0ae62"]
     },
     : {
         "dataset_key": "fa8e9bf7-5084-4b59-b089-a459495a08be",
-        "test_data_sample_keys": ["73715d69-9447-4270-9d3f-d0b17bb88a87"],
+        "data_sample_keys": ["73715d69-9447-4270-9d3f-d0b17bb88a87"],
         "train_data_sample_keys": ["766d2029-f90b-440e-8b39-2389ab04041d"]
     },
     ...
@@ -239,7 +239,7 @@ def get_test_data_nodes(
         TestDataNode(
             organization_id=msp_id,
             data_manager_key=asset_keys.get(msp_id)["dataset_key"],
-            test_data_sample_keys=asset_keys.get(msp_id)["test_data_sample_keys"],
+            data_sample_keys=asset_keys.get(msp_id)["test_data_sample_keys"],
         )
     )
diff --git a/docs/api/nodes.rst b/docs/api/nodes.rst
index 0ad17b8d..d60423c2 100755
--- a/docs/api/nodes.rst
+++ b/docs/api/nodes.rst
@@ -14,13 +14,17 @@ TestDataNode
 ^^^^^^^^^^^^^
 .. autoclass:: substrafl.nodes.test_data_node.TestDataNode
-Node
-^^^^^
-.. autoclass:: substrafl.nodes.node.Node
+Protocols
+^^^^^^^^^
+.. autoclass:: substrafl.nodes.protocol.TrainDataNodeProtocol
+.. autoclass:: substrafl.nodes.protocol.TestDataNodeProtocol
+.. autoclass:: substrafl.nodes.protocol.AggregationNodeProtocol
 References
 ^^^^^^^^^^
 .. automodule:: substrafl.nodes.references
-.. autoclass:: substrafl.nodes.node.OperationKey
+.. autoclass:: substrafl.nodes.schemas.OperationKey
+.. autoclass:: substrafl.nodes.schemas.InputIdentifiers
+.. 
autoclass:: substrafl.nodes.schemas.OutputIdentifiers .. autoclass:: substrafl.nodes.references.local_state.LocalStateRef .. autoclass:: substrafl.nodes.references.shared_state.SharedStateRef diff --git a/substrafl/algorithms/pytorch/torch_base_algo.py b/substrafl/algorithms/pytorch/torch_base_algo.py index 3edb66b8..e409ebb9 100644 --- a/substrafl/algorithms/pytorch/torch_base_algo.py +++ b/substrafl/algorithms/pytorch/torch_base_algo.py @@ -390,12 +390,14 @@ def summary(self): { "model": str(type(self._model)), "criterion": str(type(self._criterion)), - "optimizer": None - if self._optimizer is None - else { - "type": str(type(self._optimizer)), - "parameters": self._optimizer.defaults, - }, + "optimizer": ( + None + if self._optimizer is None + else { + "type": str(type(self._optimizer)), + "parameters": self._optimizer.defaults, + } + ), "scheduler": None if self._scheduler is None else str(type(self._scheduler)), } ) diff --git a/substrafl/compute_plan_builder.py b/substrafl/compute_plan_builder.py index d4946f08..a1b9ea39 100644 --- a/substrafl/compute_plan_builder.py +++ b/substrafl/compute_plan_builder.py @@ -5,8 +5,8 @@ from typing import Optional from substrafl.evaluation_strategy import EvaluationStrategy -from substrafl.nodes.aggregation_node import AggregationNode -from substrafl.nodes.train_data_node import TrainDataNode +from substrafl.nodes import AggregationNodeProtocol +from substrafl.nodes import TrainDataNodeProtocol class ComputePlanBuilder(abc.ABC): @@ -31,8 +31,8 @@ def __init__(self, custom_arg, my_custom_kwargs="value"): @abc.abstractmethod def build_compute_plan( self, - train_data_nodes: Optional[List[TrainDataNode]], - aggregation_node: Optional[List[AggregationNode]], + train_data_nodes: Optional[List[TrainDataNodeProtocol]], + aggregation_node: Optional[List[AggregationNodeProtocol]], evaluation_strategy: Optional[EvaluationStrategy], num_rounds: Optional[int], clean_models: Optional[bool] = True, @@ -41,8 +41,8 @@ def build_compute_plan( :func:`~substrafl.experiment.execute_experiment` function. Args: - train_data_nodes (typing.List[TrainDataNode]): list of the train organizations - aggregation_node (typing.Optional[AggregationNode]): aggregation node, necessary for + train_data_nodes (typing.List[TrainDataNodeProtocol]): list of the train organizations + aggregation_node (typing.Optional[AggregationNodeProtocol]): aggregation node, necessary for centralized strategy, unused otherwise evaluation_strategy (Optional[EvaluationStrategy]): evaluation strategy to follow for testing models. num_rounds (int): Number of times to repeat the compute plan sub-graph (define in perform round). It is diff --git a/substrafl/evaluation_strategy.py b/substrafl/evaluation_strategy.py index ab51bf01..c21c7921 100644 --- a/substrafl/evaluation_strategy.py +++ b/substrafl/evaluation_strategy.py @@ -2,13 +2,13 @@ from typing import Optional from typing import Set -from substrafl.nodes.test_data_node import TestDataNode +from substrafl.nodes import TestDataNodeProtocol class EvaluationStrategy: def __init__( self, - test_data_nodes: List[TestDataNode], + test_data_nodes: List[TestDataNodeProtocol], eval_frequency: Optional[int] = None, eval_rounds: Optional[List[int]] = None, ) -> None: @@ -17,7 +17,7 @@ def __init__( union of both selected indexes will be evaluated. Args: - test_data_nodes (List[TestDataNode]): nodes on which the model is to be tested. + test_data_nodes (List[TestDataNodeProtocol]): nodes on which the model is to be tested. 
eval_frequency (Optional[int]): The model will be tested every ``eval_frequency`` rounds. Set to None to activate eval_rounds only. Defaults to None. eval_rounds (Optional[List[int]]): If specified, the model will be tested on the index of a round given @@ -25,7 +25,7 @@ def __init__( Raises: ValueError: test_data_nodes cannot be an empty list - TypeError: test_data_nodes must be filled with instances of TestDataNode + TypeError: test_data_nodes must be filled with instances of TestDataNodeProtocol TypeError: rounds must be a list or an int ValueError: both eval_rounds and eval_frequency cannot be None at the same time @@ -83,8 +83,8 @@ def __init__( if not test_data_nodes: raise ValueError("test_data_nodes lists cannot be empty") - if not all(isinstance(node, TestDataNode) for node in test_data_nodes): - raise TypeError("test_data_nodes must include objects of TestDataNode type") + if not all(isinstance(node, TestDataNodeProtocol) for node in test_data_nodes): + raise TypeError("test_data_nodes must implement the TestDataNodeProtocol") if eval_frequency is None and eval_rounds is None: raise ValueError("At least one of eval_frequency or eval_rounds must be defined") diff --git a/substrafl/experiment.py b/substrafl/experiment.py index cdde8458..4a41f2d7 100644 --- a/substrafl/experiment.py +++ b/substrafl/experiment.py @@ -20,9 +20,9 @@ from substrafl.evaluation_strategy import EvaluationStrategy from substrafl.exceptions import KeyMetadataError from substrafl.exceptions import LenMetadataError -from substrafl.nodes.aggregation_node import AggregationNode -from substrafl.nodes.node import OperationKey -from substrafl.nodes.train_data_node import TrainDataNode +from substrafl.nodes import AggregationNodeProtocol +from substrafl.nodes import TrainDataNodeProtocol +from substrafl.nodes.schemas import OperationKey from substrafl.remote.remote_struct import RemoteStruct logger = logging.getLogger(__name__) @@ -30,8 +30,8 @@ def _register_operations( client: substra.Client, - train_data_nodes: List[TrainDataNode], - aggregation_node: Optional[AggregationNode], + train_data_nodes: List[TrainDataNodeProtocol], + aggregation_node: Optional[AggregationNodeProtocol], evaluation_strategy: Optional[EvaluationStrategy], dependencies: Dependency, ) -> Tuple[List[dict], Dict[RemoteStruct, OperationKey]]: @@ -39,8 +39,8 @@ def _register_operations( Args: client (substra.Client): substra client - train_data_nodes (typing.List[TrainDataNode]): list of train data nodes - aggregation_node (typing.Optional[AggregationNode]): the aggregation node for + train_data_nodes (typing.List[TrainDataNodeProtocol]): list of train data nodes + aggregation_node (typing.Optional[AggregationNodeProtocol]): the aggregation node for centralized strategies evaluation_strategy (typing.Optional[EvaluationStrategy]): the evaluation strategy if there is one dependencies @@ -80,14 +80,14 @@ def _register_operations( if evaluation_strategy is not None: for test_data_node in evaluation_strategy.test_data_nodes: - test_function_cache = test_data_node.register_test_operations( + test_function_cache = test_data_node.register_operations( client=client, permissions=permissions, cache=test_function_cache, dependencies=dependencies, ) - tasks += test_data_node.testtasks + tasks += test_data_node.tasks # The aggregation operation is defined in the strategy, its dependencies are # the strategy dependencies @@ -110,8 +110,8 @@ def _save_experiment_summary( strategy: ComputePlanBuilder, num_rounds: int, operation_cache: Dict[RemoteStruct, 
OperationKey], - train_data_nodes: List[TrainDataNode], - aggregation_node: Optional[AggregationNode], + train_data_nodes: List[TrainDataNodeProtocol], + aggregation_node: Optional[AggregationNodeProtocol], evaluation_strategy: EvaluationStrategy, timestamp: str, additional_metadata: Optional[Dict], @@ -124,8 +124,8 @@ def _save_experiment_summary( strategy (substrafl.strategies.Strategy): strategy num_rounds (int): num_rounds operation_cache (typing.Dict[RemoteStruct, OperationKey]): operation_cache - train_data_nodes (typing.List[TrainDataNode]): train_data_nodes - aggregation_node (typing.Optional[AggregationNode]): aggregation_node + train_data_nodes (typing.List[TrainDataNodeProtocol]): train_data_nodes + aggregation_node (typing.Optional[AggregationNodeProtocol]): aggregation_node evaluation_strategy (EvaluationStrategy): evaluation_strategy timestamp (str): timestamp with "%Y_%m_%d_%H_%M_%S" format additional_metadata (dict, Optional): Optional dictionary of metadata to be shown on the Substra WebApp. @@ -209,10 +209,10 @@ def execute_experiment( *, client: substra.Client, strategy: ComputePlanBuilder, - train_data_nodes: List[TrainDataNode], + train_data_nodes: List[TrainDataNodeProtocol], experiment_folder: Union[str, Path], num_rounds: Optional[int] = None, - aggregation_node: Optional[AggregationNode] = None, + aggregation_node: Optional[AggregationNodeProtocol] = None, evaluation_strategy: Optional[EvaluationStrategy] = None, dependencies: Optional[Dependency] = None, clean_models: bool = True, @@ -243,10 +243,10 @@ def execute_experiment( Args: client (substra.Client): A substra client to interact with the Substra platform strategy (Strategy): The strategy that will be executed - train_data_nodes (typing.List[TrainDataNode]): List of the nodes where training on data + train_data_nodes (typing.List[TrainDataNodeProtocol]): List of the nodes where training on data occurs evaluation_strategy (EvaluationStrategy, Optional): If None performance will not be measured at all. Otherwise measuring of performance will follow the EvaluationStrategy. Defaults to None. - aggregation_node (typing.Optional[AggregationNode]): For centralized strategy, the aggregation + aggregation_node (typing.Optional[AggregationNodeProtocol]): For centralized strategy, the aggregation node, where all the shared tasks occurs else None. 
evaluation_strategy: Optional[EvaluationStrategy] num_rounds (int): The number of time your strategy will be executed diff --git a/substrafl/model_loading.py b/substrafl/model_loading.py index ce5d88ef..b60d7261 100644 --- a/substrafl/model_loading.py +++ b/substrafl/model_loading.py @@ -15,7 +15,7 @@ import substrafl from substrafl import exceptions from substrafl.constants import SUBSTRAFL_FOLDER -from substrafl.nodes.node import OutputIdentifiers +from substrafl.nodes.schemas import OutputIdentifiers from substrafl.remote.remote_struct import RemoteStruct from substrafl.schemas import TaskType diff --git a/substrafl/nodes/__init__.py b/substrafl/nodes/__init__.py index 286ece96..23ba919d 100644 --- a/substrafl/nodes/__init__.py +++ b/substrafl/nodes/__init__.py @@ -1,13 +1,22 @@ -from substrafl.nodes.node import Node # isort:skip -from substrafl.nodes.node import OperationKey # isort:skip +from substrafl.nodes.schemas import OperationKey # isort:skip from substrafl.nodes.aggregation_node import AggregationNode +from substrafl.nodes.protocol import AggregationNodeProtocol +from substrafl.nodes.protocol import TestDataNodeProtocol +from substrafl.nodes.protocol import TrainDataNodeProtocol from substrafl.nodes.test_data_node import TestDataNode from substrafl.nodes.train_data_node import TrainDataNode # This is needed for auto doc to find that Node module's is organizations.organization, otherwise when # trying to link Node references from one page to the Node documentation page, it fails. AggregationNode.__module__ = "organizations.aggregation_node" -Node.__module__ = "organizations.organization" -__all__ = ["Node", "AggregationNode", "TrainDataNode", "TestDataNode", "OperationKey"] +__all__ = [ + "TestDataNodeProtocol", + "TrainDataNodeProtocol", + "AggregationNodeProtocol", + "AggregationNode", + "TrainDataNode", + "TestDataNode", + "OperationKey", +] diff --git a/substrafl/nodes/aggregation_node.py b/substrafl/nodes/aggregation_node.py index f7446d4c..15f72008 100644 --- a/substrafl/nodes/aggregation_node.py +++ b/substrafl/nodes/aggregation_node.py @@ -1,5 +1,6 @@ import uuid from typing import Dict +from typing import List from typing import Optional from typing import Set from typing import TypeVar @@ -7,11 +8,11 @@ import substra from substrafl.dependency import Dependency -from substrafl.nodes.node import InputIdentifiers -from substrafl.nodes.node import Node -from substrafl.nodes.node import OperationKey -from substrafl.nodes.node import OutputIdentifiers +from substrafl.nodes.protocol import AggregationNodeProtocol from substrafl.nodes.references.shared_state import SharedStateRef +from substrafl.nodes.schemas import InputIdentifiers +from substrafl.nodes.schemas import OperationKey +from substrafl.nodes.schemas import OutputIdentifiers from substrafl.remote.operations import RemoteOperation from substrafl.remote.register import register_function from substrafl.remote.remote_struct import RemoteStruct @@ -20,12 +21,16 @@ SharedState = TypeVar("SharedState") -class AggregationNode(Node): +class AggregationNode(AggregationNodeProtocol): """The node which applies operations to the shared states which are received from ``TrainDataNode`` data operations. The result is sent to the ``TrainDataNode`` and/or ``TestDataNode`` data operations. 
""" + def __init__(self, organization_id: str): + self.organization_id = organization_id + self.tasks: List[Dict] = [] + def update_states( self, operation: RemoteOperation, @@ -161,3 +166,13 @@ def register_operations( task["function_key"] = function_key return cache + + def summary(self) -> dict: + """Summary of the class to be exposed in the experiment summary file + + Returns: + dict: a json-serializable dict with the attributes the user wants to store + """ + return { + "organization_id": self.organization_id, + } diff --git a/substrafl/nodes/node.py b/substrafl/nodes/node.py deleted file mode 100644 index d2354159..00000000 --- a/substrafl/nodes/node.py +++ /dev/null @@ -1,55 +0,0 @@ -from enum import Enum -from typing import Dict -from typing import List -from typing import NewType - -OperationKey = NewType("OperationKey", str) - - -class InputIdentifiers(str, Enum): - local = "local" - shared = "shared" - predictions = "predictions" - opener = "opener" - datasamples = "datasamples" - rank = "rank" - X = "X" - y = "y" - - -class OutputIdentifiers(str, Enum): - local = "local" - shared = "shared" - predictions = "predictions" - - -class Node: - def __init__(self, organization_id: str): - self.organization_id = organization_id - self.tasks: List[Dict] = [] - - def summary(self) -> dict: - """Summary of the class to be exposed in the experiment summary file - For inherited classes, override this function and add ``super.summary()`` - - Example: - - .. code-block:: python - - def summary(self): - - summary = super().summary() - summary.update( - { - "attribute": self.attribute, - ... - } - ) - return summary - - Returns: - dict: a json-serializable dict with the attributes the user wants to store - """ - return { - "organization_id": self.organization_id, - } diff --git a/substrafl/nodes/protocol.py b/substrafl/nodes/protocol.py new file mode 100644 index 00000000..8ffc5f0f --- /dev/null +++ b/substrafl/nodes/protocol.py @@ -0,0 +1,70 @@ +from abc import abstractmethod +from typing import Any +from typing import List +from typing import Protocol +from typing import runtime_checkable + +import substra + +from substrafl.nodes.references.local_state import LocalStateRef +from substrafl.remote.operations import RemoteDataOperation +from substrafl.remote.operations import RemoteOperation + + +@runtime_checkable +class TrainDataNodeProtocol(Protocol): + organization_id: str + data_manager_key: str + data_sample_keys: List[str] + + @abstractmethod + def init_states(self, *args, **kwargs) -> LocalStateRef: + pass + + @abstractmethod + def update_states(self, operation: RemoteDataOperation, *args, **kwargs) -> (LocalStateRef, Any): + pass + + @abstractmethod + def register_operations(self, client: substra.Client, *args, **kwargs) -> Any: + pass + + @abstractmethod + def summary(self) -> dict: + pass + + +@runtime_checkable +class TestDataNodeProtocol(Protocol): + organization_id: str + data_manager_key: str + data_sample_keys: List[str] + + @abstractmethod + def update_states(self, operation: RemoteDataOperation, *args, **kwargs) -> None: + pass + + @abstractmethod + def register_operations(self, client: substra.Client, *args, **kwargs) -> Any: + pass + + @abstractmethod + def summary(self) -> dict: + pass + + +@runtime_checkable +class AggregationNodeProtocol(Protocol): + organization_id: str + + @abstractmethod + def update_states(self, operation: RemoteOperation, *args, **kwargs) -> Any: + pass + + @abstractmethod + def register_operations(self, client: substra.Client, *args, **kwargs) -> 
Any: + pass + + @abstractmethod + def summary(self) -> dict: + pass diff --git a/substrafl/nodes/schemas.py b/substrafl/nodes/schemas.py new file mode 100644 index 00000000..eb1b60b1 --- /dev/null +++ b/substrafl/nodes/schemas.py @@ -0,0 +1,21 @@ +from enum import Enum +from typing import NewType + +OperationKey = NewType("OperationKey", str) + + +class InputIdentifiers(str, Enum): + local = "local" + shared = "shared" + predictions = "predictions" + opener = "opener" + datasamples = "datasamples" + rank = "rank" + X = "X" + y = "y" + + +class OutputIdentifiers(str, Enum): + local = "local" + shared = "shared" + predictions = "predictions" diff --git a/substrafl/nodes/test_data_node.py b/substrafl/nodes/test_data_node.py index c1256479..7e5543b8 100644 --- a/substrafl/nodes/test_data_node.py +++ b/substrafl/nodes/test_data_node.py @@ -6,36 +6,36 @@ import substra from substrafl.dependency import Dependency -from substrafl.nodes.node import InputIdentifiers -from substrafl.nodes.node import Node -from substrafl.nodes.node import OperationKey -from substrafl.nodes.node import OutputIdentifiers +from substrafl.nodes.protocol import TestDataNodeProtocol +from substrafl.nodes.schemas import InputIdentifiers +from substrafl.nodes.schemas import OperationKey +from substrafl.nodes.schemas import OutputIdentifiers from substrafl.remote.operations import RemoteDataOperation from substrafl.remote.register import register_function from substrafl.remote.remote_struct import RemoteStruct -class TestDataNode(Node): +class TestDataNode(TestDataNodeProtocol): """A node on which you will test your algorithm. Args: organization_id (str): The substra organization ID (shared with other organizations if permissions are needed) data_manager_key (str): Substra data_manager_key opening data samples used by the strategy - test_data_sample_keys (List[str]): Substra data_sample_keys used for the training on this node + data_sample_keys (List[str]): Substra data_sample_keys used for the training on this node """ def __init__( self, organization_id: str, data_manager_key: str, - test_data_sample_keys: List[str], + data_sample_keys: List[str], ): - self.data_manager_key = data_manager_key - self.test_data_sample_keys = test_data_sample_keys + self.organization_id = organization_id - self.testtasks: List[Dict] = [] + self.data_manager_key = data_manager_key + self.data_sample_keys = data_sample_keys - super().__init__(organization_id) + self.tasks: List[Dict] = [] def update_states( self, @@ -58,7 +58,7 @@ def update_states( substra.schemas.InputRef(identifier=InputIdentifiers.opener, asset_key=self.data_manager_key) ] + [ substra.schemas.InputRef(identifier=InputIdentifiers.datasamples, asset_key=data_sample) - for data_sample in self.test_data_sample_keys + for data_sample in self.data_sample_keys ] local_input = [ substra.schemas.InputRef( @@ -89,9 +89,9 @@ def update_states( ).model_dump() testtask.pop("function_key") testtask["remote_operation"] = operation.remote_struct - self.testtasks.append(testtask) + self.tasks.append(testtask) - def register_test_operations( + def register_operations( self, *, client: substra.Client, @@ -99,8 +99,8 @@ def register_test_operations( cache: Dict[RemoteStruct, OperationKey], dependencies: Dependency, ): - for testtask in self.testtasks: - remote_struct: RemoteStruct = testtask["remote_operation"] + for task in self.tasks: + remote_struct: RemoteStruct = task["remote_operation"] if remote_struct not in cache: # Register the evaluation task @@ -141,11 +141,11 @@ def 
register_test_operations( ], dependencies=dependencies, ) - testtask["function_key"] = function_key + task["function_key"] = function_key cache[remote_struct] = function_key else: function_key = cache[remote_struct] - testtask["function_key"] = function_key + task["function_key"] = function_key return cache @@ -155,11 +155,8 @@ def summary(self) -> dict: Returns: dict: a json-serializable dict with the attributes the user wants to store """ - summary = super().summary() - summary.update( - { - "data_manager_key": self.data_manager_key, - "data_sample_keys": self.test_data_sample_keys, - } - ) - return summary + return { + "organization_id": self.organization_id, + "data_manager_key": self.data_manager_key, + "data_sample_keys": self.data_sample_keys, + } diff --git a/substrafl/nodes/train_data_node.py b/substrafl/nodes/train_data_node.py index 3c1b4d10..e35c168d 100644 --- a/substrafl/nodes/train_data_node.py +++ b/substrafl/nodes/train_data_node.py @@ -8,12 +8,12 @@ import substra from substrafl.dependency import Dependency -from substrafl.nodes.node import InputIdentifiers -from substrafl.nodes.node import Node -from substrafl.nodes.node import OperationKey -from substrafl.nodes.node import OutputIdentifiers +from substrafl.nodes.protocol import TrainDataNodeProtocol from substrafl.nodes.references.local_state import LocalStateRef from substrafl.nodes.references.shared_state import SharedStateRef +from substrafl.nodes.schemas import InputIdentifiers +from substrafl.nodes.schemas import OperationKey +from substrafl.nodes.schemas import OutputIdentifiers from substrafl.remote.operations import RemoteDataOperation from substrafl.remote.operations import RemoteOperation from substrafl.remote.register import register_function @@ -21,7 +21,7 @@ from substrafl.schemas import TaskType -class TrainDataNode(Node): +class TrainDataNode(TrainDataNodeProtocol): """ A predefined structure that allows you to register operations on your train node in a static way before submitting them to substra. 
@@ -38,12 +38,13 @@ def __init__( data_manager_key: str, data_sample_keys: List[str], ): + self.organization_id = organization_id + self.data_manager_key = data_manager_key self.data_sample_keys = data_sample_keys self.init_task = None - - super().__init__(organization_id) + self.tasks: List[Dict] = [] def init_states( self, @@ -299,11 +300,8 @@ def summary(self) -> dict: Returns: dict: a json-serializable dict with the attributes the user wants to store """ - summary = super().summary() - summary.update( - { - "data_manager_key": self.data_manager_key, - "data_sample_keys": self.data_sample_keys, - } - ) - return summary + return { + "organization_id": self.organization_id, + "data_manager_key": self.data_manager_key, + "data_sample_keys": self.data_sample_keys, + } diff --git a/substrafl/remote/substratools_methods.py b/substrafl/remote/substratools_methods.py index 7b457e62..c71ea1e4 100644 --- a/substrafl/remote/substratools_methods.py +++ b/substrafl/remote/substratools_methods.py @@ -9,8 +9,8 @@ import substratools as tools -from substrafl.nodes.node import InputIdentifiers -from substrafl.nodes.node import OutputIdentifiers +from substrafl.nodes.schemas import InputIdentifiers +from substrafl.nodes.schemas import OutputIdentifiers from substrafl.remote.serializers.pickle_serializer import PickleSerializer from substrafl.remote.serializers.serializer import Serializer diff --git a/substrafl/strategies/fed_avg.py b/substrafl/strategies/fed_avg.py index 1b8fc8b0..bbc39c7a 100644 --- a/substrafl/strategies/fed_avg.py +++ b/substrafl/strategies/fed_avg.py @@ -8,11 +8,11 @@ from substrafl.algorithms.algo import Algo from substrafl.exceptions import EmptySharedStatesError -from substrafl.nodes.aggregation_node import AggregationNode +from substrafl.nodes import AggregationNodeProtocol +from substrafl.nodes import TestDataNodeProtocol +from substrafl.nodes import TrainDataNodeProtocol from substrafl.nodes.references.local_state import LocalStateRef from substrafl.nodes.references.shared_state import SharedStateRef -from substrafl.nodes.test_data_node import TestDataNode -from substrafl.nodes.train_data_node import TrainDataNode from substrafl.remote import remote from substrafl.strategies.schemas import FedAvgAveragedState from substrafl.strategies.schemas import FedAvgSharedState @@ -28,8 +28,8 @@ class FedAvg(Strategy): passes on each client, aggregating updates by computing their means and distributing the consensus update to all clients. In FedAvg, strategy is performed in a centralized way, where a single server or - ``AggregationNode`` communicates with a number of clients ``TrainDataNode`` - and ``TestDataNode``. + ``AggregationNodeProtocol`` communicates with a number of clients ``TrainDataNodeProtocol`` + and ``TestDataNodeProtocol``. 
Formally, if :math:`w_t` denotes the parameters of the model at round :math:`t`, a single round consists in the following steps: @@ -79,8 +79,8 @@ def name(self) -> StrategyName: def perform_round( self, *, - train_data_nodes: List[TrainDataNode], - aggregation_node: AggregationNode, + train_data_nodes: List[TrainDataNodeProtocol], + aggregation_node: AggregationNodeProtocol, round_idx: int, clean_models: bool, additional_orgs_permissions: Optional[set] = None, @@ -93,9 +93,9 @@ def perform_round( - perform a local update (train on n mini-batches) of the models on each train data nodes Args: - train_data_nodes (typing.List[TrainDataNode]): List of the nodes on which to perform + train_data_nodes (typing.List[TrainDataNodeProtocol]): List of the nodes on which to perform local updates. - aggregation_node (AggregationNode): Node without data, used to perform + aggregation_node (AggregationNodeProtocol): Node without data, used to perform operations on the shared states of the models round_idx (int): Round number, it starts at 0. clean_models (bool): Clean the intermediary models of this round on the Substra platform. @@ -138,15 +138,15 @@ def perform_round( def perform_evaluation( self, - test_data_nodes: List[TestDataNode], - train_data_nodes: List[TrainDataNode], + test_data_nodes: List[TestDataNodeProtocol], + train_data_nodes: List[TrainDataNodeProtocol], round_idx: int, ): """Perform evaluation on test_data_nodes. Args: - test_data_nodes (List[TestDataNode]): test data nodes to perform the prediction from the algo on. - train_data_nodes (List[TrainDataNode]): train data nodes the model has been trained + test_data_nodes (List[TestDataNodeProtocol]): test data nodes to perform the prediction from the algo on. + train_data_nodes (List[TrainDataNodeProtocol]): train data nodes the model has been trained on. round_idx (int): round index. """ @@ -167,7 +167,7 @@ def perform_evaluation( test_data_node.update_states( traintask_id=local_state.key, operation=self.evaluate( - data_samples=test_data_node.test_data_sample_keys, + data_samples=test_data_node.data_sample_keys, _algo_name=f"Evaluating with {self.__class__.__name__}", ), round_idx=round_idx, @@ -225,7 +225,7 @@ def avg_shared_states(self, shared_states: List[FedAvgSharedState]) -> FedAvgAve def _perform_local_updates( self, - train_data_nodes: List[TrainDataNode], + train_data_nodes: List[TrainDataNodeProtocol], current_aggregation: Optional[SharedStateRef], round_idx: int, aggregation_id: str, @@ -236,7 +236,7 @@ def _perform_local_updates( on each train data nodes. Args: - train_data_nodes (typing.List[TrainDataNode]): List of the organizations on which to perform + train_data_nodes (typing.List[TrainDataNodeProtocol]): List of the organizations on which to perform local updates current_aggregation (SharedStateRef, Optional): Reference of an aggregation operation to be passed as input to each local training round_idx (int): Round number, it starts at 1. 
diff --git a/substrafl/strategies/fed_pca.py b/substrafl/strategies/fed_pca.py index e8bb7a60..83bd00e7 100644 --- a/substrafl/strategies/fed_pca.py +++ b/substrafl/strategies/fed_pca.py @@ -9,11 +9,11 @@ from substrafl.algorithms.algo import Algo from substrafl.exceptions import EmptySharedStatesError -from substrafl.nodes.aggregation_node import AggregationNode +from substrafl.nodes import AggregationNodeProtocol +from substrafl.nodes import TestDataNodeProtocol +from substrafl.nodes import TrainDataNodeProtocol from substrafl.nodes.references.local_state import LocalStateRef from substrafl.nodes.references.shared_state import SharedStateRef -from substrafl.nodes.test_data_node import TestDataNode -from substrafl.nodes.train_data_node import TrainDataNode from substrafl.remote import remote from substrafl.strategies.schemas import FedPCAAveragedState from substrafl.strategies.schemas import FedPCASharedState @@ -103,16 +103,16 @@ def name(self) -> StrategyName: def perform_evaluation( self, - test_data_nodes: List[TestDataNode], - train_data_nodes: List[TrainDataNode], + test_data_nodes: List[TestDataNodeProtocol], + train_data_nodes: List[TrainDataNodeProtocol], round_idx: int, ) -> None: """Perform evaluation on test_data_nodes. Perform prediction before round 3 is not take into account as all objects to compute prediction are not initialize before the second round. Args: - test_data_nodes (List[TestDataNode]): test data nodes to perform the prediction from the algo on. - train_data_nodes (List[TrainDataNode]): train data nodes the model has been trained + test_data_nodes (List[TestDataNodeProtocol]): test data nodes to perform the prediction from the algo on. + train_data_nodes (List[TrainDataNodeProtocol]): train data nodes the model has been trained on. round_idx (int): round index. """ @@ -137,7 +137,7 @@ def perform_evaluation( test_data_node.update_states( traintask_id=local_state.key, operation=self.evaluate( - data_samples=test_data_node.test_data_sample_keys, + data_samples=test_data_node.data_sample_keys, _algo_name=f"Evaluating with {self.__class__.__name__}", ), round_idx=round_idx, @@ -145,8 +145,8 @@ def perform_evaluation( def perform_round( self, - train_data_nodes: List[TrainDataNode], - aggregation_node: AggregationNode, + train_data_nodes: List[TrainDataNodeProtocol], + aggregation_node: AggregationNodeProtocol, round_idx: int, clean_models: bool, additional_orgs_permissions: Optional[set] = None, @@ -159,9 +159,9 @@ def perform_round( - Use the local covariance matrices to compute the orthogonal matrix for every next rounds. Args: - train_data_nodes (typing.List[TrainDataNode]): List of the nodes on which to perform + train_data_nodes (typing.List[TrainDataNodeProtocol]): List of the nodes on which to perform local updates. - aggregation_node (AggregationNode): Node without data, used to perform + aggregation_node (AggregationNodeProtocol): Node without data, used to perform operations on the shared states of the models round_idx (int): Round number, it starts at 0. clean_models (bool): Clean the intermediary models of this round on the Substra platform. @@ -300,7 +300,7 @@ def avg_shared_states_with_qr(self, shared_states: List[FedPCASharedState]) -> F def _perform_local_updates( self, - train_data_nodes: List[TrainDataNode], + train_data_nodes: List[TrainDataNodeProtocol], current_aggregation: Optional[SharedStateRef], round_idx: int, aggregation_id: str, @@ -311,7 +311,7 @@ def _perform_local_updates( on each train data nodes. 
Args: - train_data_nodes (typing.List[TrainDataNode]): List of the organizations on which to perform + train_data_nodes (typing.List[TrainDataNodeProtocol]): List of the organizations on which to perform local updates current_aggregation (SharedStateRef, Optional): Reference of an aggregation operation to be passed as input to each local training round_idx (int): Round number, it starts at 1. diff --git a/substrafl/strategies/newton_raphson.py b/substrafl/strategies/newton_raphson.py index 13a3bad8..6bdd966f 100644 --- a/substrafl/strategies/newton_raphson.py +++ b/substrafl/strategies/newton_raphson.py @@ -10,11 +10,11 @@ from substrafl.exceptions import DampingFactorValueError from substrafl.exceptions import EmptySharedStatesError from substrafl.exceptions import SharedStatesError -from substrafl.nodes.aggregation_node import AggregationNode +from substrafl.nodes import AggregationNodeProtocol +from substrafl.nodes import TestDataNodeProtocol +from substrafl.nodes import TrainDataNodeProtocol from substrafl.nodes.references.local_state import LocalStateRef from substrafl.nodes.references.shared_state import SharedStateRef -from substrafl.nodes.test_data_node import TestDataNode -from substrafl.nodes.train_data_node import TrainDataNode from substrafl.remote import remote from substrafl.strategies.schemas import NewtonRaphsonAveragedStates from substrafl.strategies.schemas import NewtonRaphsonSharedState @@ -85,8 +85,8 @@ def name(self) -> StrategyName: def perform_round( self, *, - train_data_nodes: List[TrainDataNode], - aggregation_node: AggregationNode, + train_data_nodes: List[TrainDataNodeProtocol], + aggregation_node: AggregationNodeProtocol, round_idx: int, clean_models: bool, additional_orgs_permissions: Optional[set] = None, @@ -100,9 +100,9 @@ def perform_round( - perform a local update of the models on each train data nodes Args: - train_data_nodes (typing.List[TrainDataNode]): List of the nodes on which to perform + train_data_nodes (typing.List[TrainDataNodeProtocol]): List of the nodes on which to perform local updates - aggregation_node (AggregationNode): node without data, used to perform operations + aggregation_node (AggregationNodeProtocol): node without data, used to perform operations on the shared states of the models round_idx (int): Round number, it starts at 0. clean_models (bool): Clean the intermediary models of this round on the Substra platform. @@ -245,7 +245,7 @@ def _unflatten_array( def _perform_local_updates( self, - train_data_nodes: List[TrainDataNode], + train_data_nodes: List[TrainDataNodeProtocol], current_aggregation: Optional[SharedStateRef], round_idx: int, aggregation_id: str, @@ -255,7 +255,7 @@ def _perform_local_updates( """Perform a local update of the model on each train data nodes. Args: - train_data_nodes (typing.List[TrainDataNode]): List of the nodes on which to + train_data_nodes (typing.List[TrainDataNodeProtocol]): List of the nodes on which to perform local updates current_aggregation (SharedStateRef, Optional): Reference of an aggregation operation to be passed as input to each local training @@ -295,15 +295,15 @@ def _perform_local_updates( def perform_evaluation( self, - test_data_nodes: List[TestDataNode], - train_data_nodes: List[TrainDataNode], + test_data_nodes: List[TestDataNodeProtocol], + train_data_nodes: List[TrainDataNodeProtocol], round_idx: int, ): """Perform evaluation on test_data_nodes. Args: - test_data_nodes (List[TestDataNode]): test data nodes to perform the prediction from the algo on. 
- train_data_nodes (List[TrainDataNode]): train data nodes the model has been trained + test_data_nodes (List[TestDataNodeProtocol]): test data nodes to perform the prediction from the algo on. + train_data_nodes (List[TrainDataNodeProtocol]): train data nodes the model has been trained on. round_idx (int): round index. """ @@ -324,7 +324,7 @@ def perform_evaluation( test_data_node.update_states( operation=self.evaluate( - data_samples=test_data_node.test_data_sample_keys, + data_samples=test_data_node.data_sample_keys, _algo_name=f"Evaluating with {self.__class__.__name__}", ), traintask_id=local_state.key, diff --git a/substrafl/strategies/scaffold.py b/substrafl/strategies/scaffold.py index a8960589..fc73381a 100644 --- a/substrafl/strategies/scaffold.py +++ b/substrafl/strategies/scaffold.py @@ -7,11 +7,11 @@ import numpy as np from substrafl.algorithms.algo import Algo -from substrafl.nodes.aggregation_node import AggregationNode +from substrafl.nodes import AggregationNodeProtocol +from substrafl.nodes import TestDataNodeProtocol +from substrafl.nodes import TrainDataNodeProtocol from substrafl.nodes.references.local_state import LocalStateRef from substrafl.nodes.references.shared_state import SharedStateRef -from substrafl.nodes.test_data_node import TestDataNode -from substrafl.nodes.train_data_node import TrainDataNode from substrafl.remote import remote from substrafl.strategies.schemas import ScaffoldAveragedStates from substrafl.strategies.schemas import ScaffoldSharedState @@ -31,8 +31,8 @@ class Scaffold(Strategy): passes on each client, aggregating updates by computing their means and distributing the consensus update to all clients. In Scaffold, strategy is performed in a centralized way, where a single server or - ``AggregationNode`` communicates with a number of clients ``TrainDataNode`` - and ``TestDataNode``. + ``AggregationNodeProtocol`` communicates with a number of clients ``TrainDataNodeProtocol`` + and ``TestDataNodeProtocol``. """ def __init__( @@ -72,8 +72,8 @@ def name(self) -> StrategyName: def perform_round( self, *, - train_data_nodes: List[TrainDataNode], - aggregation_node: AggregationNode, + train_data_nodes: List[TrainDataNodeProtocol], + aggregation_node: AggregationNodeProtocol, round_idx: int, clean_models: bool, additional_orgs_permissions: Optional[set] = None, @@ -86,8 +86,8 @@ def perform_round( - perform a local update (train on n mini-batches) of the models on each train data nodes Args: - train_data_nodes (typing.List[TrainDataNode]): List of the organizations on which to perform - local updates aggregation_node (AggregationNode): Node without data, used to perform + train_data_nodes (typing.List[TrainDataNodeProtocol]): List of the organizations on which to perform + local updates aggregation_node (AggregationNodeProtocol): Node without data, used to perform operations on the shared states of the models round_idx (int): Round number, it starts at 0. clean_models (bool): Clean the intermediary models of this round on the Substra platform. @@ -130,15 +130,15 @@ def perform_round( def perform_evaluation( self, - test_data_nodes: List[TestDataNode], - train_data_nodes: List[TrainDataNode], + test_data_nodes: List[TestDataNodeProtocol], + train_data_nodes: List[TrainDataNodeProtocol], round_idx: int, ): """Perform evaluation on test_data_nodes. Args: - test_data_nodes (List[TestDataNode]): test data nodes to perform the prediction from the algo on. 
- train_data_nodes (List[TrainDataNode]): train data nodes the model has been trained + test_data_nodes (List[TestDataNodeProtocol]): test data nodes to perform the prediction from the algo on. + train_data_nodes (List[TrainDataNodeProtocol]): train data nodes the model has been trained on. round_idx (int): round index. """ @@ -158,7 +158,7 @@ def perform_evaluation( test_data_node.update_states( operation=self.evaluate( - data_samples=test_data_node.test_data_sample_keys, + data_samples=test_data_node.data_sample_keys, _algo_name=f"Evaluating with {self.__class__.__name__}", ), traintask_id=local_state.key, @@ -338,7 +338,7 @@ def avg_shared_states(self, shared_states: List[ScaffoldSharedState]) -> Scaffol def _perform_local_updates( self, - train_data_nodes: List[TrainDataNode], + train_data_nodes: List[TrainDataNodeProtocol], current_aggregation: Optional[SharedStateRef], round_idx: int, aggregation_id: str, @@ -349,7 +349,7 @@ def _perform_local_updates( on each train data nodes. Args: - train_data_nodes (typing.List[TrainDataNode]): List of the organizations on which to perform + train_data_nodes (typing.List[TrainDataNodeProtocol]): List of the organizations on which to perform local updates current_aggregation (SharedStateRef, Optional): Reference of an aggregation operation to be passed as input to each local training round_idx (int): Round number, it starts at 1. diff --git a/substrafl/strategies/single_organization.py b/substrafl/strategies/single_organization.py index 75a0a121..b6fe56d1 100644 --- a/substrafl/strategies/single_organization.py +++ b/substrafl/strategies/single_organization.py @@ -6,9 +6,9 @@ from typing import Union from substrafl.algorithms import Algo -from substrafl.nodes import AggregationNode -from substrafl.nodes import TestDataNode -from substrafl.nodes import TrainDataNode +from substrafl.nodes import AggregationNodeProtocol +from substrafl.nodes import TestDataNodeProtocol +from substrafl.nodes import TrainDataNodeProtocol from substrafl.nodes.references.local_state import LocalStateRef from substrafl.strategies.schemas import StrategyName from substrafl.strategies.strategy import Strategy @@ -22,7 +22,7 @@ class SingleOrganization(Strategy): Single organization is not a real federated strategy and it is rather used for testing as it is faster than other 'real' strategies. The training and prediction are performed on a single Node. However, the number of passes to that Node (num_rounds) is still defined to test the actual federated setting. - In SingleOrganization strategy a single client ``TrainDataNode`` and ``TestDataNode`` performs + In SingleOrganization strategy a single client ``TrainDataNodeProtocol`` and ``TestDataNodeProtocol`` performs all the model execution. """ @@ -56,7 +56,7 @@ def name(self) -> StrategyName: def initialization_round( self, *, - train_data_nodes: List[TrainDataNode], + train_data_nodes: List[TrainDataNodeProtocol], clean_models: bool, round_idx: Optional[int] = 0, additional_orgs_permissions: Optional[set] = None, @@ -64,7 +64,7 @@ def initialization_round( """Call the initialize function of the algo on each train node. Args: - train_data_nodes (typing.List[TrainDataNode]): list of the train organizations + train_data_nodes (typing.List[TrainDataNodeProtocol]): list of the train organizations clean_models (bool): Clean the intermediary models of this round on the Substra platform. Set it to False if you want to download or re-use intermediary models. 
This causes the disk space to fill quickly so should be set to True unless needed. @@ -93,19 +93,19 @@ def initialization_round( def perform_round( self, *, - train_data_nodes: List[TrainDataNode], + train_data_nodes: List[TrainDataNodeProtocol], round_idx: int, clean_models: bool, - aggregation_node: Optional[AggregationNode] = None, + aggregation_node: Optional[AggregationNodeProtocol] = None, additional_orgs_permissions: Optional[set] = None, ): """One round of the SingleOrganization strategy: perform a local update (train on n mini-batches) of the models on a given data node Args: - train_data_nodes (List[TrainDataNode]): List of the nodes on which to perform local + train_data_nodes (List[TrainDataNodeProtocol]): List of the nodes on which to perform local updates, there should be exactly one item in the list. - aggregation_node (AggregationNode): Should be None otherwise it will be ignored + aggregation_node (AggregationNodeProtocol): Should be None otherwise it will be ignored round_idx (int): Round number, it starts at 0. clean_models (bool): Clean the intermediary models of this round on the Substra platform. Set it to False if you want to download or re-use intermediary models. This causes the disk @@ -143,15 +143,15 @@ def perform_round( def perform_evaluation( self, - test_data_nodes: List[TestDataNode], - train_data_nodes: List[TrainDataNode], + test_data_nodes: List[TestDataNodeProtocol], + train_data_nodes: List[TrainDataNodeProtocol], round_idx: int, ): """Perform evaluation on test_data_nodes. Args: - test_data_nodes (List[TestDataNode]): test data nodes to perform the prediction from the algo on. - train_data_nodes (List[TrainDataNode]): train data nodes the model has been trained + test_data_nodes (List[TestDataNodeProtocol]): test data nodes to perform the prediction from the algo on. + train_data_nodes (List[TrainDataNodeProtocol]): train data nodes the model has been trained on. round_idx (int): round index. 
""" @@ -166,7 +166,7 @@ def perform_evaluation( test_data_node.update_states( traintask_id=self.local_state.key, operation=self.evaluate( - data_samples=test_data_node.test_data_sample_keys, + data_samples=test_data_node.data_sample_keys, _algo_name=f"Evaluating with {self.__class__.__name__}", ), round_idx=round_idx, diff --git a/substrafl/strategies/strategy.py b/substrafl/strategies/strategy.py index c918985f..3442ae6e 100644 --- a/substrafl/strategies/strategy.py +++ b/substrafl/strategies/strategy.py @@ -15,10 +15,10 @@ from substrafl.algorithms.algo import Algo from substrafl.compute_plan_builder import ComputePlanBuilder from substrafl.evaluation_strategy import EvaluationStrategy -from substrafl.nodes.aggregation_node import AggregationNode -from substrafl.nodes.node import OutputIdentifiers -from substrafl.nodes.test_data_node import TestDataNode -from substrafl.nodes.train_data_node import TrainDataNode +from substrafl.nodes import AggregationNodeProtocol +from substrafl.nodes import TestDataNodeProtocol +from substrafl.nodes import TrainDataNodeProtocol +from substrafl.nodes.schemas import OutputIdentifiers from substrafl.remote.decorators import remote_data from substrafl.strategies.schemas import StrategyName @@ -83,7 +83,7 @@ def name(self) -> StrategyName: def initialization_round( self, *, - train_data_nodes: List[TrainDataNode], + train_data_nodes: List[TrainDataNodeProtocol], clean_models: bool, round_idx: Optional[int] = 0, additional_orgs_permissions: Optional[set] = None, @@ -91,7 +91,7 @@ def initialization_round( """Call the initialize function of the algo on each train node. Args: - train_data_nodes (typing.List[TrainDataNode]): list of the train organizations + train_data_nodes (typing.List[TrainDataNodeProtocol]): list of the train organizations clean_models (bool): Clean the intermediary models of this round on the Substra platform. Set it to False if you want to download or re-use intermediary models. This causes the disk space to fill quickly so should be set to True unless needed. @@ -119,8 +119,8 @@ def initialization_round( def perform_round( self, *, - train_data_nodes: List[TrainDataNode], - aggregation_node: Optional[AggregationNode], + train_data_nodes: List[TrainDataNodeProtocol], + aggregation_node: Optional[AggregationNodeProtocol], round_idx: int, clean_models: bool, additional_orgs_permissions: Optional[set] = None, @@ -128,8 +128,8 @@ def perform_round( """Perform one round of the strategy Args: - train_data_nodes (typing.List[TrainDataNode]): list of the train organizations - aggregation_node (typing.Optional[AggregationNode]): aggregation node, necessary for + train_data_nodes (typing.List[TrainDataNodeProtocol]): list of the train organizations + aggregation_node (typing.Optional[AggregationNodeProtocol]): aggregation node, necessary for centralized strategy, unused otherwise round_idx (int): index of the round clean_models (bool): Clean the intermediary models of this round on the Substra platform. @@ -143,8 +143,8 @@ def perform_round( @abstractmethod def perform_evaluation( self, - test_data_nodes: List[TestDataNode], - train_data_nodes: List[TrainDataNode], + test_data_nodes: List[TestDataNodeProtocol], + train_data_nodes: List[TrainDataNodeProtocol], round_idx: int, ): """Perform the evaluation of the algo on each test nodes. @@ -152,8 +152,8 @@ def perform_evaluation( test nodes. 
         Args:
-            test_data_nodes (typing.List[TestDataNode]): list of nodes on which to evaluate
-            train_data_nodes (typing.List[TrainDataNode]): list of nodes on which the model has
+            test_data_nodes (typing.List[TestDataNodeProtocol]): list of nodes on which to evaluate
+            train_data_nodes (typing.List[TrainDataNodeProtocol]): list of nodes on which the model has
                 been trained
             round_idx (int): index of the round
         """
@@ -180,8 +180,8 @@ def evaluate(self, datasamples: Any, shared_state: Any = None) -> Dict[str, floa

     def build_compute_plan(
         self,
-        train_data_nodes: List[TrainDataNode],
-        aggregation_node: Optional[List[AggregationNode]],
+        train_data_nodes: List[TrainDataNodeProtocol],
+        aggregation_node: Optional[List[AggregationNodeProtocol]],
         evaluation_strategy: Optional[EvaluationStrategy],
         num_rounds: int,
         clean_models: Optional[bool] = True,
@@ -195,8 +195,8 @@ def build_compute_plan(
         called to complete the graph.

         Args:
-            train_data_nodes (typing.List[TrainDataNode]): list of the train organizations
-            aggregation_node (typing.Optional[AggregationNode]): aggregation node, necessary for
+            train_data_nodes (typing.List[TrainDataNodeProtocol]): list of the train organizations
+            aggregation_node (typing.Optional[AggregationNodeProtocol]): aggregation node, necessary for
                 centralized strategy, unused otherwise
             evaluation_strategy (Optional[EvaluationStrategy]): evaluation strategy to follow for testing models.
             num_rounds (int): Number of times to repeat the compute plan sub-graph (define in perform round).
diff --git a/tests/algorithms/pytorch/test_base_algo.py b/tests/algorithms/pytorch/test_base_algo.py
index daf966c0..3327d19f 100644
--- a/tests/algorithms/pytorch/test_base_algo.py
+++ b/tests/algorithms/pytorch/test_base_algo.py
@@ -14,8 +14,8 @@
 from substrafl.exceptions import DatasetSignatureError
 from substrafl.exceptions import DatasetTypeError
 from substrafl.index_generator import NpIndexGenerator
-from substrafl.nodes.node import InputIdentifiers
-from substrafl.nodes.node import OutputIdentifiers
+from substrafl.nodes.schemas import InputIdentifiers
+from substrafl.nodes.schemas import OutputIdentifiers
 from substrafl.remote.decorators import remote_data
 from substrafl.remote.remote_struct import RemoteStruct
 from substrafl.remote.serializers import PickleSerializer
diff --git a/tests/dependency/test_dependency.py b/tests/dependency/test_dependency.py
index 5791e575..22cbd066 100644
--- a/tests/dependency/test_dependency.py
+++ b/tests/dependency/test_dependency.py
@@ -11,8 +11,8 @@
 from substrafl.constants import SUBSTRAFL_FOLDER
 from substrafl.dependency import Dependency
 from substrafl.exceptions import InvalidPathError
-from substrafl.nodes.node import InputIdentifiers
-from substrafl.nodes.node import OutputIdentifiers
+from substrafl.nodes.schemas import InputIdentifiers
+from substrafl.nodes.schemas import OutputIdentifiers
 from substrafl.remote import remote_data
 from substrafl.remote.register.register import register_function
diff --git a/tests/remote/test_decorator.py b/tests/remote/test_decorator.py
index 6871b989..aa008f6f 100644
--- a/tests/remote/test_decorator.py
+++ b/tests/remote/test_decorator.py
@@ -3,8 +3,8 @@

 import numpy as np

-from substrafl.nodes.node import InputIdentifiers
-from substrafl.nodes.node import OutputIdentifiers
+from substrafl.nodes.schemas import InputIdentifiers
+from substrafl.nodes.schemas import OutputIdentifiers
 from substrafl.remote.decorators import remote
 from substrafl.remote.decorators import remote_data
 from substrafl.remote.operations import RemoteDataOperation
diff --git a/tests/strategies/test_fed_pca.py b/tests/strategies/test_fed_pca.py
index b8070970..a0366ead 100644
--- a/tests/strategies/test_fed_pca.py
+++ b/tests/strategies/test_fed_pca.py
@@ -131,7 +131,7 @@ def test_fed_pca_predict(dummy_algo_class):
         round_idx=1,
     )

-    assert all([len(test_data_node.testtasks) == 0 for test_data_node in test_data_nodes])
+    assert all([len(test_data_node.tasks) == 0 for test_data_node in test_data_nodes])

     strategy.perform_evaluation(
         test_data_nodes=test_data_nodes,
@@ -139,7 +139,7 @@ def test_fed_pca_predict(dummy_algo_class):
         round_idx=3,
     )

-    assert all([len(test_data_node.testtasks) == 1 for test_data_node in test_data_nodes])
+    assert all([len(test_data_node.tasks) == 1 for test_data_node in test_data_nodes])


 @pytest.mark.parametrize("additional_orgs_permissions", [set(), {"TestId"}, {"TestId1", "TestId2"}])
diff --git a/tests/strategies/test_newton_raphson.py b/tests/strategies/test_newton_raphson.py
index b3c5764f..52e29fb5 100644
--- a/tests/strategies/test_newton_raphson.py
+++ b/tests/strategies/test_newton_raphson.py
@@ -150,7 +150,7 @@ def test_newton_raphson_predict(dummy_algo_class):
         round_idx=0,
     )

-    assert all([len(test_data_node.testtasks) == 1 for test_data_node in test_data_nodes])
+    assert all([len(test_data_node.tasks) == 1 for test_data_node in test_data_nodes])


 @pytest.mark.parametrize("additional_orgs_permissions", [set(), {"TestId"}, {"TestId1", "TestId2"}])
diff --git a/tests/strategies/test_strategy.py b/tests/strategies/test_strategy.py
index 0e383ab1..340becc7 100644
--- a/tests/strategies/test_strategy.py
+++ b/tests/strategies/test_strategy.py
@@ -5,7 +5,7 @@
 import pytest

 from substrafl import exceptions
-from substrafl.nodes.node import OutputIdentifiers
+from substrafl.nodes.schemas import OutputIdentifiers


 @pytest.mark.parametrize(
diff --git a/tests/test_evaluation_strategy.py b/tests/test_evaluation_strategy.py
index d40ad9ce..fb3b2c28 100644
--- a/tests/test_evaluation_strategy.py
+++ b/tests/test_evaluation_strategy.py
@@ -1,20 +1,21 @@
-from unittest.mock import Mock
-
 import pytest

 from substrafl.evaluation_strategy import EvaluationStrategy
 from substrafl.nodes.test_data_node import TestDataNode


+@pytest.fixture(scope="module")
+def test_data_node():
+    return TestDataNode("fake_id", "fake_id", ["fake_id"])
+
+
 @pytest.mark.parametrize("eval_frequency", [1, 2, 4, 10])
-def test_eval_frequency(eval_frequency):
+def test_eval_frequency(eval_frequency, test_data_node):
     # tests that each next() returns expected True or False
     # tests that next called > num_rounds raises StopIteration
     n_nodes = 3
     num_rounds = 10
     # test rounds as frequencies give expected result
-    # mock the test nodes
-    test_data_node = Mock(spec=TestDataNode)
     test_data_nodes = [test_data_node] * n_nodes

     evaluation_strategy = EvaluationStrategy(test_data_nodes=test_data_nodes, eval_frequency=eval_frequency)
@@ -34,14 +35,12 @@


 @pytest.mark.parametrize("eval_rounds", [[1], [1, 4], [5, 1, 7, 3]])
-def test_eval_rounds(eval_rounds):
+def test_eval_rounds(eval_rounds, test_data_node):
     # tests that each next() returns expected True or False
     # tests that next called > num_rounds raises StopIteration
     n_nodes = 3
     num_rounds = 10
     # test rounds as frequencies give expected result
-    # mock the test nodes
-    test_data_node = Mock(spec=TestDataNode)
     test_data_nodes = [test_data_node] * n_nodes

     evaluation_strategy = EvaluationStrategy(
@@ -63,14 +62,12 @@


 @pytest.mark.parametrize("eval_frequency, eval_rounds", [(2, [1, 3]), (3, [0, 1])])
-def test_union_eval_rounds_and_eval_frequency(eval_frequency, eval_rounds):
+def test_union_eval_rounds_and_eval_frequency(eval_frequency, eval_rounds, test_data_node):
     # tests that each next() returns expected True or False
     # tests that next called > num_rounds raises StopIteration
     n_nodes = 3
     num_rounds = 10
     # test rounds as frequencies give expected result
-    # mock the test nodes
-    test_data_node = Mock(spec=TestDataNode)
     test_data_nodes = [test_data_node] * n_nodes

     evaluation_strategy = EvaluationStrategy(
@@ -94,8 +91,7 @@ def test_union_eval_rounds_and_eval_frequency(eval_frequency, eval_rounds):
         response = next(evaluation_strategy)


-def test_eval_rounds_and_eval_frequency_at_none():
-    test_data_node = Mock(spec=TestDataNode)
+def test_eval_rounds_and_eval_frequency_at_none(test_data_node):
     test_data_nodes = [test_data_node]

     with pytest.raises(ValueError):
@@ -113,13 +109,11 @@
         [4.5, TypeError],
     ],
 )
-def test_eval_frequency_edges(eval_frequency, e):
+def test_eval_frequency_edges(test_data_node, eval_frequency, e):
     # tests that EvaluationStrategy raises appropriate error if the eval_frequency
     # is not correct
     n_nodes = 3

-    # mock the test data nodes
-    test_data_node = Mock(spec=TestDataNode)
     test_data_nodes = [test_data_node] * n_nodes

     with pytest.raises(e):
@@ -134,13 +128,11 @@
         [[4, -1, 5], ValueError],
     ],
 )
-def test_eval_rounds_edges(eval_rounds, e):
+def test_eval_rounds_edges(test_data_node, eval_rounds, e):
     # tests that EvaluationStrategy raises appropriate error if the eval_rounds
     # is not correct
     n_nodes = 3

-    # mock the test data nodes
-    test_data_node = Mock(spec=TestDataNode)
     test_data_nodes = [test_data_node] * n_nodes

     with pytest.raises(e):
@@ -158,13 +150,11 @@
         [None, [1], 1, None],
     ],
 )
-def test_rounds_inconsitancy(eval_frequency, eval_rounds, num_rounds, e):
+def test_rounds_inconsitancy(test_data_node, eval_frequency, eval_rounds, num_rounds, e):
     # tests that consistency between selected rounds and num_rounds is
     # checked for and if inconsistency is found error is raised
     n_nodes = 3

-    # mock the test data nodes
-    test_data_node = Mock(spec=TestDataNode)
     test_data_nodes = [test_data_node] * n_nodes

     evaluation_strategy = EvaluationStrategy(
@@ -178,20 +168,10 @@ def test_rounds_inconsitancy(eval_frequency, eval_rounds, num_rounds, e):
         evaluation_strategy.num_rounds = num_rounds


-@pytest.mark.parametrize(
-    "test_data_nodes, e",
-    [
-        [[Mock(spec=TestDataNode)], None],
-        [[1], TypeError],
-    ],
-)
-def test_error_on_wrong_node_instance(test_data_nodes, e):
+def test_error_on_wrong_node_instance():
     # test that only list of TestDataNodes are accepted as test_data_nodes
-    if e is not None:
-        with pytest.raises(e):
-            EvaluationStrategy(test_data_nodes=test_data_nodes, eval_frequency=3)
-    else:
-        EvaluationStrategy(test_data_nodes=test_data_nodes, eval_frequency=3)
+    with pytest.raises(TypeError):
+        EvaluationStrategy(test_data_nodes=[1], eval_frequency=3)


 @pytest.mark.parametrize(
@@ -201,11 +181,10 @@
         [2, None, 4, [True, False, True, False, True, StopIteration]],
     ],
 )
-def test_docstring_examples(eval_frequency, eval_rounds, num_rounds, result):
+def test_docstring_examples(test_data_node, eval_frequency, eval_rounds, num_rounds, result):
     """tests that the examples given in the docstring
     of EvaluationStrategy indeed give the correct result"""
     n_nodes = 3
-    data_node = Mock(spec=TestDataNode)
-    data_nodes = [data_node] * n_nodes
+    data_nodes = [test_data_node] * n_nodes
     evaluation_strategy = EvaluationStrategy(
         test_data_nodes=data_nodes, eval_frequency=eval_frequency, eval_rounds=eval_rounds
@@ -223,15 +202,13 @@
 @pytest.mark.parametrize(
     "eval_frequency, eval_rounds", [(1, None), (2, None), (4, None), (10, None), (None, [1]), (None, [1, 4])]
 )
-def test_restart_rounds(eval_frequency, eval_rounds):
+def test_restart_rounds(test_data_node, eval_frequency, eval_rounds):
     # tests running a second time an evaluation strategy after calling restart_rounds
     # give the same results
     n_nodes = 3
     num_rounds = 10

-    # mock the test nodes
-    test_data_node = Mock(spec=TestDataNode)
     test_data_nodes = [test_data_node] * n_nodes

     evaluation_strategy = EvaluationStrategy(
diff --git a/tests/test_experiment.py b/tests/test_experiment.py
index 831f9b17..286c77ed 100644
--- a/tests/test_experiment.py
+++ b/tests/test_experiment.py
@@ -57,7 +57,7 @@ def test_execute_experiment_has_no_side_effect(
         experiment_folder=session_dir / "experiment_folder",
     )

-    assert sum(len(node.testtasks) for node in test_linear_nodes) == 0
+    assert sum(len(node.tasks) for node in test_linear_nodes) == 0
     assert sum(len(node.tasks) for node in train_linear_nodes) == 0
     assert len(aggregation_node.tasks) == 0
     assert cp1 == cp2
diff --git a/tests/test_model_loading.py b/tests/test_model_loading.py
index de93cbb9..19a348f5 100644
--- a/tests/test_model_loading.py
+++ b/tests/test_model_loading.py
@@ -22,7 +22,7 @@
 from substrafl.model_loading import REQUIRED_KEYS
 from substrafl.model_loading import _download_task_output_files
 from substrafl.model_loading import _load_from_files
-from substrafl.nodes.node import OutputIdentifiers
+from substrafl.nodes.schemas import OutputIdentifiers
 from substrafl.remote.register.register import _create_substra_function_files
 from substrafl.schemas import TaskType
diff --git a/tests/utils.py b/tests/utils.py
index 9247e2f7..0d312c38 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -1,6 +1,6 @@
 import pickle

-from substrafl.nodes.node import OutputIdentifiers
+from substrafl.nodes.schemas import OutputIdentifiers
 from substrafl.schemas import TaskType
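
Note (not part of the patch): the snippet below is a minimal sketch of how downstream test code looks once these renames land -- `InputIdentifiers`/`OutputIdentifiers` imported from `substrafl.nodes.schemas`, and evaluation tasks read from `tasks` instead of `testtasks`. The ids and the test name are placeholders, and the positional argument order is assumed to match the `TestDataNode("fake_id", "fake_id", ["fake_id"])` fixture added in tests/test_evaluation_strategy.py above.

    import pytest

    from substrafl.evaluation_strategy import EvaluationStrategy
    from substrafl.nodes.schemas import OutputIdentifiers  # noqa: F401 - new import path (was substrafl.nodes.node)
    from substrafl.nodes.test_data_node import TestDataNode


    @pytest.fixture
    def test_data_node():
        # Placeholder ids, mirroring the module-scoped fixture added in this patch;
        # the argument order is assumed identical.
        return TestDataNode("fake_org_id", "fake_dataset_key", ["fake_data_sample_key"])


    def test_renamed_node_attributes(test_data_node):
        # A freshly created node has no registered evaluation tasks; they are now
        # exposed as `tasks` (formerly `testtasks`).
        assert len(test_data_node.tasks) == 0

        # EvaluationStrategy still takes TestDataNode instances; anything else
        # (e.g. a bare int) is rejected, as test_error_on_wrong_node_instance checks.
        EvaluationStrategy(test_data_nodes=[test_data_node], eval_frequency=2)
        with pytest.raises(TypeError):
            EvaluationStrategy(test_data_nodes=[1], eval_frequency=3)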