Skip to content

Commit

Permalink
formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
chancejohnstone committed Oct 20, 2024
1 parent 5e2a418 commit 48130eb
Show file tree
Hide file tree
Showing 7 changed files with 36 additions and 36 deletions.
13 changes: 6 additions & 7 deletions baselines/fedht/fedht/aggregate.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,14 +367,14 @@ def _aggregate_n_closest_weights(

# calls hardthreshold function for each list element in weights_all
def hardthreshold_list(weights_all, num_keep: int) -> NDArrays:

"""Call hardthreshold."""
params = [hardthreshold(each, num_keep) for each in weights_all]
return params


# hardthreshold function applied to array
def hardthreshold(weights_prime, num_keep: int) -> NDArrays:

"""Perform hardthresholding on single array."""
# check for len of array
val_len = weights_prime.size

Expand All @@ -383,7 +383,7 @@ def hardthreshold(weights_prime, num_keep: int) -> NDArrays:
if num_keep > val_len:
params = weights_prime
print(
"The number of parameters kept is greater than the length of the vector. All parameters will be kept."
"num_keep parameter greater than length of vector. All parameters kept."
)
else:
# Compute the magnitudes
Expand All @@ -404,8 +404,7 @@ def hardthreshold(weights_prime, num_keep: int) -> NDArrays:
def aggregate_hardthreshold(
results: List[Tuple[NDArrays, int]], num_keep: int, iterht: bool
) -> NDArrays:
"""Applies hard thresholding to keep only the k largest weights in a client-weight
vector.
"""Apply hard thresholding to keep only the k largest weights.
Fed-HT (Fed-IterHT) can be found at
https://arxiv.org/abs/2101.00052
Expand All @@ -423,7 +422,7 @@ def aggregate_hardthreshold(
# check for iterht=True; set in cfg
if iterht:
print(
f"{green}INFO {reset}:\t\tUsing Fed-IterHT for model aggregation with threshold = ",
f"{green}INFO {reset}:\t\tUsing Fed-IterHT with num_keep = ",
num_keep,
)

Expand All @@ -448,7 +447,7 @@ def aggregate_hardthreshold(

else:
print(
f"{green}INFO {reset}:\t\tUsing Fed-HT for model aggregation with threshold = ",
f"{green}INFO {reset}:\t\tUsing Fed-HT with num_keep = ",
num_keep,
)

Expand Down
14 changes: 8 additions & 6 deletions baselines/fedht/fedht/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Generate client for fedht baseline."""

from collections import OrderedDict
from typing import cast

Expand All @@ -13,6 +15,8 @@

# SimII client
class SimIIClient(NumPyClient):
"""Define SimIIClient class."""

def __init__(
self,
trainloader,
Expand All @@ -24,7 +28,6 @@ def __init__(
cfg: DictConfig,
) -> None:
"""SimII client for simulation II experimentation."""

self.trainloader = trainloader
self.testloader = testloader
self.model = model
Expand All @@ -35,6 +38,7 @@ def __init__(

# get parameters from existing model
def get_parameters(self, config):
"""Get parameters."""
return [val.cpu().numpy() for _, val in self.model.state_dict().items()]

def fit(self, parameters, config):
Expand Down Expand Up @@ -69,12 +73,11 @@ def evaluate(self, parameters, config):
def generate_client_fn_simII(
dataset, num_features, num_classes, model, cfg: DictConfig
):
"""Generates client function for simulated FL."""
"""Generate client function for simulated FL."""

# def client_fn(cid: int):
def client_fn(context: Context) -> Client:
"""Define client function for centralized metrics."""

# Get node_config value to fetch partition_id
partition_id = cast(int, context.node_config["partition-id"])

Expand All @@ -96,6 +99,8 @@ def client_fn(context: Context) -> Client:

# MNIST client
class MnistClient(NumPyClient):
"""Define MnistClient class."""

def __init__(
self,
trainloader,
Expand All @@ -107,7 +112,6 @@ def __init__(
cfg: DictConfig,
) -> None:
"""MNIST client for MNIST experimentation."""

self.trainloader = trainloader
self.testloader = testloader
self.model = model
Expand Down Expand Up @@ -138,7 +142,6 @@ def fit(self, parameters, config):

def evaluate(self, parameters, config):
"""Evaluate model."""

# set model parameters
params_dict = zip(self.model.state_dict().keys(), parameters)
state_dict = OrderedDict({k: torch.Tensor(v) for k, v in params_dict})
Expand All @@ -159,7 +162,6 @@ def generate_client_fn_mnist(
# def client_fn(cid: int):
def client_fn(context: Context) -> Client:
"""Define client function for centralized metrics."""

# Get node_config value to fetch partition_id
partition_id = cast(int, context.node_config["partition-id"])

Expand Down
7 changes: 3 additions & 4 deletions baselines/fedht/fedht/fedht.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Federated Hardthresholding (FedHT)"""
"""Federated Hardthresholding (FedHT)."""


from logging import WARNING
Expand All @@ -35,8 +35,6 @@
from flwr.server.client_proxy import ClientProxy
from flwr.server.strategy.strategy import Strategy

# from flwr.server.strategy.aggregate import aggregate, aggregate_inplace, weighted_loss_avg
# from flwr.server.strategy.aggregate import aggregate_inplace, weighted_loss_avg, aggregate_hardthreshold
from fedht.aggregate import aggregate_hardthreshold, weighted_loss_avg

WARNING_MIN_AVAILABLE_CLIENTS_TOO_LOW = """
Expand Down Expand Up @@ -71,7 +69,8 @@ class FedHT(Strategy):
Minimum number of clients used during validation. Defaults to 2.
min_available_clients : int, optional
Minimum number of total clients in the system. Defaults to 2.
evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]],Optional[Tuple[float, Dict[str, Scalar]]]]]
evaluate_fn : Optional[Callable[[int, NDArrays, Dict[str, Scalar]],
Optional[Tuple[float, Dict[str, Scalar]]]]]
Optional function used for validation. Defaults to None.
on_fit_config_fn : Callable[[int], Dict[str, Scalar]], optional
Function used to configure training. Defaults to None.
Expand Down
17 changes: 4 additions & 13 deletions baselines/fedht/fedht/main.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,4 @@
"""
Author: Chance Johnstone
Purpose: main function for fedht baseline Simulation II example from Tong et al 2020
Notes
-----
Code included in this baseline generated with the help of numerous
Flower, Python, and PyTorch resources.
"""
"""Run main for fedht baseline."""

import pickle
import random
Expand All @@ -29,14 +21,13 @@

@hydra.main(config_path="conf", config_name="base_mnist", version_base=None)
def main(cfg: DictConfig):
"""Main file for fedht baseline
"""Run main file for fedht baseline.
Parameters
----------
cfg : DictConfig
Config file for federated baseline; read from fedht/conf
Config file for federated baseline; read from fedht/conf.
"""

# set seed
random.seed(2024)

Expand All @@ -53,7 +44,7 @@ def main(cfg: DictConfig):
# load MNIST data
num_features = 28 * 28
num_classes = 10
dataset = FederatedDataset(dataset="mnist", partitioners={"train": num_classes})
dataset = FederatedDataset(dataset="mnist", partitioners={"train": partitioner})
test_dataset = dataset.load_split("test").with_format("numpy")
testloader = DataLoader(test_dataset, batch_size=cfg.batch_size, shuffle=False)

Expand Down
15 changes: 10 additions & 5 deletions baselines/fedht/fedht/model.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Define model for fedht baseline."""

import torch
import torch.nn as nn
import torch.optim as optim
Expand All @@ -8,30 +10,33 @@
# model code initially pulled from fedprox baseline
# generates multinomial logistic regression model via torch
class LogisticRegression(nn.Module):
"""Define LogisticRegression class."""

def __init__(self, num_features, num_classes: int) -> None:
"""Define model."""
super().__init__()

# one single linear layer
self.linear = nn.Linear(num_features, num_classes)

def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
"""Define forward pass."""
# forward pass; sigmoid transform included in CBELoss criterion
output_tensor = self.linear(torch.flatten(input_tensor, 1))
return output_tensor


# define train function that will be called by each client to train the model
def train(model, trainloader: DataLoader, cfg: DictConfig) -> None:

"""Train model."""
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
model.parameters(), lr=cfg.learning_rate, weight_decay=cfg.weight_decay
)

# train
for epoch in range(cfg.num_local_epochs):
for i, data in enumerate(trainloader):
for _epoch in range(cfg.num_local_epochs):
for _i, data in enumerate(trainloader):

inputs, labels = data["image"], data["label"]

Expand All @@ -49,7 +54,7 @@ def train(model, trainloader: DataLoader, cfg: DictConfig) -> None:


def test(model, testloader: DataLoader) -> None:

"""Test model."""
criterion = nn.CrossEntropyLoss()

# initialize
Expand All @@ -58,7 +63,7 @@ def test(model, testloader: DataLoader) -> None:
# put into evlauate mode
model.eval()
with torch.no_grad():
for i, data in enumerate(testloader):
for _i, data in enumerate(testloader):

images, labels = data["image"], data["label"]

Expand Down
2 changes: 2 additions & 0 deletions baselines/fedht/fedht/server.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Generate server for fedht baseline."""

from collections import OrderedDict
from typing import Dict

Expand Down
4 changes: 3 additions & 1 deletion baselines/fedht/fedht/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
"""Execute utility functions for fedht baseline."""

import numpy as np
from torch.utils.data import Dataset

Expand All @@ -10,7 +12,7 @@ def __init__(self, features, labels):
self.data = {"image": features, "label": labels}

def __len__(self):
"""Returns length of features in data."""
"""Return length of features in data."""
return len(self.data["image"])

def __getitem__(self, idx):
Expand Down

0 comments on commit 48130eb

Please sign in to comment.