-
Notifications
You must be signed in to change notification settings - Fork 210
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Straggler handling Follow-up #1097
base: develop
Are you sure you want to change the base?
Changes from all commits
2606ed6
87152a5
c779a4a
2f3873d
896786f
f5cb6d6
2f8ad23
be8aeb7
989828d
92c8871
664774a
bd23ceb
92f8497
1c86a4c
e89959b
03f2bc1
cab6d9f
3969416
eff79fc
7f16cdb
f55e890
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,28 +2,68 @@ | |
# SPDX-License-Identifier: Apache-2.0 | ||
|
||
|
||
"""Cutoff time based Straggler Handling function.""" | ||
"""Straggler handling module.""" | ||
|
||
import threading | ||
import time | ||
from abc import ABC, abstractmethod | ||
from logging import getLogger | ||
from typing import Callable | ||
|
||
import numpy as np | ||
|
||
from openfl.component.straggler_handling_functions.straggler_handling_function import ( | ||
StragglerHandlingPolicy, | ||
) | ||
logger = getLogger(__name__) | ||
|
||
|
||
class CutoffTimeBasedStragglerHandling(StragglerHandlingPolicy): | ||
class StragglerPolicy(ABC): | ||
"""Federated Learning straggler handling interface.""" | ||
|
||
@abstractmethod | ||
def start_policy(self, **kwargs) -> None: | ||
""" | ||
Start the straggler handling policy for a collaborator for a particular round. | ||
NOTE: See CutoffPolicy.start_policy for a reference implementation. | ||
|
||
Args: | ||
**kwargs | ||
""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def reset_policy_for_round(self) -> None: | ||
"""Reset policy for the next round.""" | ||
raise NotImplementedError | ||
|
||
@abstractmethod | ||
def straggler_cutoff_check( | ||
self, num_collaborators_done: int, num_all_collaborators: int, **kwargs | ||
) -> bool: | ||
""" | ||
Determines whether it is time to end the round early. | ||
|
||
Args: | ||
num_collaborators_done: int | ||
Number of collaborators finished. | ||
num_all_collaborators: int | ||
Total number of collaborators. | ||
|
||
Returns: | ||
bool: True if it is time to end the round early, False otherwise. | ||
|
||
Raises: | ||
NotImplementedError: This method must be implemented by a subclass. | ||
""" | ||
raise NotImplementedError | ||
|
||
|
||
class CutoffPolicy(StragglerPolicy): | ||
"""Cutoff time based Straggler Handling function.""" | ||
|
||
def __init__( | ||
self, round_start_time=None, straggler_cutoff_time=np.inf, minimum_reporting=1, **kwargs | ||
): | ||
""" | ||
Initialize a CutoffTimeBasedStragglerHandling object. | ||
Initialize a CutoffPolicy object. | ||
|
||
Args: | ||
round_start_time (optional): The start time of the round. Defaults | ||
|
@@ -40,21 +80,16 @@ def __init__( | |
self.round_start_time = round_start_time | ||
self.straggler_cutoff_time = straggler_cutoff_time | ||
self.minimum_reporting = minimum_reporting | ||
self.logger = getLogger(__name__) | ||
self.is_timer_started = False | ||
|
||
if self.straggler_cutoff_time == np.inf: | ||
self.logger.warning( | ||
"CutoffTimeBasedStragglerHandling is disabled as straggler_cutoff_time " | ||
"is set to np.inf." | ||
) | ||
logger.warning("CutoffPolicy is disabled as straggler_cutoff_time " "is set to np.inf.") | ||
|
||
def reset_policy_for_round(self) -> None: | ||
""" | ||
Reset timer for the next round. | ||
""" | ||
"""Reset timer for the next round.""" | ||
if hasattr(self, "timer"): | ||
self.timer.cancel() | ||
delattr(self, "timer") | ||
self.is_timer_started = False | ||
|
||
def start_policy(self, callback: Callable) -> None: | ||
""" | ||
|
@@ -64,22 +99,21 @@ def start_policy(self, callback: Callable) -> None: | |
Args: | ||
callback: Callable | ||
Callback function for when straggler_cutoff_time elapses | ||
|
||
Returns: | ||
None | ||
""" | ||
# If straggler_cutoff_time is set to infinity | ||
# or if the timer is already running, | ||
# do not start the policy. | ||
if self.straggler_cutoff_time == np.inf or hasattr(self, "timer"): | ||
if self.straggler_cutoff_time == np.inf or self.is_timer_started: | ||
return | ||
|
||
self.round_start_time = time.time() | ||
self.timer = threading.Timer( | ||
self.straggler_cutoff_time, | ||
callback, | ||
) | ||
self.timer.daemon = True | ||
self.timer.start() | ||
self.is_timer_started = True | ||
|
||
def straggler_cutoff_check( | ||
self, | ||
|
@@ -108,13 +142,13 @@ def straggler_cutoff_check( | |
# Time has expired | ||
# Check if minimum_reporting collaborators have reported results | ||
elif self.__minimum_collaborators_reported(num_collaborators_done): | ||
self.logger.info( | ||
logger.info( | ||
f"{num_collaborators_done} collaborators have reported results. " | ||
"Applying cutoff policy and proceeding with end of round." | ||
) | ||
return True | ||
else: | ||
self.logger.info( | ||
logger.info( | ||
f"Waiting for minimum {self.minimum_reporting} collaborator(s) to report results." | ||
) | ||
return False | ||
|
@@ -141,3 +175,66 @@ def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: | |
False otherwise. | ||
""" | ||
return num_collaborators_done >= self.minimum_reporting | ||
|
||
|
||
class PercentagePolicy(StragglerPolicy): | ||
"""Percentage based Straggler Handling function.""" | ||
|
||
def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs): | ||
"""Initialize a PercentagePolicy object. | ||
|
||
Args: | ||
percent_collaborators_needed (float, optional): The percentage of | ||
collaborators needed. Defaults to 1.0. | ||
minimum_reporting (int, optional): The minimum number of | ||
collaborators that should report. Defaults to 1. | ||
**kwargs: Variable length argument list. | ||
""" | ||
Comment on lines
+184
to
+192
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Class docstring should go near attributes, not in `__init__`:
class PercentagePolicy(StragglerHandlingPolicy):
"""Percentage based Straggler Handling function.
Args:
percent_collaborators_needed (float, optional): The percentage of
collaborators required before concluding the round. Defaults to 1.0.
minimum_reporting (int, optional): The minimum number of
collaborators that should report, regardless of percentage threshold. Defaults to 1.
**kwargs: Variable length argument list.
"""
def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs):
if minimum_reporting <= 0:
raise ValueError("minimum_reporting must be >0")
self.percent_collaborators_needed = percent_collaborators_needed
self.minimum_reporting = minimum_reporting
self.logger = getLogger(__name__)
On the
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
The arguments specified are intended for the constructor rather than the class attributes, so they are correctly placed. I have removed the kwargs as per your suggestion. |
||
if minimum_reporting <= 0: | ||
raise ValueError("minimum_reporting must be >0") | ||
Comment on lines
+193
to
+194
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
One liner:
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
IMO, assert statements are primarily intended for debugging purposes. |
||
|
||
self.percent_collaborators_needed = percent_collaborators_needed | ||
self.minimum_reporting = minimum_reporting | ||
|
||
def reset_policy_for_round(self) -> None: | ||
"""Not required in PercentagePolicy.""" | ||
pass | ||
|
||
def start_policy(self, **kwargs) -> None: | ||
"""Not required in PercentagePolicy.""" | ||
pass | ||
|
||
def straggler_cutoff_check( | ||
self, | ||
num_collaborators_done: int, | ||
num_all_collaborators: int, | ||
) -> bool: | ||
""" | ||
Signal that the round can end early once at least | ||
percent_collaborators_needed * num_all_collaborators collaborators, and at | ||
least minimum_reporting collaborators, have reported results. | ||
|
||
Args: | ||
num_collaborators_done (int): The number of collaborators that | ||
have reported. | ||
num_all_collaborators (int): The total number of collaborators. | ||
|
||
Returns: | ||
bool: True if the straggler cutoff conditions are met, False | ||
otherwise. | ||
""" | ||
return ( | ||
num_collaborators_done >= self.percent_collaborators_needed * num_all_collaborators | ||
) and self.__minimum_collaborators_reported(num_collaborators_done) | ||
|
||
def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: | ||
"""Check if the minimum number of collaborators have reported. | ||
|
||
Args: | ||
num_collaborators_done (int): The number of collaborators that | ||
have reported. | ||
|
||
Returns: | ||
bool: True if the minimum number of collaborators have reported, | ||
False otherwise. | ||
""" | ||
return num_collaborators_done >= self.minimum_reporting |
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why is `CutoffPolicy` mentioned here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To showcase `CutoffPolicy`'s `start_policy` implementation as a reference for users.