Skip to content

Commit

Permalink
Incorporated review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Ishant Thakare <[email protected]>
  • Loading branch information
ishant162 committed Oct 23, 2024
1 parent 2f3873d commit 896786f
Show file tree
Hide file tree
Showing 10 changed files with 37 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ The Open Federated Learning (|productName|) framework supports straggler handlin

The following are the straggler handling algorithms supported in |productName|:

``CutoffTimeBasedStragglerHandling``
``CutoffPolicy``
Identifies stragglers based on the cutoff time specified in the settings. Arguments to the function are:
- *Cutoff Time* (straggler_cutoff_time), specifies the cutoff time by which the aggregator should end the round early.
- *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model.

For example, in a federation of 5 collaborators, if :code:`straggler_cutoff_time` (in seconds) is set to 20 and :code:`minimum_reporting` is set to 2, atleast 2 collaborators (or more) would be included in the round, provided that the time limit of 20 seconds is not exceeded.
In an event where :code:`minimum_reporting` collaborators don't make it within the :code:`straggler_cutoff_time`, the straggler handling policy is disregarded.

``PercentageBasedStragglerHandling``
``PercentagePolicy``
Identifies stragglers based on the percetage specified. Arguments to the function are:
- *Percentage of collaborators* (percent_collaborators_needed), specifies a percentage of collaborators enough to end the round early.
- *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model.
Expand All @@ -29,12 +29,12 @@ The following are the straggler handling algorithms supported in |productName|:
Demonstration of adding the straggler handling interface
=========================================================

The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentageBasedStragglerHandling``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffTimeBasedStragglerHandling** function instead:
The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentagePolicy``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffPolicy** function instead:

.. code-block:: yaml
straggler_handling_policy :
template : openfl.component.straggler_handling_policy.CutoffTimeBasedStragglerHandling
template : openfl.component.aggregator.straggler_handling.CutoffPolicy
settings :
straggler_cutoff_time : 20
minimum_reporting : 1
Expand Down
1 change: 0 additions & 1 deletion docs/source/api/openfl_component.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,3 @@ Component modules reference:
openfl.component.collaborator
openfl.component.director
openfl.component.envoy
openfl.component.straggler_handling_policy
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ compression_pipeline :
defaults : plan/defaults/compression_pipeline.yaml

straggler_handling_policy :
template : openfl.component.straggler_handling_policy.CutoffTimeBasedStragglerHandling
template : openfl.component.aggregator.straggler_handling.PercentagePolicy
settings :
percent_collaborators_needed : 0.5
minimum_reporting : 1
10 changes: 5 additions & 5 deletions openfl/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@


from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.aggregator.straggler_handling import (
CutoffPolicy,
PercentagePolicy,
StragglerHandlingPolicy,
)
from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner
from openfl.component.assigner.static_grouped_assigner import StaticGroupedAssigner
from openfl.component.collaborator.collaborator import Collaborator
from openfl.component.straggler_handling_policy import (
CutoffTimeBasedStragglerHandling,
PercentageBasedStragglerHandling,
StragglerHandlingPolicy,
)
5 changes: 5 additions & 0 deletions openfl/component/aggregator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@


from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.aggregator.straggler_handling import (
CutoffPolicy,
PercentagePolicy,
StragglerHandlingPolicy,
)
14 changes: 8 additions & 6 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from logging import getLogger
from threading import Lock

from openfl.component.straggler_handling_policy import CutoffTimeBasedStragglerHandling
from openfl.component.aggregator.straggler_handling import CutoffPolicy
from openfl.databases import TensorDB
from openfl.interface.aggregation_functions import WeightedAverage
from openfl.pipelines import NoCompressionPipeline, TensorCodec
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(
best_state_path,
last_state_path,
assigner,
straggler_handling_policy=None,
straggler_handling_policy=CutoffPolicy,
rounds_to_train=256,
single_col_cert_common_name=None,
compression_pipeline=None,
Expand All @@ -92,7 +92,7 @@ def __init__(
weight.
assigner: Assigner object.
straggler_handling_policy (optional): Straggler handling policy.
Defaults to CutoffTimeBasedStragglerHandling.
Defaults to CutoffPolicy.
rounds_to_train (int, optional): Number of rounds to train.
Defaults to 256.
single_col_cert_common_name (str, optional): Common name for single
Expand All @@ -117,9 +117,11 @@ def __init__(
# Cleaner solution?
self.single_col_cert_common_name = ""

self.straggler_handling_policy = (
straggler_handling_policy or CutoffTimeBasedStragglerHandling()
)
if straggler_handling_policy == CutoffPolicy:
self.straggler_handling_policy = straggler_handling_policy()
else:
self.straggler_handling_policy = straggler_handling_policy

self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ class StragglerHandlingPolicy(ABC):
def start_policy(self, **kwargs) -> None:
"""
Start straggler handling policy for collaborator for a particular round.
NOTE: Refer CutoffTimeBasedStragglerHandling for reference.
NOTE: Refer CutoffPolicy for reference.
Args:
**kwargs
Expand Down Expand Up @@ -65,14 +65,14 @@ def straggler_cutoff_check(
raise NotImplementedError


class CutoffTimeBasedStragglerHandling(StragglerHandlingPolicy):
class CutoffPolicy(StragglerHandlingPolicy):
"""Cutoff time based Straggler Handling function."""

def __init__(
self, round_start_time=None, straggler_cutoff_time=np.inf, minimum_reporting=1, **kwargs
):
"""
Initialize a CutoffTimeBasedStragglerHandling object.
Initialize a CutoffPolicy object.
Args:
round_start_time (optional): The start time of the round. Defaults
Expand All @@ -89,12 +89,12 @@ def __init__(
self.round_start_time = round_start_time
self.straggler_cutoff_time = straggler_cutoff_time
self.minimum_reporting = minimum_reporting
self.is_timer_started = False
self.logger = getLogger(__name__)

if self.straggler_cutoff_time == np.inf:
self.logger.warning(
"CutoffTimeBasedStragglerHandling is disabled as straggler_cutoff_time "
"is set to np.inf."
"CutoffPolicy is disabled as straggler_cutoff_time " "is set to np.inf."
)

def reset_policy_for_round(self) -> None:
Expand All @@ -103,7 +103,7 @@ def reset_policy_for_round(self) -> None:
"""
if hasattr(self, "timer"):
self.timer.cancel()
delattr(self, "timer")
self.is_timer_started = False

def start_policy(self, callback: Callable) -> None:
"""
Expand All @@ -120,15 +120,17 @@ def start_policy(self, callback: Callable) -> None:
# If straggler_cutoff_time is set to infinity
# or if the timer is already running,
# do not start the policy.
if self.straggler_cutoff_time == np.inf or hasattr(self, "timer"):
if self.straggler_cutoff_time == np.inf or self.is_timer_started:
return

self.round_start_time = time.time()
self.timer = threading.Timer(
self.straggler_cutoff_time,
callback,
)
self.timer.daemon = True
self.timer.start()
self.is_timer_started = True

def straggler_cutoff_check(
self,
Expand Down Expand Up @@ -192,11 +194,11 @@ def __minimum_collaborators_reported(self, num_collaborators_done) -> bool:
return num_collaborators_done >= self.minimum_reporting


class PercentageBasedStragglerHandling(StragglerHandlingPolicy):
class PercentagePolicy(StragglerHandlingPolicy):
"""Percentage based Straggler Handling function."""

def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs):
"""Initialize a PercentageBasedStragglerHandling object.
"""Initialize a PercentagePolicy object.
Args:
percent_collaborators_needed (float, optional): The percentage of
Expand All @@ -214,13 +216,13 @@ def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwar

def reset_policy_for_round(self) -> None:
"""
Not required in PercentageBasedStragglerHandling.
Not required in PercentagePolicy.
"""
pass

def start_policy(self, **kwargs) -> None:
"""
Not required in PercentageBasedStragglerHandling.
Not required in PercentagePolicy.
"""
pass

Expand Down
9 changes: 0 additions & 9 deletions openfl/component/straggler_handling_policy/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion openfl/federated/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def get_tensor_pipe(self):

def get_straggler_handling_policy(self):
"""Get straggler handling policy."""
template = "openfl.component.straggler_handling_policy.CutoffTimeBasedStragglerHandling"
template = "openfl.component.aggregator.straggler_handling.CutoffPolicy"
defaults = self.config.get("straggler_handling_policy", {TEMPLATE: template, SETTINGS: {}})

if self.straggler_policy_ is None:
Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,6 @@ def run(self):
'openfl.component.collaborator',
'openfl.component.director',
'openfl.component.envoy',
'openfl.component.straggler_handling_policy',
'openfl.cryptography',
'openfl.databases',
'openfl.databases.utilities',
Expand Down

0 comments on commit 896786f

Please sign in to comment.