From eff79fc52a5c5b59de958c510837127459380f91 Mon Sep 17 00:00:00 2001 From: Ishant Thakare Date: Fri, 20 Dec 2024 14:38:11 +0530 Subject: [PATCH] Incorporated karan's review comments Signed-off-by: Ishant Thakare --- openfl/component/__init__.py | 2 +- openfl/component/aggregator/__init__.py | 2 +- openfl/component/aggregator/aggregator.py | 10 ++-- .../aggregator/straggler_handling.py | 46 +++++-------------- openfl/federated/plan/plan.py | 8 +++- 5 files changed, 23 insertions(+), 45 deletions(-) diff --git a/openfl/component/__init__.py b/openfl/component/__init__.py index 60a078371e..792cfd855d 100644 --- a/openfl/component/__init__.py +++ b/openfl/component/__init__.py @@ -6,7 +6,7 @@ from openfl.component.aggregator.straggler_handling import ( CutoffPolicy, PercentagePolicy, - StragglerHandlingPolicy, + StragglerPolicy, ) from openfl.component.assigner.assigner import Assigner from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner diff --git a/openfl/component/aggregator/__init__.py b/openfl/component/aggregator/__init__.py index 728e4c68d8..1bbddaf4f2 100644 --- a/openfl/component/aggregator/__init__.py +++ b/openfl/component/aggregator/__init__.py @@ -6,5 +6,5 @@ from openfl.component.aggregator.straggler_handling import ( CutoffPolicy, PercentagePolicy, - StragglerHandlingPolicy, + StragglerPolicy, ) diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index 354d512d15..db718dfc85 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -9,7 +9,7 @@ from logging import getLogger from threading import Lock -from openfl.component.aggregator.straggler_handling import CutoffPolicy +from openfl.component.aggregator.straggler_handling import CutoffPolicy, StragglerPolicy from openfl.databases import TensorDB from openfl.interface.aggregation_functions import WeightedAverage from openfl.pipelines import NoCompressionPipeline, TensorCodec @@ -71,7 +71,7 @@ def __init__( last_state_path, assigner, use_delta_updates=True, - straggler_handling_policy=CutoffPolicy, + straggler_handling_policy: StragglerPolicy = CutoffPolicy, rounds_to_train=256, single_col_cert_common_name=None, compression_pipeline=None, @@ -95,7 +95,6 @@ def __init__( weight. assigner: Assigner object. straggler_handling_policy (optional): Straggler handling policy. - Defaults to CutoffPolicy. rounds_to_train (int, optional): Number of rounds to train. Defaults to 256. single_col_cert_common_name (str, optional): Common name for single @@ -123,10 +122,7 @@ def __init__( # FIXME: "" instead of None is for protobuf compatibility. self.single_col_cert_common_name = single_col_cert_common_name or "" - if straggler_handling_policy == CutoffPolicy: - self.straggler_handling_policy = straggler_handling_policy() - else: - self.straggler_handling_policy = straggler_handling_policy + self.straggler_handling_policy = straggler_handling_policy() self._end_of_round_check_done = [False] * rounds_to_train self.stragglers = [] diff --git a/openfl/component/aggregator/straggler_handling.py b/openfl/component/aggregator/straggler_handling.py index df26080e90..1d5f6ef7de 100644 --- a/openfl/component/aggregator/straggler_handling.py +++ b/openfl/component/aggregator/straggler_handling.py @@ -12,8 +12,10 @@ import numpy as np +logger = getLogger(__name__) -class StragglerHandlingPolicy(ABC): + +class StragglerPolicy(ABC): """Federated Learning straggler handling interface.""" @abstractmethod @@ -24,23 +26,12 @@ def start_policy(self, **kwargs) -> None: Args: **kwargs - - Returns: - None """ raise NotImplementedError @abstractmethod def reset_policy_for_round(self) -> None: - """ - Reset policy variable for the next round. - - Args: - None - - Returns: - None - """ + """Reset policy for the next round.""" raise NotImplementedError @abstractmethod @@ -65,7 +56,7 @@ def straggler_cutoff_check( raise NotImplementedError -class CutoffPolicy(StragglerHandlingPolicy): +class CutoffPolicy(StragglerPolicy): """Cutoff time based Straggler Handling function.""" def __init__( @@ -90,17 +81,12 @@ def __init__( self.straggler_cutoff_time = straggler_cutoff_time self.minimum_reporting = minimum_reporting self.is_timer_started = False - self.logger = getLogger(__name__) if self.straggler_cutoff_time == np.inf: - self.logger.warning( - "CutoffPolicy is disabled as straggler_cutoff_time " "is set to np.inf." - ) + logger.warning("CutoffPolicy is disabled as straggler_cutoff_time " "is set to np.inf.") def reset_policy_for_round(self) -> None: - """ - Reset timer for the next round. - """ + """Reset timer for the next round.""" if hasattr(self, "timer"): self.timer.cancel() self.is_timer_started = False @@ -113,9 +99,6 @@ def start_policy(self, callback: Callable) -> None: Args: callback: Callable Callback function for when straggler_cutoff_time elapses - - Returns: - None """ # If straggler_cutoff_time is set to infinity # or if the timer is already running, @@ -159,13 +142,13 @@ def straggler_cutoff_check( # Time has expired # Check if minimum_reporting collaborators have reported results elif self.__minimum_collaborators_reported(num_collaborators_done): - self.logger.info( + logger.info( f"{num_collaborators_done} collaborators have reported results. " "Applying cutoff policy and proceeding with end of round." ) return True else: - self.logger.info( + logger.info( f"Waiting for minimum {self.minimum_reporting} collaborator(s) to report results." ) return False @@ -194,7 +177,7 @@ def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: return num_collaborators_done >= self.minimum_reporting -class PercentagePolicy(StragglerHandlingPolicy): +class PercentagePolicy(StragglerPolicy): """Percentage based Straggler Handling function.""" def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs): @@ -212,18 +195,13 @@ def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwar self.percent_collaborators_needed = percent_collaborators_needed self.minimum_reporting = minimum_reporting - self.logger = getLogger(__name__) def reset_policy_for_round(self) -> None: - """ - Not required in PercentagePolicy. - """ + """Not required in PercentagePolicy.""" pass def start_policy(self, **kwargs) -> None: - """ - Not required in PercentagePolicy. - """ + """Not required in PercentagePolicy.""" pass def straggler_cutoff_check( diff --git a/openfl/federated/plan/plan.py b/openfl/federated/plan/plan.py index 9c326a746a..b2321d2011 100644 --- a/openfl/federated/plan/plan.py +++ b/openfl/federated/plan/plan.py @@ -4,6 +4,7 @@ """Plan module.""" +from functools import partial from hashlib import sha384 from importlib import import_module from logging import getLogger @@ -43,7 +44,7 @@ class Plan: server_ (AggregatorGRPCServer): gRPC server object. client_ (AggregatorGRPCClient): gRPC client object. pipe_ (CompressionPipeline): Compression pipeline object. - straggler_policy_ (StragglerHandlingPolicy): Straggler handling policy. + straggler_policy_ (StragglerPolicy): Straggler handling policy. hash_ (str): Hash of the instance. name_ (str): Name of the instance. serializer_ (SerializerPlugin): Serializer plugin. @@ -426,7 +427,10 @@ def get_straggler_handling_policy(self): defaults = self.config.get("straggler_handling_policy", {TEMPLATE: template, SETTINGS: {}}) if self.straggler_policy_ is None: - self.straggler_policy_ = Plan.build(**defaults) + # Prepare a partial function for the straggler policy + self.straggler_policy_ = partial( + Plan.import_(defaults["template"]), **defaults["settings"] + ) return self.straggler_policy_