Rework BareConfig to allow for easier loading and parsing
JMGaljaard committed Mar 28, 2022
1 parent 16c8154 commit 667b2c0
Showing 13 changed files with 134 additions and 234 deletions.
23 changes: 0 additions & 23 deletions configs/cloud_experiment.yaml

This file was deleted.

30 changes: 30 additions & 0 deletions configs/example_cloud_experiment.json
@@ -0,0 +1,30 @@
{
"cluster": {
"orchestrator": {
"wait_for_clients": true,
"service": "fl-server.test.svc.cluster.local",
"nic": "eth0"
},
"client": {
"prefix": "client",
"tensorboard_active": true
}
},
"execution_config": {
"experiment_prefix": "cloud_experiment",
"tensorboard_active": true,
"cuda": false,
"net": {
"save_model": false,
"save_temp_model": false,
"save_epoch_interval": 1,
"save_model_path": "models",
"epoch_save_start_suffix": "start",
"epoch_save_end_suffix": "end"
},
"reproducibility": {
"torch_seed": 42,
"arrival_seed": 123
}
}
}
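
The reworked entry point loads this file through BareConfig.from_json, as the fltk/__main__.py diff below shows. A minimal usage sketch; the attribute access on the last line assumes field names that mirror the JSON keys above, which the diff does not confirm:

from fltk.util.config.base_config import BareConfig

with open('configs/example_cloud_experiment.json') as config_file:
    config = BareConfig.from_json(config_file)

# Assumed attribute layout mirroring the JSON above; the real BareConfig
# may expose these values differently.
print(config.execution_config.reproducibility.torch_seed)  # -> 42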
112 changes: 37 additions & 75 deletions fltk/__main__.py
@@ -3,14 +3,13 @@
from multiprocessing.pool import ThreadPool
from pathlib import Path

import yaml
from dotenv import load_dotenv

from fltk.launch import run_single

from fltk.util.base_config import BareConfig
from fltk.util.config.base_config import BareConfig
from fltk.util.cluster.client import ClusterManager
from fltk.util.generator.arrival_generator import ExperimentGenerator
from fltk.util.task.generator.arrival_generator import ExperimentGenerator

logging.basicConfig(level=logging.INFO)

@@ -27,61 +26,27 @@ def main():
subparsers = parser.add_subparsers(dest="mode")

# Create single experiment parser
single_parser = subparsers.add_parser('single')
single_parser.add_argument('config', type=str)
single_parser.add_argument('--rank', type=int)
single_parser.add_argument('--nic', type=str, default=None)
single_parser.add_argument('--host', type=str, default=None)
add_default_arguments(single_parser)

# Create spawn parser
spawn_parser = subparsers.add_parser('spawn')
spawn_parser.add_argument('config', type=str)
add_default_arguments(spawn_parser)

# Create remote parser
remote_parser = subparsers.add_parser('remote')
remote_parser.add_argument('--rank', type=int)
remote_parser.add_argument('--nic', type=str, default=None)
remote_parser.add_argument('--host', type=str, default=None)
add_default_arguments(remote_parser)

# Create poisoned parser
poison_parser = subparsers.add_parser('poison')
poison_parser.add_argument('config', type=str)
poison_parser.add_argument('--rank', type=int)
poison_parser.add_argument('--nic', type=str, default=None)
poison_parser.add_argument('--host', type=str, default=None)
add_default_arguments(poison_parser)

poison_parser = subparsers.add_parser('cluster')
poison_parser.add_argument('config', type=str)
poison_parser.add_argument('--rank', type=int)
poison_parser.add_argument('--nic', type=str, default=None)
poison_parser.add_argument('--host', type=str, default=None)
add_default_arguments(poison_parser)

args = parser.parse_args()

if args.mode == 'cluster':
logging.info("[Fed] Starting in cluster mode.")
# TODO: Load configuration path
config_path: Path = None
cluster_manager = ClusterManager()
arrival_generator = ExperimentGenerator(config_path)

pool = ThreadPool(4)
pool.apply(cluster_manager.start)
pool.apply(arrival_generator.run)

pool.join()
else:
with open(args.config) as config_file:
cfg = BareConfig()
yaml_data = yaml.load(config_file, Loader=yaml.FullLoader)
cfg.merge_yaml(yaml_data)
if args.mode == 'poison':
perform_poison_experiment(args, cfg, parser, yaml_data)
cluster_parser = subparsers.add_parser('cluster')
cluster_parser.add_argument('config', type=str)
cluster_parser.add_argument('--rank', type=int)
cluster_parser.add_argument('--nic', type=str, default=None)
cluster_parser.add_argument('--host', type=str, default=None)
add_default_arguments(cluster_parser)

arguments = parser.parse_args()


with open(arguments.config) as config_file:
try:
config = BareConfig.from_json(config_file)
except Exception as e:
print("Cannot load provided configuration, exiting...")
exit(-1)

if arguments.mode == 'orchestrator':
start_clusterized(arguments, config)
elif arguments.mode == 'client':
run_single()


def perform_single_experiment(args, cfg, parser, yaml_data):
@@ -103,29 +68,26 @@ def perform_single_experiment(args, cfg, parser, yaml_data):
run_single(rank=args.rank, world_size=world_size, host=master_address, args=cfg, nic=nic)


def perform_poison_experiment(args, cfg, yaml_data):
def start_clusterized(args: dict, config: BareConfig):
"""
Function to start poisoned experiment.
"""
if args.rank is None:
print('Missing rank argument when in \'poison\' mode!')
exit(1)
if not yaml_data.get('poison'):
print(f'Missing poison configuration for \'poison\' mode')
exit(1)
logging.info("[Fed] Starting in cluster mode.")
# TODO: Load configuration path
config_path: Path = None
cluster_manager = ClusterManager()
arrival_generator = ExperimentGenerator(config_path)

pool = ThreadPool(4)
pool.apply(cluster_manager.start)
pool.apply(arrival_generator.run)

pool.join()


world_size = args.world_size
master_address = args.host
nic = args.nic

if not world_size:
world_size = yaml_data['system']['clients']['amount'] + 1
if not master_address:
master_address = yaml_data['system']['federator']['hostname']
if not nic:
nic = yaml_data['system']['federator']['nic']
print(f'rank={args.rank}, world_size={world_size}, host={master_address}, args=cfg, nic={nic}')
run_single(rank=args.rank, world_size=world_size, host=master_address, args=cfg, nic=nic)
run_single(rank=args.rank, args=config, nic=nic)


if __name__ == "__main__":
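For reference, a classmethod of roughly this shape could back the from_json call above. This is a hypothetical sketch written against configs/example_cloud_experiment.json, not the actual implementation in fltk/util/config/base_config.py, which this commit page does not show:

import json
from dataclasses import dataclass

@dataclass
class ReproducibilityConfig:
    torch_seed: int
    arrival_seed: int

@dataclass
class ExecutionConfig:
    experiment_prefix: str
    tensorboard_active: bool
    cuda: bool
    net: dict
    reproducibility: ReproducibilityConfig

@dataclass
class BareConfig:
    cluster: dict
    execution_config: ExecutionConfig

    @classmethod
    def from_json(cls, config_file):
        # Parse the raw JSON file object and build nested config objects,
        # raising KeyError on missing sections so the caller's try/except
        # can abort with a clear failure.
        data = json.load(config_file)
        exec_cfg = data['execution_config']
        return cls(
            cluster=data['cluster'],
            execution_config=ExecutionConfig(
                experiment_prefix=exec_cfg['experiment_prefix'],
                tensorboard_active=exec_cfg['tensorboard_active'],
                cuda=exec_cfg['cuda'],
                net=exec_cfg['net'],
                reproducibility=ReproducibilityConfig(**exec_cfg['reproducibility']),
            ),
        )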
2 changes: 1 addition & 1 deletion fltk/client.py
Expand Up @@ -15,7 +15,7 @@
from torch.distributed import rpc

from fltk.schedulers import MinCapableStepLR
from fltk.util.base_config import BareConfig
from fltk.util.config.base_config import BareConfig
from fltk.util.log import DistLearningLogger
from fltk.util.results import EpochData

20 changes: 15 additions & 5 deletions fltk/launch.py
@@ -16,34 +16,44 @@ def run_ps(rpc_ids_triple, args):
fed = Orchestrator(rpc_ids_triple, config=args)
fed.run()

def await_assigned_orchestrator():
# TODO: Implement await function for client

"""
1. Setup everything correctly according to provided configuration files.
2. Register to client
3. Start working on task description provided by orchestrator
4. Send heartbeats? (Alternatively use Kubernetes for this)
5. Send completed data
6. Terminate/complete pod execution.
"""
pass
def run_single(rank, world_size, host = None, args = None, nic = None):
logging.info(f'Starting with rank={rank} and world size={world_size}')
prepare_environment(host, nic)

logging.info(f'Starting with host={os.environ["MASTER_ADDR"]} and port={os.environ["MASTER_PORT"]}')
options = rpc.TensorPipeRpcBackendOptions(
num_worker_threads=20, # TODO: Retrieve number of cores from system
rpc_timeout=0, # infinite timeout
num_worker_threads=20,
rpc_timeout=0,
init_method=f'tcp://{os.environ["MASTER_ADDR"]}:{os.environ["MASTER_PORT"]}'
)

if rank != 0:

logging.info(f'Starting worker {rank}')
rpc.init_rpc(
f"client{rank}",
rank=rank,
world_size=world_size,
rpc_backend_options=options,
)
# trainer passively waiting for ps to kick off training iterations
else:
logging.info('Starting the ps')
rpc.init_rpc(
"ps",
rank=rank,
world_size=world_size,
rpc_backend_options=options

)
run_ps([(f"client{r}", r, world_size) for r in range(1, world_size)], args)

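The list comprehension in the last line above hands run_ps one (name, rank, world_size) triple per client, skipping rank 0, which is the parameter server. A quick worked example, assuming a world size of 3:

world_size = 3
rpc_ids_triple = [(f"client{r}", r, world_size) for r in range(1, world_size)]
# Rank 0 registered itself as "ps"; ranks 1 and 2 are the clients.
assert rpc_ids_triple == [("client1", 1, 3), ("client2", 2, 3)]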
2 changes: 1 addition & 1 deletion fltk/nets/util/utils.py
Expand Up @@ -6,7 +6,7 @@
import torch
from torch.utils.tensorboard import SummaryWriter

from fltk.util.base_config import BareConfig
from fltk.util.config.base_config import BareConfig
from fltk.util.results import EpochData


4 changes: 2 additions & 2 deletions fltk/orchestrator.py
@@ -10,9 +10,9 @@

from fltk.client import Client
from fltk.nets.util.utils import flatten_params, save_model
from fltk.util.base_config import BareConfig
from fltk.util.config.base_config import BareConfig
from fltk.util.cluster.client import ClientRef
from fltk.util.generator.arrival_generator import ArrivalGenerator
from fltk.util.task.generator.arrival_generator import ArrivalGenerator
from fltk.util.log import DistLearningLogger
from fltk.util.results import EpochData
