Skip to content

Commit

Permalink
Refactor config and create json dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
JMGaljaard committed Mar 28, 2022
1 parent b2634b6 commit 16c8154
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 143 deletions.
41 changes: 20 additions & 21 deletions configs/cloud_experiment.yaml
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@
# Experiment configuration
total_epochs: 130
epochs_per_cycle: 1
wait_for_clients: true
net: Cifar10CNN
dataset: cifar10
sampler: "uniform"
sampler_args:
- 0.5 # p degree
- 42 # random seed
# Use cuda is available; setting to false will force CPU
cuda: false
experiment_prefix: 'experiment_single_machine'
output_location: 'output'
tensor_board_active: true
clients_per_round: 50
system:
federator:
cluster:
orchestrator:
wait_for_clients: true
# Use the SERVICE provided by the fl-server to connect
hostname: 'fl-server.test.svc.cluster.local'
service: 'fl-server.test.svc.cluster.local'
# Default NIC is eth0
nic: 'eth0'
clients:
amount: 50
worker:
prefix: 'client'
execution_config:
experiment_prefix: 'cloud_experiment'
tensor_board_active: true
cuda: false
net:
save_model: false
save_temp_model: false
save_epoch_interval: 1
save_model_path: "models"
epoch_save_start_suffix: "start"
epoch_save_end_suffix: "end"
reproducability:
torch_seed: 42
arrival_seed: 123
15 changes: 0 additions & 15 deletions configs/experiment.yaml

This file was deleted.

38 changes: 0 additions & 38 deletions configs/local_experiment.yaml

This file was deleted.

23 changes: 0 additions & 23 deletions configs/non_iid_experiment.yaml

This file was deleted.

88 changes: 42 additions & 46 deletions fltk/util/base_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from dataclasses import dataclass
from typing import Dict

import torch
from dataclasses_json import dataclass_json
Expand All @@ -9,10 +8,10 @@
# SEED = 1
# torch.manual_seed(SEED)


@dataclass
@dataclass_json
class ExecutionConfig():
cuda: bool = False
class GeneralNetConfig:
save_model: bool = False
save_temp_model: bool = False
save_epoch_interval: int = 1
Expand All @@ -21,14 +20,52 @@ class ExecutionConfig():
epoch_save_end_suffix = "end"


@dataclass(frozen=True)
@dataclass_json
class ReproducabilityConfig:
torch_seed: int
arrival_seed: int


@dataclass
@dataclass_json
class ExecutionConfig:
general_net: GeneralNetConfig
reproducability: ReproducabilityConfig
experiment_prefix: str = "experiment"
tensorboard_active: str = True
cuda: bool = False


@dataclass
@dataclass_json
class OrchestratorConfig:
service: str
nic: str


@dataclass
@dataclass_json
class ClientConfig:
prefix: str


@dataclass
@dataclass_json
class ClusterConfig:
orchestrator: OrchestratorConfig
client: ClientConfig
wait_for_clients: bool = True


@dataclass
@dataclass_json
class BareConfig(object):
# Configuration parameters for PyTorch and models that are generated.
execution_config = ExecutionConfig()
execution_config: ExecutionConfig
cluster_config: ClusterConfig

def __init__(self):

# TODO: Move to external class/object
self.train_data_loader_pickle_path = {
'cifar10': 'data_loaders/cifar10/train_data_loader.pickle',
Expand All @@ -48,47 +85,6 @@ def __init__(self):
self.default_model_folder_path = "default_models"
self.data_path = "data"

def merge_yaml(self, cfg: Dict[str, str] = {}):
"""
@deprecated This function will become redundant after using dataclasses_json to load the values into the object.
"""
if 'total_epochs' in cfg:
self.epochs = cfg['total_epochs']
if 'epochs_per_cycle' in cfg:
self.epochs_per_cycle = cfg['epochs_per_cycle']
if 'wait_for_clients' in cfg:
self.wait_for_clients = cfg['wait_for_clients']
if 'net' in cfg:
self.set_net_by_name(cfg['net'])
if 'dataset' in cfg:
self.dataset_name = cfg['dataset']
if 'experiment_prefix' in cfg:
self.experiment_prefix = cfg['experiment_prefix']
if 'output_location' in cfg:
self.output_location = cfg['output_location']
if 'tensor_board_active' in cfg:
self.tensor_board_active = cfg['tensor_board_active']
if 'clients_per_round' in cfg:
self.clients_per_round = cfg['clients_per_round']
if 'system' in cfg:
if 'clients' in cfg['system']:
if 'amount' in cfg['system']['clients']:
self.world_size = cfg['system']['clients']['amount'] + 1

if 'system' in cfg:
if 'federator' in cfg['system']:
if 'hostname' in cfg['system']['federator']:
self.federator_host = cfg['system']['federator']['hostname']
if 'cuda' in cfg:
if cfg['cuda']:
self.cuda = True
else:
self.cuda = False
if 'sampler' in cfg:
self.data_sampler = cfg['sampler']
if 'sampler_args' in cfg:
self.data_sampler_args = cfg['sampler_args']

def get_dataloader_list(self):
"""
@deprecated
Expand Down

0 comments on commit 16c8154

Please sign in to comment.