From 6ecc5cc73948f72d6e12d00315f6a423f5168f94 Mon Sep 17 00:00:00 2001 From: JMGaljaard Date: Sun, 18 Sep 2022 11:10:10 +0200 Subject: [PATCH] Prepare arguments to be compatible with Orchestrator implementation --- fltk/launch.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/fltk/launch.py b/fltk/launch.py index a0183b47..a64632ad 100644 --- a/fltk/launch.py +++ b/fltk/launch.py @@ -64,20 +64,20 @@ def exec_distributed_client(task_id: str, conf: DistributedConfig = None, print(epoch_data) -def get_arrival_generator_args(conf: DistributedConfig, replication: int) -> List[Any]: +def get_arrival_generator_args(conf: DistributedConfig, replication: int) -> (List[Any], Dict[str, Any]): """ Function to get arrival generator arguments based on current configuration of Orchestrator. - @param conf: - @type conf: - @param replication: - @type replication: - @return: Configuration for args and kwargs for the generator. + @param conf: Configuration for distributed/federated learning experiments. + @type conf: DistributedConfig + @param replication: Replication index for experiment to retrieve arguemnts for. + @type replication: int + @return: Configuration for args and keyword args (kwd) for the generator. @rtype: Tuple[List[Any], Dict[str, Any]] """ - args = [conf.get_duration()] + args, kwd_args = [conf.get_duration()], {} if conf.cluster_config.orchestrator.orchestrator_type == OrchestratorType.BATCH: - args.append(conf.execution_config.reproducibility.seeds[replication]) - return args + kwd_args['seed'] = conf.execution_config.reproducibility.seeds[replication] + return args, kwd_args def exec_orchestrator(args: Namespace = None, conf: DistributedConfig = None, replication: int = 0): @@ -116,8 +116,8 @@ def exec_orchestrator(args: Namespace = None, conf: DistributedConfig = None, re pool.apply(cluster_manager.start) logging.info("Starting arrival generator") - arv_gen_args = get_arrival_generator_args(conf, replication) - pool.apply_async(arrival_generator.start, args=arv_gen_args) + arv_gen_args, kwd_args = get_arrival_generator_args(conf, replication) + pool.apply_async(arrival_generator.start, args=arv_gen_args, kwds=kwd_args) logging.info("Starting orchestrator") pool.apply(orchestrator.run, kwds={"experiment_replication": replication})