Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Straggler handling Follow-up #1097

Draft
wants to merge 21 commits into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
2606ed6
Renamed Straggler Handling package
ishant162 Oct 18, 2024
87152a5
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Oct 18, 2024
c779a4a
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Oct 21, 2024
2f3873d
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Oct 23, 2024
896786f
Incorporated review comments
ishant162 Oct 23, 2024
f5cb6d6
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Oct 25, 2024
2f8ad23
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Oct 25, 2024
be8aeb7
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Oct 25, 2024
989828d
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Nov 2, 2024
92c8871
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Nov 8, 2024
664774a
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Nov 11, 2024
bd23ceb
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Nov 12, 2024
92f8497
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Nov 13, 2024
1c86a4c
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Nov 19, 2024
e89959b
Merge branch 'develop' into straggler_handling_update
ishant162 Nov 27, 2024
03f2bc1
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Dec 2, 2024
cab6d9f
resolving merge conflicts
ishant162 Dec 17, 2024
3969416
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Dec 19, 2024
eff79fc
Incorporated karan's review comments
ishant162 Dec 20, 2024
7f16cdb
Resolving merge conflicts
ishant162 Dec 20, 2024
f55e890
Merge branch 'securefederatedai:develop' into straggler_handling_update
ishant162 Dec 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@ The Open Federated Learning (OpenFL) framework supports straggler handling inter

The following are the straggler handling algorithms supported in OpenFL:

``CutoffTimeBasedStragglerHandling``
``CutoffPolicy``
Identifies stragglers based on the cutoff time specified in the settings. Arguments to the function are:
- *Cutoff Time* (straggler_cutoff_time), specifies the cutoff time by which the aggregator should end the round early.
- *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model.

For example, in a federation of 5 collaborators, if :code:`straggler_cutoff_time` (in seconds) is set to 20 and :code:`minimum_reporting` is set to 2, atleast 2 collaborators (or more) would be included in the round, provided that the time limit of 20 seconds is not exceeded.
In an event where :code:`minimum_reporting` collaborators don't make it within the :code:`straggler_cutoff_time`, the straggler handling policy is disregarded.

``PercentageBasedStragglerHandling``
``PercentagePolicy``
Identifies stragglers based on the percetage specified. Arguments to the function are:
- *Percentage of collaborators* (percent_collaborators_needed), specifies a percentage of collaborators enough to end the round early.
- *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model.
Expand All @@ -29,12 +29,12 @@ The following are the straggler handling algorithms supported in OpenFL:
Demonstration of adding the straggler handling interface
=========================================================

The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentageBasedStragglerHandling``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffTimeBasedStragglerHandling** function instead:
The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentagePolicy``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffPolicy** function instead:

.. code-block:: yaml

straggler_handling_policy :
template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling
template : openfl.component.aggregator.straggler_handling.CutoffPolicy
settings :
straggler_cutoff_time : 20
minimum_reporting : 1
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ compression_pipeline :
defaults : plan/defaults/compression_pipeline.yaml

straggler_handling_policy :
template : openfl.component.straggler_handling_functions.PercentageBasedStragglerHandling
template : openfl.component.aggregator.straggler_handling.PercentagePolicy
settings :
percent_collaborators_needed : 0.5
minimum_reporting : 1
14 changes: 5 additions & 9 deletions openfl/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,12 @@


from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.aggregator.straggler_handling import (
CutoffPolicy,
PercentagePolicy,
StragglerPolicy,
)
from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner
from openfl.component.assigner.static_grouped_assigner import StaticGroupedAssigner
from openfl.component.collaborator.collaborator import Collaborator
from openfl.component.straggler_handling_functions.cutoff_time_based_straggler_handling import (
CutoffTimeBasedStragglerHandling,
)
from openfl.component.straggler_handling_functions.percentage_based_straggler_handling import (
PercentageBasedStragglerHandling,
)
from openfl.component.straggler_handling_functions.straggler_handling_function import (
StragglerHandlingPolicy,
)
5 changes: 5 additions & 0 deletions openfl/component/aggregator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,8 @@


from openfl.component.aggregator.aggregator import Aggregator
from openfl.component.aggregator.straggler_handling import (
CutoffPolicy,
PercentagePolicy,
StragglerPolicy,
)
10 changes: 4 additions & 6 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from logging import getLogger
from threading import Lock

from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling
from openfl.component.aggregator.straggler_handling import CutoffPolicy, StragglerPolicy
from openfl.databases import TensorDB
from openfl.interface.aggregation_functions import WeightedAverage
from openfl.pipelines import NoCompressionPipeline, TensorCodec
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(
last_state_path,
assigner,
use_delta_updates=True,
straggler_handling_policy=None,
straggler_handling_policy: StragglerPolicy = CutoffPolicy,
rounds_to_train=256,
single_col_cert_common_name=None,
compression_pipeline=None,
Expand All @@ -95,7 +95,6 @@ def __init__(
weight.
assigner: Assigner object.
straggler_handling_policy (optional): Straggler handling policy.
Defaults to CutoffTimeBasedStragglerHandling.
rounds_to_train (int, optional): Number of rounds to train.
Defaults to 256.
single_col_cert_common_name (str, optional): Common name for single
Expand Down Expand Up @@ -123,9 +122,8 @@ def __init__(
# FIXME: "" instead of None is for protobuf compatibility.
self.single_col_cert_common_name = single_col_cert_common_name or ""

self.straggler_handling_policy = (
straggler_handling_policy or CutoffTimeBasedStragglerHandling()
)
self.straggler_handling_policy = straggler_handling_policy()

self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []
# Flag can be enabled to get memory usage details for ubuntu system
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,28 +2,68 @@
# SPDX-License-Identifier: Apache-2.0


"""Cutoff time based Straggler Handling function."""
"""Straggler handling module."""

import threading
import time
from abc import ABC, abstractmethod
from logging import getLogger
from typing import Callable

import numpy as np

from openfl.component.straggler_handling_functions.straggler_handling_function import (
StragglerHandlingPolicy,
)
logger = getLogger(__name__)


class CutoffTimeBasedStragglerHandling(StragglerHandlingPolicy):
class StragglerPolicy(ABC):
"""Federated Learning straggler handling interface."""

@abstractmethod
def start_policy(self, **kwargs) -> None:
"""
Start straggler handling policy for collaborator for a particular round.
NOTE: Refer CutoffPolicy for reference.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is CutoffPolicy mentioned here?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To showcase the CutoffPolicy's start_policy implementation as a reference for users.


Args:
**kwargs
"""
raise NotImplementedError

@abstractmethod
def reset_policy_for_round(self) -> None:
"""Reset policy for the next round."""
raise NotImplementedError

@abstractmethod
def straggler_cutoff_check(
self, num_collaborators_done: int, num_all_collaborators: int, **kwargs
) -> bool:
"""
Determines whether it is time to end the round early.

Args:
num_collaborators_done: int
Number of collaborators finished.
num_all_collaborators: int
Total number of collaborators.

Returns:
bool: True if it is time to end the round early, False otherwise.

Raises:
NotImplementedError: This method must be implemented by a subclass.
"""
raise NotImplementedError


class CutoffPolicy(StragglerPolicy):
"""Cutoff time based Straggler Handling function."""

def __init__(
self, round_start_time=None, straggler_cutoff_time=np.inf, minimum_reporting=1, **kwargs
):
"""
Initialize a CutoffTimeBasedStragglerHandling object.
Initialize a CutoffPolicy object.

Args:
round_start_time (optional): The start time of the round. Defaults
Expand All @@ -40,21 +80,16 @@ def __init__(
self.round_start_time = round_start_time
self.straggler_cutoff_time = straggler_cutoff_time
self.minimum_reporting = minimum_reporting
self.logger = getLogger(__name__)
self.is_timer_started = False

if self.straggler_cutoff_time == np.inf:
self.logger.warning(
"CutoffTimeBasedStragglerHandling is disabled as straggler_cutoff_time "
"is set to np.inf."
)
logger.warning("CutoffPolicy is disabled as straggler_cutoff_time " "is set to np.inf.")

def reset_policy_for_round(self) -> None:
"""
Reset timer for the next round.
"""
"""Reset timer for the next round."""
if hasattr(self, "timer"):
self.timer.cancel()
delattr(self, "timer")
self.is_timer_started = False

def start_policy(self, callback: Callable) -> None:
"""
Expand All @@ -64,22 +99,21 @@ def start_policy(self, callback: Callable) -> None:
Args:
callback: Callable
Callback function for when straggler_cutoff_time elapses

Returns:
None
"""
# If straggler_cutoff_time is set to infinity
# or if the timer is already running,
# do not start the policy.
if self.straggler_cutoff_time == np.inf or hasattr(self, "timer"):
if self.straggler_cutoff_time == np.inf or self.is_timer_started:
return

self.round_start_time = time.time()
self.timer = threading.Timer(
self.straggler_cutoff_time,
callback,
)
self.timer.daemon = True
self.timer.start()
self.is_timer_started = True

def straggler_cutoff_check(
self,
Expand Down Expand Up @@ -108,13 +142,13 @@ def straggler_cutoff_check(
# Time has expired
# Check if minimum_reporting collaborators have reported results
elif self.__minimum_collaborators_reported(num_collaborators_done):
self.logger.info(
logger.info(
f"{num_collaborators_done} collaborators have reported results. "
"Applying cutoff policy and proceeding with end of round."
)
return True
else:
self.logger.info(
logger.info(
f"Waiting for minimum {self.minimum_reporting} collaborator(s) to report results."
)
return False
Expand All @@ -141,3 +175,66 @@ def __minimum_collaborators_reported(self, num_collaborators_done) -> bool:
False otherwise.
"""
return num_collaborators_done >= self.minimum_reporting


class PercentagePolicy(StragglerPolicy):
"""Percentage based Straggler Handling function."""

def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs):
"""Initialize a PercentagePolicy object.

Args:
percent_collaborators_needed (float, optional): The percentage of
collaborators needed. Defaults to 1.0.
minimum_reporting (int, optional): The minimum number of
collaborators that should report. Defaults to 1.
**kwargs: Variable length argument list.
"""
Comment on lines +184 to +192
Copy link
Collaborator

@MasterSkepticista MasterSkepticista Oct 25, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Class docstring should go near attributes, not in __init__. Example:

class PercentagePolicy(StragglerHandlingPolicy):
    """Percentage based Straggler Handling function.

    Args:
            percent_collaborators_needed (float, optional): The percentage of
                collaborators required before concluding the round. Defaults to 1.0.
            minimum_reporting (int, optional): The minimum number of
                collaborators that should report, regardless of percentage threshold. Defaults to 1.
            **kwargs: Variable length argument list.
    """

    def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs):
        if minimum_reporting <= 0:
            raise ValueError("minimum_reporting must be >0")

        self.percent_collaborators_needed = percent_collaborators_needed
        self.minimum_reporting = minimum_reporting
        self.logger = getLogger(__name__)

On the **kwargs: Why would a user provide this? Where does it get used, please mention like keyword arguments forwarded to some.module.class). If they are never forwarded or stored (which seems to be the case looking at certain classes, they should not be accepted from the user either.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The arguments specified are intended for the constructor rather than the class attributes, so they are correctly placed. I have removed the kwargs as per your suggestion.

if minimum_reporting <= 0:
raise ValueError("minimum_reporting must be >0")
Comment on lines +193 to +194
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One liner: assert minimum_reporting > 0, "Minimum reporting must be >0"

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO, assert statements are primarily intended for debugging purposes.
Using raise with a ValueError more clearly indicates that the error is due to an invalid argument


self.percent_collaborators_needed = percent_collaborators_needed
self.minimum_reporting = minimum_reporting

def reset_policy_for_round(self) -> None:
"""Not required in PercentagePolicy."""
pass

def start_policy(self, **kwargs) -> None:
"""Not required in PercentagePolicy."""
pass

def straggler_cutoff_check(
self,
num_collaborators_done: int,
num_all_collaborators: int,
) -> bool:
"""
If percent_collaborators_needed and minimum_reporting collaborators have
reported results, then it is time to end round early.

Args:
num_collaborators_done (int): The number of collaborators that
have reported.
all_collaborators (list): All the collaborators.

Returns:
bool: True if the straggler cutoff conditions are met, False
otherwise.
"""
return (
num_collaborators_done >= self.percent_collaborators_needed * num_all_collaborators
) and self.__minimum_collaborators_reported(num_collaborators_done)

def __minimum_collaborators_reported(self, num_collaborators_done) -> bool:
"""Check if the minimum number of collaborators have reported.

Args:
num_collaborators_done (int): The number of collaborators that
have reported.

Returns:
bool: True if the minimum number of collaborators have reported,
False otherwise.
"""
return num_collaborators_done >= self.minimum_reporting
13 changes: 0 additions & 13 deletions openfl/component/straggler_handling_functions/__init__.py

This file was deleted.

Loading
Loading