From 16a4e91e8eb39d7c96e5cac65a5c6355830cb4a8 Mon Sep 17 00:00:00 2001
From: bacox <bartcox93@gmail.com>
Date: Wed, 12 Jan 2022 10:24:13 +0100
Subject: [PATCH] Enable offloading

---
 Dockerfile                                |   2 +-
 configs/experiment.yaml                   |   9 +-
 configs/experiment_vanilla.yaml           |  10 +-
 deploy/templates/client_stub_default.yml  |   2 +-
 deploy/templates/client_stub_medium.yml   |   2 +-
 fltk/client.py                            | 185 +++++++++++++++--------
 fltk/federator.py                         | 100 +++++++++---
 fltk/strategy/aggregation.py              |  17 +++
 fltk/strategy/offloading.py               |  22 +++
 fltk/util/base_config.py                  |  10 ++
 fltk/util/generate_docker_compose.py      |  35 ++++-
 fltk/util/results.py                      |   1 +
 requirements.txt                          |   4 +-
 13 files changed, 301 insertions(+), 98 deletions(-)
 create mode 100644 fltk/strategy/offloading.py

diff --git a/Dockerfile b/Dockerfile
index 8ad4937b..006c97d0 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -45,5 +45,5 @@ EXPOSE 5000
 COPY fltk ./fltk
 COPY configs ./configs
 #CMD python3 ./fltk/__main__.py single configs/experiment.yaml --rank=$RANK
-CMD python3 -m fltk single configs/experiment.yaml --rank=$RANK
+CMD python3 -m fltk single configs/experiment_vanilla.yaml --rank=$RANK
 #CMD python3 setup.py
\ No newline at end of file
diff --git a/configs/experiment.yaml b/configs/experiment.yaml
index c8e30bce..62ee3a93 100644
--- a/configs/experiment.yaml
+++ b/configs/experiment.yaml
@@ -1,6 +1,6 @@
 ---
 # Experiment configuration
-total_epochs: 4
+total_epochs: 30
 epochs_per_cycle: 1
 wait_for_clients: true
 net: Cifar10CNN
@@ -8,11 +8,14 @@ dataset: cifar10
 # Use cuda is available; setting to false will force CPU
 cuda: false
 experiment_prefix: 'experiment_sample'
+offload_stategy: vanilla
+profiling_time: 100
+deadline: 500
 output_location: 'output'
 tensor_board_active: true
 clients_per_round: 2
-# sampler: "dirichlet" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
-sampler: "uniform" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
+sampler: "dirichlet" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
+#sampler: "uniform" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
 sampler_args:
   - 0.07 # label limit || q probability || alpha || unused
   - 42 # random seed || random seed || random seed || unused
diff --git a/configs/experiment_vanilla.yaml b/configs/experiment_vanilla.yaml
index 90fcb77b..a8c10a79 100644
--- a/configs/experiment_vanilla.yaml
+++ b/configs/experiment_vanilla.yaml
@@ -1,19 +1,21 @@
 ---
 # Experiment configuration
-total_epochs: 4
+total_epochs: 20
 epochs_per_cycle: 1
 wait_for_clients: true
 net: Cifar10CNN
 dataset: cifar10
 # Use cuda is available; setting to false will force CPU
 cuda: false
-experiment_prefix: 'offloading_vanilla'
+experiment_prefix: 'exp_offload_vanilla'
 offload_stategy: vanilla
+profiling_time: 100
+deadline: 500
 output_location: 'output'
 tensor_board_active: true
 clients_per_round: 2
-# sampler: "dirichlet" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
-sampler: "uniform" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
+sampler: "dirichlet" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
+#sampler: "uniform" # "limit labels" || "q sampler" || "dirichlet" || "uniform" (default)
 sampler_args:
   - 0.07 # label limit || q probability || alpha || unused
   - 42 # random seed || random seed || random seed || unused
diff --git a/deploy/templates/client_stub_default.yml b/deploy/templates/client_stub_default.yml
index d8955310..838cf699 100644
--- a/deploy/templates/client_stub_default.yml
+++ b/deploy/templates/client_stub_default.yml
@@ -20,4 +20,4 @@ client_name: # name can be anything
     resources:
       limits:
         cpus: '2'
-        memory: 1024M
+#        memory: 1024M
diff --git a/deploy/templates/client_stub_medium.yml b/deploy/templates/client_stub_medium.yml
index 8f07f46b..6037ce44 100644
--- a/deploy/templates/client_stub_medium.yml
+++ b/deploy/templates/client_stub_medium.yml
@@ -19,5 +19,5 @@ client_name: # name can be anything
   deploy:
     resources:
       limits:
-        cpus: '0.75'
+        cpus: '1'
         memory: 1024M
diff --git a/fltk/client.py b/fltk/client.py
index f841a332..7c5fa710 100644
--- a/fltk/client.py
+++ b/fltk/client.py
@@ -15,6 +15,7 @@
 from torch.distributed.rpc import RRef
 
 from fltk.schedulers import MinCapableStepLR
+from fltk.strategy.offloading import OffloadingStrategy
 from fltk.util.arguments import Arguments
 from fltk.util.fed_avg import average_nn_parameters
 from fltk.util.log import FLLogger
@@ -68,6 +69,8 @@ class Client:
     call_to_offload = False
     client_to_offload_to : str = None
 
+    strategy = OffloadingStrategy.VANILLA
+
     def __init__(self, id, log_rref, rank, world_size, config = None):
         logging.info(f'Welcome to client {id}')
@@ -92,6 +95,43 @@ def __init__(self, id, log_rref, rank, world_size, config = None):
                                             self.args.get_scheduler_step_size(),
                                             self.args.get_scheduler_gamma(),
                                             self.args.get_min_lr())
+        self.strategy = OffloadingStrategy.Parse(config.offload_strategy)
+        self.configure_strategy(self.strategy)
+
+
+    def configure_strategy(self, strategy : OffloadingStrategy):
+        if strategy == OffloadingStrategy.VANILLA:
+            logging.info('Running with offloading strategy: VANILLA')
+            self.deadline_enabled = False
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = False
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.DEADLINE:
+            logging.info('Running with offloading strategy: DEADLINE')
+            self.deadline_enabled = True
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = False
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.SWYH:
+            logging.info('Running with offloading strategy: SWYH')
+            self.deadline_enabled = True
+            self.swyh_enabled = True
+            self.freeze_layers_enabled = False
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.FREEZE:
+            logging.info('Running with offloading strategy: FREEZE')
+            self.deadline_enabled = True
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = True
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.MODEL_OFFLOAD:
+            logging.info('Running with offloading strategy: MODEL_OFFLOAD')
+            self.deadline_enabled = True
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = True
+            self.offload_enabled = True
+        logging.info(f'Offload strategy params: deadline={self.deadline_enabled}, swyh={self.swyh_enabled}, freeze={self.freeze_layers_enabled}, offload={self.offload_enabled}')
+
 
     def init_device(self):
         if self.args.cuda and torch.cuda.is_available():
@@ -254,11 +294,36 @@ def unfreeze_layers(self):
         for param in self.net.parameters():
             param.requires_grad = True
 
-    def train(self, epoch, deadline_time: int = None):
+    def train(self, epoch, deadline: int = None):
         """
+
+        Different modes:
+        1. Vanilla
+        2. Deadline
+        3. SWYH
+        4. Just Freeze
+        5. Model Offload
+
+
+        :: Vanilla
+        Disable deadline
+        Disable swyh
+        Disable offload
+
+        :: Deadline
+        We need to keep track of the incoming deadline
+        We don't need to send data before the deadline
+
        :param epoch: Current epoch #
        :type epoch: int
        """
+        start_time = time.time()
+        deadline_threshold = 5
+        train_stop_time = None
+        if self.deadline_enabled and deadline is not None:
+            train_stop_time = start_time + deadline - deadline_threshold
+
+        strategy = OffloadingStrategy.VANILLA
 
         # Ignore profiler for now
         # p = Profiler()
@@ -266,7 +331,7 @@ def train(self, epoch, deadline_time: int = None):
         # self.net.train()
 
         global global_model_weights, global_offload_received
-        deadline_time = None
+        # deadline_time = None
         # save model
         if self.args.should_save_model(epoch):
             self.save_model(epoch, self.args.get_epoch_save_start_suffix())
@@ -281,65 +346,58 @@ def train(self, epoch, deadline_time: int = None):
         # performance_metric_interval = 20
         # perf_resp = None
 
-        profiling_size = 40
+        # Profiling parameters
+        profiling_size = self.args.profiling_size
         profiling_data = np.zeros(profiling_size)
         active_profiling = True
         control_start_time = time.time()
+        training_process = 0
         for i, (inputs, labels) in enumerate(self.dataset.get_train_loader(), 0):
             start_train_time = time.time()
 
-            # Check if there is a call to offload
-            if self.call_to_offload:
-                self.args.get_logger().info('Got call to offload model')
-                model_weights = self.get_nn_parameters()
-                # print(self.client_to_offload_to)
-                # r_ref = rpc.remote(self.client_to_offload_to, Client.static_ping, args=())
-                # print(f'Result of rref: {r_ref.to_here()}')
-                # ret = rpc.rpc_sync(self.client_to_offload_to, Client.static_ping, args=())
-                # print(f'Result of rref: {ret}')
-                # ret = rpc.rpc_sync(self.client_to_offload_to, Client.offload_receive_endpoint_2, args=(["Hello"]))
-                # print(f'Result of rref: {ret}')
-
-                ret = rpc.rpc_sync(self.client_to_offload_to, Client.offload_receive_endpoint, args=([model_weights]))
-                print(f'Result of rref: {ret}')
-
-                # r_ref = rpc.remote(self.client_to_offload_to, Client.static_ping, args=())
-                # r_ref = rpc.remote(self.client_to_offload_to, Client.offload_receive_endpoint_2, args=("Hello world"))
-                # _remote_method_async(Client.static_ping, self.client_to_offload_to)
-                # fut1 = rpc.rpc_async(self.client_to_offload_to, Client.ping)
-                # _remote_method_async_by_info(Client.offload_receive_endpoint, self.client_to_offload_to, model_weights)
-                self.call_to_offload = False
-                self.client_to_offload_to = None
-                # This number only works for cifar10cnn
-                self.freeze_layers(15)
-
-            # Check if there is a model to incorporate
-            if global_offload_received:
-                self.args.get_logger().info('Merging offloaded model')
-                self.args.get_logger().info('FedAvg locally with offloaded model')
-                updated_weights = average_nn_parameters([self.get_nn_parameters(), global_model_weights])
-                self.args.get_logger().info('Updating local weights due to offloading')
-                self.update_nn_parameters(updated_weights)
-                global_offload_received = False
-                global_model_weights = None
-
-
-            if deadline_time is not None:
-                if time.time() >= deadline_time:
-                    self.args.get_logger().info('Stopping training due to deadline time')
-                    break
-                else:
-                    self.args.get_logger().info(f'Time to deadline: {deadline_time - time.time()}')
+            if self.offload_enabled:
+                # Check if there is a call to offload
+                if self.call_to_offload:
+                    self.args.get_logger().info('Got call to offload model')
+                    model_weights = self.get_nn_parameters()
+
+                    ret = rpc.rpc_sync(self.client_to_offload_to, Client.offload_receive_endpoint, args=([model_weights]))
+                    print(f'Result of rref: {ret}')
+
+                    self.call_to_offload = False
+                    self.client_to_offload_to = None
+                    # This number only works for cifar10cnn
+                    # @TODO: Make this dynamic for other networks
+                    self.freeze_layers(15)
+
+                # Check if there is a model to incorporate
+                if global_offload_received:
+                    self.args.get_logger().info('Merging offloaded model')
+                    self.args.get_logger().info('FedAvg locally with offloaded model')
+                    updated_weights = average_nn_parameters([self.get_nn_parameters(), global_model_weights])
+                    self.args.get_logger().info('Updating local weights due to offloading')
+                    self.update_nn_parameters(updated_weights)
+                    global_offload_received = False
+                    global_model_weights = None
+
+            if self.deadline_enabled:
+                # Deadline
+                if train_stop_time is not None:
+                    if time.time() >= train_stop_time:
+                        self.args.get_logger().info('Stopping training due to deadline time')
+                        break
+                    # else:
+                    #     self.args.get_logger().info(f'Time to deadline: {train_stop_time - time.time()}')
+
+
 
             inputs, labels = inputs.to(self.device), labels.to(self.device)
+            training_process = i
 
             # zero the parameter gradients
             self.optimizer.zero_grad()
 
-            # Ignore profile for now
-            # p.set_warmup(False)
-            # p.signal_forward_start()
-
             # forward + backward + optimize
             outputs = self.net(inputs)
             loss = self.loss_function(outputs, labels)
@@ -376,15 +434,25 @@ def train(self, epoch, deadline: int = None):
                     est_total_time = number_of_training_samples * time_per_batch
                     logging.info(f'Estimated training time is {est_total_time}')
                     self.report_performance_estimate((time_per_batch, est_total_time, number_of_training_samples))
-                # logging.info(f'Batch time is {batch_duration}')
+                    if self.freeze_layers_enabled:
+                        logging.info(f'Checking if need to freeze layers ? {est_total_time} > {deadline}')
+                        if est_total_time > deadline:
+                            logging.info('Will freeze layers to speed up computation')
+                            # This number only works for cifar10cnn
+                            # @TODO: Make this dynamic for other networks
+                            self.freeze_layers(15)
+            # logging.info(f'Batch time is {batch_duration}')
 
-            if i > 50:
-                break
+            # Break away from loop for debug purposes
+            # if i > 50:
+            #     break
 
         control_end_time = time.time()
 
         logging.info(f'Measure end time is {(control_end_time - control_start_time)}')
+        logging.info(f'Trained on {training_process} samples')
+
         self.scheduler.step()
@@ -395,7 +463,7 @@ def train(self, epoch, deadline_time: int = None):
         if self.args.should_save_model(epoch):
             self.save_model(epoch, self.args.get_epoch_save_end_suffix())
 
-        return final_running_loss, self.get_nn_parameters()
+        return final_running_loss, self.get_nn_parameters(), training_process
 
     def test(self):
         self.net.eval()
@@ -435,14 +503,11 @@ def test(self):
         return accuracy, loss, class_precision, class_recall
 
     def run_epochs(self, num_epoch, deadline: int = None):
-        start_time = time.time()
-        deadline_threshold = 10
         start_time_train = datetime.datetime.now()
-        train_stop_time = None
-        if deadline is not None:
-            train_stop_time = start_time + deadline - deadline_threshold
+
         self.dataset.get_train_sampler().set_epoch_size(num_epoch)
-        loss, weights = self.train(self.epoch_counter, train_stop_time)
+        # Train locally
+        loss, weights, training_process = self.train(self.epoch_counter, deadline)
         self.epoch_counter += num_epoch
         elapsed_time_train = datetime.datetime.now() - start_time_train
         train_time_ms = int(elapsed_time_train.total_seconds()*1000)
@@ -452,7 +517,7 @@ def run_epochs(self, num_epoch, deadline: int = None):
         elapsed_time_test = datetime.datetime.now() - start_time_test
         test_time_ms = int(elapsed_time_test.total_seconds()*1000)
 
-        data = EpochData(self.epoch_counter, train_time_ms, test_time_ms, loss, accuracy, test_loss, class_precision, class_recall, client_id=self.id)
+        data = EpochData(self.epoch_counter, num_epoch, train_time_ms, test_time_ms, loss, accuracy, test_loss, class_precision, class_recall, training_process, self.id)
         self.epoch_results.append(data)
 
         # Copy GPU tensors to CPU
diff --git a/fltk/federator.py b/fltk/federator.py
index f70f30cc..8790747c 100644
--- a/fltk/federator.py
+++ b/fltk/federator.py
@@ -12,7 +12,9 @@
 from fltk.client import Client
 from fltk.datasets.data_distribution import distribute_batches_equally
+from fltk.strategy.aggregation import FedAvg
 from fltk.strategy.client_selection import random_selection
+from fltk.strategy.offloading import OffloadingStrategy
 from fltk.util.arguments import Arguments
 from fltk.util.base_config import BareConfig
 from fltk.util.data_loader_utils import load_train_data_loader, load_test_data_loader, \
@@ -112,6 +114,14 @@ class Federator:
     reference_lookup = {}
     performance_estimate = {}
 
+    # Strategies
+    deadline_enabled = False
+    swyh_enabled = False
+    freeze_layers_enabled = False
+    offload_enabled = False
+
+    strategy = OffloadingStrategy.VANILLA
+
     # Keep track of the experiment data
     exp_data_general = []
 
@@ -134,8 +144,43 @@ def __init__(self, client_id_triple, num_epochs = 3, config=None):
         self.test_data = Client("test", None, 1, 2, config)
         config.data_sampler = copy_sampler
         self.reference_lookup[get_worker_info().name] = RRef(self)
-
-
+        self.strategy = OffloadingStrategy.Parse(config.offload_strategy)
+        self.configure_strategy(self.strategy)
+
+
+
+    def configure_strategy(self, strategy : OffloadingStrategy):
+        if strategy == OffloadingStrategy.VANILLA:
+            logging.info('Running with offloading strategy: VANILLA')
+            self.deadline_enabled = False
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = False
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.DEADLINE:
+            logging.info('Running with offloading strategy: DEADLINE')
+            self.deadline_enabled = True
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = False
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.SWYH:
+            logging.info('Running with offloading strategy: SWYH')
+            self.deadline_enabled = True
+            self.swyh_enabled = True
+            self.freeze_layers_enabled = False
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.FREEZE:
+            logging.info('Running with offloading strategy: FREEZE')
+            self.deadline_enabled = True
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = True
+            self.offload_enabled = False
+        if strategy == OffloadingStrategy.MODEL_OFFLOAD:
+            logging.info('Running with offloading strategy: MODEL_OFFLOAD')
+            self.deadline_enabled = True
+            self.swyh_enabled = False
+            self.freeze_layers_enabled = True
+            self.offload_enabled = True
+        logging.info(f'Offload strategy params: deadline={self.deadline_enabled}, swyh={self.swyh_enabled}, freeze={self.freeze_layers_enabled}, offload={self.offload_enabled}')
 
     def create_clients(self, client_id_triple):
         for id, rank, world_size in client_id_triple:
@@ -235,8 +280,8 @@ def ask_client_to_offload(self, client1_ref, client2_ref):
 
     def remote_run_epoch(self, epochs):
         start_epoch_time = time.time()
-        deadline = 400
-
+        deadline = self.config.deadline
+        deadline_time = self.config.deadline
         """
         1. Client selection
        2. Run local updates
@@ -245,6 +290,9 @@ def remote_run_epoch(self, epochs):
        4. Aggregate data
        """
         client_weights = []
+
+        client_weights_dict = {}
+        client_training_process_dict = {}
         while self.num_available_clients() < self.config.clients_per_round:
             logging.warning(f'Waiting for enough clients to become available. # Available Clients = {self.num_available_clients()}, but need {self.config.clients_per_round}')
             self.process_response_list()
@@ -264,6 +312,10 @@ def remote_run_epoch(self, epochs):
             res[1].wait()
         logging.info('Weights are updated')
 
+        # Let clients train locally
+
+        if not self.deadline_enabled:
+            deadline = 0
         responses: List[ClientResponse] = []
         for client in selected_clients:
             cr = ClientResponse(self.response_id, client, _remote_method_async(Client.run_epochs, client.ref, num_epoch=epochs, deadline=deadline))
@@ -274,7 +326,6 @@ def remote_run_epoch(self, epochs):
             # responses.append((client, time.time(), _remote_method_async(Client.run_epochs, client.ref, num_epoch=epochs)))
         self.epoch_counter += epochs
 
-        deadline_time = 400
         # deadline_time = None
         # Wait loop with deadline
         start = time.time()
@@ -292,8 +343,8 @@ def reached_deadline():
         has_not_called = True
 
         show_perf_data = True
-        while not all_finished and not reached_deadline():
-
+        while not all_finished and not (self.deadline_enabled and reached_deadline()):
+            # if self.deadline_enabled and reached_deadline()
             # if has_not_called and (time.time() -start) > 10:
             #     logging.info('Sending call to offload')
             #     has_not_called = False
@@ -325,16 +376,17 @@ def reached_deadline():
                     #     weak_client = k
                     # else:
                     #     strong_client = k
+                    if self.offload_enabled:
+                        weak_client = est_keys[0]
+                        strong_client = est_keys[1]
+                        if self.performance_estimate[est_keys[1]][1] > self.performance_estimate[est_keys[0]][1]:
+                            weak_client = est_keys[1]
+                            strong_client = est_keys[0]
 
-                    weak_client = est_keys[0]
-                    strong_client = est_keys[1]
-                    if self.performance_estimate[est_keys[1]][1] > self.performance_estimate[est_keys[0]][1]:
-                        weak_client = est_keys[1]
-                        strong_client = est_keys[0]
+                        logging.info(f'Offloading from {weak_client} -> {strong_client} due to {self.performance_estimate[weak_client]} and {self.performance_estimate[strong_client]}')
+                        logging.info('Sending call to offload')
+                        self.ask_client_to_offload(self.reference_lookup[selected_clients[0].name], selected_clients[1].name)
 
-                    logging.info(f'Offloading from {weak_client} -> {strong_client} due to {self.performance_estimate[weak_client]} and {self.performance_estimate[strong_client]}')
-                    logging.info('Sending call to offload')
-                    self.ask_client_to_offload(self.reference_lookup[selected_clients[0].name], selected_clients[1].name)
             # selected_clients[0]
             # logging.info(f'Status of all_finished={all_finished} and deadline={reached_deadline()}')
             all_finished = True
@@ -344,6 +396,7 @@ def reached_deadline():
                     client_response.finish()
                 else:
                     all_finished = False
+            time.sleep(0.1)
         logging.info(f'Stopped waiting due to all_finished={all_finished} and deadline={reached_deadline()}')
 
         for client_response in responses:
@@ -361,6 +414,7 @@ def reached_deadline():
             self.client_data[epoch_data.client_id].append(epoch_data)
             logging.info(f'{client} had a loss of {epoch_data.loss}')
             logging.info(f'{client} had a epoch data of {epoch_data}')
+            logging.info(f'{client} has trained on {epoch_data.training_process} samples')
 
             client.tb_writer.add_scalar('training loss',
                                         epoch_data.loss_train,  # for every 1000 minibatches
@@ -379,10 +433,13 @@ def reached_deadline():
                                         self.epoch_counter)
 
             client_weights.append(weights)
+            client_weights_dict[client.name] = weights
+            client_training_process_dict[client.name] = epoch_data.training_process
         self.performance_estimate = {}
         if len(client_weights):
-            updated_model = average_nn_parameters(client_weights)
+            updated_model = FedAvg(client_weights_dict, client_training_process_dict)
+            # updated_model = average_nn_parameters(client_weights)
 
         # test global model
         logging.info("Testing on global test set")
@@ -399,13 +456,13 @@ def reached_deadline():
 
     def save_experiment_data(self):
         p = Path(f'./{self.config.output_location}')
-        file_output = f'./{self.config.output_location}'
+        # file_output = f'./{self.config.output_location}'
         exp_prefix = self.config.experiment_prefix
-        self.ensure_path_exists(file_output)
-        file_output /= f'{exp_prefix}-general_data.csv'
+        self.ensure_path_exists(p)
+        p /= f'{exp_prefix}-general_data.csv'
         # general_filename = f'{file_output}/general_data.csv'
         df = pd.DataFrame(self.exp_data_general, columns=['epoch', 'duration', 'accuracy', 'loss', 'class_precision', 'class_recall'])
-        df.to_csv(file_output)
+        df.to_csv(p)
 
     def update_client_data_sizes(self):
         responses = []
@@ -427,9 +484,10 @@ def remote_test_sync(self):
 
     def save_epoch_data(self):
         file_output = f'./{self.config.output_location}'
+        exp_prefix = self.config.experiment_prefix
         self.ensure_path_exists(file_output)
         for key in self.client_data:
-            filename = f'{file_output}/{key}_epochs.csv'
+            filename = f'{file_output}/{exp_prefix}_{key}_epochs.csv'
             logging.info(f'Saving data at {filename}')
             with open(filename, "w") as f:
                 w = DataclassWriter(f, self.client_data[key], EpochData)
diff --git a/fltk/strategy/aggregation.py b/fltk/strategy/aggregation.py
index 81726d9f..10a9975c 100644
--- a/fltk/strategy/aggregation.py
+++ b/fltk/strategy/aggregation.py
@@ -25,6 +25,23 @@ def average_nn_parameters(parameters):
 
     return new_params
 
+def FedAvg(parameters, sizes):
+    new_params = {}
+    sum_size = 0
+    for client in parameters:
+        for name in parameters[client].keys():
+            try:
+                new_params[name].data += (parameters[client][name].data * sizes[client])
+            except:
+                new_params[name] = (parameters[client][name].data * sizes[client])
+        sum_size += sizes[client]
+
+    for name in new_params:
+        # @TODO: Is .long() really required?
+        new_params[name].data = new_params[name].data.long() / sum_size
+
+    return new_params
+
 def average_nn_parameters(parameters, sizes):
     new_params = {}
     sum_size = 0
diff --git a/fltk/strategy/offloading.py b/fltk/strategy/offloading.py
new file mode 100644
index 00000000..4473ad90
--- /dev/null
+++ b/fltk/strategy/offloading.py
@@ -0,0 +1,22 @@
+from enum import Enum
+
+
+class OffloadingStrategy(Enum):
+    VANILLA = 1
+    DEADLINE = 2
+    SWYH = 3
+    FREEZE = 4
+    MODEL_OFFLOAD = 5
+
+    @classmethod
+    def Parse(cls, string_value):
+        if string_value == 'vanilla':
+            return OffloadingStrategy.VANILLA
+        if string_value == 'deadline':
+            return OffloadingStrategy.DEADLINE
+        if string_value == 'swyh':
+            return OffloadingStrategy.SWYH
+        if string_value == 'freeze':
+            return OffloadingStrategy.FREEZE
+        if string_value == 'offload':
+            return OffloadingStrategy.MODEL_OFFLOAD
\ No newline at end of file
diff --git a/fltk/util/base_config.py b/fltk/util/base_config.py
index a5a3b74b..e41b92b9 100644
--- a/fltk/util/base_config.py
+++ b/fltk/util/base_config.py
@@ -43,6 +43,10 @@ def __init__(self):
         self.num_workers = 50
         # self.num_poisoned_workers = 10
 
+        self.offload_strategy = 'vanilla'
+        self.profiling_size = 100
+        self.deadline = 400
+
         self.federator_host = '0.0.0.0'
         self.rank = 0
         self.world_size = 0
@@ -109,6 +113,12 @@ def merge_yaml(self, cfg = {}):
             self.set_net_by_name(cfg['net'])
         if 'dataset' in cfg:
             self.dataset_name = cfg['dataset']
+        if 'offload_stategy' in cfg:
+            self.offload_strategy = cfg['offload_stategy']
+        if 'profiling_size' in cfg:
+            self.profiling_size = cfg['profiling_size']
+        if 'deadline' in cfg:
+            self.deadline = cfg['deadline']
         if 'experiment_prefix' in cfg:
             self.experiment_prefix = cfg['experiment_prefix']
         else:
diff --git a/fltk/util/generate_docker_compose.py b/fltk/util/generate_docker_compose.py
index 5c67c8da..8d910446 100644
--- a/fltk/util/generate_docker_compose.py
+++ b/fltk/util/generate_docker_compose.py
@@ -29,6 +29,28 @@ def generate_client(id, template: dict, world_size: int, type='default'):
 
     return local_template, container_name
 
+def generate_offload_exp():
+    num_clients = 2
+    world_size = num_clients + 1
+    system_template: dict = load_system_template()
+
+    for key, item in enumerate(system_template['services']['fl_server']['environment']):
+        if item == 'WORLD_SIZE={world_size}':
+            system_template['services']['fl_server']['environment'][key] = item.format(world_size=world_size)
+
+    for client_id in range(1, num_clients + 1):
+        client_type = 'default'
+        if client_id == 1:
+            client_type = 'medium'
+        # if client_id == 2:
+        #     client_type = 'slow'
+        client_template: dict = load_client_template(type=client_type)
+        client_definition, container_name = generate_client(client_id, client_template, world_size, type=client_type)
+        system_template['services'].update(client_definition)
+
+    with open(r'./docker-compose.yml', 'w') as file:
+        yaml.dump(system_template, file, sort_keys=False)
+
 def generate(num_clients: int):
     world_size = num_clients + 1
     system_template :dict = load_system_template()
@@ -39,10 +61,10 @@ def generate(num_clients: int):
 
     for client_id in range(1, num_clients+1):
         client_type = 'default'
-        # if client_id == 1:
-        #     client_type='slow'
-        # if client_id == 2:
-        #     client_type='medium'
+        if client_id == 1:
+            client_type='slow'
+        if client_id == 2:
+            client_type='medium'
         client_template: dict = load_client_template(type=client_type)
         client_definition, container_name = generate_client(client_id, client_template, world_size, type=client_type)
         system_template['services'].update(client_definition)
 
@@ -53,7 +75,8 @@ def generate(num_clients: int):
 
 
 if __name__ == '__main__':
-    num_clients = int(sys.argv[1])
-    generate(num_clients)
+    # num_clients = int(sys.argv[1])
+    # generate(num_clients)
+    generate_offload_exp()
 
     print('Done')
diff --git a/fltk/util/results.py b/fltk/util/results.py
index cf762b8a..a37fc8ad 100644
--- a/fltk/util/results.py
+++ b/fltk/util/results.py
@@ -12,6 +12,7 @@ class EpochData:
     loss: float
     class_precision: Any
     class_recall: Any
+    training_process: int
     client_id: str = None
 
     def to_csv_line(self):
diff --git a/requirements.txt b/requirements.txt
index b01b714e..e87e007e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,4 +10,6 @@ requests
 pyyaml
 torchsummary
 dataclass-csv
-tensorboard
\ No newline at end of file
+tensorboard
+seaborn
+matplotlib
\ No newline at end of file
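
Note on the new aggregation: the FedAvg added in fltk/strategy/aggregation.py is a sample-weighted average, where each client's parameters are scaled by the number of samples it actually trained on (the training_process counter that clients now report back in EpochData) and the sum is divided by the total sample count. The following is a minimal, self-contained sketch of that idea, not the patch's exact implementation: it keeps tensors in floating point (the patch itself casts via .long(), which its own @TODO flags as questionable), and the client names and sizes are invented for illustration.

from typing import Dict
import torch

def sample_weighted_fedavg(parameters: Dict[str, Dict[str, torch.Tensor]],
                           sizes: Dict[str, int]) -> Dict[str, torch.Tensor]:
    # Scale every client's tensors by its sample count, then divide by the total.
    total = sum(sizes.values())
    aggregated: Dict[str, torch.Tensor] = {}
    for client, state in parameters.items():
        for name, tensor in state.items():
            weighted = tensor.float() * sizes[client]
            aggregated[name] = aggregated.get(name, torch.zeros_like(weighted)) + weighted
    return {name: value / total for name, value in aggregated.items()}

# Toy usage (hypothetical data): client1 trained on 80 samples, client2 on 20.
clients = {
    'client1': {'layer.weight': torch.ones(2, 2)},
    'client2': {'layer.weight': torch.zeros(2, 2)},
}
print(sample_weighted_fedavg(clients, {'client1': 80, 'client2': 20}))  # 0.8 everywhere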