Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: rename test_data_samples_keys to data_samples_keys #398

Merged
merged 5 commits into from
Feb 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Changed

- Rename `test_data_sample_keys` to `data_sample_keys` on `TestDataNodes` after [SubstraFL #185](https://github.com/Substra/substrafl/pull/185) ([#398](https://github.com/Substra/substra-documentation/pull/398))
- Test and predict tasks are now merged, after [SubstraFL #177](https://github.com/Substra/substrafl/pull/177)
- Rename `predictions_path` to `predictions` in metrics ([#376](https://github.com/Substra/substra-documentation/pull/376))
- Pass `metric_functions` to `Strategy` instead to `TestDataNodes` ([#376](https://github.com/Substra/substra-documentation/pull/376))
Expand Down
10 changes: 4 additions & 6 deletions docs/source/examples/substrafl/get_started/run_mnist_torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -229,11 +229,9 @@
"**datasamples**.\n",
"\n",
"To add a metric, you need to define a function that computes and returns a performance\n",
"from the datasamples (as returned by the opener) and the predictions_path (to be loaded within the function).\n",
"from the datasamples (as returned by the opener) and the predictions of the model.\n",
"\n",
"When using a Torch SubstraFL algorithm, the predictions are saved in the `predict` function in numpy format\n",
"so that you can simply load them using `np.load`.\n",
"\n"
"When using a Torch SubstraFL algorithm, the predictions are returned by the `predict` function.\n"
]
},
{
Expand Down Expand Up @@ -553,7 +551,7 @@
" TestDataNode(\n",
" organization_id=org_id,\n",
" data_manager_key=dataset_keys[org_id],\n",
" test_data_sample_keys=[test_datasample_keys[org_id]],\n",
" data_sample_keys=[test_datasample_keys[org_id]],\n",
" )\n",
" for org_id in DATA_PROVIDER_ORGS_ID\n",
"]\n",
Expand Down Expand Up @@ -788,7 +786,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
"version": "3.11.6"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,6 @@
" org_id: clients[org_id].add_data_sample(\n",
" DataSampleSpec(\n",
" data_manager_keys=[dataset_keys[org_id]],\n",
" test_only=False,\n",
" path=data_path / f\"org_{i + 1}\",\n",
" ),\n",
" local=True,\n",
Expand Down Expand Up @@ -690,7 +689,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
"version": "3.11.6"
}
},
"nbformat": 4,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@
"\n",
" def predict(self, datasamples, shared_state):\n",
" \"\"\"The predict function to be executed by the evaluation function on\n",
" data we want to test our model on. The predict method is mandatory and is \n",
" data we want to test our model on. The predict method is mandatory and is\n",
" an `abstractmethod` of the `Algo` class.\n",
"\n",
" Args:\n",
Expand Down Expand Up @@ -436,7 +436,7 @@
"outputs": [],
"source": [
"from substrafl.strategies import FedAvg\n",
" \n",
"\n",
"strategy = FedAvg(algo=SklearnLogisticRegression(model=cls, seed=SEED), metric_functions=accuracy)"
]
},
Expand Down Expand Up @@ -495,7 +495,7 @@
" TestDataNode(\n",
" organization_id=org_id,\n",
" data_manager_key=dataset_keys[org_id],\n",
" test_data_sample_keys=[test_datasample_keys[org_id]],\n",
" data_sample_keys=[test_datasample_keys[org_id]],\n",
" )\n",
" for org_id in DATA_PROVIDER_ORGS_ID\n",
"]\n",
Expand Down Expand Up @@ -658,7 +658,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.17"
"version": "3.11.6"
}
},
"nbformat": 4,
Expand Down
17 changes: 8 additions & 9 deletions docs/source/examples/substrafl/go_further/run_mnist_cyclic.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -234,10 +234,9 @@
"**datasamples**.\n",
"\n",
"To add a metric, you need to define a function that computes and returns a performance\n",
"from the datasamples (as returned by the opener) and the predictions_path (to be loaded within the function).\n",
"from the datasamples (as returned by the opener) and the predictions of the model.\n",
"\n",
"When using a Torch SubstraFL algorithm, the predictions are saved in the `predict` function in numpy format\n",
"so that you can simply load them using `np.load`.\n",
"When using a Torch SubstraFL algorithm, the predictions are returned by the `predict` function.\n",
"\n"
]
},
Expand Down Expand Up @@ -480,10 +479,10 @@
" \"\"\"\n",
"\n",
" def __init__(\n",
" self, \n",
" algo: Algo, \n",
" metric_functions: Optional[Dict[str, Callable]] = None, \n",
" *args, \n",
" self,\n",
" algo: Algo,\n",
" metric_functions: Optional[Dict[str, Callable]] = None,\n",
" *args,\n",
" **kwargs,\n",
" ):\n",
" \"\"\"\n",
Expand Down Expand Up @@ -622,7 +621,7 @@
" test_node.update_states(\n",
" traintask_id=self._cyclic_local_state.key,\n",
" operation=self.evaluate(\n",
" data_samples=test_node.test_data_sample_keys,\n",
" data_samples=test_node.data_sample_keys,\n",
" _algo_name=f\"Evaluating with {self.__class__.__name__}\",\n",
" ),\n",
" round_idx=round_idx,\n",
Expand Down Expand Up @@ -901,7 +900,7 @@
" TestDataNode(\n",
" organization_id=org_id,\n",
" data_manager_key=dataset_keys[org_id],\n",
" test_data_sample_keys=[test_datasample_keys[org_id]],\n",
" data_sample_keys=[test_datasample_keys[org_id]],\n",
" )\n",
" for org_id in DATA_PROVIDER_ORGS_ID\n",
"]\n",
Expand Down
Loading