diff --git a/fltk/launch.py b/fltk/launch.py index df90ec77..08e85bc9 100644 --- a/fltk/launch.py +++ b/fltk/launch.py @@ -1,7 +1,7 @@ # pylint: disable=unused-argument import logging import os -from typing import Callable, Optional, NewType +from typing import Callable, Optional, NewType, Any, List, Tuple, Dict from argparse import Namespace from multiprocessing.pool import ThreadPool @@ -11,7 +11,7 @@ from kubernetes import config from torch.distributed import rpc from fltk.core.distributed import DistClient, download_datasets -from fltk.util.config.definitions.orchestrator import get_orchestrator, get_arrival_generator +from fltk.util.config.definitions.orchestrator import get_orchestrator, get_arrival_generator, OrchestratorType from fltk.core import Client, Federator from fltk.nets.util.reproducability import init_reproducibility, init_learning_reproducibility from fltk.util.cluster.client import ClusterManager @@ -64,7 +64,23 @@ def exec_distributed_client(task_id: str, conf: DistributedConfig = None, print(epoch_data) -def exec_orchestrator(args: Namespace = None, conf: DistributedConfig = None, replication: int = 1): +def get_arrival_generator_args(conf: DistributedConfig, replication: int) -> Tuple[List[Any], Dict[str, Any]]: + """ + Function to get arrival generator arguments based on current configuration of Orchestrator. + @param conf: + @type conf: + @param replication: + @type replication: + @return: Configuration for args and kwargs for the generator. + @rtype: Tuple[List[Any], Dict[str, Any]] + """ + args, kwargs = [conf.get_duration()], {} + if conf.cluster_config.orchestrator.orchestrator_type == OrchestratorType.BATCH: + kwargs['seed'] = conf.execution_config.reproducibility.seeds[replication] + return args, kwargs + + +def exec_orchestrator(args: Namespace = None, conf: DistributedConfig = None, replication: int = 0): """ Default runner for the Orchestrator that is based on KubeFlow @param args: Commandline arguments passed to the execution. Might be removed in a future commit. @@ -72,7 +88,7 @@ def exec_orchestrator(args: Namespace = None, conf: DistributedConfig = None, re @param conf: Configuration for execution of Orchestrators components, needed for spinning up components of the Orchestrator. @type conf: Optional[DistributedConfig] - @param replication: Replication index of the experiment. + @param replication: Replication index of the experiment, zero indexed. @type replication: int @return: None @rtype: None @@ -100,7 +116,8 @@ def exec_orchestrator(args: Namespace = None, conf: DistributedConfig = None, re pool.apply(cluster_manager.start) logging.info("Starting arrival generator") - pool.apply_async(arrival_generator.start, args=[conf.get_duration()]) + arv_gen_args, arv_gen_kwargs = get_arrival_generator_args(conf, replication) + pool.apply_async(arrival_generator.start, args=arv_gen_args, kwds=arv_gen_kwargs) logging.info("Starting orchestrator") pool.apply(orchestrator.run, kwds={"experiment_replication": replication}) @@ -323,7 +340,7 @@ def launch_cluster(arg_path: Path, conf_path: Path, rank: Rank, nic: Optional[NI try: logging.info(f"Starting with experiment replication: {replication} with seed: {experiment_seed}") init_reproducibility(conf.execution_config) - exec_orchestrator(args=args, conf=conf, replication=replication+1) + exec_orchestrator(args=args, conf=conf, replication=replication) except Exception as e: logging.info(f"Execution of replication {replication} with seed {experiment_seed} failed." f"Reason: {e}")