diff --git a/baselines/fedht/fedht/aggregate.py b/baselines/fedht/fedht/aggregate.py index c2c6116629f..462dd7dade9 100644 --- a/baselines/fedht/fedht/aggregate.py +++ b/baselines/fedht/fedht/aggregate.py @@ -364,57 +364,66 @@ def _aggregate_n_closest_weights( aggregated_weights.append(np.mean(beta_closest_weights, axis=0)) return aggregated_weights + # calls hardthreshold function for each list element in weights_all def hardthreshold_list(weights_all, num_keep: int) -> NDArrays: params = [hardthreshold(each, num_keep) for each in weights_all] return params + # hardthreshold function applied to array def hardthreshold(weights_prime, num_keep: int) -> NDArrays: - + # check for len of array val_len = weights_prime.size - # intercepts not hardthresholded + # intercepts not hardthresholded if val_len > 1: if num_keep > val_len: params = weights_prime - print("The number of parameters kept is greater than the length of the vector. All parameters will be kept.") + print( + "The number of parameters kept is greater than the length of the vector. All parameters will be kept." + ) else: # Compute the magnitudes magnitudes = np.abs(weights_prime) - + # Get the k-th largest value in the vector threshold = np.partition(magnitudes, -num_keep)[-num_keep] - + # Create a new vector where values below the threshold are set to zero params = np.where(magnitudes >= threshold, weights_prime, 0) - else: + else: params = weights_prime return params - + + def aggregate_hardthreshold( - results: List[Tuple[NDArrays, int]], num_keep: int, iterht: bool) -> NDArrays: + results: List[Tuple[NDArrays, int]], num_keep: int, iterht: bool +) -> NDArrays: """ - Applies hard thresholding to keep only the k largest weights in a client-weight vector. Fed-HT (Fed-IterHT) can be + Applies hard thresholding to keep only the k largest weights in a client-weight vector. Fed-HT (Fed-IterHT) can be found at https://arxiv.org/abs/2101.00052 """ if num_keep <= 0: raise ValueError("k must be a positive integer.") - + """Compute weighted average.""" # Calculate the total number of examples used during training num_examples_total = sum(num_examples for (_, num_examples) in results) - - green = '\033[92m' - reset = '\033[0m' + + green = "\033[92m" + reset = "\033[0m" # check for iterht=True; set in cfg if iterht: - print(f"{green}INFO {reset}:\t\tUsing Fed-IterHT for model aggregation with threshold = ", num_keep) + print( + f"{green}INFO {reset}:\t\tUsing Fed-IterHT for model aggregation with threshold = ", + num_keep, + ) # apply across all models within each client @@ -428,23 +437,31 @@ def aggregate_hardthreshold( for j in range(len(results[i][0])): for k in range(len(results[i][0][j])): results[i][0][j][k] = hardthreshold(results[i][0][j][k], num_keep) - + weighted_weights1 = [ - [layer * num_examples for layer in weights] for weights, num_examples in results + [layer * num_examples for layer in weights] + for weights, num_examples in results ] weighted_weights2 = weighted_weights1 - else: - print(f"{green}INFO {reset}:\t\tUsing Fed-HT for model aggregation with threshold = ", num_keep) + else: + print( + f"{green}INFO {reset}:\t\tUsing Fed-HT for model aggregation with threshold = ", + num_keep, + ) weighted_weights1 = [ - [layer * num_examples for layer in weights] for weights, num_examples in results + [layer * num_examples for layer in weights] + for weights, num_examples in results ] weighted_weights2 = weighted_weights1 - - hold = [reduce(np.add, layer_updates) / num_examples_total for layer_updates in zip(*weighted_weights2)] + + hold = [ + reduce(np.add, layer_updates) / num_examples_total + for layer_updates in zip(*weighted_weights2) + ] params = [hardthreshold_list(layer_updates, num_keep) for layer_updates in hold] - + result: NDArrays = params - + return result