Commit

chore: rename test_data_samples_keys to data_samples_keys (#398)

Signed-off-by: ThibaultFy <[email protected]>
ThibaultFy authored Feb 28, 2024
1 parent bec5bd7 commit 74b145c
Showing 5 changed files with 18 additions and 21 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Rename `test_data_sample_keys` to `data_sample_keys` on `TestDataNodes` after [SubstraFL #185](https://github.com/Substra/substrafl/pull/185) ([#398](https://github.com/Substra/substra-documentation/pull/398))
- Test and predict tasks are now merged, after [SubstraFL #177](https://github.com/Substra/substrafl/pull/177)
- Rename `predictions_path` to `predictions` in metrics ([#376](https://github.com/Substra/substra-documentation/pull/376))
- Pass `metric_functions` to `Strategy` instead to `TestDataNodes` ([#376](https://github.com/Substra/substra-documentation/pull/376))
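The first changelog entry above is a keyword-argument rename on `TestDataNode`: scripts that still pass the old `test_data_sample_keys` keyword need a one-line update. A minimal migration sketch for code that assembles node arguments as dicts; the helper name and dict-based call style are illustrative, not part of the SubstraFL API:

```python
def migrate_test_data_node_kwargs(kwargs):
    """Rename the pre-#398 keyword to its new name, leaving other kwargs untouched.

    Hypothetical helper, not part of the SubstraFL API.
    """
    if "test_data_sample_keys" in kwargs:
        kwargs = dict(kwargs)  # copy so the caller's dict is not mutated
        kwargs["data_sample_keys"] = kwargs.pop("test_data_sample_keys")
    return kwargs
```

The returned dict can then be splatted into the `TestDataNode(...)` constructor shown in the diffs below.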
10 changes: 4 additions & 6 deletions docs/source/examples/substrafl/get_started/run_mnist_torch.ipynb
@@ -229,11 +229,9 @@
"**datasamples**.\n",
"\n",
"To add a metric, you need to define a function that computes and returns a performance\n",
"from the datasamples (as returned by the opener) and the predictions_path (to be loaded within the function).\n",
"from the datasamples (as returned by the opener) and the predictions of the model.\n",
"\n",
"When using a Torch SubstraFL algorithm, the predictions are saved in the `predict` function in numpy format\n",
"so that you can simply load them using `np.load`.\n",
"\n"
"When using a Torch SubstraFL algorithm, the predictions are returned by the `predict` function.\n"
]
},
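The markdown cell above describes the new metric contract: a metric function receives the datasamples returned by the opener and the predictions returned by `predict`, with no file loading. A minimal sketch under assumed shapes; the `"labels"` key and the already-decoded predictions are assumptions for illustration, not fixed by the opener API:

```python
def accuracy(datasamples, predictions):
    # datasamples: whatever the opener returns; assumed here to be a dict
    # exposing ground-truth labels under "labels" (illustrative key name).
    # predictions: values returned by `predict`, assumed already decoded to
    # class labels; no predictions_path and no np.load involved anymore.
    y_true = datasamples["labels"]
    correct = sum(int(t == p) for t, p in zip(y_true, predictions))
    return correct / len(y_true)
```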
{
@@ -553,7 +551,7 @@
" TestDataNode(\n",
" organization_id=org_id,\n",
" data_manager_key=dataset_keys[org_id],\n",
" test_data_sample_keys=[test_datasample_keys[org_id]],\n",
" data_sample_keys=[test_datasample_keys[org_id]],\n",
" )\n",
" for org_id in DATA_PROVIDER_ORGS_ID\n",
"]\n",
@@ -788,7 +786,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
"version": "3.11.6"
}
},
"nbformat": 4,
@@ -205,7 +205,6 @@
" org_id: clients[org_id].add_data_sample(\n",
" DataSampleSpec(\n",
" data_manager_keys=[dataset_keys[org_id]],\n",
" test_only=False,\n",
" path=data_path / f\"org_{i + 1}\",\n",
" ),\n",
" local=True,\n",
@@ -690,7 +689,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
"version": "3.11.6"
}
},
"nbformat": 4,
@@ -390,7 +390,7 @@
"\n",
" def predict(self, datasamples, shared_state):\n",
" \"\"\"The predict function to be executed by the evaluation function on\n",
" data we want to test our model on. The predict method is mandatory and is \n",
" data we want to test our model on. The predict method is mandatory and is\n",
" an `abstractmethod` of the `Algo` class.\n",
"\n",
" Args:\n",
@@ -436,7 +436,7 @@
"outputs": [],
"source": [
"from substrafl.strategies import FedAvg\n",
" \n",
"\n",
"strategy = FedAvg(algo=SklearnLogisticRegression(model=cls, seed=SEED), metric_functions=accuracy)"
]
},
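The cell above attaches `metric_functions` to the strategy (here `FedAvg`) rather than to the `TestDataNodes`, per the changelog. A pure-Python stand-in sketching that wiring; the class name and `evaluate` method are illustrative, not SubstraFL's actual interfaces:

```python
class StrategySketch:
    """Illustrative stand-in: metrics live on the strategy, not on the test nodes."""

    def __init__(self, algo, metric_functions):
        self.algo = algo
        # Mapping of metric name -> callable(datasamples, predictions) -> float
        self.metric_functions = metric_functions

    def evaluate(self, datasamples, predictions):
        # Apply every registered metric to one evaluation round's outputs.
        return {name: fn(datasamples, predictions)
                for name, fn in self.metric_functions.items()}
```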
@@ -495,7 +495,7 @@
" TestDataNode(\n",
" organization_id=org_id,\n",
" data_manager_key=dataset_keys[org_id],\n",
" test_data_sample_keys=[test_datasample_keys[org_id]],\n",
" data_sample_keys=[test_datasample_keys[org_id]],\n",
" )\n",
" for org_id in DATA_PROVIDER_ORGS_ID\n",
"]\n",
@@ -658,7 +658,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
"version": "3.11.6"
}
},
"nbformat": 4,
17 changes: 8 additions & 9 deletions docs/source/examples/substrafl/go_further/run_mnist_cyclic.ipynb
@@ -234,10 +234,9 @@
"**datasamples**.\n",
"\n",
"To add a metric, you need to define a function that computes and returns a performance\n",
"from the datasamples (as returned by the opener) and the predictions_path (to be loaded within the function).\n",
"from the datasamples (as returned by the opener) and the predictions of the model.\n",
"\n",
"When using a Torch SubstraFL algorithm, the predictions are saved in the `predict` function in numpy format\n",
"so that you can simply load them using `np.load`.\n",
"When using a Torch SubstraFL algorithm, the predictions are returned by the `predict` function.\n",
"\n"
]
},
@@ -480,10 +479,10 @@
" \"\"\"\n",
"\n",
" def __init__(\n",
" self, \n",
" algo: Algo, \n",
" metric_functions: Optional[Dict[str, Callable]] = None, \n",
" *args, \n",
" self,\n",
" algo: Algo,\n",
" metric_functions: Optional[Dict[str, Callable]] = None,\n",
" *args,\n",
" **kwargs,\n",
" ):\n",
" \"\"\"\n",
@@ -622,7 +621,7 @@
" test_node.update_states(\n",
" traintask_id=self._cyclic_local_state.key,\n",
" operation=self.evaluate(\n",
" data_samples=test_node.test_data_sample_keys,\n",
" data_samples=test_node.data_sample_keys,\n",
" _algo_name=f\"Evaluating with {self.__class__.__name__}\",\n",
" ),\n",
" round_idx=round_idx,\n",
@@ -901,7 +900,7 @@
" TestDataNode(\n",
" organization_id=org_id,\n",
" data_manager_key=dataset_keys[org_id],\n",
" test_data_sample_keys=[test_datasample_keys[org_id]],\n",
" data_sample_keys=[test_datasample_keys[org_id]],\n",
" )\n",
" for org_id in DATA_PROVIDER_ORGS_ID\n",
"]\n",
