Skip to content

Commit

Permalink
Incorporated karan's review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Ishant Thakare <[email protected]>
  • Loading branch information
ishant162 committed Dec 20, 2024
1 parent 3969416 commit eff79fc
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 45 deletions.
2 changes: 1 addition & 1 deletion openfl/component/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from openfl.component.aggregator.straggler_handling import (
CutoffPolicy,
PercentagePolicy,
StragglerHandlingPolicy,
StragglerPolicy,
)
from openfl.component.assigner.assigner import Assigner
from openfl.component.assigner.random_grouped_assigner import RandomGroupedAssigner
Expand Down
2 changes: 1 addition & 1 deletion openfl/component/aggregator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,5 @@
from openfl.component.aggregator.straggler_handling import (
CutoffPolicy,
PercentagePolicy,
StragglerHandlingPolicy,
StragglerPolicy,
)
10 changes: 3 additions & 7 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from logging import getLogger
from threading import Lock

from openfl.component.aggregator.straggler_handling import CutoffPolicy
from openfl.component.aggregator.straggler_handling import CutoffPolicy, StragglerPolicy
from openfl.databases import TensorDB
from openfl.interface.aggregation_functions import WeightedAverage
from openfl.pipelines import NoCompressionPipeline, TensorCodec
Expand Down Expand Up @@ -71,7 +71,7 @@ def __init__(
last_state_path,
assigner,
use_delta_updates=True,
straggler_handling_policy=CutoffPolicy,
straggler_handling_policy: StragglerPolicy = CutoffPolicy,
rounds_to_train=256,
single_col_cert_common_name=None,
compression_pipeline=None,
Expand All @@ -95,7 +95,6 @@ def __init__(
weight.
assigner: Assigner object.
straggler_handling_policy (optional): Straggler handling policy.
Defaults to CutoffPolicy.
rounds_to_train (int, optional): Number of rounds to train.
Defaults to 256.
single_col_cert_common_name (str, optional): Common name for single
Expand Down Expand Up @@ -123,10 +122,7 @@ def __init__(
# FIXME: "" instead of None is for protobuf compatibility.
self.single_col_cert_common_name = single_col_cert_common_name or ""

if straggler_handling_policy == CutoffPolicy:
self.straggler_handling_policy = straggler_handling_policy()
else:
self.straggler_handling_policy = straggler_handling_policy
self.straggler_handling_policy = straggler_handling_policy()

self._end_of_round_check_done = [False] * rounds_to_train
self.stragglers = []
Expand Down
46 changes: 12 additions & 34 deletions openfl/component/aggregator/straggler_handling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@

import numpy as np

logger = getLogger(__name__)

class StragglerHandlingPolicy(ABC):

class StragglerPolicy(ABC):
"""Federated Learning straggler handling interface."""

@abstractmethod
Expand All @@ -24,23 +26,12 @@ def start_policy(self, **kwargs) -> None:
Args:
**kwargs
Returns:
None
"""
raise NotImplementedError

@abstractmethod
def reset_policy_for_round(self) -> None:
"""
Reset policy variable for the next round.
Args:
None
Returns:
None
"""
"""Reset policy for the next round."""
raise NotImplementedError

@abstractmethod
Expand All @@ -65,7 +56,7 @@ def straggler_cutoff_check(
raise NotImplementedError


class CutoffPolicy(StragglerHandlingPolicy):
class CutoffPolicy(StragglerPolicy):
"""Cutoff time based Straggler Handling function."""

def __init__(
Expand All @@ -90,17 +81,12 @@ def __init__(
self.straggler_cutoff_time = straggler_cutoff_time
self.minimum_reporting = minimum_reporting
self.is_timer_started = False
self.logger = getLogger(__name__)

if self.straggler_cutoff_time == np.inf:
self.logger.warning(
"CutoffPolicy is disabled as straggler_cutoff_time " "is set to np.inf."
)
logger.warning("CutoffPolicy is disabled as straggler_cutoff_time " "is set to np.inf.")

def reset_policy_for_round(self) -> None:
"""
Reset timer for the next round.
"""
"""Reset timer for the next round."""
if hasattr(self, "timer"):
self.timer.cancel()
self.is_timer_started = False
Expand All @@ -113,9 +99,6 @@ def start_policy(self, callback: Callable) -> None:
Args:
callback: Callable
Callback function for when straggler_cutoff_time elapses
Returns:
None
"""
# If straggler_cutoff_time is set to infinity
# or if the timer is already running,
Expand Down Expand Up @@ -159,13 +142,13 @@ def straggler_cutoff_check(
# Time has expired
# Check if minimum_reporting collaborators have reported results
elif self.__minimum_collaborators_reported(num_collaborators_done):
self.logger.info(
logger.info(
f"{num_collaborators_done} collaborators have reported results. "
"Applying cutoff policy and proceeding with end of round."
)
return True
else:
self.logger.info(
logger.info(
f"Waiting for minimum {self.minimum_reporting} collaborator(s) to report results."
)
return False
Expand Down Expand Up @@ -194,7 +177,7 @@ def __minimum_collaborators_reported(self, num_collaborators_done) -> bool:
return num_collaborators_done >= self.minimum_reporting


class PercentagePolicy(StragglerHandlingPolicy):
class PercentagePolicy(StragglerPolicy):
"""Percentage based Straggler Handling function."""

def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs):
Expand All @@ -212,18 +195,13 @@ def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwar

self.percent_collaborators_needed = percent_collaborators_needed
self.minimum_reporting = minimum_reporting
self.logger = getLogger(__name__)

def reset_policy_for_round(self) -> None:
"""
Not required in PercentagePolicy.
"""
"""Not required in PercentagePolicy."""
pass

def start_policy(self, **kwargs) -> None:
"""
Not required in PercentagePolicy.
"""
"""Not required in PercentagePolicy."""
pass

def straggler_cutoff_check(
Expand Down
8 changes: 6 additions & 2 deletions openfl/federated/plan/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""Plan module."""

from functools import partial
from hashlib import sha384
from importlib import import_module
from logging import getLogger
Expand Down Expand Up @@ -43,7 +44,7 @@ class Plan:
server_ (AggregatorGRPCServer): gRPC server object.
client_ (AggregatorGRPCClient): gRPC client object.
pipe_ (CompressionPipeline): Compression pipeline object.
straggler_policy_ (StragglerHandlingPolicy): Straggler handling policy.
straggler_policy_ (StragglerPolicy): Straggler handling policy.
hash_ (str): Hash of the instance.
name_ (str): Name of the instance.
serializer_ (SerializerPlugin): Serializer plugin.
Expand Down Expand Up @@ -426,7 +427,10 @@ def get_straggler_handling_policy(self):
defaults = self.config.get("straggler_handling_policy", {TEMPLATE: template, SETTINGS: {}})

if self.straggler_policy_ is None:
self.straggler_policy_ = Plan.build(**defaults)
# Prepare a partial function for the straggler policy
self.straggler_policy_ = partial(
Plan.import_(defaults["template"]), **defaults["settings"]
)

return self.straggler_policy_

Expand Down

0 comments on commit eff79fc

Please sign in to comment.