Skip to content

Commit

Permalink
chore: use protocols for nodes (#185)
Browse files Browse the repository at this point in the history
Signed-off-by: ThibaultFy <[email protected]>
Signed-off-by: ThibaultFy <[email protected]>
  • Loading branch information
ThibaultFy authored Feb 28, 2024
1 parent ff7f043 commit 6476e19
Show file tree
Hide file tree
Showing 33 changed files with 340 additions and 299 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- BREAKING: the `perform_predict` method of `Strategy` changed in favor of `perform_evaluation` that calls the new `evaluate` method [#177](https://github.com/Substra/substrafl/pull/177)
- BREAKING: `metric_functions` are now passed to the `Strategy` instead of the `TestDataNode` [#177](https://github.com/Substra/substrafl/pull/177)
- BREAKING: the `predict` method of `Algo` has no `@remote_data` decorator anymore. It signatures does not take `prediction_path` anymore, and the predictions are return by the method [#177](https://github.com/Substra/substrafl/pull/177)
- Abstract base class `Node` is replaced by `Protocols`, defined in `substrafl.nodes.protocol.py` ([#185](https://github.com/Substra/substrafl/pull/185))
- BREAKING: rename `test_data_sample_keys`, `test_tasks` and `register_test_operations`, `tasks` to `data_sample_keys` and `register_operations` in `TestDataNodes` ([#185](https://github.com/Substra/substrafl/pull/185))
- BREAKING: `InputIdentifiers` and `OutputIdentifiers` move from `substrafl.nodes.node` to `substrafl.nodes.schemas` ([#185](https://github.com/Substra/substrafl/pull/185))

## [0.43.0](https://github.com/Substra/substrafl/releases/tag/0.43.0) - 2024-02-26

Expand Down
4 changes: 2 additions & 2 deletions benchmark/camelyon/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,12 +184,12 @@ In remote, one can choose to reuse some assets by passing their keys in the [key
{
"MyOrg1MSP": {
"dataset_key": "b8d754f0-40a5-4976-ae16-8dd4eca35ffc",
"test_data_sample_keys": ["1238452c-a1dd-47ef-84a8-410c0841693a"],
"data_sample_keys": ["1238452c-a1dd-47ef-84a8-410c0841693a"],
"train_data_sample_keys": ["38071944-c974-4b3b-a671-aa4835a0ae62"]
},
"MyOrg2MSP": {
"dataset_key": "fa8e9bf7-5084-4b59-b089-a459495a08be",
"test_data_sample_keys": ["73715d69-9447-4270-9d3f-d0b17bb88a87"],
"data_sample_keys": ["73715d69-9447-4270-9d3f-d0b17bb88a87"],
"train_data_sample_keys": ["766d2029-f90b-440e-8b39-2389ab04041d"]
},
"metric_key": "e5a99be6-0138-461a-92fe-23f685cdc9e1"
Expand Down
6 changes: 3 additions & 3 deletions benchmark/camelyon/pure_substrafl/register_assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,12 +111,12 @@ def add_duplicated_dataset(
{
<msp_id>: {
"dataset_key": "b8d754f0-40a5-4976-ae16-8dd4eca35ffc",
"test_data_sample_keys": ["1238452c-a1dd-47ef-84a8-410c0841693a"],
"data_sample_keys": ["1238452c-a1dd-47ef-84a8-410c0841693a"],
"train_data_sample_keys": ["38071944-c974-4b3b-a671-aa4835a0ae62"]
},
<msp_id>: {
"dataset_key": "fa8e9bf7-5084-4b59-b089-a459495a08be",
"test_data_sample_keys": ["73715d69-9447-4270-9d3f-d0b17bb88a87"],
"data_sample_keys": ["73715d69-9447-4270-9d3f-d0b17bb88a87"],
"train_data_sample_keys": ["766d2029-f90b-440e-8b39-2389ab04041d"]
},
...
Expand Down Expand Up @@ -239,7 +239,7 @@ def get_test_data_nodes(
TestDataNode(
organization_id=msp_id,
data_manager_key=asset_keys.get(msp_id)["dataset_key"],
test_data_sample_keys=asset_keys.get(msp_id)["test_data_sample_keys"],
data_sample_keys=asset_keys.get(msp_id)["test_data_sample_keys"],
)
)

Expand Down
12 changes: 8 additions & 4 deletions docs/api/nodes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,17 @@ TestDataNode
^^^^^^^^^^^^^
.. autoclass:: substrafl.nodes.test_data_node.TestDataNode

Node
^^^^^
.. autoclass:: substrafl.nodes.node.Node
Protocols
^^^^^^^^^
.. autoclass:: substrafl.nodes.protocol.TrainDataNodeProtocol
.. autoclass:: substrafl.nodes.protocol.TestDataNodeProtocol
.. autoclass:: substrafl.nodes.protocol.AggregationNodeProtocol

References
^^^^^^^^^^
.. automodule:: substrafl.nodes.references
.. autoclass:: substrafl.nodes.node.OperationKey
.. autoclass:: substrafl.nodes.schemas.OperationKey
.. autoclass:: substrafl.nodes.schemas.InputIdentifiers
.. autoclass:: substrafl.nodes.schemas.OutputIdentifiers
.. autoclass:: substrafl.nodes.references.local_state.LocalStateRef
.. autoclass:: substrafl.nodes.references.shared_state.SharedStateRef
14 changes: 8 additions & 6 deletions substrafl/algorithms/pytorch/torch_base_algo.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,12 +390,14 @@ def summary(self):
{
"model": str(type(self._model)),
"criterion": str(type(self._criterion)),
"optimizer": None
if self._optimizer is None
else {
"type": str(type(self._optimizer)),
"parameters": self._optimizer.defaults,
},
"optimizer": (
None
if self._optimizer is None
else {
"type": str(type(self._optimizer)),
"parameters": self._optimizer.defaults,
}
),
"scheduler": None if self._scheduler is None else str(type(self._scheduler)),
}
)
Expand Down
12 changes: 6 additions & 6 deletions substrafl/compute_plan_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import Optional

from substrafl.evaluation_strategy import EvaluationStrategy
from substrafl.nodes.aggregation_node import AggregationNode
from substrafl.nodes.train_data_node import TrainDataNode
from substrafl.nodes import AggregationNodeProtocol
from substrafl.nodes import TrainDataNodeProtocol


class ComputePlanBuilder(abc.ABC):
Expand All @@ -31,8 +31,8 @@ def __init__(self, custom_arg, my_custom_kwargs="value"):
@abc.abstractmethod
def build_compute_plan(
self,
train_data_nodes: Optional[List[TrainDataNode]],
aggregation_node: Optional[List[AggregationNode]],
train_data_nodes: Optional[List[TrainDataNodeProtocol]],
aggregation_node: Optional[List[AggregationNodeProtocol]],
evaluation_strategy: Optional[EvaluationStrategy],
num_rounds: Optional[int],
clean_models: Optional[bool] = True,
Expand All @@ -41,8 +41,8 @@ def build_compute_plan(
:func:`~substrafl.experiment.execute_experiment` function.
Args:
train_data_nodes (typing.List[TrainDataNode]): list of the train organizations
aggregation_node (typing.Optional[AggregationNode]): aggregation node, necessary for
train_data_nodes (typing.List[TrainDataNodeProtocol]): list of the train organizations
aggregation_node (typing.Optional[AggregationNodeProtocol]): aggregation node, necessary for
centralized strategy, unused otherwise
evaluation_strategy (Optional[EvaluationStrategy]): evaluation strategy to follow for testing models.
num_rounds (int): Number of times to repeat the compute plan sub-graph (define in perform round). It is
Expand Down
12 changes: 6 additions & 6 deletions substrafl/evaluation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from typing import Optional
from typing import Set

from substrafl.nodes.test_data_node import TestDataNode
from substrafl.nodes import TestDataNodeProtocol


class EvaluationStrategy:
def __init__(
self,
test_data_nodes: List[TestDataNode],
test_data_nodes: List[TestDataNodeProtocol],
eval_frequency: Optional[int] = None,
eval_rounds: Optional[List[int]] = None,
) -> None:
Expand All @@ -17,15 +17,15 @@ def __init__(
union of both selected indexes will be evaluated.
Args:
test_data_nodes (List[TestDataNode]): nodes on which the model is to be tested.
test_data_nodes (List[TestDataNodeProtocol]): nodes on which the model is to be tested.
eval_frequency (Optional[int]): The model will be tested every ``eval_frequency`` rounds.
Set to None to activate eval_rounds only. Defaults to None.
eval_rounds (Optional[List[int]]): If specified, the model will be tested on the index of a round given
in the rounds list. Set to None to activate eval_frequency only. Defaults to None.
Raises:
ValueError: test_data_nodes cannot be an empty list
TypeError: test_data_nodes must be filled with instances of TestDataNode
TypeError: test_data_nodes must be filled with instances of TestDataNodeProtocol
TypeError: rounds must be a list or an int
ValueError: both eval_rounds and eval_frequency cannot be None at the same time
Expand Down Expand Up @@ -83,8 +83,8 @@ def __init__(
if not test_data_nodes:
raise ValueError("test_data_nodes lists cannot be empty")

if not all(isinstance(node, TestDataNode) for node in test_data_nodes):
raise TypeError("test_data_nodes must include objects of TestDataNode type")
if not all(isinstance(node, TestDataNodeProtocol) for node in test_data_nodes):
raise TypeError("test_data_nodes must implement the TestDataNodeProtocol")

if eval_frequency is None and eval_rounds is None:
raise ValueError("At least one of eval_frequency or eval_rounds must be defined")
Expand Down
34 changes: 17 additions & 17 deletions substrafl/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,27 +20,27 @@
from substrafl.evaluation_strategy import EvaluationStrategy
from substrafl.exceptions import KeyMetadataError
from substrafl.exceptions import LenMetadataError
from substrafl.nodes.aggregation_node import AggregationNode
from substrafl.nodes.node import OperationKey
from substrafl.nodes.train_data_node import TrainDataNode
from substrafl.nodes import AggregationNodeProtocol
from substrafl.nodes import TrainDataNodeProtocol
from substrafl.nodes.schemas import OperationKey
from substrafl.remote.remote_struct import RemoteStruct

logger = logging.getLogger(__name__)


def _register_operations(
client: substra.Client,
train_data_nodes: List[TrainDataNode],
aggregation_node: Optional[AggregationNode],
train_data_nodes: List[TrainDataNodeProtocol],
aggregation_node: Optional[AggregationNodeProtocol],
evaluation_strategy: Optional[EvaluationStrategy],
dependencies: Dependency,
) -> Tuple[List[dict], Dict[RemoteStruct, OperationKey]]:
"""Register the operations in Substra: define the functions we need and submit them
Args:
client (substra.Client): substra client
train_data_nodes (typing.List[TrainDataNode]): list of train data nodes
aggregation_node (typing.Optional[AggregationNode]): the aggregation node for
train_data_nodes (typing.List[TrainDataNodeProtocol]): list of train data nodes
aggregation_node (typing.Optional[AggregationNodeProtocol]): the aggregation node for
centralized strategies
evaluation_strategy (typing.Optional[EvaluationStrategy]): the evaluation strategy
if there is one dependencies
Expand Down Expand Up @@ -80,14 +80,14 @@ def _register_operations(

if evaluation_strategy is not None:
for test_data_node in evaluation_strategy.test_data_nodes:
test_function_cache = test_data_node.register_test_operations(
test_function_cache = test_data_node.register_operations(
client=client,
permissions=permissions,
cache=test_function_cache,
dependencies=dependencies,
)

tasks += test_data_node.testtasks
tasks += test_data_node.tasks

# The aggregation operation is defined in the strategy, its dependencies are
# the strategy dependencies
Expand All @@ -110,8 +110,8 @@ def _save_experiment_summary(
strategy: ComputePlanBuilder,
num_rounds: int,
operation_cache: Dict[RemoteStruct, OperationKey],
train_data_nodes: List[TrainDataNode],
aggregation_node: Optional[AggregationNode],
train_data_nodes: List[TrainDataNodeProtocol],
aggregation_node: Optional[AggregationNodeProtocol],
evaluation_strategy: EvaluationStrategy,
timestamp: str,
additional_metadata: Optional[Dict],
Expand All @@ -124,8 +124,8 @@ def _save_experiment_summary(
strategy (substrafl.strategies.Strategy): strategy
num_rounds (int): num_rounds
operation_cache (typing.Dict[RemoteStruct, OperationKey]): operation_cache
train_data_nodes (typing.List[TrainDataNode]): train_data_nodes
aggregation_node (typing.Optional[AggregationNode]): aggregation_node
train_data_nodes (typing.List[TrainDataNodeProtocol]): train_data_nodes
aggregation_node (typing.Optional[AggregationNodeProtocol]): aggregation_node
evaluation_strategy (EvaluationStrategy): evaluation_strategy
timestamp (str): timestamp with "%Y_%m_%d_%H_%M_%S" format
additional_metadata (dict, Optional): Optional dictionary of metadata to be shown on the Substra WebApp.
Expand Down Expand Up @@ -209,10 +209,10 @@ def execute_experiment(
*,
client: substra.Client,
strategy: ComputePlanBuilder,
train_data_nodes: List[TrainDataNode],
train_data_nodes: List[TrainDataNodeProtocol],
experiment_folder: Union[str, Path],
num_rounds: Optional[int] = None,
aggregation_node: Optional[AggregationNode] = None,
aggregation_node: Optional[AggregationNodeProtocol] = None,
evaluation_strategy: Optional[EvaluationStrategy] = None,
dependencies: Optional[Dependency] = None,
clean_models: bool = True,
Expand Down Expand Up @@ -243,10 +243,10 @@ def execute_experiment(
Args:
client (substra.Client): A substra client to interact with the Substra platform
strategy (Strategy): The strategy that will be executed
train_data_nodes (typing.List[TrainDataNode]): List of the nodes where training on data
train_data_nodes (typing.List[TrainDataNodeProtocol]): List of the nodes where training on data
occurs evaluation_strategy (EvaluationStrategy, Optional): If None performance will not be measured at all.
Otherwise measuring of performance will follow the EvaluationStrategy. Defaults to None.
aggregation_node (typing.Optional[AggregationNode]): For centralized strategy, the aggregation
aggregation_node (typing.Optional[AggregationNodeProtocol]): For centralized strategy, the aggregation
node, where all the shared tasks occurs else None.
evaluation_strategy: Optional[EvaluationStrategy]
num_rounds (int): The number of time your strategy will be executed
Expand Down
2 changes: 1 addition & 1 deletion substrafl/model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import substrafl
from substrafl import exceptions
from substrafl.constants import SUBSTRAFL_FOLDER
from substrafl.nodes.node import OutputIdentifiers
from substrafl.nodes.schemas import OutputIdentifiers
from substrafl.remote.remote_struct import RemoteStruct
from substrafl.schemas import TaskType

Expand Down
17 changes: 13 additions & 4 deletions substrafl/nodes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from substrafl.nodes.node import Node # isort:skip
from substrafl.nodes.node import OperationKey # isort:skip
from substrafl.nodes.schemas import OperationKey # isort:skip

from substrafl.nodes.aggregation_node import AggregationNode
from substrafl.nodes.protocol import AggregationNodeProtocol
from substrafl.nodes.protocol import TestDataNodeProtocol
from substrafl.nodes.protocol import TrainDataNodeProtocol
from substrafl.nodes.test_data_node import TestDataNode
from substrafl.nodes.train_data_node import TrainDataNode

# This is needed for auto doc to find that Node module's is organizations.organization, otherwise when
# trying to link Node references from one page to the Node documentation page, it fails.
AggregationNode.__module__ = "organizations.aggregation_node"
Node.__module__ = "organizations.organization"

__all__ = ["Node", "AggregationNode", "TrainDataNode", "TestDataNode", "OperationKey"]
__all__ = [
"TestDataNodeProtocol",
"TrainDataNodeProtocol",
"AggregationNodeProtocol",
"AggregationNode",
"TrainDataNode",
"TestDataNode",
"OperationKey",
]
25 changes: 20 additions & 5 deletions substrafl/nodes/aggregation_node.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import uuid
from typing import Dict
from typing import List
from typing import Optional
from typing import Set
from typing import TypeVar

import substra

from substrafl.dependency import Dependency
from substrafl.nodes.node import InputIdentifiers
from substrafl.nodes.node import Node
from substrafl.nodes.node import OperationKey
from substrafl.nodes.node import OutputIdentifiers
from substrafl.nodes.protocol import AggregationNodeProtocol
from substrafl.nodes.references.shared_state import SharedStateRef
from substrafl.nodes.schemas import InputIdentifiers
from substrafl.nodes.schemas import OperationKey
from substrafl.nodes.schemas import OutputIdentifiers
from substrafl.remote.operations import RemoteOperation
from substrafl.remote.register import register_function
from substrafl.remote.remote_struct import RemoteStruct
Expand All @@ -20,12 +21,16 @@
SharedState = TypeVar("SharedState")


class AggregationNode(Node):
class AggregationNode(AggregationNodeProtocol):
"""The node which applies operations to the shared states which are received from ``TrainDataNode``
data operations.
The result is sent to the ``TrainDataNode`` and/or ``TestDataNode`` data operations.
"""

def __init__(self, organization_id: str):
self.organization_id = organization_id
self.tasks: List[Dict] = []

def update_states(
self,
operation: RemoteOperation,
Expand Down Expand Up @@ -161,3 +166,13 @@ def register_operations(
task["function_key"] = function_key

return cache

def summary(self) -> dict:
"""Summary of the class to be exposed in the experiment summary file
Returns:
dict: a json-serializable dict with the attributes the user wants to store
"""
return {
"organization_id": self.organization_id,
}
Loading

0 comments on commit 6476e19

Please sign in to comment.