Jean/efficient bootstrapping #3
Great job @jeandut. I have a few comments to better understand the logic.
btst_algo = BtstAlgo(**strategy.algo.kwargs)

class BtstStrategy(strategy.__class__):
`BootstrapStrategy` would have been more explicit IMO (same for `BootstrapAlgo` vs `BtstAlgo`).
self.checkpoints_list = [
    copy.deepcopy(self._get_state_to_save())
    for _ in range(len(bootstrap_seeds_list))
]
I don't understand this part. @jeandut could you please explain why you do not need to pass a variable to indicate which bootstrap replicate to save?
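For readers following this thread, here is a minimal sketch of the pattern in the hunk above. It assumes that `_get_state_to_save()` returns a plain dict and that each replicate is later addressed by its position in `bootstrap_seeds_list`; both are assumptions, and `ToyAlgo` and `save_replicate` are hypothetical stand-ins, not the PR's code.

```python
import copy


class ToyAlgo:
    """Hypothetical stand-in for the algo in the PR."""

    def __init__(self, bootstrap_seeds_list):
        self.weights = [0.0, 0.0]
        # One independent copy of the initial state per bootstrap replicate.
        self.checkpoints_list = [
            copy.deepcopy(self._get_state_to_save())
            for _ in range(len(bootstrap_seeds_list))
        ]

    def _get_state_to_save(self):
        return {"weights": self.weights}

    def save_replicate(self, idx):
        # Hypothetical helper: the replicate is chosen by index when writing back.
        self.checkpoints_list[idx] = copy.deepcopy(self._get_state_to_save())


algo = ToyAlgo(bootstrap_seeds_list=[1, 2, 3])
algo.weights[0] = 5.0  # mutate the live state, as a training step would
algo.save_replicate(1)
```

Because every entry is a deep copy, updating one replicate leaves the others untouched, so the initialisation itself does not need to know which replicate it is saving.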
    datasamples.shape[0], replace=True, random_state=rng
)
# Loading the correct state into the current main algo
if self.checkpoints_list[idx] is not None:
It must be `not None`, no? Otherwise, there is an issue?
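As a rough illustration of what the hunk under discussion does, the sketch below resamples indices with replacement from a per-replicate seed and only restores a state when one exists. It uses the standard library's `random` module instead of the resampling call in the PR; `bootstrap_indices`, the seeds, and the contents of `checkpoints_list` are hypothetical. Whether `None` can actually occur is the open question in this thread; the sketch only shows what the guard does.

```python
import random


def bootstrap_indices(n_samples, seed):
    """Draw n_samples indices with replacement, reproducibly for a given seed."""
    rng = random.Random(seed)
    return rng.choices(range(n_samples), k=n_samples)


# Hypothetical per-replicate states; the first one has never been saved.
checkpoints_list = [None, {"weights": [1.0]}]
restored = []
for idx, seed in enumerate([11, 22]):
    indices = bootstrap_indices(n_samples=5, seed=seed)
    if checkpoints_list[idx] is not None:
        # Only restore when a state was saved for this replicate.
        restored.append(idx)
```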
def equality_check(a, b):
    if type(a) != type(b):  # noqa: E721
        return False
    else:
        if isinstance(a, dict):
            for k in a.keys():
                if not equality_check(a[k], b[k]):
                    return False
            return True
        elif isinstance(a, list):
            for i in range(len(a)):
                if not equality_check(a[i], b[i]):
                    return False
            return True
        elif isinstance(a, np.ndarray):
            return np.all(a == b)
        elif isinstance(a, torch.Tensor):
            return torch.all(a == b)
        else:
            return a == b
Moving this function outside would ease readability.
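One possible shape for a module-level version is sketched below. It is not the PR's code: it additionally treats differing key sets and lengths as unequal (the hunk above only iterates over `a`), swaps in `np.array_equal` and `torch.equal` so that a plain `bool` is returned, and makes both imports optional so the sketch runs without them.

```python
try:  # numpy and torch are optional in this sketch
    import numpy as np
except ImportError:
    np = None
try:
    import torch
except ImportError:
    torch = None


def equality_check(a, b):
    """Recursively compare nested dicts/lists of arrays, tensors, or scalars."""
    if type(a) != type(b):  # noqa: E721
        return False
    if isinstance(a, dict):
        # Assumption: differing key sets count as "not equal".
        return a.keys() == b.keys() and all(equality_check(a[k], b[k]) for k in a)
    if isinstance(a, list):
        # Assumption: differing lengths count as "not equal".
        return len(a) == len(b) and all(equality_check(x, y) for x, y in zip(a, b))
    if np is not None and isinstance(a, np.ndarray):
        return bool(np.array_equal(a, b))
    if torch is not None and isinstance(a, torch.Tensor):
        return bool(torch.equal(a, b))
    return a == b
```

Defined at module level, the helper can be imported and unit-tested on its own rather than being redefined inside the enclosing function.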
This PR implements a more efficient way to do bootstrapping in FL with substra, using variations of @arthurPignetOwkin's code patterns.
While the details of this implementation are quite tricky, the main idea is simple: hook all methods of the algo and strategies so that they are executed `n_bootstrap` times, producing `n_bootstrap` outputs. The rest is just boilerplate code to glue everything together while respecting substra and substrafl's constraints.

This PR is a first attempt and will probably be simplified over time (the next iteration will use a different `self` for each bootstrap run, i.e. the `self` objects will not be merged, which will allow much cleaner code and avoid tricky side effects).
Note that local versus global bootstrap is an open issue, but it is a side detail of this PR and will be handled in another PR.
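The hooking idea described above can be sketched in a few lines. This is a simplified illustration under stated assumptions, not the PR's implementation: `run_per_bootstrap`, `make_bootstrap_class`, `current_idx`, and `ToyAlgo` are hypothetical names, and none of the substra or substrafl constraints are modelled.

```python
import functools


def run_per_bootstrap(method):
    """Hypothetical hook: run `method` once per bootstrap replicate."""

    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        outputs = []
        for idx in range(len(self.bootstrap_seeds_list)):
            self.current_idx = idx  # tells the method which replicate is active
            outputs.append(method(self, *args, **kwargs))
        return outputs

    return wrapper


def make_bootstrap_class(base_cls, method_names):
    """Build a subclass of base_cls whose listed methods are hooked."""
    hooked = {name: run_per_bootstrap(getattr(base_cls, name)) for name in method_names}
    return type("Bootstrap" + base_cls.__name__, (base_cls,), hooked)


class ToyAlgo:
    def __init__(self, bootstrap_seeds_list):
        self.bootstrap_seeds_list = bootstrap_seeds_list
        self.current_idx = None

    def train(self, x):
        # Toy computation that depends on the active replicate's seed.
        return x + self.bootstrap_seeds_list[self.current_idx]


BootstrapToyAlgo = make_bootstrap_class(ToyAlgo, ["train"])
outputs = BootstrapToyAlgo([10, 20, 30]).train(1)
```

Here a single call to `train` yields one output per seed. Note that all replicates share the same `self` in this sketch, which is the source of the side effects that the planned next iteration aims to avoid.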