Skip to content

Commit

Permalink
- Add a `--mode` option to the aggregator CLI, defaulting to "train_and_validate"
Browse files Browse the repository at this point in the history
- Enhance Aggregator to accept a `mode` attribute, so that switching between federated evaluation (fedeval) and training can be done at the aggregator level
- rebase 10.Jan.2
- fixed formatting issues
Signed-off-by: Shailesh Pant <[email protected]>
  • Loading branch information
ishaileshpant committed Jan 10, 2025
1 parent dfe9512 commit 368b451
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 10 deletions.
15 changes: 13 additions & 2 deletions openfl/component/aggregator/aggregator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


"""Aggregator module."""

import logging
Expand All @@ -20,6 +19,8 @@

logger = logging.getLogger(__name__)

VALID_MODES = {"train_and_validate", "evaluate"}


class Aggregator:
"""An Aggregator is the central node in federated learning.
Expand Down Expand Up @@ -82,6 +83,7 @@ def __init__(
log_memory_usage=False,
write_logs=False,
callbacks: Optional[List] = None,
mode: str = "train_and_validate",
):
"""Initializes the Aggregator.
Expand All @@ -108,7 +110,12 @@ def __init__(
Defaults to 1.
initial_tensor_dict (dict, optional): Initial tensor dictionary.
callbacks: List of callbacks to be used during the experiment.
mode (str, optional): Operation mode. Must be either 'train_and_validate'
or 'evaluate'; any other value raises ValueError. Defaults to 'train_and_validate'.
"""
if mode not in VALID_MODES:
raise ValueError(f"Mode must be one of {VALID_MODES}, got {mode}")
self.mode = mode
self.round_number = 0

if single_col_cert_common_name:
Expand Down Expand Up @@ -208,9 +215,13 @@ def _load_initial_tensors(self):
self.model, compression_pipeline=self.compression_pipeline
)

if round_number > self.round_number:
# Check mode before updating round number
if self.mode == "evaluate":
logger.info(f"Skipping round_number check for mode {self.mode}")
elif round_number > self.round_number:
logger.info(f"Starting training from round {round_number} of previously saved model")
self.round_number = round_number

tensor_key_dict = {
TensorKey(k, self.uuid, self.round_number, False, ("model",)): v
for k, v in tensor_dict.items()
Expand Down
51 changes: 43 additions & 8 deletions openfl/interface/aggregator.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,34 @@
# Copyright 2020-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0


"""Aggregator module."""

import sys
from logging import getLogger
from pathlib import Path

from click import Path as ClickPath
from click import confirm, echo, group, option, pass_context, style
from click import (
Choice,
confirm,
echo,
group,
option,
pass_context,
style,
)
from click import (
Path as ClickPath,
)

from openfl.cryptography.ca import sign_certificate
from openfl.cryptography.io import get_csr_hash, read_crt, read_csr, read_key, write_crt, write_key
from openfl.cryptography.io import (
get_csr_hash,
read_crt,
read_csr,
read_key,
write_crt,
write_key,
)
from openfl.cryptography.participant import generate_csr
from openfl.federated import Plan
from openfl.interface.cli_helper import CERT_DIR
Expand Down Expand Up @@ -52,24 +68,43 @@ def aggregator(context):
default="plan/cols.yaml",
type=ClickPath(exists=True),
)
def start_(plan, authorized_cols):
"""Start the aggregator service."""
@option(
"-m",
"--mode",
type=Choice(["train_and_validate", "evaluate"]),
default="train_and_validate",
help="Operation mode - either train_and_validate or evaluate",
)
def start_(plan, authorized_cols, mode):
"""Start the aggregator service.
Args:
plan (str): Path to plan config file
authorized_cols (str): Path to authorized collaborators file
mode (str): Operation mode - either train_and_validate or evaluate
"""
if is_directory_traversal(plan):
echo("Federated learning plan path is out of the openfl workspace scope.")
sys.exit(1)
if is_directory_traversal(authorized_cols):
echo("Authorized collaborator list file path is out of the openfl workspace scope.")
sys.exit(1)

plan = Plan.parse(
# Parse the plan; the aggregator mode is then always set from the --mode option (which has a default)
parsed_plan = Plan.parse(
plan_config_path=Path(plan).absolute(),
cols_config_path=Path(authorized_cols).absolute(),
)

# Set mode in aggregator settings
if "settings" not in parsed_plan.config["aggregator"]:
parsed_plan.config["aggregator"]["settings"] = {}
parsed_plan.config["aggregator"]["settings"]["mode"] = mode
logger.info(f"Setting aggregator mode to: {mode}")

logger.info("🧿 Starting the Aggregator Service.")

plan.get_server().serve()
parsed_plan.get_server().serve()


@aggregator.command(name="generate-cert-request")
Expand Down

0 comments on commit 368b451

Please sign in to comment.