Skip to content

Commit

Permalink
agg reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
chancejohnstone committed Oct 20, 2024
1 parent acc7671 commit 0a080f0
Showing 1 changed file with 40 additions and 23 deletions.
63 changes: 40 additions & 23 deletions baselines/fedht/fedht/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,57 +364,66 @@ def _aggregate_n_closest_weights(
aggregated_weights.append(np.mean(beta_closest_weights, axis=0))
return aggregated_weights


# calls hardthreshold function for each list element in weights_all
def hardthreshold_list(weights_all, num_keep: int) -> NDArrays:

params = [hardthreshold(each, num_keep) for each in weights_all]
return params


# hardthreshold function applied to array
def hardthreshold(weights_prime, num_keep: int) -> NDArrays:

# check for len of array
val_len = weights_prime.size

# intercepts not hardthresholded
# intercepts not hardthresholded
if val_len > 1:
if num_keep > val_len:
params = weights_prime
print("The number of parameters kept is greater than the length of the vector. All parameters will be kept.")
print(
"The number of parameters kept is greater than the length of the vector. All parameters will be kept."
)
else:
# Compute the magnitudes
magnitudes = np.abs(weights_prime)

# Get the k-th largest value in the vector
threshold = np.partition(magnitudes, -num_keep)[-num_keep]

# Create a new vector where values below the threshold are set to zero
params = np.where(magnitudes >= threshold, weights_prime, 0)

else:
else:
params = weights_prime

return params



def aggregate_hardthreshold(
results: List[Tuple[NDArrays, int]], num_keep: int, iterht: bool) -> NDArrays:
results: List[Tuple[NDArrays, int]], num_keep: int, iterht: bool
) -> NDArrays:
"""
Applies hard thresholding to keep only the k largest weights in a client-weight vector. Fed-HT (Fed-IterHT) can be
Applies hard thresholding to keep only the k largest weights in a client-weight vector. Fed-HT (Fed-IterHT) can be
found at https://arxiv.org/abs/2101.00052
"""
if num_keep <= 0:
raise ValueError("k must be a positive integer.")

"""Compute weighted average."""
# Calculate the total number of examples used during training
num_examples_total = sum(num_examples for (_, num_examples) in results)
green = '\033[92m'
reset = '\033[0m'

green = "\033[92m"
reset = "\033[0m"

# check for iterht=True; set in cfg
if iterht:
print(f"{green}INFO {reset}:\t\tUsing Fed-IterHT for model aggregation with threshold = ", num_keep)
print(
f"{green}INFO {reset}:\t\tUsing Fed-IterHT for model aggregation with threshold = ",
num_keep,
)

# apply across all models within each client

Expand All @@ -428,23 +437,31 @@ def aggregate_hardthreshold(
for j in range(len(results[i][0])):
for k in range(len(results[i][0][j])):
results[i][0][j][k] = hardthreshold(results[i][0][j][k], num_keep)

weighted_weights1 = [
[layer * num_examples for layer in weights] for weights, num_examples in results
[layer * num_examples for layer in weights]
for weights, num_examples in results
]
weighted_weights2 = weighted_weights1

else:
print(f"{green}INFO {reset}:\t\tUsing Fed-HT for model aggregation with threshold = ", num_keep)
else:
print(
f"{green}INFO {reset}:\t\tUsing Fed-HT for model aggregation with threshold = ",
num_keep,
)

weighted_weights1 = [
[layer * num_examples for layer in weights] for weights, num_examples in results
[layer * num_examples for layer in weights]
for weights, num_examples in results
]
weighted_weights2 = weighted_weights1

hold = [reduce(np.add, layer_updates) / num_examples_total for layer_updates in zip(*weighted_weights2)]

hold = [
reduce(np.add, layer_updates) / num_examples_total
for layer_updates in zip(*weighted_weights2)
]
params = [hardthreshold_list(layer_updates, num_keep) for layer_updates in hold]

result: NDArrays = params

return result

0 comments on commit 0a080f0

Please sign in to comment.