Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updates to straggler handling functionality #996

Merged
merged 62 commits into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from 57 commits
Commits
Show all changes
62 commits
Select commit Hold shift + click to select a range
fe27f17
v1.0 straggler handling added
ParthMandaliya Jun 20, 2024
1bb0169
Added start_straggler_cutoff_timer abstract function in StragglerHand…
ParthMandaliya Jun 20, 2024
419abe1
Updated logs straggler handling in aggregator.py.
ParthMandaliya Jun 21, 2024
f175657
Only time based straggler handling policies require timer thread,
ParthMandaliya Jun 21, 2024
9dd88dc
Find all unfinished tasks and straggler collaborators
ParthMandaliya Jun 25, 2024
4ccac38
Merge branch 'securefederatedai:develop' into straggler-handling
ParthMandaliya Jun 25, 2024
210cdf3
Review comments incorporated
ParthMandaliya Jul 2, 2024
22e9ea1
Added inline comments
ParthMandaliya Jul 2, 2024
535dddb
Changed logic to keep track of collaborators which have reported resu…
ParthMandaliya Jul 2, 2024
ba41adc
1. StragglerHandlingFunction: Added an interface method start_policy
ParthMandaliya Jul 4, 2024
9aaba8f
Removed logger argument from start_policy &
ParthMandaliya Jul 4, 2024
4e477d8
Merge branch 'develop' into straggler-handling
ParthMandaliya Jul 4, 2024
37af42e
If cutoff time is set to infinite, do not start the timer thread.
ParthMandaliya Jul 4, 2024
21c0a23
Resolved lint issues.
ParthMandaliya Jul 4, 2024
7796414
Pytest and code coverage test case failure resolved
ParthMandaliya Jul 4, 2024
3343ecb
Default logger value is set to None
ParthMandaliya Jul 4, 2024
38d21c0
Redesigned percentage based straggler policy.
ParthMandaliya Jul 9, 2024
ae82b53
Code cleanup
ParthMandaliya Jul 9, 2024
7a028ff
Only collaborators_done are used for aggregation
ParthMandaliya Jul 9, 2024
dcf4a36
Internal review comments incorporated
ParthMandaliya Jul 10, 2024
c5a84c2
Merge branch 'securefederatedai:develop' into straggler-handling
ParthMandaliya Jul 10, 2024
2995e39
Merge branch 'securefederatedai:develop' into straggler-handling
ParthMandaliya Jul 11, 2024
83c03d8
Resolving potential issues found during testing
ParthMandaliya Jul 11, 2024
50ff1a2
Use _collaborator_task_completed method to check if all given tasks to
ParthMandaliya Jul 11, 2024
7b7726e
Few test cases failing issue resolved
ParthMandaliya Jul 11, 2024
09ae6bb
Potential issue in aggregator based workflow tutorial resolved
ParthMandaliya Jul 11, 2024
bc27cfa
Merge branch 'securefederatedai:develop' into straggler-handling
ParthMandaliya Jul 15, 2024
7eede72
Corner case issue discovered during testing is patched.
ParthMandaliya Jul 15, 2024
966a8de
Added reset_policy_for_round function in straggler handling policy ba…
ParthMandaliya Jul 16, 2024
bab851a
This commit includes following changes:
ParthMandaliya Jul 17, 2024
7fccd1c
Code cleanup
ParthMandaliya Jul 18, 2024
ce4747f
Logs modified
ParthMandaliya Jul 18, 2024
1fc17a7
Review comments on PR incorporated.
ParthMandaliya Jul 18, 2024
e1e32ea
Log updated
ParthMandaliya Jul 18, 2024
10ea392
Condition in straggler_cutoff_check fixed
ParthMandaliya Jul 18, 2024
726b2bd
If straggler cutoff time set to infinite only wait for minimum requir…
ParthMandaliya Jul 18, 2024
c20eabf
If straggler cutoff time is set to infinite wait for ALL collaborator…
ParthMandaliya Jul 18, 2024
6a795d9
Teodor's review comments and internal review comments incorporated.
ParthMandaliya Jul 18, 2024
64f18ae
Changed minimum_reporting validation.
ParthMandaliya Jul 19, 2024
39b289a
Resolved all flake8 issues.
ParthMandaliya Jul 19, 2024
b8a85c9
Merge pull request #4 from ParthMandaliya/straggler-handling-v2.0
ParthMandaliya Jul 22, 2024
f56f9c0
Merge branch 'securefederatedai:develop' into straggler-handling
ParthMandaliya Jul 23, 2024
a11db0a
Modified single inline comment
ParthMandaliya Jul 23, 2024
cabcc2f
Merge branch 'develop' into straggler-handling
ParthMandaliya Aug 1, 2024
cc7454b
Review comments incorporated.
ParthMandaliya Aug 1, 2024
181cd8d
Merge branch 'securefederatedai:develop' into straggler-handling
ishant162 Aug 9, 2024
3433e36
Merge branch 'securefederatedai:develop' into straggler-handling
ishant162 Aug 10, 2024
148d230
Merge branch 'develop' into straggler-handling
ishant162 Aug 19, 2024
474c2e1
Lint fixes
ishant162 Aug 19, 2024
8a1cc8d
Merge branch 'securefederatedai:develop' into straggler-handling
ishant162 Aug 20, 2024
75dd979
Incorporated Micah's review comments and removed unused code
ishant162 Aug 20, 2024
1d94a7e
Merge branch 'securefederatedai:develop' into straggler-handling
ishant162 Aug 25, 2024
d4b9323
Merge branch 'securefederatedai:develop' into straggler-handling
ishant162 Aug 27, 2024
55f71fd
Incorporated Teo's review comments
ishant162 Aug 27, 2024
20073e5
Merge branch 'securefederatedai:develop' into straggler-handling
ishant162 Aug 29, 2024
afa27ff
Merge branch 'securefederatedai:develop' into straggler-handling
ishant162 Aug 30, 2024
6c8b755
Merge branch 'securefederatedai:develop' into straggler-handling
ishant162 Sep 4, 2024
185243f
Incorporated review comments & added mutex in aggregator for thread s…
ishant162 Sep 10, 2024
65e3576
Merge branch 'securefederatedai:develop' into straggler-handling
ishant162 Sep 12, 2024
7be92ea
Review comments incorporated and updated logs
ishant162 Sep 12, 2024
f07592e
Merge branch 'securefederatedai:develop' into straggler-handling
ishant162 Sep 19, 2024
3aca322
Reverted a comment for Pytest and code coverage fix
ishant162 Sep 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion openfl/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
PercentageBasedStragglerHandling,
)
from openfl.component.straggler_handling_functions.straggler_handling_function import (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even to a developer, this naming scheme is unintuitive. How can we expect users to remember?
Please eliminate redundancy in naming modules.

openfl.component.straggler_policy.CutOffPolicy
# This is self-documenting, and easy to remember than

openfl.component.straggler_handling_functions.straggler_handling_function.CutoffTimeBasedStragglerHandling

If this also means merging the three policy modules into one file, consider doing so. The code complexity of any policy does not justify having separate modules.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Teo had alternate naming suggestions #996 (comment).
Recommended that naming and other suggestions to be taken up in a seperate follow-up PR.

StragglerHandlingFunction,
StragglerHandlingPolicy,
)
151 changes: 84 additions & 67 deletions openfl/component/aggregator/aggregator.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Class attribute docstring is out of line with the actual arguments. Can you please fix that as well?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This change was introduced in seperate PR: #1003.
Unsure whether the changes to class attribute were intentional. Suggested to be discussed seperately as they do not seem to be related to straggler handling.

Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,6 @@ def __init__(
# Cleaner solution?
self.single_col_cert_common_name = ""

self.straggler_handling_policy = (
straggler_handling_policy or CutoffTimeBasedStragglerHandling()
)
self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []

Expand All @@ -140,6 +137,10 @@ def __init__(
self.write_logs = write_logs
self.log_metric_callback = log_metric_callback

self.straggler_handling_policy = (
straggler_handling_policy or CutoffTimeBasedStragglerHandling()
teoparvanov marked this conversation as resolved.
Show resolved Hide resolved
psfoley marked this conversation as resolved.
Show resolved Hide resolved
)

if self.write_logs:
self.log_metric = write_metric
if self.log_metric_callback:
Expand Down Expand Up @@ -177,6 +178,10 @@ def __init__(

self.collaborator_task_weight = {} # {TaskResultKey: data_size}

# maintain a list of collaborators that have completed task and
# reported results in a given round
self.collaborators_done = []

Comment on lines +182 to +185
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this variable is only used by straggler policies, it should be stored in the policy object. Intent is to not spill a particular component's needs into other component's code.

Example APIs: self.straggler_policy.reset() or self.straggler_policy.mark_as_done(...)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not only used by straggler policies but also to identify collaborators who have completed all their tasks in the
_compute_validation_related_task_metrics() function.

def _load_initial_tensors(self):
"""Load all of the tensors required to begin federated learning.

Expand Down Expand Up @@ -391,11 +396,29 @@ def get_tasks(self, collaborator_name):
)
sleep_time = 0

if hasattr(self.straggler_handling_policy, "round_start_time"):
self.straggler_handling_policy.round_start_time = time.time()
# Start straggler handling policy for timer based callback is required
# for %age based policy callback is not required
self.straggler_handling_policy.start_policy(callback=self._straggler_cutoff_time_elapsed)

return tasks, self.round_number, sleep_time, time_to_quit

def _straggler_cutoff_time_elapsed(self) -> None:
"""
This method is called by the straggler handling policy when cutoff timer is elapsed.
It applies straggler handling policy and ends the round early.

Returns:
None
"""
self.logger.warning(
f"Round number: {self.round_number} cutoff timer elapsed after "
f"{self.straggler_handling_policy.straggler_cutoff_time}s. "
f"Applying {self.straggler_handling_policy.__class__.__name__} policy."
)

# Check if minimum collaborators reported results
self._end_of_round_with_stragglers_check()

def get_aggregated_tensor(
self,
collaborator_name,
Expand Down Expand Up @@ -573,10 +596,10 @@ def send_local_task_results(
Returns:
None
"""
if self._time_to_quit() or self._is_task_done(task_name):
if self._time_to_quit() or collaborator_name in self.stragglers:
self.logger.warning(
f"STRAGGLER: Collaborator {collaborator_name} is reporting results "
"after task {task_name} has finished."
f"after task {task_name} has finished."
)
return

Expand Down Expand Up @@ -632,7 +655,30 @@ def send_local_task_results(
task_results.append(tensor_key)

self.collaborator_tasks_results[task_key] = task_results
self._end_of_task_check(task_name)

self._is_collaborator_done(collaborator_name)

self._end_of_round_with_stragglers_check()

def _end_of_round_with_stragglers_check(self):
"""
Checks if the minimum required collaborators have reported their results,
identifies any stragglers, and initiates an early round end if necessary.

Returns:
None
"""
if self.straggler_handling_policy.straggler_cutoff_check(
len(self.collaborators_done), len(self.authorized_cols)
):
self.stragglers = [
collab_name
for collab_name in self.authorized_cols
if collab_name not in self.collaborators_done
]
if len(self.stragglers) != 0:
self.logger.warning(f"Identified stragglers: {self.stragglers}")
self._end_of_round_check()

def _process_named_tensor(self, named_tensor, collaborator_name):
"""Extract the named tensor fields.
Expand Down Expand Up @@ -724,21 +770,6 @@ def _process_named_tensor(self, named_tensor, collaborator_name):

return final_tensor_key, final_nparray

def _end_of_task_check(self, task_name):
"""Check whether all collaborators who are supposed to perform the
task complete.

Args:
task_name (str): Task name.
The task name to check.

Returns:
bool: Whether the task is done.
"""
if self._is_task_done(task_name):
# now check for the end of the round
self._end_of_round_check()

def _prepare_trained(self, tensor_name, origin, round_number, report, agg_results):
"""Prepare aggregated tensorkey tags.

Expand Down Expand Up @@ -839,11 +870,12 @@ def _compute_validation_related_task_metrics(self, task_name):
all_collaborators_for_task = self.assigner.get_collaborators_for_task(
task_name, self.round_number
)
# leave out stragglers for the round
# Leave out straggler for the round even if they've paritally
# completed given tasks
collaborators_for_task = []
for c in all_collaborators_for_task:
if self._collaborator_task_completed(c, task_name, self.round_number):
collaborators_for_task.append(c)
collaborators_for_task = [
c for c in all_collaborators_for_task if c in self.collaborators_done
]

# The collaborator data sizes for that task
collaborator_weights_unnormalized = {
Expand Down Expand Up @@ -919,7 +951,7 @@ def _end_of_round_check(self):
Returns:
None
"""
if not self._is_round_done() or self._end_of_round_check_done[self.round_number]:
if self._end_of_round_check_done[self.round_number]:
return

# Compute all validation related metrics
Expand All @@ -932,6 +964,8 @@ def _end_of_round_check(self):
self.round_number += 1
# resetting stragglers for task for a new round
self.stragglers = []
# resetting collaborators_done for next round
self.collaborators_done = []

# Save the latest model
self.logger.info("Saving round %s model...", self.round_number)
Expand All @@ -945,49 +979,32 @@ def _end_of_round_check(self):

# Cleaning tensor db
self.tensor_db.clean_up(self.db_store_rounds)
# Reset straggler handling policy for the next round.
self.straggler_handling_policy.reset_policy_for_round()

def _is_task_done(self, task_name):
"""Check that task is done.

Args:
task_name (str): Task name.

Returns:
bool: Whether the task is done.
def _is_collaborator_done(self, collaborator_name: str) -> None:
"""
all_collaborators = self.assigner.get_collaborators_for_task(task_name, self.round_number)

collaborators_done = []
for c in all_collaborators:
if self._collaborator_task_completed(c, task_name, self.round_number):
collaborators_done.append(c)

straggler_check = self.straggler_handling_policy.straggler_cutoff_check(
len(collaborators_done), all_collaborators
)

if straggler_check:
for c in all_collaborators:
if c not in collaborators_done:
self.stragglers.append(c)
Check if all tasks given to the collaborator are completed then,
completed or not.
"""
# Get all tasks given to the collaborator for current round
all_tasks = self.assigner.get_tasks_for_collaborator(collaborator_name, self.round_number)
# Check if all given tasks are completed by the collaborator
all_tasks_completed = True
for task in all_tasks:
if hasattr(task, "name"):
task = task.name
all_tasks_completed = all_tasks_completed and self._collaborator_task_completed(
collaborator=collaborator_name, task_name=task, round_num=self.round_number
)
# If the collaborator has completed ALL tasks for current round,
# update collaborators_done
if all_tasks_completed:
self.collaborators_done.append(collaborator_name)
self.logger.info(
"\tEnding task %s early due to straggler cutoff policy",
task_name,
f"Round: {self.round_number}, Collaborators that have completed all tasks: "
f"{self.collaborators_done}"
)
self.logger.warning("\tIdentified stragglers: %s", self.stragglers)

# all are done or straggler policy calls for early round end.
return straggler_check or len(all_collaborators) == len(collaborators_done)

def _is_round_done(self):
"""Check that round is done.

Returns:
bool: Whether the round is done.
"""
tasks_for_round = self.assigner.get_all_tasks_for_round(self.round_number)

return all(self._is_task_done(task_name) for task_name in tasks_for_round)

def _log_big_warning(self):
"""Warn user about single collaborator cert mode."""
Expand Down
2 changes: 1 addition & 1 deletion openfl/component/straggler_handling_functions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,5 @@
PercentageBasedStragglerHandling,
)
from openfl.component.straggler_handling_functions.straggler_handling_function import (
StragglerHandlingFunction,
StragglerHandlingPolicy,
)
Loading
Loading