diff --git a/fltk/core/federator.py b/fltk/core/federator.py index 081fd752..75a45b22 100644 --- a/fltk/core/federator.py +++ b/fltk/core/federator.py @@ -301,9 +301,7 @@ def all_futures_done(futures: List[torch.Future]) -> bool: # pylint: disable=no self.logger.info('Continue with rest [1]') time.sleep(3) - # updated_model = FedAvg(client_weights, client_sizes) updated_model = self.aggregation_method(client_weights, client_sizes) - # updated_model = average_nn_parameters_simple(list(client_weights.values())) self.update_nn_parameters(updated_model) test_accuracy, test_loss = self.test(self.net) diff --git a/fltk/strategy/aggregation.py b/fltk/strategy/aggregation.py deleted file mode 100644 index f18ac1aa..00000000 --- a/fltk/strategy/aggregation.py +++ /dev/null @@ -1,35 +0,0 @@ -def average_nn_parameters(parameters): - """ - @deprecated Average passed parameters. - @param parameters: nn model named parameters - @type parameters: list - """ - new_params = {} - for name in parameters[0].keys(): - new_params[name] = sum([param[name].data for param in parameters]) / len(parameters) - - return new_params - - -def fed_average_nn_parameters(parameters, sizes): - """ - @deprecated Federated Average passed parameters. - @param parameters: nn model named parameters - @type parameters: list - @param sizes: - @type sizes: - """ - new_params = {} - sum_size = 0 - for client in parameters: - for name in parameters[client].keys(): - try: - new_params[name].data += (parameters[client][name].data * sizes[client]) - except Exception as e: - new_params[name] = (parameters[client][name].data * sizes[client]) - sum_size += sizes[client] - - for name in new_params: - new_params[name].data /= sum_size - - return new_params diff --git a/fltk/strategy/aggregation/aggregation.py b/fltk/strategy/aggregation/aggregation.py deleted file mode 100644 index 57bb1287..00000000 --- a/fltk/strategy/aggregation/aggregation.py +++ /dev/null @@ -1,37 +0,0 @@ - - -def average_nn_parameters_simple(parameters): - """ - Averages passed parameters. - :param parameters: nn model named parameters - :type parameters: list - """ - new_params = {} - for name in parameters[0].keys(): - new_params[name] = sum([param[name].data for param in parameters]) / len(parameters) - - return new_params - - -def average_nn_parameters(parameters, sizes): - """ - @deprecated Federated Average passed parameters. - :param parameters: nn model named parameters - :type parameters: list - :param sizes: - :type sizes: - """ - new_params = {} - sum_size = 0 - for client in parameters: - for key, _ in parameters[client].items(): - try: - new_params[key].data += (parameters[client][key].data * sizes[client]) - except Exception: # pylint: disable=broad-except - new_params[key] = (parameters[client][key].data * sizes[client]) - sum_size += sizes[client] - - for key, _ in new_params.items(): - new_params[key].data /= sum_size - - return new_params