diff --git a/openfl/component/aggregator/aggregator.py b/openfl/component/aggregator/aggregator.py index eaac9fa6a0..c64ae4a5f7 100644 --- a/openfl/component/aggregator/aggregator.py +++ b/openfl/component/aggregator/aggregator.py @@ -1,7 +1,6 @@ # Copyright 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Aggregator module.""" import logging @@ -20,6 +19,8 @@ logger = logging.getLogger(__name__) +VALID_MODES = {"train_and_validate", "evaluate"} + class Aggregator: """An Aggregator is the central node in federated learning. @@ -82,6 +83,7 @@ def __init__( log_memory_usage=False, write_logs=False, callbacks: Optional[List] = None, + mode: str = "train_and_validate", ): """Initializes the Aggregator. @@ -108,7 +110,12 @@ def __init__( Defaults to 1. initial_tensor_dict (dict, optional): Initial tensor dictionary. callbacks: List of callbacks to be used during the experiment. + mode (str, optional): Operation mode. Can be 'train_and_validate', + 'evaluate'. Defaults to 'train_and_validate'. """ + if mode not in VALID_MODES: + raise ValueError(f"Mode must be one of {VALID_MODES}, got {mode}") + self.mode = mode self.round_number = 0 if single_col_cert_common_name: @@ -208,9 +215,13 @@ def _load_initial_tensors(self): self.model, compression_pipeline=self.compression_pipeline ) - if round_number > self.round_number: + # Check mode before updating round number + if self.mode == "evaluate": + logger.info(f"Skipping round_number check for mode {self.mode}") + elif round_number > self.round_number: logger.info(f"Starting training from round {round_number} of previously saved model") self.round_number = round_number + tensor_key_dict = { TensorKey(k, self.uuid, self.round_number, False, ("model",)): v for k, v in tensor_dict.items() diff --git a/openfl/interface/aggregator.py b/openfl/interface/aggregator.py index 80ce56e32e..fd9246ad61 100644 --- a/openfl/interface/aggregator.py +++ b/openfl/interface/aggregator.py @@ -1,18 +1,34 @@ # Copyright 2020-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 - """Aggregator module.""" import sys from logging import getLogger from pathlib import Path -from click import Path as ClickPath -from click import confirm, echo, group, option, pass_context, style +from click import ( + Choice, + confirm, + echo, + group, + option, + pass_context, + style, +) +from click import ( + Path as ClickPath, +) from openfl.cryptography.ca import sign_certificate -from openfl.cryptography.io import get_csr_hash, read_crt, read_csr, read_key, write_crt, write_key +from openfl.cryptography.io import ( + get_csr_hash, + read_crt, + read_csr, + read_key, + write_crt, + write_key, +) from openfl.cryptography.participant import generate_csr from openfl.federated import Plan from openfl.interface.cli_helper import CERT_DIR @@ -52,9 +68,21 @@ def aggregator(context): default="plan/cols.yaml", type=ClickPath(exists=True), ) -def start_(plan, authorized_cols): - """Start the aggregator service.""" +@option( + "-m", + "--mode", + type=Choice(["train_and_validate", "evaluate"]), + default="train_and_validate", + help="Operation mode - either train_and_validate or evaluate", +) +def start_(plan, authorized_cols, mode): + """Start the aggregator service. + Args: + plan (str): Path to plan config file + authorized_cols (str): Path to authorized collaborators file + mode (str): Operation mode - either train_and_validate or evaluate + """ if is_directory_traversal(plan): echo("Federated learning plan path is out of the openfl workspace scope.") sys.exit(1) @@ -62,14 +90,21 @@ def start_(plan, authorized_cols): echo("Authorized collaborator list file path is out of the openfl workspace scope.") sys.exit(1) - plan = Plan.parse( + # Parse plan and override mode if specified + parsed_plan = Plan.parse( plan_config_path=Path(plan).absolute(), cols_config_path=Path(authorized_cols).absolute(), ) + # Set mode in aggregator settings + if "settings" not in parsed_plan.config["aggregator"]: + parsed_plan.config["aggregator"]["settings"] = {} + parsed_plan.config["aggregator"]["settings"]["mode"] = mode + logger.info(f"Setting aggregator mode to: {mode}") + logger.info("🧿 Starting the Aggregator Service.") - plan.get_server().serve() + parsed_plan.get_server().serve() @aggregator.command(name="generate-cert-request")