Skip to content

Commit

Permalink
Merge branch 'main' into refactor/alexnet_model
Browse files Browse the repository at this point in the history
  • Loading branch information
jafermarq authored Jan 24, 2024
2 parents 953ca47 + a96bc7f commit ab5200a
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 13 deletions.
2 changes: 1 addition & 1 deletion examples/advanced-pytorch/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def main() -> None:
)
parser.add_argument(
"--toy",
action='store_true',
action="store_true",
help="Set to true to quicky run the client using only 10 datasamples. \
Useful for testing purposes. Default: False",
)
Expand Down
2 changes: 1 addition & 1 deletion examples/advanced-pytorch/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def main():
parser = argparse.ArgumentParser(description="Flower")
parser.add_argument(
"--toy",
action='store_true',
action="store_true",
help="Set to true to use only 10 datasamples for validation. \
Useful for testing purposes. Default: False",
)
Expand Down
24 changes: 14 additions & 10 deletions examples/advanced-pytorch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,21 @@ def load_centralized_data():

def apply_transforms(batch):
"""Apply transforms to the partition from FederatedDataset."""
pytorch_transforms = Compose([
Resize(256),
CenterCrop(224),
ToTensor(),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
pytorch_transforms = Compose(
[
Resize(256),
CenterCrop(224),
ToTensor(),
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
)
batch["img"] = [pytorch_transforms(img) for img in batch["img"]]
return batch


def train(net, trainloader, valloader, epochs,
device: torch.device = torch.device("cpu")):
def train(
net, trainloader, valloader, epochs, device: torch.device = torch.device("cpu")
):
"""Train the network on the training set."""
print("Starting training...")
net.to(device) # move model to GPU if available
Expand Down Expand Up @@ -72,8 +75,9 @@ def train(net, trainloader, valloader, epochs,
return results


def test(net, testloader, steps: int = None,
device: torch.device = torch.device("cpu")):
def test(
net, testloader, steps: int = None, device: torch.device = torch.device("cpu")
):
"""Validate the network on the entire test set."""
print("Starting evalutation...")
net.to(device) # move model to GPU if available
Expand Down
2 changes: 1 addition & 1 deletion examples/advanced-tensorflow/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def main() -> None:
)
parser.add_argument(
"--toy",
action='store_true',
action="store_true",
help="Set to true to quicky run the client using only 10 datasamples. "
"Useful for testing purposes. Default: False",
)
Expand Down

0 comments on commit ab5200a

Please sign in to comment.