diff --git a/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst b/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst index cdfa2b6a2c..32489e9ef1 100644 --- a/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst +++ b/docs/developer_guide/advanced_topics/straggler_handling_algorithms.rst @@ -11,7 +11,7 @@ The Open Federated Learning (OpenFL) framework supports straggler handling inter The following are the straggler handling algorithms supported in OpenFL: -``CutoffTimeBasedStragglerHandling`` +``CutoffPolicy`` Identifies stragglers based on the cutoff time specified in the settings. Arguments to the function are: - *Cutoff Time* (straggler_cutoff_time), specifies the cutoff time by which the aggregator should end the round early. - *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model. @@ -19,7 +19,7 @@ The following are the straggler handling algorithms supported in OpenFL: For example, in a federation of 5 collaborators, if :code:`straggler_cutoff_time` (in seconds) is set to 20 and :code:`minimum_reporting` is set to 2, atleast 2 collaborators (or more) would be included in the round, provided that the time limit of 20 seconds is not exceeded. In an event where :code:`minimum_reporting` collaborators don't make it within the :code:`straggler_cutoff_time`, the straggler handling policy is disregarded. -``PercentageBasedStragglerHandling`` +``PercentagePolicy`` Identifies stragglers based on the percetage specified. Arguments to the function are: - *Percentage of collaborators* (percent_collaborators_needed), specifies a percentage of collaborators enough to end the round early. - *Minimum Reporting* (minimum_reporting), specifies the minimum number of collaborators needed to aggregate the model. @@ -29,12 +29,12 @@ The following are the straggler handling algorithms supported in OpenFL: Demonstration of adding the straggler handling interface ========================================================= -The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentageBasedStragglerHandling``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffTimeBasedStragglerHandling** function instead: +The example template, **torch_cnn_mnist_straggler_check**, uses the ``PercentagePolicy``. To gain a better understanding of how experiments perform, you can modify the **percent_collaborators_needed** or **minimum_reporting** parameter in the template **plan.yaml** or even choose **CutoffPolicy** function instead: .. code-block:: yaml straggler_handling_policy : - template : openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling + template : openfl.component.aggregator.straggler_handling.CutoffPolicy settings : straggler_cutoff_time : 20 minimum_reporting : 1 diff --git a/openfl-workspace/torch_cnn_mnist_straggler_check/plan/plan.yaml b/openfl-workspace/torch_cnn_mnist_straggler_check/plan/plan.yaml index a42b064e56..e2378424b9 100644 --- a/openfl-workspace/torch_cnn_mnist_straggler_check/plan/plan.yaml +++ b/openfl-workspace/torch_cnn_mnist_straggler_check/plan/plan.yaml @@ -45,7 +45,7 @@ compression_pipeline : defaults : plan/defaults/compression_pipeline.yaml straggler_handling_policy : - template : openfl.component.straggler_handling_functions.PercentageBasedStragglerHandling + template : openfl.component.aggregator.straggler_handling.PercentagePolicy settings : percent_collaborators_needed : 0.5 minimum_reporting : 1 \ No newline at end of file diff --git a/openfl/component/__init__.py b/openfl/component/__init__.py index 3b787f87d0..792cfd855d 100644 --- a/openfl/component/__init__.py +++ b/openfl/component/__init__.py @@ -3,16 +3,12 @@ from openfl.component.aggregator.aggregator import Aggregator +from openfl.component.aggregator.straggler_handling import ( + CutoffPolicy, + PercentagePolicy, + StragglerPolicy, +) from openfl.component.assigner.assigner import Assigner from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner from openfl.component.assigner.static_grouped_assigner import StaticGroupedAssigner from openfl.component.collaborator.collaborator import Collaborator -from openfl.component.straggler_handling_functions.cutoff_time_based_straggler_handling import ( - CutoffTimeBasedStragglerHandling, -) -from openfl.component.straggler_handling_functions.percentage_based_straggler_handling import ( - PercentageBasedStragglerHandling, -) -from openfl.component.straggler_handling_functions.straggler_handling_function import ( - StragglerHandlingPolicy, -) diff --git a/openfl/component/aggregator/__init__.py b/openfl/component/aggregator/__init__.py index ed7661486e..1bbddaf4f2 100644 --- a/openfl/component/aggregator/__init__.py +++ b/openfl/component/aggregator/__init__.py @@ -3,3 +3,8 @@ from openfl.component.aggregator.aggregator import Aggregator +from openfl.component.aggregator.straggler_handling import ( + CutoffPolicy, + PercentagePolicy, + StragglerPolicy, +) diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index 1e34aa7e92..db718dfc85 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -9,7 +9,7 @@ from logging import getLogger from threading import Lock -from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling +from openfl.component.aggregator.straggler_handling import CutoffPolicy, StragglerPolicy from openfl.databases import TensorDB from openfl.interface.aggregation_functions import WeightedAverage from openfl.pipelines import NoCompressionPipeline, TensorCodec @@ -71,7 +71,7 @@ def __init__( last_state_path, assigner, use_delta_updates=True, - straggler_handling_policy=None, + straggler_handling_policy: StragglerPolicy = CutoffPolicy, rounds_to_train=256, single_col_cert_common_name=None, compression_pipeline=None, @@ -95,7 +95,6 @@ def __init__( weight. assigner: Assigner object. straggler_handling_policy (optional): Straggler handling policy. - Defaults to CutoffTimeBasedStragglerHandling. rounds_to_train (int, optional): Number of rounds to train. Defaults to 256. single_col_cert_common_name (str, optional): Common name for single @@ -123,9 +122,8 @@ def __init__( # FIXME: "" instead of None is for protobuf compatibility. self.single_col_cert_common_name = single_col_cert_common_name or "" - self.straggler_handling_policy = ( - straggler_handling_policy or CutoffTimeBasedStragglerHandling() - ) + self.straggler_handling_policy = straggler_handling_policy() + self._end_of_round_check_done = [False] * rounds_to_train self.stragglers = [] # Flag can be enabled to get memory usage details for ubuntu system diff --git a/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py b/openfl/component/aggregator/straggler_handling.py similarity index 50% rename from openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py rename to openfl/component/aggregator/straggler_handling.py index c3d11ffa4b..1d5f6ef7de 100644 --- a/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py +++ b/openfl/component/aggregator/straggler_handling.py @@ -2,28 +2,68 @@ # SPDX-License-Identifier: Apache-2.0 -"""Cutoff time based Straggler Handling function.""" +"""Straggler handling module.""" import threading import time +from abc import ABC, abstractmethod from logging import getLogger from typing import Callable import numpy as np -from openfl.component.straggler_handling_functions.straggler_handling_function import ( - StragglerHandlingPolicy, -) +logger = getLogger(__name__) -class CutoffTimeBasedStragglerHandling(StragglerHandlingPolicy): +class StragglerPolicy(ABC): + """Federated Learning straggler handling interface.""" + + @abstractmethod + def start_policy(self, **kwargs) -> None: + """ + Start straggler handling policy for collaborator for a particular round. + NOTE: Refer CutoffPolicy for reference. + + Args: + **kwargs + """ + raise NotImplementedError + + @abstractmethod + def reset_policy_for_round(self) -> None: + """Reset policy for the next round.""" + raise NotImplementedError + + @abstractmethod + def straggler_cutoff_check( + self, num_collaborators_done: int, num_all_collaborators: int, **kwargs + ) -> bool: + """ + Determines whether it is time to end the round early. + + Args: + num_collaborators_done: int + Number of collaborators finished. + num_all_collaborators: int + Total number of collaborators. + + Returns: + bool: True if it is time to end the round early, False otherwise. + + Raises: + NotImplementedError: This method must be implemented by a subclass. + """ + raise NotImplementedError + + +class CutoffPolicy(StragglerPolicy): """Cutoff time based Straggler Handling function.""" def __init__( self, round_start_time=None, straggler_cutoff_time=np.inf, minimum_reporting=1, **kwargs ): """ - Initialize a CutoffTimeBasedStragglerHandling object. + Initialize a CutoffPolicy object. Args: round_start_time (optional): The start time of the round. Defaults @@ -40,21 +80,16 @@ def __init__( self.round_start_time = round_start_time self.straggler_cutoff_time = straggler_cutoff_time self.minimum_reporting = minimum_reporting - self.logger = getLogger(__name__) + self.is_timer_started = False if self.straggler_cutoff_time == np.inf: - self.logger.warning( - "CutoffTimeBasedStragglerHandling is disabled as straggler_cutoff_time " - "is set to np.inf." - ) + logger.warning("CutoffPolicy is disabled as straggler_cutoff_time " "is set to np.inf.") def reset_policy_for_round(self) -> None: - """ - Reset timer for the next round. - """ + """Reset timer for the next round.""" if hasattr(self, "timer"): self.timer.cancel() - delattr(self, "timer") + self.is_timer_started = False def start_policy(self, callback: Callable) -> None: """ @@ -64,15 +99,13 @@ def start_policy(self, callback: Callable) -> None: Args: callback: Callable Callback function for when straggler_cutoff_time elapses - - Returns: - None """ # If straggler_cutoff_time is set to infinity # or if the timer is already running, # do not start the policy. - if self.straggler_cutoff_time == np.inf or hasattr(self, "timer"): + if self.straggler_cutoff_time == np.inf or self.is_timer_started: return + self.round_start_time = time.time() self.timer = threading.Timer( self.straggler_cutoff_time, @@ -80,6 +113,7 @@ def start_policy(self, callback: Callable) -> None: ) self.timer.daemon = True self.timer.start() + self.is_timer_started = True def straggler_cutoff_check( self, @@ -108,13 +142,13 @@ def straggler_cutoff_check( # Time has expired # Check if minimum_reporting collaborators have reported results elif self.__minimum_collaborators_reported(num_collaborators_done): - self.logger.info( + logger.info( f"{num_collaborators_done} collaborators have reported results. " "Applying cutoff policy and proceeding with end of round." ) return True else: - self.logger.info( + logger.info( f"Waiting for minimum {self.minimum_reporting} collaborator(s) to report results." ) return False @@ -141,3 +175,66 @@ def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: False otherwise. """ return num_collaborators_done >= self.minimum_reporting + + +class PercentagePolicy(StragglerPolicy): + """Percentage based Straggler Handling function.""" + + def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs): + """Initialize a PercentagePolicy object. + + Args: + percent_collaborators_needed (float, optional): The percentage of + collaborators needed. Defaults to 1.0. + minimum_reporting (int, optional): The minimum number of + collaborators that should report. Defaults to 1. + **kwargs: Variable length argument list. + """ + if minimum_reporting <= 0: + raise ValueError("minimum_reporting must be >0") + + self.percent_collaborators_needed = percent_collaborators_needed + self.minimum_reporting = minimum_reporting + + def reset_policy_for_round(self) -> None: + """Not required in PercentagePolicy.""" + pass + + def start_policy(self, **kwargs) -> None: + """Not required in PercentagePolicy.""" + pass + + def straggler_cutoff_check( + self, + num_collaborators_done: int, + num_all_collaborators: int, + ) -> bool: + """ + If percent_collaborators_needed and minimum_reporting collaborators have + reported results, then it is time to end round early. + + Args: + num_collaborators_done (int): The number of collaborators that + have reported. + all_collaborators (list): All the collaborators. + + Returns: + bool: True if the straggler cutoff conditions are met, False + otherwise. + """ + return ( + num_collaborators_done >= self.percent_collaborators_needed * num_all_collaborators + ) and self.__minimum_collaborators_reported(num_collaborators_done) + + def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: + """Check if the minimum number of collaborators have reported. + + Args: + num_collaborators_done (int): The number of collaborators that + have reported. + + Returns: + bool: True if the minimum number of collaborators have reported, + False otherwise. + """ + return num_collaborators_done >= self.minimum_reporting diff --git a/openfl/component/straggler_handling_functions/__init__.py b/openfl/component/straggler_handling_functions/__init__.py deleted file mode 100644 index 5ab0af1794..0000000000 --- a/openfl/component/straggler_handling_functions/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2020-2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - - -from openfl.component.straggler_handling_functions.cutoff_time_based_straggler_handling import ( - CutoffTimeBasedStragglerHandling, -) -from openfl.component.straggler_handling_functions.percentage_based_straggler_handling import ( - PercentageBasedStragglerHandling, -) -from openfl.component.straggler_handling_functions.straggler_handling_function import ( - StragglerHandlingPolicy, -) diff --git a/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py b/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py deleted file mode 100644 index 099ab7e870..0000000000 --- a/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py +++ /dev/null @@ -1,79 +0,0 @@ -# Copyright 2020-2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - - -"""Percentage based Straggler Handling function.""" - -from logging import getLogger - -from openfl.component.straggler_handling_functions.straggler_handling_function import ( - StragglerHandlingPolicy, -) - - -class PercentageBasedStragglerHandling(StragglerHandlingPolicy): - """Percentage based Straggler Handling function.""" - - def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs): - """Initialize a PercentageBasedStragglerHandling object. - - Args: - percent_collaborators_needed (float, optional): The percentage of - collaborators needed. Defaults to 1.0. - minimum_reporting (int, optional): The minimum number of - collaborators that should report. Defaults to 1. - **kwargs: Variable length argument list. - """ - if minimum_reporting <= 0: - raise ValueError("minimum_reporting must be >0") - - self.percent_collaborators_needed = percent_collaborators_needed - self.minimum_reporting = minimum_reporting - self.logger = getLogger(__name__) - - def reset_policy_for_round(self) -> None: - """ - Not required in PercentageBasedStragglerHandling. - """ - pass - - def start_policy(self, **kwargs) -> None: - """ - Not required in PercentageBasedStragglerHandling. - """ - pass - - def straggler_cutoff_check( - self, - num_collaborators_done: int, - num_all_collaborators: int, - ) -> bool: - """ - If percent_collaborators_needed and minimum_reporting collaborators have - reported results, then it is time to end round early. - - Args: - num_collaborators_done (int): The number of collaborators that - have reported. - all_collaborators (list): All the collaborators. - - Returns: - bool: True if the straggler cutoff conditions are met, False - otherwise. - """ - return ( - num_collaborators_done >= self.percent_collaborators_needed * num_all_collaborators - ) and self.__minimum_collaborators_reported(num_collaborators_done) - - def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: - """Check if the minimum number of collaborators have reported. - - Args: - num_collaborators_done (int): The number of collaborators that - have reported. - - Returns: - bool: True if the minimum number of collaborators have reported, - False otherwise. - """ - return num_collaborators_done >= self.minimum_reporting diff --git a/openfl/component/straggler_handling_functions/straggler_handling_function.py b/openfl/component/straggler_handling_functions/straggler_handling_function.py deleted file mode 100644 index 8bd47bc045..0000000000 --- a/openfl/component/straggler_handling_functions/straggler_handling_function.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright 2020-2024 Intel Corporation -# SPDX-License-Identifier: Apache-2.0 - - -"""Straggler handling module.""" - -from abc import ABC, abstractmethod - - -class StragglerHandlingPolicy(ABC): - """Federated Learning straggler handling interface.""" - - @abstractmethod - def start_policy(self, **kwargs) -> None: - """ - Start straggler handling policy for collaborator for a particular round. - NOTE: Refer CutoffTimeBasedStragglerHandling for reference. - - Args: - **kwargs - - Returns: - None - """ - raise NotImplementedError - - @abstractmethod - def reset_policy_for_round(self) -> None: - """ - Reset policy variable for the next round. - - Args: - None - - Returns: - None - """ - raise NotImplementedError - - @abstractmethod - def straggler_cutoff_check( - self, num_collaborators_done: int, num_all_collaborators: int, **kwargs - ) -> bool: - """ - Determines whether it is time to end the round early. - - Args: - num_collaborators_done: int - Number of collaborators finished. - num_all_collaborators: int - Total number of collaborators. - - Returns: - bool: True if it is time to end the round early, False otherwise. - - Raises: - NotImplementedError: This method must be implemented by a subclass. - """ - raise NotImplementedError diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 34c50a4d1e..b2321d2011 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -4,6 +4,7 @@ """Plan module.""" +from functools import partial from hashlib import sha384 from importlib import import_module from logging import getLogger @@ -43,7 +44,7 @@ class Plan: server_ (AggregatorGRPCServer): gRPC server object. client_ (AggregatorGRPCClient): gRPC client object. pipe_ (CompressionPipeline): Compression pipeline object. - straggler_policy_ (StragglerHandlingPolicy): Straggler handling policy. + straggler_policy_ (StragglerPolicy): Straggler handling policy. hash_ (str): Hash of the instance. name_ (str): Name of the instance. serializer_ (SerializerPlugin): Serializer plugin. @@ -422,11 +423,14 @@ def get_tensor_pipe(self): def get_straggler_handling_policy(self): """Get straggler handling policy.""" - template = "openfl.component.straggler_handling_functions.CutoffTimeBasedStragglerHandling" + template = "openfl.component.aggregator.straggler_handling.CutoffPolicy" defaults = self.config.get("straggler_handling_policy", {TEMPLATE: template, SETTINGS: {}}) if self.straggler_policy_ is None: - self.straggler_policy_ = Plan.build(**defaults) + # Prepare a partial function for the straggler policy + self.straggler_policy_ = partial( + Plan.import_(defaults["template"]), **defaults["settings"] + ) return self.straggler_policy_