Skip to content

Commit

Permalink
Update launch to reflect moving towards repeatable batch experiments
Browse files Browse the repository at this point in the history
  • Loading branch information
JMGaljaard committed Sep 16, 2022
1 parent 5c61dc7 commit 462f7c1
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions fltk/launch.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# pylint: disable=unused-argument
import logging
import os
from typing import Callable, Optional, NewType
from typing import Callable, Optional, NewType, Any, List, Tuple, Dict

from argparse import Namespace
from multiprocessing.pool import ThreadPool
Expand All @@ -11,7 +11,7 @@
from kubernetes import config
from torch.distributed import rpc
from fltk.core.distributed import DistClient, download_datasets
from fltk.util.config.definitions.orchestrator import get_orchestrator, get_arrival_generator
from fltk.util.config.definitions.orchestrator import get_orchestrator, get_arrival_generator, OrchestratorType
from fltk.core import Client, Federator
from fltk.nets.util.reproducability import init_reproducibility, init_learning_reproducibility
from fltk.util.cluster.client import ClusterManager
Expand Down Expand Up @@ -64,15 +64,31 @@ def exec_distributed_client(task_id: str, conf: DistributedConfig = None,
print(epoch_data)


def exec_orchestrator(args: Namespace = None, conf: DistributedConfig = None, replication: int = 1):
def get_arrival_generator_args(conf: DistributedConfig, replication: int) -> Tuple[List[Any], Dict[str, Any]]:
"""
Function to get arrival generator arguments based on current configuration of Orchestrator.
@param conf:
@type conf:
@param replication:
@type replication:
@return: Configuration for args and kwargs for the generator.
@rtype: Tuple[List[Any], Dict[str, Any]]
"""
args, kwargs = [conf.get_duration()], {}
if conf.cluster_config.orchestrator.orchestrator_type == OrchestratorType.BATCH:
kwargs['seed'] = conf.execution_config.reproducibility.seeds[replication]
return args, kwargs


def exec_orchestrator(args: Namespace = None, conf: DistributedConfig = None, replication: int = 0):
"""
Default runner for the Orchestrator that is based on KubeFlow
@param args: Commandline arguments passed to the execution. Might be removed in a future commit.
@type args: Namespace
@param conf: Configuration for execution of Orchestrators components, needed for spinning up components of the
Orchestrator.
@type conf: Optional[DistributedConfig]
@param replication: Replication index of the experiment.
@param replication: Replication index of the experiment, zero indexed.
@type replication: int
@return: None
@rtype: None
Expand Down Expand Up @@ -100,7 +116,8 @@ def exec_orchestrator(args: Namespace = None, conf: DistributedConfig = None, re
pool.apply(cluster_manager.start)

logging.info("Starting arrival generator")
pool.apply_async(arrival_generator.start, args=[conf.get_duration()])
arv_gen_args, arv_gen_kwargs = get_arrival_generator_args(conf, replication)
pool.apply_async(arrival_generator.start, args=arv_gen_args, kwds=arv_gen_kwargs)
logging.info("Starting orchestrator")
pool.apply(orchestrator.run, kwds={"experiment_replication": replication})

Expand Down Expand Up @@ -323,7 +340,7 @@ def launch_cluster(arg_path: Path, conf_path: Path, rank: Rank, nic: Optional[NI
try:
logging.info(f"Starting with experiment replication: {replication} with seed: {experiment_seed}")
init_reproducibility(conf.execution_config)
exec_orchestrator(args=args, conf=conf, replication=replication+1)
exec_orchestrator(args=args, conf=conf, replication=replication)
except Exception as e:
logging.info(f"Execution of replication {replication} with seed {experiment_seed} failed."
f"Reason: {e}")

0 comments on commit 462f7c1

Please sign in to comment.