Updates to straggler handling functionality #996
Changes from 57 commits
Reviewer: Class attribute docstring is out of line with the actual arguments. Can you please fix that as well?
Reply: This change was introduced in a separate PR: #1003.
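For context, the request above is about keeping the class-level docstring in sync with the constructor arguments. A generic, hypothetical illustration of the expected pattern (not the actual Aggregator docstring):

class Aggregator:
    """Aggregates model updates from collaborators.

    Attributes:
        rounds_to_train (int): Total number of federation rounds.
        straggler_handling_policy: Policy used to handle slow collaborators.
    """

    def __init__(self, rounds_to_train, straggler_handling_policy=None):
        # Every constructor argument should appear in the Attributes section
        # above, and every documented attribute should exist here.
        self.rounds_to_train = rounds_to_train
        self.straggler_handling_policy = straggler_handling_policy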
@@ -115,9 +115,6 @@ def __init__(
             # Cleaner solution?
             self.single_col_cert_common_name = ""
 
-        self.straggler_handling_policy = (
-            straggler_handling_policy or CutoffTimeBasedStragglerHandling()
-        )
         self._end_of_round_check_done = [False] * rounds_to_train
         self.stragglers = []
 
@@ -140,6 +137,10 @@ def __init__(
         self.write_logs = write_logs
         self.log_metric_callback = log_metric_callback
 
+        self.straggler_handling_policy = (
+            straggler_handling_policy or CutoffTimeBasedStragglerHandling()
+        )
+
         if self.write_logs:
             self.log_metric = write_metric
             if self.log_metric_callback:
@@ -177,6 +178,10 @@ def __init__(
         self.collaborator_task_weight = {}  # {TaskResultKey: data_size}
 
+        # maintain a list of collaborators that have completed their tasks and
+        # reported results in a given round
+        self.collaborators_done = []
+
Comment on lines +182 to +185:
Reviewer: If this variable is only used by straggler policies, it should be stored in the policy object. The intent is to not spill a particular component's needs into other components' code. Example APIs:
Reply: It is not only used by straggler policies, but also to identify collaborators who have completed all their tasks in the …
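As a purely hypothetical sketch of the kind of policy-owned API the reviewer is suggesting (names are illustrative, not the PR's actual interfaces), the bookkeeping could live inside the policy object:

class StragglerHandlingPolicySketch:
    """Hypothetical: the policy tracks finished collaborators itself."""

    def __init__(self):
        self._collaborators_done = set()

    def record_collaborator_done(self, collaborator_name):
        # Called by the aggregator whenever a collaborator finishes all its tasks.
        self._collaborators_done.add(collaborator_name)

    def straggler_cutoff_check(self, total_collaborators):
        # The policy decides from its own state whether to end the round early.
        return len(self._collaborators_done) >= total_collaborators

    def reset_policy_for_round(self):
        self._collaborators_done.clear()

As the reply notes, collaborators_done is also used outside the straggler policies, which is why the PR keeps it on the aggregator.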
     def _load_initial_tensors(self):
         """Load all of the tensors required to begin federated learning.
 
@@ -391,11 +396,29 @@ def get_tasks(self, collaborator_name):
         )
         sleep_time = 0
 
-        if hasattr(self.straggler_handling_policy, "round_start_time"):
-            self.straggler_handling_policy.round_start_time = time.time()
+        # Start the straggler handling policy. A callback is required for the
+        # timer-based policy; the percentage-based policy does not need one.
+        self.straggler_handling_policy.start_policy(callback=self._straggler_cutoff_time_elapsed)
 
         return tasks, self.round_number, sleep_time, time_to_quit
 
+    def _straggler_cutoff_time_elapsed(self) -> None:
+        """
+        Called by the straggler handling policy when the cutoff timer elapses.
+        Applies the straggler handling policy and ends the round early.
+
+        Returns:
+            None
+        """
+        self.logger.warning(
+            f"Round number: {self.round_number} cutoff timer elapsed after "
+            f"{self.straggler_handling_policy.straggler_cutoff_time}s. "
+            f"Applying {self.straggler_handling_policy.__class__.__name__} policy."
+        )
+
+        # Check if minimum collaborators reported results
+        self._end_of_round_with_stragglers_check()
+
     def get_aggregated_tensor(
         self,
         collaborator_name,
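To make the callback flow above concrete, here is a minimal sketch of how a cutoff-time policy might implement start_policy with a one-shot timer; the details are assumed and may differ from the actual CutoffTimeBasedStragglerHandling implementation:

import threading

class CutoffTimePolicySketch:
    """Hypothetical timer-based policy: fire a callback after the cutoff time."""

    def __init__(self, straggler_cutoff_time=600):
        self.straggler_cutoff_time = straggler_cutoff_time
        self._timer = None

    def start_policy(self, callback):
        # Arm a one-shot timer; when it fires, the aggregator's
        # _straggler_cutoff_time_elapsed callback runs and may end the round early.
        self._timer = threading.Timer(self.straggler_cutoff_time, callback)
        self._timer.daemon = True
        self._timer.start()

    def reset_policy_for_round(self):
        # Cancel any pending timer so the next round starts from a clean state.
        if self._timer is not None:
            self._timer.cancel()
            self._timer = None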
@@ -573,10 +596,10 @@ def send_local_task_results(
         Returns:
             None
         """
-        if self._time_to_quit() or self._is_task_done(task_name):
+        if self._time_to_quit() or collaborator_name in self.stragglers:
             self.logger.warning(
                 f"STRAGGLER: Collaborator {collaborator_name} is reporting results "
-                "after task {task_name} has finished."
+                f"after task {task_name} has finished."
             )
             return
 
@@ -632,7 +655,30 @@ def send_local_task_results(
             task_results.append(tensor_key)
 
         self.collaborator_tasks_results[task_key] = task_results
-        self._end_of_task_check(task_name)
+
+        self._is_collaborator_done(collaborator_name)
+
+        self._end_of_round_with_stragglers_check()
+
+    def _end_of_round_with_stragglers_check(self):
+        """
+        Checks if the minimum required collaborators have reported their results,
+        identifies any stragglers, and initiates an early round end if necessary.
+
+        Returns:
+            None
+        """
+        if self.straggler_handling_policy.straggler_cutoff_check(
+            len(self.collaborators_done), len(self.authorized_cols)
+        ):
+            self.stragglers = [
+                collab_name
+                for collab_name in self.authorized_cols
+                if collab_name not in self.collaborators_done
+            ]
+            if len(self.stragglers) != 0:
+                self.logger.warning(f"Identified stragglers: {self.stragglers}")
+            self._end_of_round_check()
 
     def _process_named_tensor(self, named_tensor, collaborator_name):
         """Extract the named tensor fields.
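For comparison with the cutoff check used above, a percentage-based policy's straggler_cutoff_check could be as simple as the following sketch (parameter names are illustrative, not necessarily the PR's exact API):

class PercentagePolicySketch:
    """Hypothetical: end the round once enough collaborators have reported."""

    def __init__(self, percent_collaborators_needed=0.5, minimum_reporting=1):
        self.percent_collaborators_needed = percent_collaborators_needed
        self.minimum_reporting = minimum_reporting

    def straggler_cutoff_check(self, num_collaborators_done, num_all_collaborators):
        # True once the reported share meets the percentage threshold and the
        # absolute minimum number of reporting collaborators.
        enough_share = (
            num_collaborators_done >= self.percent_collaborators_needed * num_all_collaborators
        )
        return enough_share and num_collaborators_done >= self.minimum_reporting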
@@ -724,21 +770,6 @@ def _process_named_tensor(self, named_tensor, collaborator_name):
 
         return final_tensor_key, final_nparray
 
-    def _end_of_task_check(self, task_name):
-        """Check whether all collaborators who are supposed to perform the
-        task have completed it.
-
-        Args:
-            task_name (str): Task name.
-                The task name to check.
-
-        Returns:
-            bool: Whether the task is done.
-        """
-        if self._is_task_done(task_name):
-            # now check for the end of the round
-            self._end_of_round_check()
-
     def _prepare_trained(self, tensor_name, origin, round_number, report, agg_results):
         """Prepare aggregated tensorkey tags.
 
@@ -839,11 +870,12 @@ def _compute_validation_related_task_metrics(self, task_name):
         all_collaborators_for_task = self.assigner.get_collaborators_for_task(
             task_name, self.round_number
         )
-        # leave out stragglers for the round
-        collaborators_for_task = []
-        for c in all_collaborators_for_task:
-            if self._collaborator_task_completed(c, task_name, self.round_number):
-                collaborators_for_task.append(c)
+        # Leave out stragglers for the round, even if they have partially
+        # completed the given tasks
+        collaborators_for_task = [
+            c for c in all_collaborators_for_task if c in self.collaborators_done
+        ]
 
         # The collaborator data sizes for that task
         collaborator_weights_unnormalized = {
@@ -919,7 +951,7 @@ def _end_of_round_check(self):
         Returns:
             None
         """
-        if not self._is_round_done() or self._end_of_round_check_done[self.round_number]:
+        if self._end_of_round_check_done[self.round_number]:
             return
 
         # Compute all validation related metrics
@@ -932,6 +964,8 @@ def _end_of_round_check(self):
         self.round_number += 1
         # resetting stragglers for a new round
         self.stragglers = []
+        # resetting collaborators_done for next round
+        self.collaborators_done = []
 
         # Save the latest model
         self.logger.info("Saving round %s model...", self.round_number)
@@ -945,49 +979,32 @@ def _end_of_round_check(self):
 
         # Cleaning tensor db
         self.tensor_db.clean_up(self.db_store_rounds)
+        # Reset straggler handling policy for the next round.
+        self.straggler_handling_policy.reset_policy_for_round()
 
-    def _is_task_done(self, task_name):
-        """Check that task is done.
-
-        Args:
-            task_name (str): Task name.
-
-        Returns:
-            bool: Whether the task is done.
-        """
-        all_collaborators = self.assigner.get_collaborators_for_task(task_name, self.round_number)
-
-        collaborators_done = []
-        for c in all_collaborators:
-            if self._collaborator_task_completed(c, task_name, self.round_number):
-                collaborators_done.append(c)
-
-        straggler_check = self.straggler_handling_policy.straggler_cutoff_check(
-            len(collaborators_done), all_collaborators
-        )
-
-        if straggler_check:
-            for c in all_collaborators:
-                if c not in collaborators_done:
-                    self.stragglers.append(c)
-            self.logger.info(
-                "\tEnding task %s early due to straggler cutoff policy",
-                task_name,
-            )
-            self.logger.warning("\tIdentified stragglers: %s", self.stragglers)
-
-        # all are done or straggler policy calls for early round end.
-        return straggler_check or len(all_collaborators) == len(collaborators_done)
-
-    def _is_round_done(self):
-        """Check that round is done.
-
-        Returns:
-            bool: Whether the round is done.
-        """
-        tasks_for_round = self.assigner.get_all_tasks_for_round(self.round_number)
-
-        return all(self._is_task_done(task_name) for task_name in tasks_for_round)
+    def _is_collaborator_done(self, collaborator_name: str) -> None:
+        """
+        Check whether the collaborator has completed all tasks given to it
+        for the current round and, if so, record it in collaborators_done.
+        """
+        # Get all tasks given to the collaborator for current round
+        all_tasks = self.assigner.get_tasks_for_collaborator(collaborator_name, self.round_number)
+        # Check if all given tasks are completed by the collaborator
+        all_tasks_completed = True
+        for task in all_tasks:
+            if hasattr(task, "name"):
+                task = task.name
+            all_tasks_completed = all_tasks_completed and self._collaborator_task_completed(
+                collaborator=collaborator_name, task_name=task, round_num=self.round_number
+            )
+        # If the collaborator has completed ALL tasks for current round,
+        # update collaborators_done
+        if all_tasks_completed:
+            self.collaborators_done.append(collaborator_name)
+            self.logger.info(
+                f"Round: {self.round_number}, Collaborators that have completed all tasks: "
+                f"{self.collaborators_done}"
+            )
 
     def _log_big_warning(self):
         """Warn user about single collaborator cert mode."""
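Taken together, the new per-round flow looks roughly like the following simplified, hypothetical driver (an illustration of the call order, not code from this PR):

def simulate_round(aggregator, collaborator_names):
    # 1. Handing out tasks arms the straggler policy; timer-based policies
    #    receive the cutoff callback.
    aggregator.straggler_handling_policy.start_policy(
        callback=aggregator._straggler_cutoff_time_elapsed
    )
    # 2. As each collaborator reports results, the aggregator records completion
    #    and re-checks the cutoff condition.
    for name in collaborator_names:
        aggregator._is_collaborator_done(name)
        aggregator._end_of_round_with_stragglers_check()
    # 3. _end_of_round_check() (reached from the check above or from the timer
    #    callback) resets stragglers and collaborators_done and calls
    #    reset_policy_for_round() before the next round begins.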
Reviewer: Even to a developer, this naming scheme is unintuitive. How can we expect users to remember? Please eliminate redundancy in naming modules. If this also means merging the three policy modules into one file, consider doing so. The code complexity of any policy does not justify having separate modules.
Reply: Teo had alternate naming suggestions #996 (comment). Recommended that naming and other suggestions be taken up in a separate follow-up PR.