feat: merge predict and test tasks (#376)

Signed-off-by: ThibaultFy <[email protected]>
ThibaultFy authored Feb 27, 2024
1 parent f355727 commit bec5bd7
Showing 5 changed files with 58 additions and 60 deletions.
12 changes: 9 additions & 3 deletions CHANGELOG.md
@@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed

- Test and predict tasks are now merged, after [SubstraFL #177](https://github.com/Substra/substrafl/pull/177)
- Rename `predictions_path` to `predictions` in metrics ([#376](https://github.com/Substra/substra-documentation/pull/376))
- Pass `metric_functions` to `Strategy` instead of `TestDataNodes` ([#376](https://github.com/Substra/substra-documentation/pull/376))

## [0.35.0]

### Added
@@ -17,15 +23,15 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

- Bump Sphinx to 7.2.6, and upgrade linked dependencies ([#388](https://github.com/Substra/substra-documentation/pull/388))
- Examples are not executed when building the documentation ([#388](https://github.com/Substra/substra-documentation/pull/388))

### Fixed

- Restore custom CSS on nbsphinx gallery ([#394](https://github.com/Substra/substra-documentation/pull/394))

### Removed

- Mentions to Orchestrator distributed mode ([#379](https://github.com/Substra/substra-documentation/pull/379))

## [0.34.0]

### Added
6 changes: 3 additions & 3 deletions Makefile
@@ -1,16 +1,16 @@
install-examples-dependencies:
pip3 install -r examples_requirements.txt

examples: example-substra example-substrafl
examples: examples-substra examples-substrafl

example-substra: example-core-diabetes example-core-titanic
examples-substra: example-core-diabetes example-core-titanic

example-core-diabetes:
cd docs/source/examples/substra_core/diabetes_example/ && ipython -c "%run run_diabetes.ipynb"
example-core-titanic:
cd docs/source/examples/substra_core/titanic_example/ && ipython -c "%run run_titanic.ipynb"

example-substrafl: example-fl-mnist example-fl-iris example-fl-cyclic example-fl-diabetes
examples-substrafl: example-fl-mnist example-fl-iris example-fl-cyclic example-fl-diabetes

example-fl-mnist:
cd docs/source/examples/substrafl/get_started/ && ipython -c "%run run_mnist_torch.ipynb"
28 changes: 12 additions & 16 deletions docs/source/examples/substrafl/get_started/run_mnist_torch.ipynb
@@ -249,21 +249,19 @@
"import numpy as np\n",
"\n",
"\n",
"def accuracy(datasamples, predictions_path):\n",
"def accuracy(datasamples, predictions):\n",
" y_true = datasamples[\"labels\"]\n",
" y_pred = np.load(predictions_path)\n",
"\n",
" return accuracy_score(y_true, np.argmax(y_pred, axis=1))\n",
" return accuracy_score(y_true, np.argmax(predictions, axis=1))\n",
"\n",
"\n",
"def roc_auc(datasamples, predictions_path):\n",
"def roc_auc(datasamples, predictions):\n",
" y_true = datasamples[\"labels\"]\n",
" y_pred = np.load(predictions_path)\n",
"\n",
" n_class = np.max(y_true) + 1\n",
" y_true_one_hot = np.eye(n_class)[y_true]\n",
"\n",
" return roc_auc_score(y_true_one_hot, y_pred)"
" return roc_auc_score(y_true_one_hot, predictions)"
]
},
{
@@ -483,7 +481,7 @@
"source": [
"from substrafl.strategies import FedAvg\n",
"\n",
"strategy = FedAvg(algo=TorchCNN())"
"strategy = FedAvg(algo=TorchCNN(), metric_functions={\"Accuracy\": accuracy, \"ROC AUC\": roc_auc})"
]
},
{
@@ -556,7 +554,6 @@
" organization_id=org_id,\n",
" data_manager_key=dataset_keys[org_id],\n",
" test_data_sample_keys=[test_datasample_keys[org_id]],\n",
" metric_functions={\"Accuracy\": accuracy, \"ROC AUC\": roc_auc},\n",
" )\n",
" for org_id in DATA_PROVIDER_ORGS_ID\n",
"]\n",
@@ -576,8 +573,7 @@
"The [Dependency](https://docs.substra.org/en/stable/substrafl_doc/api/dependency.html) object is instantiated in order to install the right libraries in\n",
"the Python environment of each organization.\n",
"\n",
"The CPU torch version is installed here to have a `Dependency` object as light as possible as we don't use GPUs (`use_gpu` set to `False`). Remove the `--extra-index-url` to install the cuda torch version.\n",
"\n"
"The CPU torch version is installed here to have a `Dependency` object as light as possible as we don't use GPUs (`use_gpu` set to `False`). Remove the `--extra-index-url` to install the cuda torch version."
]
},
{
@@ -647,13 +643,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"The compute plan created is composed of 29 tasks:\n",
"The compute plan created is composed of 21 tasks:\n",
"\n",
"* For each local training step, we create 3 tasks per organisation: training + prediction + evaluation -> 3 tasks.\n",
"* We are training on 2 data organizations; for each round, we have 3 * 2 local tasks + 1 aggregation task -> 7 tasks.\n",
"* We are training for 3 rounds: 3 * 7 -> 21 tasks.\n",
"* Before the first local training step, there is an initialization step on each data organization: 21 + 2 -> 23 tasks.\n",
"* After the last aggregation step, there are three more tasks: applying the last updates from the aggregator + prediction + evaluation, on both organizations: 23 + 2 * 3 -> 29 tasks\n",
"* For each local training step, we create 2 tasks per organisation: training + evaluation -> 2 tasks.\n",
"* We are training on 2 data organizations; for each round, we have 2 * 2 local tasks + 1 aggregation task -> 5 tasks.\n",
"* We are training for 3 rounds: 3 * 5 -> 15 tasks.\n",
"* Before the first local training step, there is an initialization step on each data organization: 15 + 2 -> 17 tasks.\n",
"* After the last aggregation step, there are two more tasks: applying the last updates from the aggregator + evaluation, on both organizations: 17 + 2 * 2 -> 21 tasks\n",
"\n"
]
},
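The task count in the updated cell can be checked with a few lines of arithmetic. This is only a sketch of the bullet-point breakdown above; `count_compute_plan_tasks` and its parameters are hypothetical names introduced here, not part of Substra or SubstraFL.

```python
def count_compute_plan_tasks(n_orgs: int, n_rounds: int, local_tasks_per_org: int) -> int:
    """Count tasks following the breakdown given in the notebook text.

    Hypothetical helper for illustration only; not a SubstraFL function.
    """
    # Each round: local tasks on every organization + 1 aggregation task.
    per_round = local_tasks_per_org * n_orgs + 1
    training = n_rounds * per_round
    # One initialization task per data organization before the first round.
    initialization = n_orgs
    # After the last aggregation: apply the last update, then the remaining
    # local tasks (predict + evaluate before, evaluate only after the merge).
    final = n_orgs * local_tasks_per_org
    return training + initialization + final


# Before the merge: training + prediction + evaluation = 3 local tasks.
before = count_compute_plan_tasks(n_orgs=2, n_rounds=3, local_tasks_per_org=3)
# After the merge: training + evaluation = 2 local tasks.
after = count_compute_plan_tasks(n_orgs=2, n_rounds=3, local_tasks_per_org=2)
print(before, after)  # 29 21
```

Merging the predict and test tasks removes one task per organization per evaluation, which accounts for the drop from 29 to 21.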
33 changes: 10 additions & 23 deletions docs/source/examples/substrafl/go_further/run_iris_sklearn.ipynb
@@ -197,11 +197,10 @@
"import numpy as np\n",
"\n",
"\n",
"def accuracy(datasamples, predictions_path):\n",
"def accuracy(datasamples, predictions):\n",
" y_true = datasamples[\"targets\"]\n",
" y_pred = np.load(predictions_path)\n",
"\n",
" return accuracy_score(y_true, y_pred)"
" return accuracy_score(y_true, predictions)"
]
},
{
@@ -281,7 +280,7 @@
" The train method must accept as parameters `datasamples` and `shared_state`.\n",
"- **predict** (method): a function to describe how to compute the\n",
" predictions from the algo model.\n",
" The predict method must accept as parameters `datasamples`, `shared_state` and `predictions_path`.\n",
" The predict method must accept as parameters `datasamples` and `shared_state`.\n",
"- **save** (method): specify how to save the important states of our algo.\n",
"- **load** (method): specify how to load the important states of our algo from a previously saved file\n",
" by the `save` function described above.\n",
@@ -302,7 +301,6 @@
"\n",
"import joblib\n",
"from typing import Optional\n",
"import shutil\n",
"\n",
"# The Iris dataset proposes four attributes to predict three different classes.\n",
"INPUT_SIZE = 4\n",
@@ -390,29 +388,19 @@
" parameters_update=[p for p in delta_coef] + [delta_bias],\n",
" )\n",
"\n",
" @remote.remote_data\n",
" def predict(self, datasamples, shared_state, predictions_path):\n",
" \"\"\"The predict function to be executed on organizations containing\n",
" data we want to test our model on. The @remote_data decorator is mandatory\n",
" to allow this function to be sent and executed on the right organization.\n",
" def predict(self, datasamples, shared_state):\n",
" \"\"\"The predict function to be executed by the evaluation function on\n",
" data we want to test our model on. The predict method is mandatory and is \n",
" an `abstractmethod` of the `Algo` class.\n",
"\n",
" Args:\n",
" datasamples: datasamples extracted from the organizations data using\n",
" the given opener.\n",
" shared_state: shared_state provided by the aggregator.\n",
" predictions_path: Path where to save the predictions.\n",
" This path is provided by Substra and the metric will automatically\n",
" get access to this path to load the predictions.\n",
" \"\"\"\n",
" predictions = self._model.predict(datasamples[\"data\"])\n",
"\n",
" if predictions_path is not None:\n",
" np.save(predictions_path, predictions)\n",
"\n",
" # np.save() automatically adds a \".npy\" to the end of the file.\n",
" # We rename the file produced by removing the \".npy\" suffix, to make sure that\n",
" # predictions_path is the actual file name.\n",
" shutil.move(str(predictions_path) + \".npy\", predictions_path)\n",
" return predictions\n",
"\n",
" def save_local_state(self, path):\n",
" joblib.dump(\n",
@@ -448,8 +436,8 @@
"outputs": [],
"source": [
"from substrafl.strategies import FedAvg\n",
"\n",
"strategy = FedAvg(algo=SklearnLogisticRegression(model=cls, seed=SEED))"
" \n",
"strategy = FedAvg(algo=SklearnLogisticRegression(model=cls, seed=SEED), metric_functions=accuracy)"
]
},
{
@@ -508,7 +496,6 @@
" organization_id=org_id,\n",
" data_manager_key=dataset_keys[org_id],\n",
" test_data_sample_keys=[test_datasample_keys[org_id]],\n",
" metric_functions=accuracy,\n",
" )\n",
" for org_id in DATA_PROVIDER_ORGS_ID\n",
"]\n",
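The change in this notebook is that `predict` now returns the predictions and the metric receives them directly, instead of passing through a file at `predictions_path`. A minimal sketch of that data flow follows, using NumPy only. `ToyAlgo` and `evaluate` are hypothetical stand-ins written for this illustration; they are not SubstraFL classes or functions, and the real evaluation task is run by the library on each organization.

```python
import numpy as np


class ToyAlgo:
    """Hypothetical stand-in for an algo; NOT a SubstraFL class."""

    def __init__(self, threshold: float):
        self.threshold = threshold

    def predict(self, datasamples, shared_state=None):
        # New style: return the predictions instead of saving them to a path.
        return (datasamples["data"] > self.threshold).astype(int)


def accuracy(datasamples, predictions):
    # New metric signature: predictions arrive as an in-memory object.
    y_true = datasamples["targets"]
    return float(np.mean(y_true == predictions))


def evaluate(algo, datasamples, metric_functions):
    """Hypothetical merged predict + test step: predict, then score in one go."""
    predictions = algo.predict(datasamples, shared_state=None)
    return {
        name: metric(datasamples=datasamples, predictions=predictions)
        for name, metric in metric_functions.items()
    }


datasamples = {
    "data": np.array([0.1, 0.7, 0.4, 0.9]),
    "targets": np.array([0, 1, 1, 1]),
}
scores = evaluate(ToyAlgo(threshold=0.5), datasamples, {"Accuracy": accuracy})
print(scores)  # {'Accuracy': 0.75}
```

Because nothing is written to disk between prediction and scoring, the `np.save` and `shutil.move` workaround removed from the notebook is no longer needed.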
39 changes: 24 additions & 15 deletions docs/source/examples/substrafl/go_further/run_mnist_cyclic.ipynb
@@ -254,21 +254,19 @@
"import numpy as np\n",
"\n",
"\n",
"def accuracy(datasamples, predictions_path):\n",
"def accuracy(datasamples, predictions):\n",
" y_true = datasamples[\"labels\"]\n",
" y_pred = np.load(predictions_path)\n",
"\n",
" return accuracy_score(y_true, np.argmax(y_pred, axis=1))\n",
" return accuracy_score(y_true, np.argmax(predictions, axis=1))\n",
"\n",
"\n",
"def roc_auc(datasamples, predictions_path):\n",
"def roc_auc(datasamples, predictions):\n",
" y_true = datasamples[\"labels\"]\n",
" y_pred = np.load(predictions_path)\n",
"\n",
" n_class = np.max(y_true) + 1\n",
" y_true_one_hot = np.eye(n_class)[y_true]\n",
"\n",
" return roc_auc_score(y_true_one_hot, y_pred)"
" return roc_auc_score(y_true_one_hot, predictions)"
]
},
{
@@ -446,7 +444,7 @@
"- `initialization_round`, to indicate what tasks to execute at round 0, in order to setup the variable\n",
" and be able to compute the performances of the model before any training.\n",
"- `perform_round`, to indicate what tasks and in which order we need to compute to execute a round of the strategy.\n",
"- `perform_predict`, to indicate how to compute the predictions and performances.\n",
"- `perform_evaluation`, to indicate how to compute the predictions and performances.\n",
"\n",
"\n"
]
@@ -462,6 +460,8 @@
"from typing import Any\n",
"from typing import List\n",
"from typing import Optional\n",
"from typing import Dict\n",
"from typing import Callable\n",
"\n",
"from substrafl import strategies\n",
"from substrafl.algorithms.algo import Algo\n",
@@ -479,7 +479,13 @@
" strategy to trigger the tests tasks when needed.\n",
" \"\"\"\n",
"\n",
" def __init__(self, algo: Algo, *args, **kwargs):\n",
" def __init__(\n",
" self, \n",
" algo: Algo, \n",
" metric_functions: Optional[Dict[str, Callable]] = None, \n",
" *args, \n",
" **kwargs,\n",
" ):\n",
" \"\"\"\n",
" It is possible to add any arguments to a Strategy. It is important to pass these arguments as\n",
" args or kwargs to the parent class, using the super().__init__(...) method.\n",
@@ -490,8 +496,12 @@
" Args:\n",
" algo (Algo): A Strategy takes an Algo as argument, in order to deal with framework\n",
" specific function in a dedicated object.\n",
" metric_functions (Optional[Dict[str, Callable]]):\n",
" list of Functions that implement the different metrics. If a Dict is given, the keys will be used to\n",
" register the result of the associated function. If a Function or a List is given, function.__name__\n",
" will be used to store the result.\n",
" \"\"\"\n",
" super().__init__(algo=algo, *args, **kwargs)\n",
" super().__init__(algo=algo, metric_functions=metric_functions, *args, **kwargs)\n",
"\n",
" self._cyclic_local_state = None\n",
" self._cyclic_shared_state = None\n",
@@ -590,14 +600,14 @@
" clean_models=clean_models,\n",
" )\n",
"\n",
" def perform_predict(\n",
" def perform_evaluation(\n",
" self,\n",
" test_data_nodes: List[TestDataNode],\n",
" train_data_nodes: List[TrainDataNode],\n",
" round_idx: int,\n",
" ):\n",
" \"\"\"This method is called regarding the given evaluation strategy. If the round is included\n",
" in the evaluation strategy, the ``perform_predict`` method will be called on the different concerned nodes.\n",
" in the evaluation strategy, the ``perform_evaluation`` method will be called on the different concerned nodes.\n",
"\n",
" We are using the last computed ``_cyclic_local_state`` to feed the test task, which means that we will\n",
" always test the model after its training on the last train data nodes of the list.\n",
@@ -611,9 +621,9 @@
" for test_node in test_data_nodes:\n",
" test_node.update_states(\n",
" traintask_id=self._cyclic_local_state.key,\n",
" operation=self.algo.predict(\n",
" operation=self.evaluate(\n",
" data_samples=test_node.test_data_sample_keys,\n",
" _algo_name=f\"Predicting with {self.algo.__class__.__name__}\",\n",
" _algo_name=f\"Evaluating with {self.__class__.__name__}\",\n",
" ),\n",
" round_idx=round_idx,\n",
" )"
@@ -823,7 +833,7 @@
" )\n",
"\n",
"\n",
"strategy = CyclicStrategy(algo=MyAlgo())"
"strategy = CyclicStrategy(algo=MyAlgo(), metric_functions={\"Accuracy\": accuracy, \"ROC AUC\": roc_auc})"
]
},
{
@@ -892,7 +902,6 @@
" organization_id=org_id,\n",
" data_manager_key=dataset_keys[org_id],\n",
" test_data_sample_keys=[test_datasample_keys[org_id]],\n",
" metric_functions={\"Accuracy\": accuracy, \"ROC AUC\": roc_auc},\n",
" )\n",
" for org_id in DATA_PROVIDER_ORGS_ID\n",
"]\n",
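The `metric_functions` docstring added to `CyclicStrategy` says that dictionary keys name the results, while a single function or a list falls back to `function.__name__`. One way that naming rule could be expressed is sketched below. `normalize_metric_functions` is a hypothetical helper written for this note under that assumption; it is not SubstraFL's implementation, which may validate or store metrics differently.

```python
from typing import Callable, Dict, List, Optional, Union

MetricInput = Optional[Union[Callable, List[Callable], Dict[str, Callable]]]


def normalize_metric_functions(metric_functions: MetricInput) -> Dict[str, Callable]:
    """Hypothetical helper: map every accepted input form to {name: function}."""
    if metric_functions is None:
        return {}
    if isinstance(metric_functions, dict):
        # Keys are used as-is to register each result.
        return dict(metric_functions)
    if callable(metric_functions):
        # A single function is registered under its own name.
        return {metric_functions.__name__: metric_functions}
    # Otherwise assume an iterable of functions, each registered by name.
    return {fn.__name__: fn for fn in metric_functions}


def accuracy(datasamples, predictions):
    return 1.0  # placeholder metric body, for illustration only


def roc_auc(datasamples, predictions):
    return 0.5  # placeholder metric body, for illustration only


as_dict = normalize_metric_functions({"Accuracy": accuracy, "ROC AUC": roc_auc})
as_list = normalize_metric_functions([accuracy, roc_auc])
as_single = normalize_metric_functions(accuracy)
print(list(as_dict), list(as_list), list(as_single))
# ['Accuracy', 'ROC AUC'] ['accuracy', 'roc_auc'] ['accuracy']
```

This matches the two call styles in the diff: the MNIST notebooks pass a dictionary with display names, and the Iris notebook passes the bare `accuracy` function.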