diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index 8a9cd757e8..b045133769 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -84,7 +84,6 @@ def __init__( callbacks: Optional[List] = None, persist_checkpoint=True, persistent_db_path=None, - task_group: str = "learning", ): """Initializes the Aggregator. @@ -111,9 +110,7 @@ def __init__( Defaults to 1. initial_tensor_dict (dict, optional): Initial tensor dictionary. callbacks: List of callbacks to be used during the experiment. - task_group (str, optional): Selected task_group for assignment. """ - self.task_group = task_group self.round_number = 0 self.next_model_round_number = 0 @@ -132,10 +129,10 @@ def __init__( ) self.rounds_to_train = rounds_to_train - if self.task_group == "evaluation": + if self.assigner.is_task_group_evaluation(): self.rounds_to_train = 1 logger.info( - f"task_group is {self.task_group}, setting rounds_to_train = {self.rounds_to_train}" + f"For evaluation tasks setting rounds_to_train = {self.rounds_to_train}" ) self._end_of_round_check_done = [False] * rounds_to_train @@ -311,8 +308,8 @@ def _load_initial_tensors(self): ) # Check selected task_group before updating round number - if self.task_group == "evaluation": - logger.info(f"Skipping round_number check for {self.task_group} task_group") + if self.assigner.is_task_group_evaluation(): + logger.info("Skipping round_number check for evaluation run") elif round_number > self.round_number: logger.info(f"Starting training from round {round_number} of previously saved model") self.round_number = round_number diff --git a/openfl/component/assigner/assigner.py b/openfl/component/assigner/assigner.py index 0b5fc36e88..2b7ad432a7 100644 --- a/openfl/component/assigner/assigner.py +++ b/openfl/component/assigner/assigner.py @@ -81,6 +81,16 @@ def get_collaborators_for_task(self, task_name, round_number): """Abstract method.""" raise NotImplementedError + def is_task_group_evaluation(self): + """Check if the selected task group is for 'evaluation' run. + + Returns: + bool: True if the selected task group is 'evaluation', False otherwise. + """ + if hasattr(self, "selected_task_group"): + return self.selected_task_group == 'evaluation' + return False + def get_all_tasks_for_round(self, round_number): """Return tasks for the current round. diff --git a/openfl/interface/aggregator.py b/openfl/interface/aggregator.py index 8c922eee19..8f4a91ee8a 100644 --- a/openfl/interface/aggregator.py +++ b/openfl/interface/aggregator.py @@ -97,7 +97,6 @@ def start_(plan, authorized_cols, task_group): # Set task_group in aggregator settings if "settings" not in parsed_plan.config["aggregator"]: parsed_plan.config["aggregator"]["settings"] = {} - parsed_plan.config["aggregator"]["settings"]["task_group"] = task_group parsed_plan.config["assigner"]["settings"]["selected_task_group"] = task_group logger.info(f"Setting aggregator to assign: {task_group} task_group")