From 16a4e91e8eb39d7c96e5cac65a5c6355830cb4a8 Mon Sep 17 00:00:00 2001
From: bacox <bartcox93@gmail.com>
Date: Wed, 12 Jan 2022 10:24:13 +0100
Subject: [PATCH] Enable offloading

---
 Dockerfile                                |   2 +-
 configs/experiment.yaml                   |   9 +-
 configs/experiment_vanilla.yaml           |  10 +-
 deploy/templates/client_stub_default.yml  |   2 +-
 deploy/templates/client_stub_medium.yml   |   2 +-
 fltk/client.py                            | 185 +++++++++++++++--------
 fltk/federator.py                         | 100 +++++++++---
 fltk/strategy/aggregation.py              |  17 +++
 fltk/strategy/offloading.py               |  22 +++
 fltk/util/base_config.py                  |  10 ++
 fltk/util/generate_docker_compose.py      |  35 ++++-
 fltk/util/results.py                      |   1 +
 requirements.txt                          |   4 +-
 13 files changed, 301 insertions(+), 98 deletions(-)
 create mode 100644 fltk/strategy/offloading.py

diff --git a/Dockerfile b/Dockerfile
index 8ad4937b..006c97d0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -45,5 +45,5 @@ EXPOSE 5000
 COPY fltk ./fltk
 COPY configs ./configs
 #CMD python3 ./fltk/__main__.py single configs/experiment.yaml --rank=$RANK
-CMD python3 -m fltk single configs/experiment.yaml --rank=$RANK
+CMD python3 -m fltk single configs/experiment_vanilla.yaml --rank=$RANK
 #CMD python3 setup.py
\ No newline at end of file
diff --git a/configs/experiment.yaml b/configs/experiment.yaml
index c8e30bce..62ee3a93 100644
--- a/configs/experiment.yaml
+++ b/configs/experiment.yaml
@@ -1,6 +1,6 @@
 ---
 # Experiment configuration
-total_epochs: 4
+total_epochs: 30
 epochs_per_cycle: 1
 wait_for_clients: true
 net: Cifar10CNN
@@ -8,11 +8,14 @@ dataset: cifar10
 # Use cuda is available; setting to false will force CPU
 cuda: false
 experiment_prefix: 'experiment_sample'
+offload_stategy: vanilla
+profiling_time: 100
+deadline: 500
 output_location: 'output'
 tensor_board_active: true
 clients_per_round: 2
-# sampler: "dirichlet" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
-sampler: "uniform" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
+sampler: "dirichlet" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
+#sampler: "uniform" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
 sampler_args:
   - 0.07 # label limit || q probability || alpha || unused
   - 42 # random seed || random seed || random seed || unused
diff --git a/configs/experiment_vanilla.yaml b/configs/experiment_vanilla.yaml
index 90fcb77b..a8c10a79 100644
--- a/configs/experiment_vanilla.yaml
+++ b/configs/experiment_vanilla.yaml
@@ -1,19 +1,21 @@
 ---
 # Experiment configuration
-total_epochs: 4
+total_epochs: 20
 epochs_per_cycle: 1
 wait_for_clients: true
 net: Cifar10CNN
 dataset: cifar10
 # Use cuda is available; setting to false will force CPU
 cuda: false
-experiment_prefix: 'offloading_vanilla'
+experiment_prefix: 'exp_offload_vanilla'
 offload_stategy: vanilla
+profiling_time: 100
+deadline: 500
 output_location: 'output'
 tensor_board_active: true
 clients_per_round: 2
-# sampler: "dirichlet" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
-sampler: "uniform" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
+sampler: "dirichlet" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
+#sampler: "uniform" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
 sampler_args:
   - 0.07 # label limit || q probability || alpha || unused
   - 42 # random seed || random seed || random seed || unused
diff --git a/deploy/templates/client_stub_default.yml b/deploy/templates/client_stub_default.yml
index d8955310..838cf699 100644
--- a/deploy/templates/client_stub_default.yml
+++ b/deploy/templates/client_stub_default.yml
@@ -20,4 +20,4 @@ client_name: # name can be anything
     resources:
       limits:
         cpus: '2'
-        memory: 1024M
+#        memory: 1024M
diff --git a/deploy/templates/client_stub_medium.yml b/deploy/templates/client_stub_medium.yml
index 8f07f46b..6037ce44 100644
--- a/deploy/templates/client_stub_medium.yml
+++ b/deploy/templates/client_stub_medium.yml
@@ -19,5 +19,5 @@ client_name: # name can be anything
   deploy:
     resources:
       limits:
-        cpus: '0.75'
+        cpus: '1'
         memory: 1024M
diff --git a/fltk/client.py b/fltk/client.py
index f841a332..7c5fa710 100644
--- a/fltk/client.py
+++ b/fltk/client.py
@@ -15,6 +15,7 @@
 from torch.distributed.rpc import RRef
 
 from fltk.schedulers import MinCapableStepLR
+from fltk.strategy.offloading import OffloadingStrategy
 from fltk.util.arguments import Arguments
 from fltk.util.fed_avg import average_nn_parameters
 from fltk.util.log import FLLogger
@@ -68,6 +69,8 @@ class Client:
     call_to_offload = False
     client_to_offload_to : str = None
 
+    strategy = OffloadingStrategy.VANILLA
+
     def __init__(self, id, log_rref, rank, world_size, config = None):
         logging.info(f'Welcome to client {id}')
@@ -92,6 +95,43 @@ def __init__(self, id, log_rref, rank, world_size, config = None):
                                             self.args.get_scheduler_step_size(),
                                             self.args.get_scheduler_gamma(),
                                             self.args.get_min_lr())
+        self.strategy = OffloadingStrategy.Parse(config.offload_strategy)
+        self.configure_strategy(self.strategy)
+
+
+    def configure_strategy(self, strategy : OffloadingStrategy):
+        if strategy == OffloadingStrategy.VANILLA:
+            logging.info('Running with offloading strategy: VANILLA')
+            self.deadline_enabled = False
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = False
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.DEADLINE:
+            logging.info('Running with offloading strategy: DEADLINE')
+            self.deadline_enabled = True
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = False
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.SWYH:
+            logging.info('Running with offloading strategy: SWYH')
+            self.deadline_enabled = True
+            self.swyh_enabled = True
+            self.freeze_layers_enabled = False
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.FREEZE:
+            logging.info('Running with offloading strategy: FREEZE')
+            self.deadline_enabled = True
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = True
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.MODEL_OFFLOAD:
+            logging.info('Running with offloading strategy: MODEL_OFFLOAD')
+            self.deadline_enabled = True
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = True
+            self.offload_enabled = True
+        logging.info(f'Offload strategy params: deadline={self.deadline_enabled}, swyh={self.swyh_enabled}, freeze={self.freeze_layers_enabled}, offload={self.offload_enabled}')
+
 
     def init_device(self):
         if self.args.cuda and torch.cuda.is_available():
@@ -254,11 +294,36 @@ def unfreeze_layers(self):
         for param in self.net.parameters():
             param.requires_grad = True
 
-    def train(self, epoch, deadline_time: int = None):
+    def train(self, epoch, deadline: int = None):
         """
+
+        Different modes:
+        1. Vanilla
+        2. Deadline
+        3. SWYH
+        4. Just Freeze
+        5. Model Offload
+
+
+        :: Vanilla
+        Disable deadline
+        Disable swyh
+        Disable offload
+
+        :: Deadline
+        We need to keep track of the incoming deadline
+        We don't need to send data before the deadline
+
        :param epoch: Current epoch #
        :type epoch: int
        """
+        start_time = time.time()
+        deadline_threshold = 5
+        train_stop_time = None
+        if self.deadline_enabled and deadline is not None:
+            train_stop_time = start_time + deadline - deadline_threshold
+
+        strategy = OffloadingStrategy.VANILLA
 
         # Ignore profiler for now
         # p = Profiler()
@@ -266,7 +331,7 @@ def train(self, epoch, deadline_time: int = None):
         # self.net.train()
 
         global global_model_weights, global_offload_received
-        deadline_time = None
+        # deadline_time = None
         # save model
         if self.args.should_save_model(epoch):
             self.save_model(epoch, self.args.get_epoch_save_start_suffix())
@@ -281,65 +346,58 @@ def train(self, epoch, deadline_time: int = None):
         # performance_metric_interval = 20
         # perf_resp = None
 
-        profiling_size = 40
+        # Profiling parameters
+        profiling_size = self.args.profiling_size
         profiling_data = np.zeros(profiling_size)
         active_profiling = True
         control_start_time = time.time()
+        training_process = 0
         for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0):
             start_train_time = time.time()
 
-            # Check if there is a call to offload
-            if self.call_to_offload:
-                self.args.get_logger().info('Got call to offload model')
-                model_weights = self.get_nn_parameters()
-                # print(self.client_to_offload_to)
-                # r_ref = rpc.remote(self.client_to_offload_to, Client.static_ping, args=())
-                # print(f'Result of rref: {r_ref.to_here()}')
-                # ret = rpc.rpc_sync(self.client_to_offload_to, Client.static_ping, args=())
-                # print(f'Result of rref: {ret}')
-                # ret = rpc.rpc_sync(self.client_to_offload_to, Client.offload_receive_endpoint_2, args=(["Hello"]))
-                # print(f'Result of rref: {ret}')
-
-                ret = rpc.rpc_sync(self.client_to_offload_to, Client.offload_receive_endpoint, args=([model_weights]))
-                print(f'Result of rref: {ret}')
-
-                # r_ref = rpc.remote(self.client_to_offload_to, Client.static_ping, args=())
-                # r_ref = rpc.remote(self.client_to_offload_to, Client.offload_receive_endpoint_2, args=("Hello world"))
-                # _remote_method_async(Client.static_ping, self.client_to_offload_to)
-                # fut1 = rpc.rpc_async(self.client_to_offload_to, Client.ping)
-                # _remote_method_async_by_info(Client.offload_receive_endpoint, self.client_to_offload_to, model_weights)
-                self.call_to_offload = False
-                self.client_to_offload_to = None
-                # This number only works for cifar10cnn
-                self.freeze_layers(15)
-
-            # Check if there is a model to incorporate
-            if global_offload_received:
-                self.args.get_logger().info('Merging offloaded model')
-                self.args.get_logger().info('FedAvg locally with offloaded model')
-                updated_weights = average_nn_parameters([self.get_nn_parameters(), global_model_weights])
-                self.args.get_logger().info('Updating local weights due to offloading')
-                self.update_nn_parameters(updated_weights)
-                global_offload_received = False
-                global_model_weights = None
-
-
-            if deadline_time is not None:
-                if time.time() >= deadline_time:
-                    self.args.get_logger().info('Stopping training due to deadline time')
-                    break
-                else:
-                    self.args.get_logger().info(f'Time to deadline: {deadline_time - time.time()}')
+            if self.offload_enabled:
+                # Check if there is a call to offload
+                if self.call_to_offload:
+                    self.args.get_logger().info('Got call to offload model')
+                    model_weights = self.get_nn_parameters()
+
+                    ret = rpc.rpc_sync(self.client_to_offload_to, Client.offload_receive_endpoint, args=([model_weights]))
+                    print(f'Result of rref: {ret}')
+
+                    self.call_to_offload = False
+                    self.client_to_offload_to = None
+                    # This number only works for cifar10cnn
+                    # @TODO: Make this dynamic for other networks
+                    self.freeze_layers(15)
+
+                # Check if there is a model to incorporate
+                if global_offload_received:
+                    self.args.get_logger().info('Merging offloaded model')
+                    self.args.get_logger().info('FedAvg locally with offloaded model')
+                    updated_weights = average_nn_parameters([self.get_nn_parameters(), global_model_weights])
+                    self.args.get_logger().info('Updating local weights due to offloading')
+                    self.update_nn_parameters(updated_weights)
+                    global_offload_received = False
+                    global_model_weights = None
+
+            if self.deadline_enabled:
+                # Deadline
+                if train_stop_time is not None:
+                    if time.time() >= train_stop_time:
+                        self.args.get_logger().info('Stopping training due to deadline time')
+                        break
+                    # else:
+                    #     self.args.get_logger().info(f'Time to deadline: {train_stop_time - time.time()}')
+
+
 
             inputs, labels = inputs.to(self.device), labels.to(self.device)
+            training_process = i
 
             # zero the parameter gradients
             self.optimizer.zero_grad()
 
-            # Ignore profile for now
-            # p.set_warmup(False)
-            # p.signal_forward_start()
-
             # forward + backward + optimize
             outputs = self.net(inputs)
             loss = self.loss_function(outputs, labels)
@@ -376,15 +434,25 @@ def train(self, epoch, deadline: int = None):
                     est_total_time = number_of_training_samples * time_per_batch
                     logging.info(f'Estimated training time is {est_total_time}')
                     self.report_performance_estimate((time_per_batch, est_total_time, number_of_training_samples))
-                # logging.info(f'Batch time is {batch_duration}')
+                    if self.freeze_layers_enabled:
+                        logging.info(f'Checking if need to freeze layers ? {est_total_time} > {deadline}')
+                        if est_total_time > deadline:
+                            logging.info('Will freeze layers to speed up computation')
+                            # This number only works for cifar10cnn
+                            # @TODO: Make this dynamic for other networks
+                            self.freeze_layers(15)
+            # logging.info(f'Batch time is {batch_duration}')
 
-            if i > 50:
-                break
+            # Break away from loop for debug purposes
+            # if i > 50:
+            #     break
 
         control_end_time = time.time()
 
         logging.info(f'Measure end time is {(control_end_time - control_start_time)}')
+        logging.info(f'Trained on {training_process} samples')
+
         self.scheduler.step()
@@ -395,7 +463,7 @@ def train(self, epoch, deadline_time: int = None):
         if self.args.should_save_model(epoch):
             self.save_model(epoch, self.args.get_epoch_save_end_suffix())
 
-        return final_running_loss, self.get_nn_parameters()
+        return final_running_loss, self.get_nn_parameters(), training_process
 
     def test(self):
         self.net.eval()
@@ -435,14 +503,11 @@ def test(self):
         return accuracy, loss, class_precision, class_recall
 
     def run_epochs(self, num_epoch, deadline: int = None):
-        start_time = time.time()
-        deadline_threshold = 10
         start_time_train = datetime.datetime.now()
-        train_stop_time = None
-        if deadline is not None:
-            train_stop_time = start_time + deadline - deadline_threshold
+
         self.dataset.get_train_sampler().set_epoch_size(num_epoch)
-        loss, weights = self.train(self.epoch_counter, train_stop_time)
+        # Train locally
+        loss, weights, training_process = self.train(self.epoch_counter, deadline)
         self.epoch_counter += num_epoch
         elapsed_time_train = datetime.datetime.now() - start_time_train
         train_time_ms = int(elapsed_time_train.total_seconds()*1000)
@@ -452,7 +517,7 @@ def run_epochs(self, num_epoch, deadline: int = None):
         elapsed_time_test = datetime.datetime.now() - start_time_test
         test_time_ms = int(elapsed_time_test.total_seconds()*1000)
 
-        data = EpochData(self.epoch_counter, train_time_ms, test_time_ms, loss, accuracy, test_loss, class_precision, class_recall, client_id=self.id)
+        data = EpochData(self.epoch_counter, num_epoch, train_time_ms, test_time_ms, loss, accuracy, test_loss, class_precision, class_recall, training_process, self.id)
         self.epoch_results.append(data)
 
         # Copy GPU tensors to CPU
diff --git a/fltk/federator.py b/fltk/federator.py
index f70f30cc..8790747c 100644
--- a/fltk/federator.py
+++ b/fltk/federator.py
@@ -12,7 +12,9 @@
 from fltk.client import Client
 from fltk.datasets.data_distribution import distribute_batches_equally
+from fltk.strategy.aggregation import FedAvg
 from fltk.strategy.client_selection import random_selection
+from fltk.strategy.offloading import OffloadingStrategy
 from fltk.util.arguments import Arguments
 from fltk.util.base_config import BareConfig
 from fltk.util.data_loader_utils import load_train_data_loader, load_test_data_loader, \
@@ -112,6 +114,14 @@ class Federator:
     reference_lookup = {}
     performance_estimate = {}
 
+    # Strategies
+    deadline_enabled = False
+    swyh_enabled = False
+    freeze_layers_enabled = False
+    offload_enabled = False
+
+    strategy = OffloadingStrategy.VANILLA
+
     # Keep track of the experiment data
     exp_data_general = []
 
@@ -134,8 +144,43 @@ def __init__(self, client_id_triple, num_epochs = 3, config=None):
         self.test_data = Client("test", None, 1, 2, config)
         config.data_sampler = copy_sampler
         self.reference_lookup[get_worker_info().name] = RRef(self)
-
-
+        self.strategy = OffloadingStrategy.Parse(config.offload_strategy)
+        self.configure_strategy(self.strategy)
+
+
+
+    def configure_strategy(self, strategy : OffloadingStrategy):
+        if strategy == OffloadingStrategy.VANILLA:
+            logging.info('Running with offloading strategy: VANILLA')
+            self.deadline_enabled = False
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = False
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.DEADLINE:
+            logging.info('Running with offloading strategy: DEADLINE')
+            self.deadline_enabled = True
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = False
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.SWYH:
+            logging.info('Running with offloading strategy: SWYH')
+            self.deadline_enabled = True
+            self.swyh_enabled = True
+            self.freeze_layers_enabled = False
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.FREEZE:
+            logging.info('Running with offloading strategy: FREEZE')
+            self.deadline_enabled = True
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = True
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.MODEL_OFFLOAD:
+            logging.info('Running with offloading strategy: MODEL_OFFLOAD')
+            self.deadline_enabled = True
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = True
+            self.offload_enabled = True
+        logging.info(f'Offload strategy params: deadline={self.deadline_enabled}, swyh={self.swyh_enabled}, freeze={self.freeze_layers_enabled}, offload={self.offload_enabled}')
 
     def create_clients(self, client_id_triple):
         for id, rank, world_size in client_id_triple:
@@ -235,8 +280,8 @@ def ask_client_to_offload(self, client1_ref, client2_ref):
 
     def remote_run_epoch(self, epochs):
         start_epoch_time = time.time()
-        deadline = 400
-
+        deadline = self.config.deadline
+        deadline_time = self.config.deadline
         """
         1. Client selection
        2. Run local updates
@@ -245,6 +290,9 @@ def remote_run_epoch(self, epochs):
        4. Aggregate data
        """
         client_weights = []
+
+        client_weights_dict = {}
+        client_training_process_dict = {}
         while self.num_available_clients() < self.config.clients_per_round:
             logging.warning(f'Waiting for enough clients to become available. # Available Clients = {self.num_available_clients()}, but need {self.config.clients_per_round}')
             self.process_response_list()
@@ -264,6 +312,10 @@ def remote_run_epoch(self, epochs):
             res[1].wait()
         logging.info('Weights are updated')
 
+        # Let clients train locally
+
+        if not self.deadline_enabled:
+            deadline = 0
         responses: List[ClientResponse] = []
         for client in selected_clients:
             cr = ClientResponse(self.response_id, client, _remote_method_async(Client.run_epochs, client.ref, num_epoch=epochs, deadline=deadline))
@@ -274,7 +326,6 @@ def remote_run_epoch(self, epochs):
             # responses.append((client, time.time(), _remote_method_async(Client.run_epochs, client.ref, num_epoch=epochs)))
         self.epoch_counter += epochs
 
-        deadline_time = 400
         # deadline_time = None
         # Wait loop with deadline
         start = time.time()
@@ -292,8 +343,8 @@ def reached_deadline():
         has_not_called = True
 
         show_perf_data = True
-        while not all_finished and not reached_deadline():
-
+        while not all_finished and not (self.deadline_enabled and reached_deadline()):
+            # if self.deadline_enabled and reached_deadline()
             # if has_not_called and (time.time() -start) > 10:
             #     logging.info('Sending call to offload')
             #     has_not_called = False
@@ -325,16 +376,17 @@ def reached_deadline():
                     #     weak_client = k
                     # else:
                     #     strong_client = k
+                    if self.offload_enabled:
+                        weak_client = est_keys[0]
+                        strong_client = est_keys[1]
+                        if self.performance_estimate[est_keys[1]][1] > self.performance_estimate[est_keys[0]][1]:
+                            weak_client = est_keys[1]
+                            strong_client = est_keys[0]
 
-                    weak_client = est_keys[0]
-                    strong_client = est_keys[1]
-                    if self.performance_estimate[est_keys[1]][1] > self.performance_estimate[est_keys[0]][1]:
-                        weak_client = est_keys[1]
-                        strong_client = est_keys[0]
+                        logging.info(f'Offloading from {weak_client} -> {strong_client} due to {self.performance_estimate[weak_client]} and {self.performance_estimate[strong_client]}')
+                        logging.info('Sending call to offload')
+                        self.ask_client_to_offload(self.reference_lookup[selected_clients[0].name], selected_clients[1].name)
 
-                    logging.info(f'Offloading from {weak_client} -> {strong_client} due to {self.performance_estimate[weak_client]} and {self.performance_estimate[strong_client]}')
-                    logging.info('Sending call to offload')
-                    self.ask_client_to_offload(self.reference_lookup[selected_clients[0].name], selected_clients[1].name)
             # selected_clients[0]
             # logging.info(f'Status of all_finished={all_finished} and deadline={reached_deadline()}')
             all_finished = True
@@ -344,6 +396,7 @@ def reached_deadline():
                     client_response.finish()
                 else:
                     all_finished = False
+            time.sleep(0.1)
         logging.info(f'Stopped waiting due to all_finished={all_finished} and deadline={reached_deadline()}')
 
         for client_response in responses:
@@ -361,6 +414,7 @@ def reached_deadline():
             self.client_data[epoch_data.client_id].append(epoch_data)
             logging.info(f'{client} had a loss of {epoch_data.loss}')
             logging.info(f'{client} had a epoch data of {epoch_data}')
+            logging.info(f'{client} has trained on {epoch_data.training_process} samples')
 
             client.tb_writer.add_scalar('training loss',
                                         epoch_data.loss_train,  # for every 1000 minibatches
@@ -379,10 +433,13 @@ def reached_deadline():
                                         self.epoch_counter)
 
             client_weights.append(weights)
+            client_weights_dict[client.name] = weights
+            client_training_process_dict[client.name] = epoch_data.training_process
         self.performance_estimate = {}
         if len(client_weights):
-            updated_model = average_nn_parameters(client_weights)
+            updated_model = FedAvg(client_weights_dict, client_training_process_dict)
+            # updated_model = average_nn_parameters(client_weights)
 
         # test global model
         logging.info("Testing on global test set")
@@ -399,13 +456,13 @@ def reached_deadline():
 
     def save_experiment_data(self):
         p = Path(f'./{self.config.output_location}')
-        file_output = f'./{self.config.output_location}'
+        # file_output = f'./{self.config.output_location}'
         exp_prefix = self.config.experiment_prefix
-        self.ensure_path_exists(file_output)
-        file_output /= f'{exp_prefix}-general_data.csv'
+        self.ensure_path_exists(p)
+        p /= f'{exp_prefix}-general_data.csv'
         # general_filename = f'{file_output}/general_data.csv'
         df = pd.DataFrame(self.exp_data_general, columns=['epoch', 'duration', 'accuracy', 'loss', 'class_precision', 'class_recall'])
-        df.to_csv(file_output)
+        df.to_csv(p)
 
     def update_client_data_sizes(self):
         responses = []
@@ -427,9 +484,10 @@ def remote_test_sync(self):
 
     def save_epoch_data(self):
         file_output = f'./{self.config.output_location}'
+        exp_prefix = self.config.experiment_prefix
         self.ensure_path_exists(file_output)
         for key in self.client_data:
-            filename = f'{file_output}/{key}_epochs.csv'
+            filename = f'{file_output}/{exp_prefix}_{key}_epochs.csv'
             logging.info(f'Saving data at {filename}')
             with open(filename, "w") as f:
                 w = DataclassWriter(f, self.client_data[key], EpochData)
diff --git a/fltk/strategy/aggregation.py b/fltk/strategy/aggregation.py
index 81726d9f..10a9975c 100644
--- a/fltk/strategy/aggregation.py
+++ b/fltk/strategy/aggregation.py
@@ -25,6 +25,23 @@ def average_nn_parameters(parameters):
 
     return new_params
 
+def FedAvg(parameters, sizes):
+    new_params = {}
+    sum_size = 0
+    for client in parameters:
+        for name in parameters[client].keys():
+            try:
+                new_params[name].data += (parameters[client][name].data * sizes[client])
+            except:
+                new_params[name] = (parameters[client][name].data * sizes[client])
+        sum_size += sizes[client]
+
+    for name in new_params:
+        # @TODO: Is .long() really required?
+        new_params[name].data = new_params[name].data.long() / sum_size
+
+    return new_params
+
 def average_nn_parameters(parameters, sizes):
     new_params = {}
     sum_size = 0
diff --git a/fltk/strategy/offloading.py b/fltk/strategy/offloading.py
new file mode 100644
index 00000000..4473ad90
--- /dev/null
+++ b/fltk/strategy/offloading.py
@@ -0,0 +1,22 @@
+from enum import Enum
+
+
+class OffloadingStrategy(Enum):
+    VANILLA = 1
+    DEADLINE = 2
+    SWYH = 3
+    FREEZE = 4
+    MODEL_OFFLOAD = 5
+
+    @classmethod
+    def Parse(cls, string_value):
+        if string_value == 'vanilla':
+            return OffloadingStrategy.VANILLA
+        if string_value == 'deadline':
+            return OffloadingStrategy.DEADLINE
+        if string_value == 'swyh':
+            return OffloadingStrategy.SWYH
+        if string_value == 'freeze':
+            return OffloadingStrategy.FREEZE
+        if string_value == 'offload':
+            return OffloadingStrategy.MODEL_OFFLOAD
\ No newline at end of file
diff --git a/fltk/util/base_config.py b/fltk/util/base_config.py
index a5a3b74b..e41b92b9 100644
--- a/fltk/util/base_config.py
+++ b/fltk/util/base_config.py
@@ -43,6 +43,10 @@ def __init__(self):
         self.num_workers = 50
         # self.num_poisoned_workers = 10
 
+        self.offload_strategy = 'vanilla'
+        self.profiling_size = 100
+        self.deadline = 400
+
         self.federator_host = '0.0.0.0'
         self.rank = 0
         self.world_size = 0
@@ -109,6 +113,12 @@ def merge_yaml(self, cfg = {}):
             self.set_net_by_name(cfg['net'])
         if 'dataset' in cfg:
             self.dataset_name = cfg['dataset']
+        if 'offload_stategy' in cfg:
+            self.offload_strategy = cfg['offload_stategy']
+        if 'profiling_size' in cfg:
+            self.profiling_size = cfg['profiling_size']
+        if 'deadline' in cfg:
+            self.deadline = cfg['deadline']
         if 'experiment_prefix' in cfg:
             self.experiment_prefix = cfg['experiment_prefix']
         else:
diff --git a/fltk/util/generate_docker_compose.py b/fltk/util/generate_docker_compose.py
index 5c67c8da..8d910446 100644
--- a/fltk/util/generate_docker_compose.py
+++ b/fltk/util/generate_docker_compose.py
@@ -29,6 +29,28 @@ def generate_client(id, template: dict, world_size: int, type='default'):
 
     return local_template, container_name
 
+def generate_offload_exp():
+    num_clients = 2
+    world_size = num_clients + 1
+    system_template: dict = load_system_template()
+
+    for key, item in enumerate(system_template['services']['fl_server']['environment']):
+        if item == 'WORLD_SIZE={world_size}':
+            system_template['services']['fl_server']['environment'][key] = item.format(world_size=world_size)
+
+    for client_id in range(1, num_clients + 1):
+        client_type = 'default'
+        if client_id == 1:
+            client_type = 'medium'
+        # if client_id == 2:
+        #     client_type = 'slow'
+        client_template: dict = load_client_template(type=client_type)
+        client_definition, container_name = generate_client(client_id, client_template, world_size, type=client_type)
+        system_template['services'].update(client_definition)
+
+    with open(r'./docker-compose.yml', 'w') as file:
+        yaml.dump(system_template, file, sort_keys=False)
+
 def generate(num_clients: int):
     world_size = num_clients + 1
     system_template :dict = load_system_template()
@@ -39,10 +61,10 @@ def generate(num_clients: int):
 
     for client_id in range(1, num_clients+1):
         client_type = 'default'
-        # if client_id == 1:
-        #     client_type='slow'
-        # if client_id == 2:
-        #     client_type='medium'
+        if client_id == 1:
+            client_type='slow'
+        if client_id == 2:
+            client_type='medium'
         client_template: dict = load_client_template(type=client_type)
         client_definition, container_name = generate_client(client_id, client_template, world_size, type=client_type)
         system_template['services'].update(client_definition)
 
@@ -53,7 +75,8 @@ def generate(num_clients: int):
 
 
 if __name__ == '__main__':
-    num_clients = int(sys.argv[1])
-    generate(num_clients)
+    # num_clients = int(sys.argv[1])
+    # generate(num_clients)
+    generate_offload_exp()
 
     print('Done')
diff --git a/fltk/util/results.py b/fltk/util/results.py
index cf762b8a..a37fc8ad 100644
--- a/fltk/util/results.py
+++ b/fltk/util/results.py
@@ -12,6 +12,7 @@ class EpochData:
     loss: float
     class_precision: Any
     class_recall: Any
+    training_process: int
     client_id: str = None
 
     def to_csv_line(self):
diff --git a/requirements.txt b/requirements.txt
index b01b714e..e87e007e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,6 @@ requests
 pyyaml
 torchsummary
 dataclass-csv
-tensorboard
\ No newline at end of file
+tensorboard
+seaborn
+matplotlib
\ No newline at end of file
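
Note on the new aggregation: the FedAvg added in fltk/strategy/aggregation.py is a sample-weighted average, where each client's parameters are scaled by the number of samples it actually trained on (the training_process counter that clients now report back in EpochData) and the sum is divided by the total sample count. The following is a minimal, self-contained sketch of that idea, not the patch's exact implementation: it keeps tensors in floating point (the patch itself casts via .long(), which its own @TODO flags as questionable), and the client names and sizes are invented for illustration.

from typing import Dict
import torch

def sample_weighted_fedavg(parameters: Dict[str, Dict[str, torch.Tensor]],
                           sizes: Dict[str, int]) -> Dict[str, torch.Tensor]:
    # Scale every client's tensors by its sample count, then divide by the total.
    total = sum(sizes.values())
    aggregated: Dict[str, torch.Tensor] = {}
    for client, state in parameters.items():
        for name, tensor in state.items():
            weighted = tensor.float() * sizes[client]
            aggregated[name] = aggregated.get(name, torch.zeros_like(weighted)) + weighted
    return {name: value / total for name, value in aggregated.items()}

# Toy usage (hypothetical data): client1 trained on 80 samples, client2 on 20.
clients = {
    'client1': {'layer.weight': torch.ones(2, 2)},
    'client2': {'layer.weight': torch.zeros(2, 2)},
}
print(sample_weighted_fedavg(clients, {'client1': 80, 'client2': 20}))  # 0.8 everywhere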