Skip to content

Commit

Permalink
Format code examples (#2767)
Browse files Browse the repository at this point in the history
  • Loading branch information
panh99 authored Jan 4, 2024
1 parent 3e8e60e commit 0f5ce99
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -484,7 +484,7 @@
" min_available_clients=10, # Wait until all 10 clients are available\n",
")\n",
"\n",
"# Specify the resources each of your clients need. By default, each \n",
"# Specify the resources each of your clients need. By default, each\n",
"# client will be allocated 1x CPU and 0x CPUs\n",
"client_resources = {\"num_cpus\": 1, \"num_gpus\": 0.0}\n",
"if DEVICE.type == \"cuda\":\n",
Expand Down
14 changes: 7 additions & 7 deletions examples/pytorch-from-centralized-to-federated/cifar.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,10 @@ def apply_transforms(batch):


def train(
net: Net,
trainloader: torch.utils.data.DataLoader,
epochs: int,
device: torch.device, # pylint: disable=no-member
net: Net,
trainloader: torch.utils.data.DataLoader,
epochs: int,
device: torch.device, # pylint: disable=no-member
) -> None:
"""Train the network."""
# Define loss and optimizer
Expand Down Expand Up @@ -110,9 +110,9 @@ def train(


def test(
net: Net,
testloader: torch.utils.data.DataLoader,
device: torch.device, # pylint: disable=no-member
net: Net,
testloader: torch.utils.data.DataLoader,
device: torch.device, # pylint: disable=no-member
) -> Tuple[float, float]:
"""Validate the network on the entire test set."""
# Define loss and metrics
Expand Down
12 changes: 6 additions & 6 deletions examples/pytorch-from-centralized-to-federated/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ class CifarClient(fl.client.NumPyClient):
"""Flower client implementing CIFAR-10 image classification using PyTorch."""

def __init__(
self,
model: cifar.Net,
trainloader: DataLoader,
testloader: DataLoader,
self,
model: cifar.Net,
trainloader: DataLoader,
testloader: DataLoader,
) -> None:
self.model = model
self.trainloader = trainloader
Expand Down Expand Up @@ -61,15 +61,15 @@ def set_parameters(self, parameters: List[np.ndarray]) -> None:
self.model.load_state_dict(state_dict, strict=True)

def fit(
self, parameters: List[np.ndarray], config: Dict[str, str]
self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[List[np.ndarray], int, Dict]:
# Set model parameters, train model, return updated model parameters
self.set_parameters(parameters)
cifar.train(self.model, self.trainloader, epochs=1, device=DEVICE)
return self.get_parameters(config={}), len(self.trainloader.dataset), {}

def evaluate(
self, parameters: List[np.ndarray], config: Dict[str, str]
self, parameters: List[np.ndarray], config: Dict[str, str]
) -> Tuple[float, int, Dict]:
# Set model parameters, evaluate model on local test dataset, return result
self.set_parameters(parameters)
Expand Down
2 changes: 1 addition & 1 deletion examples/quickstart-pytorch-lightning/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

disable_progress_bar()


class FlowerClient(fl.client.NumPyClient):
def __init__(self, model, train_loader, val_loader, test_loader):
self.model = model
Expand Down Expand Up @@ -55,7 +56,6 @@ def _set_parameters(model, parameters):


def main() -> None:

parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--node-id",
Expand Down
16 changes: 10 additions & 6 deletions examples/quickstart-pytorch-lightning/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,16 +86,20 @@ def load_data(partition):
# 60 % for the federated train and 20 % for the federated validation (both in fit)
partition_train_valid = partition_full["train"].train_test_split(train_size=0.75)
trainloader = DataLoader(
partition_train_valid["train"], batch_size=32,
shuffle=True, collate_fn=collate_fn, num_workers=1
partition_train_valid["train"],
batch_size=32,
shuffle=True,
collate_fn=collate_fn,
num_workers=1,
)
valloader = DataLoader(
partition_train_valid["test"], batch_size=32,
collate_fn=collate_fn, num_workers=1
partition_train_valid["test"],
batch_size=32,
collate_fn=collate_fn,
num_workers=1,
)
testloader = DataLoader(
partition_full["test"], batch_size=32,
collate_fn=collate_fn, num_workers=1
partition_full["test"], batch_size=32, collate_fn=collate_fn, num_workers=1
)
return trainloader, valloader, testloader

Expand Down
4 changes: 3 additions & 1 deletion examples/quickstart-sklearn-tabular/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,6 @@ def evaluate(self, parameters, config): # type: ignore
return loss, len(X_test), {"test_accuracy": accuracy}

# Start Flower client
fl.client.start_client(server_address="0.0.0.0:8080", client=IrisClient().to_client())
fl.client.start_client(
server_address="0.0.0.0:8080", client=IrisClient().to_client()
)

0 comments on commit 0f5ce99

Please sign in to comment.