FEAT: fast simulation mode for substraFL #168

Closed · wants to merge 37 commits

Conversation


@ghost commented Sep 13, 2023

Related issue

There is no formal issue open on this repo; however, this PR addresses the following real problem.

Although the subprocess mode is local (which is great), its speed is limited by hard constraints:

  • disk I/O (all exchanges go through disk saves and loads)
  • subprocess calls (which can add latency)

As a consequence, this mode can be slow compared to a plain Python process, especially when the number of local steps is small. The gap can span 1–2 orders of magnitude depending on the number of local steps.

Summary

This PR proposes a new simulation mode for substraFL. It completely sidesteps substra and performs the machine learning operations in memory, in the same process. It still requires a substra client in subprocess mode, though, for registering the data. For small datasets this is not an issue; for bigger datasets we might need to change this as well.

Expected benefits

  • speed
  • easier to debug (because it's faster, and objects are simpler to inspect)
  • code coverage tools can track the coverage of the FL strategies

General approach

This PR implements new classes, SimuTrainDataNode, SimuTestDataNode, and SimuAggregationNode, which act as stand-ins for TrainDataNode, TestDataNode and AggregationNode.

  • They expose the same methods with the same API: this way, they can be plugged seamlessly into substraFL
  • They require more initialization arguments: SimuTrainDataNode and SimuTestDataNode require both an algo and a substra_client, while SimuAggregationNode requires the strategy

The main idea is to use the update_states method to call the remote operation it receives with _skip=True, so that it runs directly in the current process. For remote_data methods, we additionally need to load the datasamples and update the method_parameters. The datasamples are loaded once and cached, to avoid an I/O bottleneck. A sketch of this mechanism is given below.
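
For concreteness, here is a minimal, self-contained sketch of this mechanism. The class name follows the PR, but the operation attributes (method_name, method_parameters), the placeholder data loading, and the exact signatures are assumptions rather than the actual substraFL internals.

from typing import Any, Dict, Optional, Tuple


class SimuTrainDataNode:
    def __init__(self, algo: Any, client: Any) -> None:
        self.algo = algo
        self.client = client
        self._datasamples: Optional[Dict[str, Any]] = None

    def _load_datasamples(self) -> Dict[str, Any]:
        # Load the datasamples through the substra client only once, then keep
        # them in memory to avoid the disk I/O bottleneck.
        if self._datasamples is None:
            self._datasamples = {"samples": [], "labels": []}  # placeholder load
        return self._datasamples

    def update_states(self, operation: Any, *args: Any, **kwargs: Any) -> Tuple[None, Any]:
        # Resolve the wrapped method on the in-memory algo and call it with
        # _skip=True, so it runs in this process instead of registering a task.
        method = getattr(self.algo, operation.method_name)
        parameters = dict(operation.method_parameters)
        parameters["datasamples"] = self._load_datasamples()
        shared_state = method(_skip=True, **parameters)
        return None, shared_state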

For each org, the corresponding SimuTrainDataNode and SimuTestDataNode share the same algo attribute: they are paired right after the beginning of execute_experiment.

To avoid any change for the user, the nodes are swapped for their Simu counterparts inside execute_experiment when the new optional argument simu_mode is True. This is less optimal than doing the changes on the client side, but it is cleaner from an architectural standpoint.

Notes

Test data node's caveats

For TrainDataNode and AggregationNode, the approach is straightforward, but for TestDataNode things become messier. More precisely, the predict method of the algo already saves the predictions to disk instead of returning them (as would be expected from a sklearn-style API). As a consequence, even if we are able to intercept the algo.predict call within the SimuTestDataNode, we still have a bottleneck. Similarly, the expected API of the metric functions also hardcodes the fact that predictions are dumped to disk.

There are at least two ways to solve this problem.

1. The hacky way (current)

  • I added an optional argument return_predictions to the predict method of the base torch algorithm. It defaults to False, so real runs are unaffected. In the SimuTestDataNode, this argument is set to True, so we can retrieve the actual predictions.
  • To then gather scores, we need to compute the metrics. Here, we cannot work with the current metrics as-is, because they expect the signature metric_func(datasamples, predictions_path). What I propose is to change the metric functions so that the content of predictions_path is already the predictions (the argument name itself stays hardcoded). This is a bit hacky, but I did not see another way; a sketch of such a metric function follows this list.
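
For illustration, here is a minimal sketch of a metric function compatible with this hacky route; the datasamples layout and the accuracy computation are assumptions made only for this example, not substraFL's actual API.

import os

import numpy as np


def accuracy(datasamples, predictions_path):
    # In a real run, predictions_path points to the file dumped by the predict
    # task; in simulation mode it already holds the in-memory predictions.
    if isinstance(predictions_path, (str, bytes, os.PathLike)):
        y_pred = np.load(predictions_path)
    else:
        y_pred = np.asarray(predictions_path)
    y_true = np.asarray(datasamples["labels"])
    return float((y_true == y_pred).mean())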

2. The sustainable way

The current way of dealing with the metrics could be improved while tackling the above issue. I think the current implementation still draws a lot on the old substra and has not completely benefited from the GenericTask update.

My proposition would be to change the API of the algo as sketched below:

from typing import Dict

from torch import Tensor


class Algo:
    ...

    def predict(self, datasamples) -> Tensor:
        return self._model.predict(datasamples["samples"])

    def score(self, datasamples, metric_functions) -> Dict[str, float]:
        # Predictions stay in memory; no intermediate dump to disk is needed.
        y_pred = self.predict(datasamples)
        output_dict = {}
        for key, metric in metric_functions.items():
            output_dict[key] = metric(datasamples["labels"], y_pred)
        return output_dict

This is much more sklearn-ish, which is always great. Further, this would circumvent the need to submit 2 tasks for 1 metric evaluation, as the predictions would no longer go through disk. Finally, by replacing the following lines

test_data_node.update_states(
    self.algo.predict(...),
    datasamples=...,
    metric_functions=...,
)

by

test_data_node.update_states(
    self.algo.score(metric_functions=metric_functions),
    datasamples=...,
)

the behaviours of SimuTrainDataNode and SimuTestDataNode would be exactly the same.

Calls to torch.no_grad

The extensive use of torch.inference_mode in the weight_manager prevents future calls to the model for training purposes within the same process and object: tensors created under inference_mode are inference tensors, which cannot later take part in computations recorded by autograd. This does not show up in the other modes because everything is dumped to disk and reloaded, but here the algo remains the same object during the full training. I therefore replaced these calls with torch.no_grad, which also disables gradient tracking but without this restriction. All tests still pass, so I do not think it is a problem.
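
A minimal reproduction of the difference, outside of substraFL (the weight_manager code path itself is more involved):

import torch

w = torch.randn(3, requires_grad=True)

with torch.inference_mode():
    frozen = 2 * w      # frozen is an inference tensor

with torch.no_grad():
    detached = 2 * w    # plain tensor created without autograd history

# The no_grad result can be reused in a later autograd computation:
(w * detached).sum().backward()

# The inference_mode result cannot; the line below raises
# "RuntimeError: Inference tensors cannot be saved for backward [...]".
# (w * frozen).sum().backward()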

Optimization of _local_predict

I made an additional modification to the _local_predict methods: instead of the pattern

predictions = torch.Tensor([])
for x in dataloader:
    # each torch.cat reallocates and copies everything accumulated so far
    predictions = torch.cat([predictions, model(x)])

I propose

predictions = []
for x in dataloader:
    predictions.append(model(x))
# a single concatenation at the end copies each prediction only once
predictions = torch.cat(predictions)

Indeed, each call to torch.cat copies the existing tensor into a newly allocated one, so the first pattern copies a quadratic amount of data overall, while the second only copies each prediction once, which is linear. This can lead to a speed-up when the prediction tensors are large. Of course, I am happy to move this to a separate PR if that is cleaner.

Please check if the PR fulfills these requirements

The following tasks have yet to be performed:

  • If the feature has an impact on the user experience, the changelog has been updated
  • Tests for the changes have been added (for bug fixes / features)
  • Docs have been added / updated (for bug fixes / features)
  • The commit message follows the conventional commit specification

@ghost force-pushed the feat/substrafl-simu-mode branch 2 times, most recently from 977d90e to ee7e4a3 on September 19, 2023
@SdgJlbl (Contributor) left a comment

Thanks a lot for the PR 🙏
It is definitely a very interesting feature we will want to have a look into at some point.
Concerning the TestDataNode predict / evaluate conundrum, we are planning on merging both tasks into one; it was done this way for legacy reasons, but there is no reason to keep them separate now that the metrics computation has been refactored.

from substrafl.nodes.references.local_state import LocalStateRef


class SimuTrainDataNode(TrainDataNode):
Contributor:

Minor, but we would probably rather define a typing.Protocol, to ensure that both objects (simulation and not) implement the same interface, while ensuring we have no coupling through shared code.

Author:

I am not sure it is possible to use typing.Protocol because both objects will not have the same type signatures:

  • TrainDataNode.update_states returns a tuple with a LocalStateRef and a SharedStateRef
  • SimuTrainDataNode.update_states returns a tuple with a LocalStateRef and, instead of a SharedStateRef, the actual output of the method (which is completely arbitrary)

So if we just say that update_states returns a tuple, it's fine, but it's not really possible to enforce a fine-grained type structure.

Contributor:

If the signatures are not all the same, I'm even more wary of subclassing, since polymorphism can get nightmarish really fast.
I think I'd prefer having some loose typing at the Protocol level, and having specialisation at the child class level.
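
For illustration, a minimal sketch of such a loosely-typed Protocol (update_states is the method discussed above; everything else is an assumption):

from typing import Any, Optional, Protocol, Tuple


class TrainDataNodeProtocol(Protocol):
    def update_states(
        self, operation: Any, local_state: Optional[Any] = None, **kwargs: Any
    ) -> Tuple[Any, Any]:
        # The second element stays Any: a SharedStateRef for TrainDataNode,
        # the raw method output for SimuTrainDataNode.
        ...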

Author:

OK I'll correct it then!

@linear bot mentioned this pull request Nov 7, 2023
jeandut and others added 26 commits January 19, 2024 11:03
jeandut and others added 11 commits January 19, 2024 13:15
@jeandut force-pushed the feat/substrafl-simu-mode branch from 7f58fa4 to 968917c on January 19, 2024
@ThibaultFy (Member) commented:

I'll close this PR as these changes have been taken into account and merged with #184.

Thanks again for all the work.

@ThibaultFy closed this Feb 28, 2024