FEAT: fast simulation mode for substraFL #168
Conversation
Thanks a lot for the PR 🙏
It is definitely a very interesting feature we will want to have a look into at some point.
Concerning the `TestDataNode` predict / evaluate conundrum, we are planning on merging both tasks into one; it was done this way for legacy reasons, but there is no reason to keep them separate now that the metrics computation has been refactored.
from substrafl.nodes.references.local_state import LocalStateRef

class SimuTrainDataNode(TrainDataNode):
Minor, but we would probably rather define a `typing.Protocol`, to ensure that both objects (simulation and not) implement the same interface, while ensuring we have no coupling through shared code.
I am not sure it is possible to use `typing.Protocol`, because both objects will not have the same type signatures:
- `TrainDataNode.update_states` returns a tuple with a `LocalStateRef` and a `SharedStateRef`
- `SimuTrainDataNode.update_states` returns a tuple with a `LocalStateRef` and, instead of a `SharedStateRef`, the real output of the method (completely arbitrary)

So if we just say that `update_states` returns a tuple it's fine, but it's not really possible to enforce a fine-grained type structure.
If the signatures are not all the same, I'm even more wary of subclassing, since polymorphism can get nightmarish really fast.
I think I'd prefer having some loose typing at the Protocol level, and having specialisation at the child class level.
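A minimal sketch of that compromise (all names below are placeholders for illustration, not actual substrafl code): the Protocol only promises a tuple, while each concrete node documents its own, more specific return type.

```python
from typing import Any, Protocol, Tuple

# "LocalStateRef" / "SharedStateRef" below are forward references standing in
# for substrafl's state reference classes; they are never evaluated here.


class DataNodeProtocol(Protocol):
    """Loosely-typed contract satisfied by both the real and the simulation nodes."""

    def update_states(self, operation: Any, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
        ...


class TrainDataNodeSketch:
    """Real node: returns the platform state references."""

    def update_states(self, operation: Any, *args: Any, **kwargs: Any) -> Tuple["LocalStateRef", "SharedStateRef"]:
        ...


class SimuTrainDataNodeSketch:
    """Simulation node: returns the real output of the method instead of a SharedStateRef."""

    def update_states(self, operation: Any, *args: Any, **kwargs: Any) -> Tuple["LocalStateRef", Any]:
        ...
```

Since `Protocol` uses structural subtyping, neither concrete class needs to inherit from the other or from the Protocol itself.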
OK I'll correct it then!
I'll close this PR as these changes have been taken into account and merged with #184. Thanks again for all the work.
Related issue
There is no formal issue open on this repo; however, this PR addresses the following real issue.
Although the subprocess mode is local (which is great), its speed is limited by hard constraints:
As a consequence, this mode can be slow compared to a plain Python process, especially when the number of local steps is small. The gap can span 1-2 orders of magnitude depending on the number of local steps.
Summary
This PR proposes a new simulation mode for substraFL. It completely sidesteps substra and performs the machine learning operations in memory, in the same process. It still requires a substra client in subprocess mode, though, to register the data. For small data this is not an issue; for bigger data, we might need to change this as well.
Expected benefits
General approach
This PR implements novel classes, `SimuTrainDataNode`, `SimuTestDataNode`, and `SimuAggregationNode`, which act as placeholders for `TrainDataNode`, `TestDataNode` and `AggregationNode`. `SimuTrainNode` and `SimuTestNode` require both an `algo` and a `substra_client`, while `SimuAggregationNode` requires the `strategy`.

The main idea is to use the `update_states` method to call the remote operation it is given, passing `_skip=True`. For `remote_data` methods, we additionally need to load the datasamples and update the `method_parameters`. This is done once and cached, to avoid an I/O bottleneck.

For each org, the corresponding `SimuTrainDataNode` and `SimuTestDataNode` share the same `algo` attribute: they are paired right after the beginning of `execute_experiment`.

To avoid any change for the user, the nodes are modified in `execute_experiment` if the new optional argument `simu_mode` is `True`. This is less optimal than doing the changes on the client side, but it is cleaner from an architectural standpoint.
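A minimal, self-contained sketch of the `update_states` mechanism described above (every name, attribute and signature here is an assumption made for illustration; it is not the actual substrafl or PR code):

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class RemoteOpSketch:
    """Stand-in for the operation object produced by a @remote_data method call."""

    method_name: str
    method_parameters: dict
    shared_state: Any = None


class SimuTrainDataNodeSketch:
    def __init__(self, algo: Any, datasamples: Any):
        self._algo = algo
        # Datasamples are loaded once and cached to avoid the I/O bottleneck.
        self._datasamples = datasamples

    def update_states(self, operation: RemoteOpSketch, local_state: Any = None):
        # Look up the method on the shared algo and run it directly in the
        # current process; _skip=True mirrors bypassing task registration.
        method = getattr(self._algo, operation.method_name)
        shared_state = method(
            datasamples=self._datasamples,
            shared_state=operation.shared_state,
            _skip=True,
            **operation.method_parameters,
        )
        return local_state, shared_state
```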
Notes

Test data node's caveats

For `TrainDataNode` and `AggregationNode`, the approach is straightforward, but for `TestDataNode` things become messier. More precisely, the `predict` method of the algo already saves the predictions to disk instead of returning them (as would be expected in a sklearn API). As a consequence, even if we are able to catch the `algo.predict` method within the `SimuTestDataNode`, we still have a bottleneck. Similarly, the expected API of the metric functions also hardcodes the fact that predictions are dumped to disk.

There would be at least two ways to solve this problem.
1. The hacky way (current)

An argument `return_predictions` is added to the `predict` method of the base torch algorithm. It is `False` by default, so this should not affect real runs. In the `SimuTestDataNode`, this argument is flagged as `True`, so we are able to retrieve the true predictions.

Metric functions are currently called as `metric_func(datasamples, predictions_path)`. What I propose is to change the metric functions so that the content of `predictions_path` is already the predictions (but the argument name is hardcoded). This is a bit hacky, but I did not see another way.
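A rough, purely illustrative sketch of the `return_predictions` flag (the class, the helper method and the `predictions_path` handling are placeholders, not the actual substrafl code):

```python
import pickle
from typing import Any, List, Optional


class TorchAlgoSketch:
    """Toy stand-in for the base torch algorithm."""

    def _compute_predictions(self, datasamples: Any) -> List[float]:
        return [0.0 for _ in datasamples]  # placeholder inference

    def predict(
        self,
        datasamples: Any,
        predictions_path: Optional[str] = None,
        return_predictions: bool = False,
    ):
        predictions = self._compute_predictions(datasamples)
        if return_predictions:
            # Simulation path (SimuTestDataNode): keep the predictions in memory.
            return predictions
        # Default path: dump the predictions to disk, as before.
        with open(predictions_path, "wb") as f:
            pickle.dump(predictions, f)
```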
2. The sustainable way

The current way of dealing with the metrics could be improved while also tackling the above issue. I think the current implementation still draws a lot on the old substra and has not completely benefitted from the `GenericTask` update.

My proposition would be to change the API of the algo as sketched below:
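(A hypothetical illustration of what such an sklearn-ish API could look like; the method names and signatures below are assumptions, not the actual proposal.)

```python
from typing import Any, Callable, Dict


class AlgoAPISketch:
    """Hypothetical shape of the algo API: predictions stay in memory."""

    def predict(self, datasamples: Any) -> Any:
        """Return the predictions instead of writing them to disk."""
        raise NotImplementedError

    def evaluate(
        self, datasamples: Any, metric_functions: Dict[str, Callable[[Any, Any], float]]
    ) -> Dict[str, float]:
        """Compute the predictions once and feed them directly to each metric."""
        predictions = self.predict(datasamples)
        return {name: metric(datasamples, predictions) for name, metric in metric_functions.items()}
```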
This is much more sklearn-ish, which is always great. Further, this would circumvent the need to submit 2 tasks for 1 metric evaluation, as the predictions would not be written to disk. Last, by replacing the following lines

by

the behaviours of `SimuTrainDataNode` and `SimuTestDataNode` would be exactly the same.

Calls to `torch.no_grad`
The extensive use of `torch.inference_mode` in the `weight_manager` prevents future calls to the model for training purposes within the same process and object: tensors created under `inference_mode` cannot be used in autograd afterwards. This does not show up in the other modes because everything is dumped to disk and deleted, but here the `algo` remains the same object during the full training. I therefore replaced them with `torch.no_grad`, which only disables gradient recording and does not have this restriction. All tests still pass, so I do not think it is a problem.
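A small standalone PyTorch snippet illustrating the difference (not code from this PR):

```python
import torch

w = torch.ones(3, requires_grad=True)

with torch.no_grad():
    frozen = w * 2              # regular tensor, usable with autograd afterwards
(frozen * w).sum().backward()   # works: gradients flow to w

with torch.inference_mode():
    inf = w * 2                 # inference tensor
try:
    (inf * w).sum().backward()  # fails: inference tensors cannot be saved for backward
except RuntimeError as err:
    print(err)
```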
Optimization of `_local_predict`
I did an additional modification of the `_local_predict` methods: instead of a pattern where the predictions tensor is grown with `torch.cat` at every iteration, I propose accumulating the per-batch outputs in a list and concatenating them once at the end.
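As a standalone illustration of the two patterns (generic PyTorch, not the actual `_local_predict` code):

```python
import torch

model = torch.nn.Linear(3, 2)
loader = [torch.randn(4, 3) for _ in range(10)]  # stand-in for a DataLoader

# Current pattern: grow the output with torch.cat at every iteration. Each call
# copies everything accumulated so far, so the total copying grows quadratically.
predictions = torch.empty(0, 2)
for x in loader:
    predictions = torch.cat((predictions, model(x)))

# Proposed pattern: accumulate the per-batch outputs in a list and concatenate
# once at the end, so each element is copied only once.
outputs = []
for x in loader:
    outputs.append(model(x))
predictions = torch.cat(outputs)
```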
Indeed, each call to `torch.cat` involves a copy of the existing array to a new place. Therefore, the first pattern involves a quadratic amount of copying, while the second one is linear. This can lead to a speed-up when arrays are large. Of course, I am happy to move this to a separate PR if this is cleaner.

Please check if the PR fulfills these requirements
The following tasks have yet to be performed: