From 82270de55b0e990a5a913dcd8c2721522fac6a08 Mon Sep 17 00:00:00 2001 From: Parth Mandaliya Date: Mon, 23 Sep 2024 23:52:07 +0530 Subject: [PATCH] Updates to straggler handling functionality (#996) * v1.0 straggler handling added Signed-off-by: Parth Mandaliya * Added start_straggler_cutoff_timer abstract function in StragglerHandlingFunction abstract class. Renamed start_timer and __timer_expired functions to start_straggler_cutoff_timer and _straggler_cutoff_time_elapsed respectively. Added docstring to both functions mentioned above. Signed-off-by: Parth Mandaliya * Updated logs straggler handling in aggregator.py. If one or more collaborator(s) does not even 1 task results in time, all tasks results sent by that collaborator is excluded from aggregation. Signed-off-by: Parth Mandaliya * Only time based straggler handling policies require timer thread, removing start_straggler_cutoff_timer function from parent class StragglerHandlingFunction Signed-off-by: Parth Mandaliya * Find all unfinished tasks and straggler collaborators Signed-off-by: Parth Mandaliya * Review comments incorporated Signed-off-by: Parth Mandaliya * Added inline comments Signed-off-by: Parth Mandaliya * Changed logic to keep track of collaborators which have reported results for all tasks Changed straggler handling logs Added docstring for functions in all straggler handling policies Signed-off-by: Parth Mandaliya * 1. StragglerHandlingFunction: Added an interface method start_policy 2. CutoffTimeBasedStragglerHandling: start_policy method implements a timer to wait for cutoff-time and then call provided call callback method 3. Aggregator: - sendlocaltaskresults: update collaborators_done to keep track of collaborators that have finished ALL tasks - _straggler_cutoff_time_elapsed: call back function that is called after cutoff time has elapsed and applies the straggler policy Signed-off-by: Parth Mandaliya * Removed logger argument from start_policy & Added logger argument to get_straggler_handling_policy in plan.py Signed-off-by: Parth Mandaliya * If cutoff time is set to infinite, do not start the timer thread. Signed-off-by: Parth Mandaliya * Resolved lint issues. Signed-off-by: Parth Mandaliya * Pytest and code coverage test case failure resolved Signed-off-by: Parth Mandaliya * Default logger value is set to None Signed-off-by: Parth Mandaliya * Redesigned percentage based straggler policy. minimum_reporting cannot be set 0 in any straggler policy Signed-off-by: Parth Mandaliya * Code cleanup Signed-off-by: Parth Mandaliya * Only collaborators_done are used for aggregation Signed-off-by: Parth Mandaliya * Internal review comments incorporated Logs updated Logger argument removed from straggler handling policy classes Signed-off-by: Parth Mandaliya * Resolving potential issues found during testing Signed-off-by: Parth Mandaliya * Use _collaborator_task_completed method to check if all given tasks to collaborator are completed or not Signed-off-by: Parth Mandaliya * Few test cases failing issue resolved Signed-off-by: Parth Mandaliya * Potential issue in aggregator based workflow tutorial resolved Signed-off-by: Parth Mandaliya * Corner case issue discovered during testing is patched. Signed-off-by: Parth Mandaliya * Added reset_policy_for_round function in straggler handling policy base class. Signed-off-by: Parth Mandaliya * This commit includes following changes: 1. In cutoff time based policy after cutoff time expires wait for all collaborators not just minimum required. 2. Irregardless of tasks assigned to collaborators if minimum required collaborators report resultsw in time apply straggler handling policy. Signed-off-by: Parth Mandaliya * Code cleanup Signed-off-by: Parth Mandaliya * Logs modified Signed-off-by: Parth Mandaliya * Review comments on PR incorporated. Signed-off-by: Parth Mandaliya * Log updated Signed-off-by: Parth Mandaliya * Condition in straggler_cutoff_check fixed Signed-off-by: Parth Mandaliya * If straggler cutoff time set to infinite only wait for minimum required collaborators to report results. Signed-off-by: Parth Mandaliya * If straggler cutoff time is set to infinite wait for ALL collaborators not only for minimum_reporting collaborators Signed-off-by: Parth Mandaliya * Teodor's review comments and internal review comments incorporated. flake8 issues resovled. Signed-off-by: Parth Mandaliya * Changed minimum_reporting validation. Merged conditions in straggler_cutoff_check function in CutoffTimeBasedStragglerHandling class. Signed-off-by: Parth Mandaliya * Resolved all flake8 issues. Signed-off-by: Parth Mandaliya * Modified single inline comment Signed-off-by: Parth Mandaliya * Review comments incorporated. Signed-off-by: Parth Mandaliya * Lint fixes Signed-off-by: Ishant Thakare * Incorporated Micah's review comments and removed unused code Signed-off-by: Ishant Thakare * Incorporated Teo's review comments Signed-off-by: Ishant Thakare * Incorporated review comments & added mutex in aggregator for thread safety Signed-off-by: Ishant Thakare * Review comments incorporated and updated logs Signed-off-by: Ishant Thakare * Reverted a comment for Pytest and code coverage fix Signed-off-by: Ishant Thakare --------- Signed-off-by: Parth Mandaliya Signed-off-by: Ishant Thakare Co-authored-by: Ishant Thakare --- openfl/component/__init__.py | 2 +- openfl/component/aggregator/aggregator.py | 160 +++++++++++------- .../straggler_handling_functions/__init__.py | 2 +- .../cutoff_time_based_straggler_handling.py | 110 +++++++++--- .../percentage_based_straggler_handling.py | 55 ++++-- .../straggler_handling_function.py | 41 ++++- .../component/aggregator/test_aggregator.py | 77 --------- 7 files changed, 265 insertions(+), 182 deletions(-) diff --git a/openfl/component/__init__.py b/openfl/component/__init__.py index df4efe1c5c..3b787f87d0 100644 --- a/openfl/component/__init__.py +++ b/openfl/component/__init__.py @@ -14,5 +14,5 @@ PercentageBasedStragglerHandling, ) from openfl.component.straggler_handling_functions.straggler_handling_function import ( - StragglerHandlingFunction, + StragglerHandlingPolicy, ) diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index e1b90d22e5..81d3e7411a 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -6,6 +6,7 @@ import queue import time from logging import getLogger +from threading import Lock from openfl.component.straggler_handling_functions import CutoffTimeBasedStragglerHandling from openfl.databases import TensorDB @@ -53,9 +54,10 @@ class Aggregator: collaborator_tasks_results (dict): Dict of collaborator tasks results. collaborator_task_weight (dict): Dict of col task weight. + lock: A threading Lock object used to ensure thread-safe operations. .. note:: - - plan setting + - plan setting """ def __init__( @@ -177,6 +179,13 @@ def __init__( self.collaborator_task_weight = {} # {TaskResultKey: data_size} + # maintain a list of collaborators that have completed task and + # reported results in a given round + self.collaborators_done = [] + + # Initialize a lock for thread safety + self.lock = Lock() + def _load_initial_tensors(self): """Load all of the tensors required to begin federated learning. @@ -391,11 +400,30 @@ def get_tasks(self, collaborator_name): ) sleep_time = 0 - if hasattr(self.straggler_handling_policy, "round_start_time"): - self.straggler_handling_policy.round_start_time = time.time() + # Start straggler handling policy for timer based callback is required + # for %age based policy callback is not required + self.straggler_handling_policy.start_policy(callback=self._straggler_cutoff_time_elapsed) return tasks, self.round_number, sleep_time, time_to_quit + def _straggler_cutoff_time_elapsed(self) -> None: + """ + This method is called by the straggler handling policy when cutoff timer is elapsed. + It applies straggler handling policy and ends the round early. + + Returns: + None + """ + self.logger.warning( + f"Round number: {self.round_number} cutoff timer elapsed after " + f"{self.straggler_handling_policy.straggler_cutoff_time}s. " + f"Applying {self.straggler_handling_policy.__class__.__name__} policy." + ) + + with self.lock: + # Check if minimum collaborators reported results + self._end_of_round_with_stragglers_check() + def get_aggregated_tensor( self, collaborator_name, @@ -573,10 +601,10 @@ def send_local_task_results( Returns: None """ - if self._time_to_quit() or self._is_task_done(task_name): + if self._time_to_quit() or collaborator_name in self.stragglers: self.logger.warning( f"STRAGGLER: Collaborator {collaborator_name} is reporting results " - "after task {task_name} has finished." + f"after task {task_name} has finished." ) return @@ -596,10 +624,11 @@ def send_local_task_results( # we mustn't have results already if self._collaborator_task_completed(collaborator_name, task_name, round_number): - raise ValueError( + self.logger.warning( f"Aggregator already has task results from collaborator {collaborator_name}" f" for task {task_key}" ) + return # By giving task_key it's own weight, we can support different # training/validation weights @@ -632,7 +661,31 @@ def send_local_task_results( task_results.append(tensor_key) self.collaborator_tasks_results[task_key] = task_results - self._end_of_task_check(task_name) + + with self.lock: + self._is_collaborator_done(collaborator_name, round_number) + + self._end_of_round_with_stragglers_check() + + def _end_of_round_with_stragglers_check(self): + """ + Checks if the minimum required collaborators have reported their results, + identifies any stragglers, and initiates an early round end if necessary. + + Returns: + None + """ + if self.straggler_handling_policy.straggler_cutoff_check( + len(self.collaborators_done), len(self.authorized_cols) + ): + self.stragglers = [ + collab_name + for collab_name in self.authorized_cols + if collab_name not in self.collaborators_done + ] + if len(self.stragglers) != 0: + self.logger.warning(f"Identified stragglers: {self.stragglers}") + self._end_of_round_check() def _process_named_tensor(self, named_tensor, collaborator_name): """Extract the named tensor fields. @@ -724,21 +777,6 @@ def _process_named_tensor(self, named_tensor, collaborator_name): return final_tensor_key, final_nparray - def _end_of_task_check(self, task_name): - """Check whether all collaborators who are supposed to perform the - task complete. - - Args: - task_name (str): Task name. - The task name to check. - - Returns: - bool: Whether the task is done. - """ - if self._is_task_done(task_name): - # now check for the end of the round - self._end_of_round_check() - def _prepare_trained(self, tensor_name, origin, round_number, report, agg_results): """Prepare aggregated tensorkey tags. @@ -839,11 +877,12 @@ def _compute_validation_related_task_metrics(self, task_name): all_collaborators_for_task = self.assigner.get_collaborators_for_task( task_name, self.round_number ) - # leave out stragglers for the round + # Leave out straggler for the round even if they've paritally + # completed given tasks collaborators_for_task = [] - for c in all_collaborators_for_task: - if self._collaborator_task_completed(c, task_name, self.round_number): - collaborators_for_task.append(c) + collaborators_for_task = [ + c for c in all_collaborators_for_task if c in self.collaborators_done + ] # The collaborator data sizes for that task collaborator_weights_unnormalized = { @@ -919,7 +958,7 @@ def _end_of_round_check(self): Returns: None """ - if not self._is_round_done() or self._end_of_round_check_done[self.round_number]: + if self._end_of_round_check_done[self.round_number]: return # Compute all validation related metrics @@ -932,6 +971,8 @@ def _end_of_round_check(self): self.round_number += 1 # resetting stragglers for task for a new round self.stragglers = [] + # resetting collaborators_done for next round + self.collaborators_done = [] # Save the latest model self.logger.info("Saving round %s model...", self.round_number) @@ -945,49 +986,46 @@ def _end_of_round_check(self): # Cleaning tensor db self.tensor_db.clean_up(self.db_store_rounds) + # Reset straggler handling policy for the next round. + self.straggler_handling_policy.reset_policy_for_round() - def _is_task_done(self, task_name): - """Check that task is done. + def _is_collaborator_done(self, collaborator_name: str, round_number: int) -> None: + """ + Check if all tasks given to the collaborator are completed then, + completed or not. Args: - task_name (str): Task name. + collaborator_name (str): Collaborator name. + round_number (int): Round number. Returns: - bool: Whether the task is done. + None """ - all_collaborators = self.assigner.get_collaborators_for_task(task_name, self.round_number) - - collaborators_done = [] - for c in all_collaborators: - if self._collaborator_task_completed(c, task_name, self.round_number): - collaborators_done.append(c) - - straggler_check = self.straggler_handling_policy.straggler_cutoff_check( - len(collaborators_done), all_collaborators - ) + if self.round_number != round_number: + self.logger.warning( + f"Collaborator {collaborator_name} is reporting results" + f" for the wrong round: {round_number}. Ignoring..." + ) + return - if straggler_check: - for c in all_collaborators: - if c not in collaborators_done: - self.stragglers.append(c) + # Get all tasks given to the collaborator for current round + all_tasks = self.assigner.get_tasks_for_collaborator(collaborator_name, self.round_number) + # Check if all given tasks are completed by the collaborator + all_tasks_completed = True + for task in all_tasks: + if hasattr(task, "name"): + task = task.name + all_tasks_completed = all_tasks_completed and self._collaborator_task_completed( + collaborator=collaborator_name, task_name=task, round_num=self.round_number + ) + # If the collaborator has completed ALL tasks for current round, + # update collaborators_done + if all_tasks_completed: + self.collaborators_done.append(collaborator_name) self.logger.info( - "\tEnding task %s early due to straggler cutoff policy", - task_name, + f"Round: {self.round_number}, Collaborators that have completed all tasks: " + f"{self.collaborators_done}" ) - self.logger.warning("\tIdentified stragglers: %s", self.stragglers) - - # all are done or straggler policy calls for early round end. - return straggler_check or len(all_collaborators) == len(collaborators_done) - - def _is_round_done(self): - """Check that round is done. - - Returns: - bool: Whether the round is done. - """ - tasks_for_round = self.assigner.get_all_tasks_for_round(self.round_number) - - return all(self._is_task_done(task_name) for task_name in tasks_for_round) def _log_big_warning(self): """Warn user about single collaborator cert mode.""" diff --git a/openfl/component/straggler_handling_functions/__init__.py b/openfl/component/straggler_handling_functions/__init__.py index 58792cdda4..5ab0af1794 100644 --- a/openfl/component/straggler_handling_functions/__init__.py +++ b/openfl/component/straggler_handling_functions/__init__.py @@ -9,5 +9,5 @@ PercentageBasedStragglerHandling, ) from openfl.component.straggler_handling_functions.straggler_handling_function import ( - StragglerHandlingFunction, + StragglerHandlingPolicy, ) diff --git a/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py b/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py index 82bc19960e..ca8e218f7c 100644 --- a/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py +++ b/openfl/component/straggler_handling_functions/cutoff_time_based_straggler_handling.py @@ -3,16 +3,19 @@ """Cutoff time based Straggler Handling function.""" +import threading import time +from logging import getLogger +from typing import Callable import numpy as np from openfl.component.straggler_handling_functions.straggler_handling_function import ( - StragglerHandlingFunction, + StragglerHandlingPolicy, ) -class CutoffTimeBasedStragglerHandling(StragglerHandlingFunction): +class CutoffTimeBasedStragglerHandling(StragglerHandlingPolicy): """Cutoff time based Straggler Handling function.""" def __init__( @@ -30,11 +33,92 @@ def __init__( collaborators that should report. Defaults to 1. **kwargs: Variable length argument list. """ + if minimum_reporting <= 0: + raise ValueError("minimum_reporting must be >0") + self.round_start_time = round_start_time self.straggler_cutoff_time = straggler_cutoff_time self.minimum_reporting = minimum_reporting + self.logger = getLogger(__name__) + + if self.straggler_cutoff_time == np.inf: + self.logger.warning( + "CutoffTimeBasedStragglerHandling is disabled as straggler_cutoff_time " + "is set to np.inf." + ) + + def reset_policy_for_round(self) -> None: + """ + Reset timer for the next round. + """ + if hasattr(self, "timer"): + self.timer.cancel() + delattr(self, "timer") + + def start_policy(self, callback: Callable) -> None: + """ + Start time-based straggler handling policy for collaborator for + a particular round. + + Args: + callback: Callable + Callback function for when straggler_cutoff_time elapses + + Returns: + None + """ + # If straggler_cutoff_time is set to infinity + # or if the timer is already running, + # do not start the policy. + if self.straggler_cutoff_time == np.inf or hasattr(self, "timer"): + return + self.round_start_time = time.time() + self.timer = threading.Timer( + self.straggler_cutoff_time, + callback, + ) + self.timer.daemon = True + self.timer.start() + + def straggler_cutoff_check( + self, + num_collaborators_done: int, + num_all_collaborators: int, + ) -> bool: + """ + If minimum_reporting collaborators have reported results within + straggler_cutoff_time then return True, otherwise False. - def straggler_time_expired(self): + Args: + num_collaborators_done: int + Number of collaborators finished. + num_all_collaborators: int + Total number of collaborators. + + Returns: + bool: True if the straggler cutoff conditions are met, False otherwise. + """ + + # if straggler time has not expired then + # wait for ALL collaborators to report results. + if not self.__straggler_time_expired(): + return num_all_collaborators == num_collaborators_done + + # Time has expired + # Check if minimum_reporting collaborators have reported results + elif self.__minimum_collaborators_reported(num_collaborators_done): + self.logger.info( + f"{num_collaborators_done} collaborators have reported results. " + "Applying cutoff policy and proceeding with end of round." + ) + return True + else: + self.logger.info( + f"Waiting for minimum {self.minimum_reporting} collaborator(s) to report results." + ) + return False + + def __straggler_time_expired(self) -> bool: """Check if the straggler time has expired. Returns: @@ -44,7 +128,7 @@ def straggler_time_expired(self): (time.time() - self.round_start_time) > self.straggler_cutoff_time ) - def minimum_collaborators_reported(self, num_collaborators_done): + def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: """Check if the minimum number of collaborators have reported. Args: @@ -56,21 +140,3 @@ def minimum_collaborators_reported(self, num_collaborators_done): False otherwise. """ return num_collaborators_done >= self.minimum_reporting - - def straggler_cutoff_check(self, num_collaborators_done, all_collaborators=None): - """Check if the straggler cutoff conditions are met. - - Args: - num_collaborators_done (int): The number of collaborators that - have reported. - all_collaborators (optional): All the collaborators. Defaults to - None. - - Returns: - bool: True if the straggler cutoff conditions are met, False - otherwise. - """ - cutoff = self.straggler_time_expired() and self.minimum_collaborators_reported( - num_collaborators_done - ) - return cutoff diff --git a/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py b/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py index 260e9ca53b..e556ea291b 100644 --- a/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py +++ b/openfl/component/straggler_handling_functions/percentage_based_straggler_handling.py @@ -3,12 +3,14 @@ """Percentage based Straggler Handling function.""" +from logging import getLogger + from openfl.component.straggler_handling_functions.straggler_handling_function import ( - StragglerHandlingFunction, + StragglerHandlingPolicy, ) -class PercentageBasedStragglerHandling(StragglerHandlingFunction): +class PercentageBasedStragglerHandling(StragglerHandlingPolicy): """Percentage based Straggler Handling function.""" def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwargs): @@ -21,35 +23,56 @@ def __init__(self, percent_collaborators_needed=1.0, minimum_reporting=1, **kwar collaborators that should report. Defaults to 1. **kwargs: Variable length argument list. """ + if minimum_reporting <= 0: + raise ValueError("minimum_reporting must be >0") + self.percent_collaborators_needed = percent_collaborators_needed self.minimum_reporting = minimum_reporting + self.logger = getLogger(__name__) - def minimum_collaborators_reported(self, num_collaborators_done): - """Check if the minimum number of collaborators have reported. + def reset_policy_for_round(self) -> None: + """ + Not required in PercentageBasedStragglerHandling. + """ + pass + + def start_policy(self, **kwargs) -> None: + """ + Not required in PercentageBasedStragglerHandling. + """ + pass + + def straggler_cutoff_check( + self, + num_collaborators_done: int, + num_all_collaborators: int, + ) -> bool: + """ + If percent_collaborators_needed and minimum_reporting collaborators have + reported results, then it is time to end round early. Args: num_collaborators_done (int): The number of collaborators that have reported. + all_collaborators (list): All the collaborators. Returns: - bool: True if the minimum number of collaborators have reported, - False otherwise. + bool: True if the straggler cutoff conditions are met, False + otherwise. """ - return num_collaborators_done >= self.minimum_reporting + return ( + num_collaborators_done >= self.percent_collaborators_needed * num_all_collaborators + ) and self.__minimum_collaborators_reported(num_collaborators_done) - def straggler_cutoff_check(self, num_collaborators_done, all_collaborators): - """Check if the straggler cutoff conditions are met. + def __minimum_collaborators_reported(self, num_collaborators_done) -> bool: + """Check if the minimum number of collaborators have reported. Args: num_collaborators_done (int): The number of collaborators that have reported. - all_collaborators (list): All the collaborators. Returns: - bool: True if the straggler cutoff conditions are met, False - otherwise. + bool: True if the minimum number of collaborators have reported, + False otherwise. """ - cutoff = ( - num_collaborators_done >= self.percent_collaborators_needed * len(all_collaborators) - ) and self.minimum_collaborators_reported(num_collaborators_done) - return cutoff + return num_collaborators_done >= self.minimum_reporting diff --git a/openfl/component/straggler_handling_functions/straggler_handling_function.py b/openfl/component/straggler_handling_functions/straggler_handling_function.py index 5700cf12a9..8bd47bc045 100644 --- a/openfl/component/straggler_handling_functions/straggler_handling_function.py +++ b/openfl/component/straggler_handling_functions/straggler_handling_function.py @@ -7,15 +7,48 @@ from abc import ABC, abstractmethod -class StragglerHandlingFunction(ABC): +class StragglerHandlingPolicy(ABC): """Federated Learning straggler handling interface.""" @abstractmethod - def straggler_cutoff_check(self, **kwargs): - """Determines whether it is time to end the round early. + def start_policy(self, **kwargs) -> None: + """ + Start straggler handling policy for collaborator for a particular round. + NOTE: Refer CutoffTimeBasedStragglerHandling for reference. + + Args: + **kwargs + + Returns: + None + """ + raise NotImplementedError + + @abstractmethod + def reset_policy_for_round(self) -> None: + """ + Reset policy variable for the next round. + + Args: + None + + Returns: + None + """ + raise NotImplementedError + + @abstractmethod + def straggler_cutoff_check( + self, num_collaborators_done: int, num_all_collaborators: int, **kwargs + ) -> bool: + """ + Determines whether it is time to end the round early. Args: - **kwargs: Variable length argument list. + num_collaborators_done: int + Number of collaborators finished. + num_all_collaborators: int + Total number of collaborators. Returns: bool: True if it is time to end the round early, False otherwise. diff --git a/tests/openfl/component/aggregator/test_aggregator.py b/tests/openfl/component/aggregator/test_aggregator.py index 81e00e8074..f90b457925 100644 --- a/tests/openfl/component/aggregator/test_aggregator.py +++ b/tests/openfl/component/aggregator/test_aggregator.py @@ -171,80 +171,3 @@ def test_collaborator_task_completed_true(agg): col1, task_name, round_num) assert is_completed is True - - -def test_is_task_done_no_cols(agg): - """Test that is_task_done returns True without corresponded collaborators.""" - task_name = 'test_task_name' - agg.assigner.get_collaborators_for_task = mock.Mock(return_value=[]) - is_task_done = agg._is_task_done(task_name) - - assert is_task_done is True - - -def test_is_task_done_not_done(agg): - """Test that is_task_done returns False in the corresponded case.""" - task_name = 'test_task_name' - col1 = 'one' - col2 = 'two' - agg.assigner.get_collaborators_for_task = mock.Mock(return_value=[col1, col2]) - is_task_done = agg._is_task_done(task_name) - - assert is_task_done is False - - -def test_is_task_done_done(agg): - """Test that is_task_done returns True in the corresponded case.""" - round_num = 0 - task_name = 'test_task_name' - col1 = 'one' - col2 = 'two' - agg.assigner.get_collaborators_for_task = mock.Mock(return_value=[col1, col2]) - agg.collaborator_tasks_results = { - TaskResultKey(task_name, col1, round_num): 1, - TaskResultKey(task_name, col2, round_num): 1 - } - is_task_done = agg._is_task_done(task_name) - - assert is_task_done is True - - -def test_is_round_done_no_tasks(agg): - """Test that is_round_done returns True in the corresponded case.""" - agg.assigner.get_all_tasks_for_round = mock.Mock(return_value=[]) - is_round_done = agg._is_round_done() - - assert is_round_done is True - - -def test_is_round_done_not_done(agg): - """Test that is_round_done returns False in the corresponded case.""" - round_num = 0 - task_name = 'test_task_name' - col1 = 'one' - col2 = 'two' - agg.assigner.get_all_tasks_for_round = mock.Mock(return_value=[task_name]) - agg.assigner.get_collaborators_for_task = mock.Mock(return_value=[col1, col2]) - agg.collaborator_tasks_results = { - TaskResultKey(task_name, col1, round_num): 1, - } - is_round_done = agg._is_round_done() - - assert is_round_done is False - - -def test_is_round_done_done(agg): - """Test that is_round_done returns True in the corresponded case.""" - round_num = 0 - task_name = 'test_task_name' - col1 = 'one' - col2 = 'two' - agg.assigner.get_all_tasks_for_round = mock.Mock(return_value=[task_name]) - agg.assigner.get_collaborators_for_task = mock.Mock(return_value=[col1, col2]) - agg.collaborator_tasks_results = { - TaskResultKey(task_name, col1, round_num): 1, - TaskResultKey(task_name, col2, round_num): 1 - } - is_round_done = agg._is_round_done() - - assert is_round_done is True