diff --git a/examples/advanced-pytorch/README.md b/examples/advanced-pytorch/README.md index 2527e8e4a820..9101105b2618 100644 --- a/examples/advanced-pytorch/README.md +++ b/examples/advanced-pytorch/README.md @@ -68,4 +68,4 @@ but this can be changed by removing the `--toy` argument in the script. You can The `run.sh` script starts processes in the background so that you don't have to open eleven terminal windows. If you experiment with the code example and something goes wrong, simply using `CTRL + C` on Linux (or `CMD + C` on macOS) wouldn't normally kill all these processes, which is why the script ends with `trap "trap - SIGTERM && kill -- -$$" SIGINT SIGTERM EXIT` and `wait`. This simply allows you to stop the experiment using `CTRL + C` (or `CMD + C`). If you change the script and anything goes wrong you can still use `killall python` (or `killall python3`) to kill all background processes (or a more specific command if you have other Python processes running that you don't want to kill). -You can also manually run `python3 server.py` and `python3 client.py --client-id ` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). +You can also manually run `python3 server.py` and `python3 client.py --client-id ` for as many clients as you want but you have to make sure that each command is run in a different terminal window (or a different computer on the network). In addition, you can make your clients use either `EfficientNet` (default) or `AlexNet` (but all clients in the experiment should use the same model). Switch between models using the `--model` flag when launching `client.py` and `server.py`. diff --git a/examples/advanced-pytorch/client.py b/examples/advanced-pytorch/client.py index b22cbcd70465..d4c8abe3d404 100644 --- a/examples/advanced-pytorch/client.py +++ b/examples/advanced-pytorch/client.py @@ -1,6 +1,5 @@ import utils from torch.utils.data import DataLoader -import torchvision.datasets import torch import flwr as fl import argparse @@ -17,26 +16,31 @@ def __init__( trainset: datasets.Dataset, testset: datasets.Dataset, device: torch.device, + model_str: str, validation_split: int = 0.1, ): self.device = device self.trainset = trainset self.testset = testset self.validation_split = validation_split + if model_str == "alexnet": + self.model = utils.load_alexnet(classes=10) + else: + self.model = utils.load_efficientnet(classes=10) def set_parameters(self, parameters): - """Loads a efficientnet model and replaces it parameters with the ones given.""" - model = utils.load_efficientnet(classes=10) - params_dict = zip(model.state_dict().keys(), parameters) + """Loads an AlexNet or EfficientNet model and replaces its parameters with the + ones given.""" + + params_dict = zip(self.model.state_dict().keys(), parameters) state_dict = OrderedDict({k: torch.tensor(v) for k, v in params_dict}) - model.load_state_dict(state_dict, strict=True) - return model + self.model.load_state_dict(state_dict, strict=True) def fit(self, parameters, config): """Train parameters on the locally held training set.""" # Update local model parameters - model = self.set_parameters(parameters) + self.set_parameters(parameters) # Get hyperparameters for this round batch_size: int = config["batch_size"] @@ -49,9 +53,9 @@ def fit(self, parameters, config): train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True) val_loader = DataLoader(valset, batch_size=batch_size) - results = utils.train(model,
train_loader, val_loader, epochs, self.device) + results = utils.train(self.model, train_loader, val_loader, epochs, self.device) - parameters_prime = utils.get_model_params(model) + parameters_prime = utils.get_model_params(self.model) num_examples_train = len(trainset) return parameters_prime, num_examples_train, results @@ -59,7 +63,7 @@ def fit(self, parameters, config): def evaluate(self, parameters, config): """Evaluate parameters on the locally held test set.""" # Update local model parameters - model = self.set_parameters(parameters) + self.set_parameters(parameters) # Get config values steps: int = config["val_steps"] @@ -67,7 +71,7 @@ def evaluate(self, parameters, config): # Evaluate global model parameters on the local test data and return results testloader = DataLoader(self.testset, batch_size=16) - loss, accuracy = utils.test(model, testloader, steps, self.device) + loss, accuracy = utils.test(self.model, testloader, steps, self.device) return float(loss), len(self.testset), {"accuracy": float(accuracy)} @@ -110,7 +114,7 @@ def main() -> None: ) parser.add_argument( "--toy", - action='store_true', + action="store_true", help="Set to true to quicky run the client using only 10 datasamples. \ Useful for testing purposes. Default: False", ) @@ -121,6 +125,14 @@ def main() -> None: required=False, help="Set to true to use GPU. Default: False", ) + parser.add_argument( + "--model", + type=str, + default="efficientnet", + choices=["efficientnet", "alexnet"], + help="Use either Efficientnet or Alexnet models. \ + If you want to achieve differential privacy, please use the Alexnet model", + ) args = parser.parse_args() @@ -138,9 +150,8 @@ def main() -> None: trainset = trainset.select(range(10)) testset = testset.select(range(10)) # Start Flower client - client = CifarClient(trainset, testset, device) - - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) + client = CifarClient(trainset, testset, device, args.model).to_client() + fl.client.start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/examples/advanced-pytorch/run.sh b/examples/advanced-pytorch/run.sh index 3367e1680535..c3d52491b987 100755 --- a/examples/advanced-pytorch/run.sh +++ b/examples/advanced-pytorch/run.sh @@ -2,11 +2,6 @@ set -e cd "$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"/ -# Download the EfficientNetB0 model -python -c "import torch; torch.hub.load( \ - 'NVIDIA/DeepLearningExamples:torchhub', \ - 'nvidia_efficientnet_b0', pretrained=True)" - python server.py --toy & sleep 10 # Sleep for 10s to give the server enough time to start and dowload the dataset diff --git a/examples/advanced-pytorch/server.py b/examples/advanced-pytorch/server.py index fda49b71a311..489694ab1ea1 100644 --- a/examples/advanced-pytorch/server.py +++ b/examples/advanced-pytorch/server.py @@ -76,21 +76,32 @@ def main(): parser = argparse.ArgumentParser(description="Flower") parser.add_argument( "--toy", - action='store_true', + action="store_true", help="Set to true to use only 10 datasamples for validation. \ Useful for testing purposes. Default: False", ) + parser.add_argument( + "--model", + type=str, + default="efficientnet", + choices=["efficientnet", "alexnet"], + help="Use either Efficientnet or Alexnet models. 
\ + If you want to achieve differential privacy, please use the Alexnet model", + ) args = parser.parse_args() - model = utils.load_efficientnet(classes=10) + if args.model == "alexnet": + model = utils.load_alexnet(classes=10) + else: + model = utils.load_efficientnet(classes=10) model_parameters = [val.cpu().numpy() for _, val in model.state_dict().items()] # Create strategy strategy = fl.server.strategy.FedAvg( - fraction_fit=0.2, - fraction_evaluate=0.2, + fraction_fit=1.0, + fraction_evaluate=1.0, min_fit_clients=2, min_evaluate_clients=2, min_available_clients=10, diff --git a/examples/advanced-pytorch/utils.py b/examples/advanced-pytorch/utils.py index 6512010b1f23..186f079010dc 100644 --- a/examples/advanced-pytorch/utils.py +++ b/examples/advanced-pytorch/utils.py @@ -1,11 +1,11 @@ import torch from torchvision.transforms import Compose, ToTensor, Normalize, Resize, CenterCrop -from torch.utils.data import DataLoader - +from torchvision.models import efficientnet_b0, AlexNet import warnings from flwr_datasets import FederatedDataset + warnings.filterwarnings("ignore") @@ -28,24 +28,27 @@ def load_centralized_data(): def apply_transforms(batch): """Apply transforms to the partition from FederatedDataset.""" - pytorch_transforms = Compose([ - Resize(256), - CenterCrop(224), - ToTensor(), - Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), - ]) + pytorch_transforms = Compose( + [ + Resize(256), + CenterCrop(224), + ToTensor(), + Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) batch["img"] = [pytorch_transforms(img) for img in batch["img"]] return batch -def train(net, trainloader, valloader, epochs, - device: torch.device = torch.device("cpu")): +def train( + net, trainloader, valloader, epochs, device: torch.device = torch.device("cpu") +): """Train the network on the training set.""" print("Starting training...") net.to(device) # move model to GPU if available criterion = torch.nn.CrossEntropyLoss().to(device) optimizer = torch.optim.SGD( - net.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4 + net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4 ) net.train() for _ in range(epochs): @@ -71,8 +74,9 @@ def train(net, trainloader, valloader, epochs, return results -def test(net, testloader, steps: int = None, - device: torch.device = torch.device("cpu")): +def test( + net, testloader, steps: int = None, device: torch.device = torch.device("cpu") +): """Validate the network on the entire test set.""" print("Starting evalutation...") net.to(device) # move model to GPU if available @@ -94,38 +98,21 @@ def test(net, testloader, steps: int = None, return loss, accuracy -def replace_classifying_layer(efficientnet_model, num_classes: int = 10): - """Replaces the final layer of the classifier.""" - num_features = efficientnet_model.classifier.fc.in_features - efficientnet_model.classifier.fc = torch.nn.Linear(num_features, num_classes) - - -def load_efficientnet(entrypoint: str = "nvidia_efficientnet_b0", classes: int = None): - """Loads pretrained efficientnet model from torch hub. Replaces final classifying - layer if classes is specified. - - Args: - entrypoint: EfficientNet model to download. - For supported entrypoints, please refer - https://pytorch.org/hub/nvidia_deeplearningexamples_efficientnet/ - classes: Number of classes in final classifying layer. Leave as None to get - the downloaded - model untouched. 
- Returns: - EfficientNet Model - - Note: One alternative implementation can be found at - https://github.com/lukemelas/EfficientNet-PyTorch - """ - efficientnet = torch.hub.load( - "NVIDIA/DeepLearningExamples:torchhub", entrypoint, pretrained=True - ) - - if classes is not None: - replace_classifying_layer(efficientnet, classes) +def load_efficientnet(classes: int = 10): + """Loads EfficientNetB0 from TorchVision.""" + efficientnet = efficientnet_b0(pretrained=True) + # Re-initialize the output linear layer with the requested number of classes + num_features = efficientnet.classifier[1].in_features + if classes != efficientnet.classifier[1].out_features: + efficientnet.classifier[1] = torch.nn.Linear(num_features, classes) return efficientnet def get_model_params(model): """Returns a model's parameters.""" return [val.cpu().numpy() for _, val in model.state_dict().items()] + + +def load_alexnet(classes): + """Load AlexNet model from TorchVision.""" + return AlexNet(num_classes=classes) diff --git a/examples/advanced-tensorflow/client.py b/examples/advanced-tensorflow/client.py index 033f20b1b027..17d1d2306270 100644 --- a/examples/advanced-tensorflow/client.py +++ b/examples/advanced-tensorflow/client.py @@ -86,7 +86,7 @@ def main() -> None: ) parser.add_argument( "--toy", - action='store_true', + action="store_true", help="Set to true to quicky run the client using only 10 datasamples. " "Useful for testing purposes. Default: False", ) @@ -106,9 +106,9 @@ def main() -> None: x_test, y_test = x_test[:10], y_test[:10] # Start Flower client - client = CifarClient(model, x_train, y_train, x_test, y_test) + client = CifarClient(model, x_train, y_train, x_test, y_test).to_client() - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", client=client, root_certificates=Path(".cache/certificates/ca.crt").read_bytes(), diff --git a/examples/custom-metrics/client.py b/examples/custom-metrics/client.py index b2206118ed44..09e786a0cfac 100644 --- a/examples/custom-metrics/client.py +++ b/examples/custom-metrics/client.py @@ -68,4 +68,4 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) +fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/examples/embedded-devices/client_pytorch.py b/examples/embedded-devices/client_pytorch.py index 5d236c9e9389..8e72d668531a 100644 --- a/examples/embedded-devices/client_pytorch.py +++ b/examples/embedded-devices/client_pytorch.py @@ -204,11 +204,11 @@ def main(): trainsets, valsets, _ = prepare_dataset(use_mnist) # Start Flower client setting its associated data partition - fl.client.start_numpy_client( + fl.client.start_client( server_address=args.server_address, client=FlowerClient( trainset=trainsets[args.cid], valset=valsets[args.cid], use_mnist=use_mnist - ), + ).to_client(), ) diff --git a/examples/embedded-devices/client_tf.py b/examples/embedded-devices/client_tf.py index 3457af1c7a66..5ea34e44a98b 100644 --- a/examples/embedded-devices/client_tf.py +++ b/examples/embedded-devices/client_tf.py @@ -123,9 +123,9 @@ def main(): trainset, valset = partitions[args.cid] # Start Flower client setting its associated data partition - fl.client.start_numpy_client( + fl.client.start_client( server_address=args.server_address, - client=FlowerClient(trainset=trainset, valset=valset, use_mnist=use_mnist), + client=FlowerClient(trainset=trainset, valset=valset, use_mnist=use_mnist).to_client(), ) diff --git
a/examples/flower-in-30-minutes/tutorial.ipynb b/examples/flower-in-30-minutes/tutorial.ipynb index 336ec4c19644..8f9eccf65b74 100644 --- a/examples/flower-in-30-minutes/tutorial.ipynb +++ b/examples/flower-in-30-minutes/tutorial.ipynb @@ -776,7 +776,7 @@ "\n", " return FlowerClient(\n", " trainloader=trainloaders[int(cid)], vallodaer=valloaders[int(cid)]\n", - " )\n", + " ).to_client()\n", "\n", " return client_fn\n", "\n", diff --git a/examples/mt-pytorch/client.py b/examples/mt-pytorch/client.py index 23cc736fd62b..295770a802bd 100644 --- a/examples/mt-pytorch/client.py +++ b/examples/mt-pytorch/client.py @@ -35,8 +35,8 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client( +fl.client.start_client( server_address="0.0.0.0:9092", # "0.0.0.0:9093" for REST - client=FlowerClient(), + client=FlowerClient().to_client(), transport="grpc-rere", # "rest" for REST ) diff --git a/examples/opacus/dp_cifar_client.py b/examples/opacus/dp_cifar_client.py index bab1451ba707..cc30e7728222 100644 --- a/examples/opacus/dp_cifar_client.py +++ b/examples/opacus/dp_cifar_client.py @@ -28,7 +28,7 @@ def load_data(): model = Net() trainloader, testloader, sample_rate = load_data() -fl.client.start_numpy_client( +fl.client.start_client( server_address="127.0.0.1:8080", - client=DPCifarClient(model, trainloader, testloader), + client=DPCifarClient(model, trainloader, testloader).to_client(), ) diff --git a/examples/opacus/dp_cifar_simulation.py b/examples/opacus/dp_cifar_simulation.py index 14a9d037685b..d957caf8785c 100644 --- a/examples/opacus/dp_cifar_simulation.py +++ b/examples/opacus/dp_cifar_simulation.py @@ -1,14 +1,14 @@ import math from collections import OrderedDict -from typing import Callable, Optional, Tuple +from typing import Callable, Dict, Optional, Tuple import flwr as fl import numpy as np import torch import torchvision.transforms as transforms -from opacus.dp_model_inspector import DPModelInspector from torch.utils.data import DataLoader from torchvision.datasets import CIFAR10 +from flwr.common.typing import Scalar from dp_cifar_main import DEVICE, PARAMS, DPCifarClient, Net, test @@ -23,8 +23,6 @@ def client_fn(cid: str) -> fl.client.Client: # Load model. model = Net() # Check model is compatible with Opacus. - # inspector = DPModelInspector() - # print(f"Is the model valid? {inspector.validate(model)}") # Load data partition (divide CIFAR10 into NUM_CLIENTS distinct partitions, using 30% for validation). transform = transforms.Compose( @@ -45,12 +43,14 @@ def client_fn(cid: str) -> fl.client.Client: client_trainloader = DataLoader(client_trainset, PARAMS["batch_size"]) client_testloader = DataLoader(client_testset, PARAMS["batch_size"]) - return DPCifarClient(model, client_trainloader, client_testloader) + return DPCifarClient(model, client_trainloader, client_testloader).to_client() # Define an evaluation function for centralized evaluation (using whole CIFAR10 testset). 
def get_evaluate_fn() -> Callable[[fl.common.NDArrays], Optional[Tuple[float, float]]]: - def evaluate(weights: fl.common.NDArrays) -> Optional[Tuple[float, float]]: + def evaluate( + server_round: int, parameters: fl.common.NDArrays, config: Dict[str, Scalar] + ): transform = transforms.Compose( [ transforms.ToTensor(), @@ -63,7 +63,7 @@ def evaluate(weights: fl.common.NDArrays) -> Optional[Tuple[float, float]]: state_dict = OrderedDict( { k: torch.tensor(np.atleast_1d(v)) - for k, v in zip(model.state_dict().keys(), weights) + for k, v in zip(model.state_dict().keys(), parameters) } ) model.load_state_dict(state_dict, strict=True) @@ -82,7 +82,7 @@ def main() -> None: client_fn=client_fn, num_clients=NUM_CLIENTS, client_resources={"num_cpus": 1}, - num_rounds=3, + config=fl.server.ServerConfig(num_rounds=3), strategy=fl.server.strategy.FedAvg( fraction_fit=0.1, fraction_evaluate=0.1, evaluate_fn=get_evaluate_fn() ), diff --git a/examples/pytorch-federated-variational-autoencoder/client.py b/examples/pytorch-federated-variational-autoencoder/client.py index ceb55c79f564..65a86bcc2184 100644 --- a/examples/pytorch-federated-variational-autoencoder/client.py +++ b/examples/pytorch-federated-variational-autoencoder/client.py @@ -93,7 +93,7 @@ def evaluate(self, parameters, config): loss = test(net, testloader) return float(loss), len(testloader), {} - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=CifarClient()) + fl.client.start_client(server_address="127.0.0.1:8080", client=CifarClient().to_client()) if __name__ == "__main__": diff --git a/examples/pytorch-from-centralized-to-federated/client.py b/examples/pytorch-from-centralized-to-federated/client.py index 61c7e7f762b3..f89e03bc2053 100644 --- a/examples/pytorch-from-centralized-to-federated/client.py +++ b/examples/pytorch-from-centralized-to-federated/client.py @@ -93,8 +93,8 @@ def main() -> None: _ = model(next(iter(trainloader))["img"].to(DEVICE)) # Start client - client = CifarClient(model, trainloader, testloader) - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) + client = CifarClient(model, trainloader, testloader).to_client() + fl.client.start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/examples/quickstart-fastai/client.py b/examples/quickstart-fastai/client.py index a88abbe525dc..6bb2a751d544 100644 --- a/examples/quickstart-fastai/client.py +++ b/examples/quickstart-fastai/client.py @@ -43,7 +43,7 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client( +fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), ) diff --git a/examples/quickstart-huggingface/client.py b/examples/quickstart-huggingface/client.py index 5fa10b9ca0f2..0bfc81342972 100644 --- a/examples/quickstart-huggingface/client.py +++ b/examples/quickstart-huggingface/client.py @@ -108,7 +108,7 @@ def evaluate(self, parameters, config): return float(loss), len(testloader), {"accuracy": float(accuracy)} # Start client - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=IMDBClient()) + fl.client.start_client(server_address="127.0.0.1:8080", client=IMDBClient().to_client()) if __name__ == "__main__": @@ -119,7 +119,7 @@ def evaluate(self, parameters, config): required=True, type=int, help="Partition of the dataset divided into 1,000 iid partitions created " - "artificially.", + "artificially.", ) node_id = parser.parse_args().node_id 
main(node_id) diff --git a/examples/quickstart-jax/client.py b/examples/quickstart-jax/client.py index f9b056276deb..0cf74e2d2c05 100644 --- a/examples/quickstart-jax/client.py +++ b/examples/quickstart-jax/client.py @@ -52,4 +52,4 @@ def evaluate( # Start Flower client -fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient()) +fl.client.start_client(server_address="127.0.0.1:8080", client=FlowerClient().to_client()) diff --git a/examples/quickstart-mlcube/client.py b/examples/quickstart-mlcube/client.py index 0a1962d8da8a..0470720bc296 100644 --- a/examples/quickstart-mlcube/client.py +++ b/examples/quickstart-mlcube/client.py @@ -43,8 +43,8 @@ def main(): os.path.dirname(os.path.abspath(__file__)), "workspaces", workspace_name ) - fl.client.start_numpy_client( - server_address="0.0.0.0:8080", client=MLCubeClient(workspace=workspace) + fl.client.start_client( + server_address="0.0.0.0:8080", client=MLCubeClient(workspace=workspace).to_client() ) diff --git a/examples/quickstart-pandas/client.py b/examples/quickstart-pandas/client.py index c2f2605594d5..8585922e4572 100644 --- a/examples/quickstart-pandas/client.py +++ b/examples/quickstart-pandas/client.py @@ -59,7 +59,7 @@ def fit( X = dataset[column_names] # Start Flower client - fl.client.start_numpy_client( + fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(X), + client=FlowerClient(X).to_client(), ) diff --git a/examples/quickstart-pytorch-lightning/client.py b/examples/quickstart-pytorch-lightning/client.py index 1dabd5732b9b..fc5f1ee03cfe 100644 --- a/examples/quickstart-pytorch-lightning/client.py +++ b/examples/quickstart-pytorch-lightning/client.py @@ -72,8 +72,8 @@ def main() -> None: train_loader, val_loader, test_loader = mnist.load_data(node_id) # Flower client - client = FlowerClient(model, train_loader, val_loader, test_loader) - fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=client) + client = FlowerClient(model, train_loader, val_loader, test_loader).to_client() + fl.client.start_client(server_address="127.0.0.1:8080", client=client) if __name__ == "__main__": diff --git a/examples/quickstart-pytorch/client.py b/examples/quickstart-pytorch/client.py index 1edb42d1ec81..b5ea4c94dd21 100644 --- a/examples/quickstart-pytorch/client.py +++ b/examples/quickstart-pytorch/client.py @@ -132,7 +132,7 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client( +fl.client.start_client( server_address="127.0.0.1:8080", - client=FlowerClient(), + client=FlowerClient().to_client(), ) diff --git a/examples/quickstart-tabnet/client.py b/examples/quickstart-tabnet/client.py index da391a95710a..53913b2a2a09 100644 --- a/examples/quickstart-tabnet/client.py +++ b/examples/quickstart-tabnet/client.py @@ -79,4 +79,4 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=TabNetClient()) +fl.client.start_client(server_address="127.0.0.1:8080", client=TabNetClient().to_client()) diff --git a/examples/quickstart-tensorflow/client.py b/examples/quickstart-tensorflow/client.py index d998adbdd899..43121f062c45 100644 --- a/examples/quickstart-tensorflow/client.py +++ b/examples/quickstart-tensorflow/client.py @@ -52,4 +52,4 @@ def evaluate(self, parameters, config): # Start Flower client -fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=CifarClient()) +fl.client.start_client(server_address="127.0.0.1:8080", 
client=CifarClient().to_client()) diff --git a/examples/simulation-pytorch/sim.ipynb b/examples/simulation-pytorch/sim.ipynb index 508630cf9422..d1e7358566cc 100644 --- a/examples/simulation-pytorch/sim.ipynb +++ b/examples/simulation-pytorch/sim.ipynb @@ -509,7 +509,7 @@ " valloader = DataLoader(valset.with_transform(apply_transforms), batch_size=32)\n", "\n", " # Create and return client\n", - " return FlowerClient(trainloader, valloader)\n", + " return FlowerClient(trainloader, valloader).to_client()\n", "\n", " return client_fn\n", "\n", diff --git a/examples/simulation-pytorch/sim.py b/examples/simulation-pytorch/sim.py index 68d9426e83ab..0a6ed8ebb9b8 100644 --- a/examples/simulation-pytorch/sim.py +++ b/examples/simulation-pytorch/sim.py @@ -104,7 +104,7 @@ def client_fn(cid: str) -> fl.client.Client: valset = valset.with_transform(apply_transforms) # Create and return client - return FlowerClient(trainset, valset) + return FlowerClient(trainset, valset).to_client() return client_fn diff --git a/examples/simulation-tensorflow/sim.ipynb b/examples/simulation-tensorflow/sim.ipynb index 575b437018f3..5ef1992bcc7e 100644 --- a/examples/simulation-tensorflow/sim.ipynb +++ b/examples/simulation-tensorflow/sim.ipynb @@ -189,7 +189,7 @@ " )\n", "\n", " # Create and return client\n", - " return FlowerClient(trainset, valset)\n", + " return FlowerClient(trainset, valset).to_client()\n", "\n", " return client_fn\n", "\n", diff --git a/examples/simulation-tensorflow/sim.py b/examples/simulation-tensorflow/sim.py index 490e25fe8c8d..043c624a40a9 100644 --- a/examples/simulation-tensorflow/sim.py +++ b/examples/simulation-tensorflow/sim.py @@ -94,7 +94,7 @@ def client_fn(cid: str) -> fl.client.Client: ) # Create and return client - return FlowerClient(trainset, valset) + return FlowerClient(trainset, valset).to_client() return client_fn diff --git a/examples/sklearn-logreg-mnist/client.py b/examples/sklearn-logreg-mnist/client.py index a5fcaba87409..c5a312b41e61 100644 --- a/examples/sklearn-logreg-mnist/client.py +++ b/examples/sklearn-logreg-mnist/client.py @@ -62,4 +62,4 @@ def evaluate(self, parameters, config): # type: ignore return loss, len(X_test), {"accuracy": accuracy} # Start Flower client - fl.client.start_numpy_client(server_address="0.0.0.0:8080", client=MnistClient()) + fl.client.start_client(server_address="0.0.0.0:8080", client=MnistClient().to_client()) diff --git a/examples/vertical-fl/README.md b/examples/vertical-fl/README.md index d5ab0ab9c30d..78588180d3d6 100644 --- a/examples/vertical-fl/README.md +++ b/examples/vertical-fl/README.md @@ -295,7 +295,7 @@ class ServerModel(nn.Module): It comprises a single linear layer that accepts the concatenated outputs from all client models as its input. The number of inputs to this layer equals the -total number of outputs from the client models ( $3 \times 4 = 12$ ). After processing +total number of outputs from the client models (3 x 4 = 12). After processing these inputs, the linear layer's output is passed through a sigmoid activation function (`nn.Sigmoid()`), which maps the result to a `(0, 1)` range, providing a probability score indicative of the likelihood of survival. 
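Editorial note: the example changes above all repeat one migration. `fl.client.start_numpy_client(...)` is replaced by `fl.client.start_client(...)`, and the `NumPyClient` instance is converted first via `.to_client()`. A minimal sketch of the pattern, where `FlowerClient` is a hypothetical stand-in for the NumPyClient subclasses used in these examples:

```python
import flwr as fl


class FlowerClient(fl.client.NumPyClient):
    """Stand-in NumPyClient; the real examples also implement fit() and evaluate()."""

    def get_parameters(self, config):
        return []


# Old API (phased out in this diff):
#   fl.client.start_numpy_client(server_address="127.0.0.1:8080", client=FlowerClient())
# New API: convert the NumPyClient to a Client, then call start_client.
fl.client.start_client(
    server_address="127.0.0.1:8080",
    client=FlowerClient().to_client(),
)
```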
diff --git a/examples/whisper-federated-finetuning/client.py b/examples/whisper-federated-finetuning/client.py index 2bfeadfbdae6..07abc641f8fe 100644 --- a/examples/whisper-federated-finetuning/client.py +++ b/examples/whisper-federated-finetuning/client.py @@ -146,7 +146,7 @@ def client_fn(cid: str): return WhisperFlowerClient( full_train_dataset, num_classes, disable_tqdm, compile - ) + ).to_client() return client_fn @@ -174,7 +174,7 @@ def run_client(): client_data_path=CLIENT_DATA, ) - fl.client.start_numpy_client( + fl.client.start_client( server_address=f"{args.server_address}:8080", client=client_fn(args.cid) ) diff --git a/examples/xgboost-quickstart/client.py b/examples/xgboost-quickstart/client.py index b5eab59ba14d..62e8a441bae1 100644 --- a/examples/xgboost-quickstart/client.py +++ b/examples/xgboost-quickstart/client.py @@ -173,4 +173,4 @@ def evaluate(self, ins: EvaluateIns) -> EvaluateRes: # Start Flower client -fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient()) +fl.client.start_client(server_address="127.0.0.1:8080", client=XgbClient().to_client()) diff --git a/src/proto/flwr/proto/recordset.proto b/src/proto/flwr/proto/recordset.proto index 8e2e5d60b6db..d51d0f9ce416 100644 --- a/src/proto/flwr/proto/recordset.proto +++ b/src/proto/flwr/proto/recordset.proto @@ -68,3 +68,9 @@ message ParametersRecord { message MetricsRecord { map<string, MetricsRecordValue> data = 1; } message ConfigsRecord { map<string, ConfigsRecordValue> data = 1; } + +message RecordSet { + map<string, ParametersRecord> parameters = 1; + map<string, MetricsRecord> metrics = 2; + map<string, ConfigsRecord> configs = 3; +} diff --git a/src/proto/flwr/proto/task.proto b/src/proto/flwr/proto/task.proto index 20dd5a3aa6c8..2cde16143d8d 100644 --- a/src/proto/flwr/proto/task.proto +++ b/src/proto/flwr/proto/task.proto @@ -29,10 +29,11 @@ message Task { string ttl = 5; repeated string ancestry = 6; string task_type = 7; - SecureAggregation sa = 8; + RecordSet recordset = 8; ServerMessage legacy_server_message = 101 [ deprecated = true ]; ClientMessage legacy_client_message = 102 [ deprecated = true ]; + SecureAggregation sa = 103 [ deprecated = true ]; } message TaskIns { diff --git a/src/py/flwr/common/context.py b/src/py/flwr/common/context.py new file mode 100644 index 000000000000..ad859d4c8407 --- /dev/null +++ b/src/py/flwr/common/context.py @@ -0,0 +1,38 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Context.""" + + +from dataclasses import dataclass + +from .recordset import RecordSet + + +@dataclass +class Context: + """State of your run. + + Parameters + ---------- + state : RecordSet + Holds records added by the entity in a given run and that will stay local. + This means that the data it holds will never leave the system it's running from. + This can be used as an intermediate storage or scratchpad when + executing middleware layers. It can also be used as a memory to access + at different points during the lifecycle of this entity (e.g.
across + multiple rounds) + """ + + state: RecordSet diff --git a/src/py/flwr/common/flowercontext.py b/src/py/flwr/common/message.py similarity index 63% rename from src/py/flwr/common/flowercontext.py rename to src/py/flwr/common/message.py index 6e26d93bfe9a..f693d8e27bc3 100644 --- a/src/py/flwr/common/flowercontext.py +++ b/src/py/flwr/common/message.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""FlowerContext and Metadata.""" +"""Message.""" from dataclasses import dataclass @@ -48,30 +48,17 @@ class Metadata: @dataclass -class FlowerContext: +class Message: """State of your application from the viewpoint of the entity using it. Parameters ---------- - in_message : RecordSet - Holds records sent by another entity (e.g. sent by the server-side - logic to a client, or vice-versa) - out_message : RecordSet - Holds records added by the current entity. This `RecordSet` will - be sent out (e.g. back to the server-side for aggregation of - parameter, or to the client to perform a certain task) - local : RecordSet - Holds record added by the current entity and that will stay local. - This means that the data it holds will never leave the system it's running from. - This can be used as an intermediate storage or scratchpad when - executing middleware layers. It can also be used as a memory to access - at different points during the lifecycle of this entity (e.g. across - multiple rounds) metadata : Metadata A dataclass including information about the task to be executed. + message : RecordSet + Holds records either sent by another entity (e.g. sent by the server-side + logic to a client, or vice-versa) or that will be sent to it. """ - in_message: RecordSet - out_message: RecordSet - local: RecordSet metadata: Metadata + message: RecordSet diff --git a/src/py/flwr/common/recordset_compat.py b/src/py/flwr/common/recordset_compat.py new file mode 100644 index 000000000000..c45f7fcd9fb8 --- /dev/null +++ b/src/py/flwr/common/recordset_compat.py @@ -0,0 +1,401 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""RecordSet utilities.""" + + +from typing import Dict, Mapping, OrderedDict, Tuple, Union, cast, get_args + +from .configsrecord import ConfigsRecord +from .metricsrecord import MetricsRecord +from .parametersrecord import Array, ParametersRecord +from .recordset import RecordSet +from .typing import ( + Code, + ConfigsRecordValues, + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + GetParametersIns, + GetParametersRes, + GetPropertiesIns, + GetPropertiesRes, + MetricsRecordValues, + Parameters, + Scalar, + Status, +) + + +def parametersrecord_to_parameters( + record: ParametersRecord, keep_input: bool = False +) -> Parameters: + """Convert ParametersRecord to legacy Parameters. + + Warning: Because `Arrays` in `ParametersRecord` encode more information about the + array-like or tensor-like data (e.g. their datatype and shape) than `Parameters`, it + might not be possible to reconstruct such data structures from `Parameters` objects + alone. Additional information or metadata must be provided from elsewhere. + + Parameters + ---------- + record : ParametersRecord + The record to be converted into Parameters. + keep_input : bool (default: False) + A boolean indicating whether entries should be kept in the input record. When + False, each entry is deleted from the record as soon as it has been added to + the returned Parameters object. + """ + parameters = Parameters(tensors=[], tensor_type="") + + for key in list(record.data.keys()): + parameters.tensors.append(record[key].data) + + if not parameters.tensor_type: + # Setting from first array in record. Recall the warning in the docstrings + # of this function. + parameters.tensor_type = record[key].stype + + if not keep_input: + del record.data[key] + + return parameters + + +def parameters_to_parametersrecord( + parameters: Parameters, keep_input: bool = False +) -> ParametersRecord: + """Convert legacy Parameters into a single ParametersRecord. + + Because there is no concept of names in the legacy Parameters, arbitrary keys will + be used when constructing the ParametersRecord. Similarly, the shape and data type + won't be recorded in the Array objects. + + Parameters + ---------- + parameters : Parameters + Parameters object to be represented as a ParametersRecord. + keep_input : bool (default: False) + A boolean indicating whether parameters should be deleted from the input + Parameters object (i.e. a list of serialized NumPy arrays) immediately after + adding them to the record. + """ + tensor_type = parameters.tensor_type + + p_record = ParametersRecord() + + num_arrays = len(parameters.tensors) + for idx in range(num_arrays): + if keep_input: + tensor = parameters.tensors[idx] + else: + tensor = parameters.tensors.pop(0) + p_record.set_parameters( + OrderedDict( + {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])} + ) + ) + + return p_record + + +def _check_mapping_from_recordscalartype_to_scalar( + record_data: Mapping[str, Union[ConfigsRecordValues, MetricsRecordValues]] +) -> Dict[str, Scalar]: + """Check mapping `common.*RecordValues` into `common.Scalar` is possible.""" + for value in record_data.values(): + if not isinstance(value, get_args(Scalar)): + raise TypeError( + "There is no 1:1 mapping between `common.Scalar` types and those " + "supported in `common.MetricsRecordValues` or " + "`common.ConfigsRecordValues`. Consider casting your values to a type " + "supported by the `common.RecordSet` infrastructure.
" + f"You used type: {type(value)}" + ) + return cast(Dict[str, Scalar], record_data) + + +def _recordset_to_fit_or_evaluate_ins_components( + recordset: RecordSet, + ins_str: str, + keep_input: bool, +) -> Tuple[Parameters, Dict[str, Scalar]]: + """Derive Fit/Evaluate Ins from a RecordSet.""" + # get Array and construct Parameters + parameters_record = recordset.get_parameters(f"{ins_str}.parameters") + + parameters = parametersrecord_to_parameters( + parameters_record, keep_input=keep_input + ) + + # get config dict + config_record = recordset.get_configs(f"{ins_str}.config") + + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + + return parameters, config_dict + + +def _fit_or_evaluate_ins_to_recordset( + ins: Union[FitIns, EvaluateIns], keep_input: bool +) -> RecordSet: + recordset = RecordSet() + + ins_str = "fitins" if isinstance(ins, FitIns) else "evaluateins" + recordset.set_parameters( + name=f"{ins_str}.parameters", + record=parameters_to_parametersrecord(ins.parameters, keep_input=keep_input), + ) + + recordset.set_configs( + name=f"{ins_str}.config", record=ConfigsRecord(ins.config) # type: ignore + ) + + return recordset + + +def _embed_status_into_recordset( + res_str: str, status: Status, recordset: RecordSet +) -> RecordSet: + status_dict: Dict[str, ConfigsRecordValues] = { + "code": int(status.code.value), + "message": status.message, + } + # we add it to a `ConfigsRecord`` because the `status.message`` is a string + # and `str` values aren't supported in `MetricsRecords` + recordset.set_configs(f"{res_str}.status", record=ConfigsRecord(status_dict)) + return recordset + + +def _extract_status_from_recordset(res_str: str, recordset: RecordSet) -> Status: + status = recordset.get_configs(f"{res_str}.status") + code = cast(int, status["code"]) + return Status(code=Code(code), message=str(status["message"])) + + +def recordset_to_fitins(recordset: RecordSet, keep_input: bool) -> FitIns: + """Derive FitIns from a RecordSet object.""" + parameters, config = _recordset_to_fit_or_evaluate_ins_components( + recordset, + ins_str="fitins", + keep_input=keep_input, + ) + + return FitIns(parameters=parameters, config=config) + + +def fitins_to_recordset(fitins: FitIns, keep_input: bool) -> RecordSet: + """Construct a RecordSet from a FitIns object.""" + return _fit_or_evaluate_ins_to_recordset(fitins, keep_input) + + +def recordset_to_fitres(recordset: RecordSet, keep_input: bool) -> FitRes: + """Derive FitRes from a RecordSet object.""" + ins_str = "fitres" + parameters = parametersrecord_to_parameters( + recordset.get_parameters(f"{ins_str}.parameters"), keep_input=keep_input + ) + + num_examples = cast( + int, recordset.get_metrics(f"{ins_str}.num_examples")["num_examples"] + ) + configs_record = recordset.get_configs(f"{ins_str}.metrics") + + metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data) + status = _extract_status_from_recordset(ins_str, recordset) + + return FitRes( + status=status, parameters=parameters, num_examples=num_examples, metrics=metrics + ) + + +def fitres_to_recordset(fitres: FitRes, keep_input: bool) -> RecordSet: + """Construct a RecordSet from a FitRes object.""" + recordset = RecordSet() + + res_str = "fitres" + + recordset.set_configs( + name=f"{res_str}.metrics", record=ConfigsRecord(fitres.metrics) # type: ignore + ) + recordset.set_metrics( + name=f"{res_str}.num_examples", + record=MetricsRecord({"num_examples": fitres.num_examples}), + ) + recordset.set_parameters( + 
name=f"{res_str}.parameters", + record=parameters_to_parametersrecord(fitres.parameters, keep_input), + ) + + # status + recordset = _embed_status_into_recordset(res_str, fitres.status, recordset) + + return recordset + + +def recordset_to_evaluateins(recordset: RecordSet, keep_input: bool) -> EvaluateIns: + """Derive EvaluateIns from a RecordSet object.""" + parameters, config = _recordset_to_fit_or_evaluate_ins_components( + recordset, + ins_str="evaluateins", + keep_input=keep_input, + ) + + return EvaluateIns(parameters=parameters, config=config) + + +def evaluateins_to_recordset(evaluateins: EvaluateIns, keep_input: bool) -> RecordSet: + """Construct a RecordSet from a EvaluateIns object.""" + return _fit_or_evaluate_ins_to_recordset(evaluateins, keep_input) + + +def recordset_to_evaluateres(recordset: RecordSet) -> EvaluateRes: + """Derive EvaluateRes from a RecordSet object.""" + ins_str = "evaluateres" + + loss = cast(int, recordset.get_metrics(f"{ins_str}.loss")["loss"]) + + num_examples = cast( + int, recordset.get_metrics(f"{ins_str}.num_examples")["num_examples"] + ) + configs_record = recordset.get_configs(f"{ins_str}.metrics") + + metrics = _check_mapping_from_recordscalartype_to_scalar(configs_record.data) + status = _extract_status_from_recordset(ins_str, recordset) + + return EvaluateRes( + status=status, loss=loss, num_examples=num_examples, metrics=metrics + ) + + +def evaluateres_to_recordset(evaluateres: EvaluateRes) -> RecordSet: + """Construct a RecordSet from a EvaluateRes object.""" + recordset = RecordSet() + + res_str = "evaluateres" + # loss + recordset.set_metrics( + name=f"{res_str}.loss", + record=MetricsRecord({"loss": evaluateres.loss}), + ) + + # num_examples + recordset.set_metrics( + name=f"{res_str}.num_examples", + record=MetricsRecord({"num_examples": evaluateres.num_examples}), + ) + + # metrics + recordset.set_configs( + name=f"{res_str}.metrics", + record=ConfigsRecord(evaluateres.metrics), # type: ignore + ) + + # status + recordset = _embed_status_into_recordset( + f"{res_str}", evaluateres.status, recordset + ) + + return recordset + + +def recordset_to_getparametersins(recordset: RecordSet) -> GetParametersIns: + """Derive GetParametersIns from a RecordSet object.""" + config_record = recordset.get_configs("getparametersins.config") + + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + + return GetParametersIns(config=config_dict) + + +def getparametersins_to_recordset(getparameters_ins: GetParametersIns) -> RecordSet: + """Construct a RecordSet from a GetParametersIns object.""" + recordset = RecordSet() + + recordset.set_configs( + name="getparametersins.config", + record=ConfigsRecord(getparameters_ins.config), # type: ignore + ) + return recordset + + +def getparametersres_to_recordset(getparametersres: GetParametersRes) -> RecordSet: + """Construct a RecordSet from a GetParametersRes object.""" + recordset = RecordSet() + res_str = "getparametersres" + parameters_record = parameters_to_parametersrecord(getparametersres.parameters) + recordset.set_parameters(f"{res_str}.parameters", parameters_record) + + # status + recordset = _embed_status_into_recordset( + res_str, getparametersres.status, recordset + ) + + return recordset + + +def recordset_to_getparametersres(recordset: RecordSet) -> GetParametersRes: + """Derive GetParametersRes from a RecordSet object.""" + res_str = "getparametersres" + parameters = parametersrecord_to_parameters( + recordset.get_parameters(f"{res_str}.parameters") + ) + + status = 
_extract_status_from_recordset(res_str, recordset) + return GetParametersRes(status=status, parameters=parameters) + + +def recordset_to_getpropertiesins(recordset: RecordSet) -> GetPropertiesIns: + """Derive GetPropertiesIns from a RecordSet object.""" + config_record = recordset.get_configs("getpropertiesins.config") + config_dict = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + + return GetPropertiesIns(config=config_dict) + + +def getpropertiesins_to_recordset(getpropertiesins: GetPropertiesIns) -> RecordSet: + """Construct a RecordSet from a GetPropertiesRes object.""" + recordset = RecordSet() + recordset.set_configs( + name="getpropertiesins.config", + record=ConfigsRecord(getpropertiesins.config), # type: ignore + ) + return recordset + + +def recordset_to_getpropertiesres(recordset: RecordSet) -> GetPropertiesRes: + """Derive GetPropertiesRes from a RecordSet object.""" + res_str = "getpropertiesres" + config_record = recordset.get_configs(f"{res_str}.properties") + properties = _check_mapping_from_recordscalartype_to_scalar(config_record.data) + + status = _extract_status_from_recordset(res_str, recordset=recordset) + + return GetPropertiesRes(status=status, properties=properties) + + +def getpropertiesres_to_recordset(getpropertiesres: GetPropertiesRes) -> RecordSet: + """Construct a RecordSet from a GetPropertiesRes object.""" + recordset = RecordSet() + res_str = "getpropertiesres" + recordset.set_configs( + name=f"{res_str}.properties", + record=ConfigsRecord(getpropertiesres.properties), # type: ignore + ) + # status + recordset = _embed_status_into_recordset( + res_str, getpropertiesres.status, recordset + ) + + return recordset diff --git a/src/py/flwr/common/recordset_compat_test.py b/src/py/flwr/common/recordset_compat_test.py new file mode 100644 index 000000000000..ad91cd3a42fc --- /dev/null +++ b/src/py/flwr/common/recordset_compat_test.py @@ -0,0 +1,234 @@ +# Copyright 2024 Flower Labs GmbH. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +"""RecordSet from legacy messages tests.""" + +from copy import deepcopy +from typing import Dict + +import numpy as np + +from .parameter import ndarrays_to_parameters +from .recordset_compat import ( + evaluateins_to_recordset, + evaluateres_to_recordset, + fitins_to_recordset, + fitres_to_recordset, + getparametersins_to_recordset, + getparametersres_to_recordset, + getpropertiesins_to_recordset, + getpropertiesres_to_recordset, + recordset_to_evaluateins, + recordset_to_evaluateres, + recordset_to_fitins, + recordset_to_fitres, + recordset_to_getparametersins, + recordset_to_getparametersres, + recordset_to_getpropertiesins, + recordset_to_getpropertiesres, +) +from .typing import ( + Code, + EvaluateIns, + EvaluateRes, + FitIns, + FitRes, + GetParametersIns, + GetParametersRes, + GetPropertiesIns, + GetPropertiesRes, + NDArrays, + Scalar, + Status, +) + + +def get_ndarrays() -> NDArrays: + """Return list of NumPy arrays.""" + arr1 = np.array([[1.0, 2.0], [3.0, 4], [5.0, 6.0]]) + arr2 = np.eye(2, 7, 3) + + return [arr1, arr2] + + +################################################## +# Testing conversion: *Ins --> RecordSet --> *Ins +# Testing conversion: *Res <-- RecordSet <-- *Res +################################################## + + +def _get_valid_fitins() -> FitIns: + arrays = get_ndarrays() + return FitIns(parameters=ndarrays_to_parameters(arrays), config={"a": 1.0, "b": 0}) + + +def _get_valid_fitres() -> FitRes: + """Return valid parameters but potentially invalid config.""" + arrays = get_ndarrays() + metrics: Dict[str, Scalar] = {"a": 1.0, "b": 0} + return FitRes( + parameters=ndarrays_to_parameters(arrays), + num_examples=1, + status=Status(code=Code(0), message=""), + metrics=metrics, + ) + + +def _get_valid_evaluateins() -> EvaluateIns: + fit_ins = _get_valid_fitins() + return EvaluateIns(parameters=fit_ins.parameters, config=fit_ins.config) + + +def _get_valid_evaluateres() -> EvaluateRes: + """Return potentially invalid config.""" + metrics: Dict[str, Scalar] = {"a": 1.0, "b": 0} + return EvaluateRes( + num_examples=1, + loss=0.1, + status=Status(code=Code(0), message=""), + metrics=metrics, + ) + + +def _get_valid_getparametersins() -> GetParametersIns: + config_dict: Dict[str, Scalar] = { + "a": 1.0, + "b": 3, + "c": True, + } # valid since both Ins/Res communicate over ConfigsRecord + + return GetParametersIns(config_dict) + + +def _get_valid_getparametersres() -> GetParametersRes: + arrays = get_ndarrays() + return GetParametersRes( + status=Status(code=Code(0), message=""), + parameters=ndarrays_to_parameters(arrays), + ) + + +def _get_valid_getpropertiesins() -> GetPropertiesIns: + getparamsins = _get_valid_getparametersins() + return GetPropertiesIns(config=getparamsins.config) + + +def _get_valid_getpropertiesres() -> GetPropertiesRes: + config_dict: Dict[str, Scalar] = { + "a": 1.0, + "b": 3, + "c": True, + } # valid since both Ins/Res communicate over ConfigsRecord + + return GetPropertiesRes( + status=Status(code=Code(0), message=""), properties=config_dict + ) + + +def test_fitins_to_recordset_and_back() -> None: + """Test conversion FitIns --> RecordSet --> FitIns.""" + fitins = _get_valid_fitins() + + fitins_copy = deepcopy(fitins) + + recordset = fitins_to_recordset(fitins, keep_input=False) + + fitins_ = recordset_to_fitins(recordset, keep_input=False) + + assert fitins_copy == fitins_ + + +def test_fitres_to_recordset_and_back() -> None: + """Test conversion FitRes --> RecordSet --> FitRes.""" + fitres = _get_valid_fitres() + + fitres_copy = deepcopy(fitres) + + recordset = fitres_to_recordset(fitres, keep_input=False) + fitres_ = recordset_to_fitres(recordset, keep_input=False) + + assert fitres_copy == fitres_ + + +def test_evaluateins_to_recordset_and_back() -> None: + """Test conversion EvaluateIns --> RecordSet --> EvaluateIns.""" + evaluateins = _get_valid_evaluateins() + + evaluateins_copy = deepcopy(evaluateins) + + recordset = evaluateins_to_recordset(evaluateins, keep_input=False) + + evaluateins_ = recordset_to_evaluateins(recordset, keep_input=False) + + assert evaluateins_copy == evaluateins_ + + +def test_evaluateres_to_recordset_and_back() -> None: + """Test conversion EvaluateRes --> RecordSet --> EvaluateRes.""" + evaluateres = _get_valid_evaluateres() + + evaluateres_copy = deepcopy(evaluateres) + + recordset = evaluateres_to_recordset(evaluateres) + evaluateres_ = recordset_to_evaluateres(recordset) + + assert evaluateres_copy == evaluateres_ + + +def test_get_properties_ins_to_recordset_and_back() -> None: + """Test conversion GetPropertiesIns --> RecordSet --> GetPropertiesIns.""" + getproperties_ins = _get_valid_getpropertiesins() + + getproperties_ins_copy = deepcopy(getproperties_ins) + + recordset = getpropertiesins_to_recordset(getproperties_ins) + getproperties_ins_ = recordset_to_getpropertiesins(recordset) + + assert getproperties_ins_copy == getproperties_ins_ + + +def test_get_properties_res_to_recordset_and_back() -> None: + """Test conversion GetPropertiesRes --> RecordSet --> GetPropertiesRes.""" + getproperties_res = _get_valid_getpropertiesres() + + getproperties_res_copy = deepcopy(getproperties_res) + + recordset = getpropertiesres_to_recordset(getproperties_res) + getproperties_res_ = recordset_to_getpropertiesres(recordset) + + assert getproperties_res_copy == getproperties_res_ + + +def test_get_parameters_ins_to_recordset_and_back() -> None: + """Test conversion GetParametersIns --> RecordSet --> GetParametersIns.""" + getparameters_ins = _get_valid_getparametersins() + + getparameters_ins_copy = deepcopy(getparameters_ins) + + recordset = getparametersins_to_recordset(getparameters_ins) + getparameters_ins_ = recordset_to_getparametersins(recordset) + + assert getparameters_ins_copy == getparameters_ins_ + + +def test_get_parameters_res_to_recordset_and_back() -> None: + """Test conversion GetParametersRes --> RecordSet --> GetParametersRes.""" + getparameters_res = _get_valid_getparametersres() + + getparameters_res_copy = deepcopy(getparameters_res) + + recordset = getparametersres_to_recordset(getparameters_res) + getparameters_res_ = recordset_to_getparametersres(recordset) + + assert getparameters_res_copy == getparameters_res_ diff --git a/src/py/flwr/common/recordset_test.py b/src/py/flwr/common/recordset_test.py index 83e1e4595f1d..e1825eaeef14 100644 --- a/src/py/flwr/common/recordset_test.py +++ b/src/py/flwr/common/recordset_test.py @@ -14,7 +14,6 @@ # ============================================================================== """RecordSet tests.""" - from typing import Callable, Dict, List, OrderedDict, Type, Union import numpy as np @@ -24,7 +23,7 @@ from .metricsrecord import MetricsRecord from .parameter import ndarrays_to_parameters, parameters_to_ndarrays from .parametersrecord import Array, ParametersRecord -from .recordset_utils import ( +from .recordset_compat import ( parameters_to_parametersrecord, parametersrecord_to_parameters, ) diff --git a/src/py/flwr/common/recordset_utils.py
b/src/py/flwr/common/recordset_utils.py deleted file mode 100644 index c1e724fa2758..000000000000 --- a/src/py/flwr/common/recordset_utils.py +++ /dev/null @@ -1,87 +0,0 @@ -# Copyright 2024 Flower Labs GmbH. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -"""RecordSet utilities.""" - - -from typing import OrderedDict - -from .parametersrecord import Array, ParametersRecord -from .typing import Parameters - - -def parametersrecord_to_parameters( - record: ParametersRecord, keep_input: bool = False -) -> Parameters: - """Convert ParameterRecord to legacy Parameters. - - Warning: Because `Arrays` in `ParametersRecord` encode more information of the - array-like or tensor-like data (e.g their datatype, shape) than `Parameters` it - might not be possible to reconstruct such data structures from `Parameters` objects - alone. Additional information or metadta must be provided from elsewhere. - - Parameters - ---------- - record : ParametersRecord - The record to be conveted into Parameters. - keep_input : bool (default: False) - A boolean indicating whether entries in the record should be deleted from the - input dictionary immediately after adding them to the record. - """ - parameters = Parameters(tensors=[], tensor_type="") - - for key in list(record.data.keys()): - parameters.tensors.append(record.data[key].data) - - if not keep_input: - del record.data[key] - - return parameters - - -def parameters_to_parametersrecord( - parameters: Parameters, keep_input: bool = False -) -> ParametersRecord: - """Convert legacy Parameters into a single ParametersRecord. - - Because there is no concept of names in the legacy Parameters, arbitrary keys will - be used when constructing the ParametersRecord. Similarly, the shape and data type - won't be recorded in the Array objects. - - Parameters - ---------- - parameters : Parameters - Parameters object to be represented as a ParametersRecord. - keep_input : bool (default: False) - A boolean indicating whether parameters should be deleted from the input - Parameters object (i.e. a list of serialized NumPy arrays) immediately after - adding them to the record. 
- """ - tensor_type = parameters.tensor_type - - p_record = ParametersRecord() - - num_arrays = len(parameters.tensors) - for idx in range(num_arrays): - if keep_input: - tensor = parameters.tensors[idx] - else: - tensor = parameters.tensors.pop(0) - p_record.set_parameters( - OrderedDict( - {str(idx): Array(data=tensor, dtype="", stype=tensor_type, shape=[])} - ) - ) - - return p_record diff --git a/src/py/flwr/common/serde.py b/src/py/flwr/common/serde.py index 2094c76a9856..2600d46edddc 100644 --- a/src/py/flwr/common/serde.py +++ b/src/py/flwr/common/serde.py @@ -28,6 +28,7 @@ from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord from flwr.proto.recordset_pb2 import MetricsRecordValue as ProtoMetricsRecordValue from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord +from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet from flwr.proto.recordset_pb2 import Sint64List, StringList from flwr.proto.task_pb2 import Value from flwr.proto.transport_pb2 import ( @@ -45,6 +46,7 @@ from .configsrecord import ConfigsRecord from .metricsrecord import MetricsRecord from .parametersrecord import Array, ParametersRecord +from .recordset import RecordSet # === ServerMessage message === @@ -719,3 +721,33 @@ def configs_record_from_proto(record_proto: ProtoConfigsRecord) -> ConfigsRecord ), keep_input=False, ) + + +# === RecordSet message === + + +def recordset_to_proto(recordset: RecordSet) -> ProtoRecordSet: + """Serialize RecordSet to ProtoBuf.""" + return ProtoRecordSet( + parameters={ + k: parameters_record_to_proto(v) for k, v in recordset.parameters.items() + }, + metrics={k: metrics_record_to_proto(v) for k, v in recordset.metrics.items()}, + configs={k: configs_record_to_proto(v) for k, v in recordset.configs.items()}, + ) + + +def recordset_from_proto(recordset_proto: ProtoRecordSet) -> RecordSet: + """Deserialize RecordSet from ProtoBuf.""" + return RecordSet( + parameters={ + k: parameters_record_from_proto(v) + for k, v in recordset_proto.parameters.items() + }, + metrics={ + k: metrics_record_from_proto(v) for k, v in recordset_proto.metrics.items() + }, + configs={ + k: configs_record_from_proto(v) for k, v in recordset_proto.configs.items() + }, + ) diff --git a/src/py/flwr/common/serde_test.py b/src/py/flwr/common/serde_test.py index c584597d89f6..53f40eee5e53 100644 --- a/src/py/flwr/common/serde_test.py +++ b/src/py/flwr/common/serde_test.py @@ -23,12 +23,14 @@ from flwr.proto.recordset_pb2 import ConfigsRecord as ProtoConfigsRecord from flwr.proto.recordset_pb2 import MetricsRecord as ProtoMetricsRecord from flwr.proto.recordset_pb2 import ParametersRecord as ProtoParametersRecord +from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet # pylint: enable=E0611 from . 
import typing from .configsrecord import ConfigsRecord from .metricsrecord import MetricsRecord from .parametersrecord import Array, ParametersRecord +from .recordset import RecordSet from .serde import ( array_from_proto, array_to_proto, @@ -40,6 +42,8 @@ named_values_to_proto, parameters_record_from_proto, parameters_record_to_proto, + recordset_from_proto, + recordset_to_proto, scalar_from_proto, scalar_to_proto, status_from_proto, @@ -242,3 +246,64 @@ def test_configs_record_serialization_deserialization() -> None: # Assert assert isinstance(proto, ProtoConfigsRecord) assert original.data == deserialized.data + + +def test_recordset_serialization_deserialization() -> None: + """Test serialization and deserialization of RecordSet.""" + # Prepare + encoder_params_record = ParametersRecord( + array_dict=OrderedDict( + [ + ( + "k1", + Array(dtype="float", shape=[2, 2], stype="dense", data=b"1234"), + ), + ("k2", Array(dtype="int", shape=[3], stype="sparse", data=b"567")), + ] + ), + keep_input=False, + ) + decoder_params_record = ParametersRecord( + array_dict=OrderedDict( + [ + ( + "k1", + Array( + dtype="float", shape=[32, 32, 4], stype="dense", data=b"0987" + ), + ), + ] + ), + keep_input=False, + ) + + original = RecordSet( + parameters={ + "encoder_parameters": encoder_params_record, + "decoder_parameters": decoder_params_record, + }, + metrics={ + "acc_metrics": MetricsRecord( + metrics_dict={"accuracy": 0.95, "loss": 0.1}, keep_input=False + ) + }, + configs={ + "my_configs": ConfigsRecord( + configs_dict={ + "learning_rate": 0.01, + "batch_size": 32, + "public_key": b"21f8sioj@!#", + "log": "Hello, world!", + }, + keep_input=False, + ) + }, + ) + + # Execute + proto = recordset_to_proto(original) + deserialized = recordset_from_proto(proto) + + # Assert + assert isinstance(proto, ProtoRecordSet) + assert original == deserialized diff --git a/src/py/flwr/proto/recordset_pb2.py b/src/py/flwr/proto/recordset_pb2.py index 4134511f1f53..f7f74d72182b 100644 --- a/src/py/flwr/proto/recordset_pb2.py +++ b/src/py/flwr/proto/recordset_pb2.py @@ -14,7 +14,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\x9f\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x42\x07\n\x05value\"\xd9\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 
\x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 \x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\x9f\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x42\x07\n\x05value\"\xd9\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 \x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\"\x97\x03\n\tRecordSet\x12\x39\n\nparameters\x18\x01 \x03(\x0b\x32%.flwr.proto.RecordSet.ParametersEntry\x12\x33\n\x07metrics\x18\x02 \x03(\x0b\x32\".flwr.proto.RecordSet.MetricsEntry\x12\x33\n\x07\x63onfigs\x18\x03 \x03(\x0b\x32\".flwr.proto.RecordSet.ConfigsEntry\x1aO\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12+\n\x05value\x18\x02 \x01(\x0b\x32\x1c.flwr.proto.ParametersRecord:\x02\x38\x01\x1aI\n\x0cMetricsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.MetricsRecord:\x02\x38\x01\x1aI\n\x0c\x43onfigsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord:\x02\x38\x01\x62\x06proto3') _globals = globals() 
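The `recordset_to_proto`/`recordset_from_proto` pair added above is what lets a whole `RecordSet` cross the wire as one protobuf message. A minimal end-to-end sketch, assuming this branch of flwr; the record names and values are illustrative, and `SerializeToString`/`FromString` are the standard protobuf byte-level round trip:

# Sketch: RecordSet -> proto -> bytes -> proto -> RecordSet.
from flwr.common.configsrecord import ConfigsRecord
from flwr.common.metricsrecord import MetricsRecord
from flwr.common.recordset import RecordSet
from flwr.common.serde import recordset_from_proto, recordset_to_proto
from flwr.proto.recordset_pb2 import RecordSet as ProtoRecordSet

original = RecordSet(
    parameters={},  # empty maps are fine
    metrics={"fit_metrics": MetricsRecord(metrics_dict={"loss": 0.35}, keep_input=False)},
    configs={"run_config": ConfigsRecord(configs_dict={"lr": 0.01}, keep_input=False)},
)

# Python object -> protobuf message -> raw bytes (e.g. for a task payload)
payload = recordset_to_proto(original).SerializeToString()

# Raw bytes -> protobuf message -> Python object
restored = recordset_from_proto(ProtoRecordSet.FromString(payload))
assert restored == original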
diff --git a/src/py/flwr/proto/recordset_pb2.py b/src/py/flwr/proto/recordset_pb2.py
index 4134511f1f53..f7f74d72182b 100644
--- a/src/py/flwr/proto/recordset_pb2.py
+++ b/src/py/flwr/proto/recordset_pb2.py
@@ -14,7 +14,7 @@



-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\x9f\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x42\x07\n\x05value\"\xd9\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 \x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1a\x66lwr/proto/recordset.proto\x12\nflwr.proto\"\x1a\n\nDoubleList\x12\x0c\n\x04vals\x18\x01 \x03(\x01\"\x1a\n\nSint64List\x12\x0c\n\x04vals\x18\x01 \x03(\x12\"\x18\n\x08\x42oolList\x12\x0c\n\x04vals\x18\x01 \x03(\x08\"\x1a\n\nStringList\x12\x0c\n\x04vals\x18\x01 \x03(\t\"\x19\n\tBytesList\x12\x0c\n\x04vals\x18\x01 \x03(\x0c\"B\n\x05\x41rray\x12\r\n\x05\x64type\x18\x01 \x01(\t\x12\r\n\x05shape\x18\x02 \x03(\x05\x12\r\n\x05stype\x18\x03 \x01(\t\x12\x0c\n\x04\x64\x61ta\x18\x04 \x01(\x0c\"\x9f\x01\n\x12MetricsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x42\x07\n\x05value\"\xd9\x02\n\x12\x43onfigsRecordValue\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"M\n\x10ParametersRecord\x12\x11\n\tdata_keys\x18\x01 \x03(\t\x12&\n\x0b\x64\x61ta_values\x18\x02 \x03(\x0b\x32\x11.flwr.proto.Array\"\x8f\x01\n\rMetricsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.MetricsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.MetricsRecordValue:\x02\x38\x01\"\x8f\x01\n\rConfigsRecord\x12\x31\n\x04\x64\x61ta\x18\x01 \x03(\x0b\x32#.flwr.proto.ConfigsRecord.DataEntry\x1aK\n\tDataEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12-\n\x05value\x18\x02 \x01(\x0b\x32\x1e.flwr.proto.ConfigsRecordValue:\x02\x38\x01\"\x97\x03\n\tRecordSet\x12\x39\n\nparameters\x18\x01 \x03(\x0b\x32%.flwr.proto.RecordSet.ParametersEntry\x12\x33\n\x07metrics\x18\x02 \x03(\x0b\x32\".flwr.proto.RecordSet.MetricsEntry\x12\x33\n\x07\x63onfigs\x18\x03 \x03(\x0b\x32\".flwr.proto.RecordSet.ConfigsEntry\x1aO\n\x0fParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12+\n\x05value\x18\x02 \x01(\x0b\x32\x1c.flwr.proto.ParametersRecord:\x02\x38\x01\x1aI\n\x0cMetricsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.MetricsRecord:\x02\x38\x01\x1aI\n\x0c\x43onfigsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12(\n\x05value\x18\x02 \x01(\x0b\x32\x19.flwr.proto.ConfigsRecord:\x02\x38\x01\x62\x06proto3')

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -25,6 +25,12 @@
   _globals['_METRICSRECORD_DATAENTRY']._serialized_options = b'8\001'
   _globals['_CONFIGSRECORD_DATAENTRY']._options = None
   _globals['_CONFIGSRECORD_DATAENTRY']._serialized_options = b'8\001'
+  _globals['_RECORDSET_PARAMETERSENTRY']._options = None
+  _globals['_RECORDSET_PARAMETERSENTRY']._serialized_options = b'8\001'
+  _globals['_RECORDSET_METRICSENTRY']._options = None
+  _globals['_RECORDSET_METRICSENTRY']._serialized_options = b'8\001'
+  _globals['_RECORDSET_CONFIGSENTRY']._options = None
+  _globals['_RECORDSET_CONFIGSENTRY']._serialized_options = b'8\001'
   _globals['_DOUBLELIST']._serialized_start=42
   _globals['_DOUBLELIST']._serialized_end=68
   _globals['_SINT64LIST']._serialized_start=70
@@ -51,4 +57,12 @@
   _globals['_CONFIGSRECORD']._serialized_end=1126
   _globals['_CONFIGSRECORD_DATAENTRY']._serialized_start=1051
   _globals['_CONFIGSRECORD_DATAENTRY']._serialized_end=1126
+  _globals['_RECORDSET']._serialized_start=1129
+  _globals['_RECORDSET']._serialized_end=1536
+  _globals['_RECORDSET_PARAMETERSENTRY']._serialized_start=1307
+  _globals['_RECORDSET_PARAMETERSENTRY']._serialized_end=1386
+  _globals['_RECORDSET_METRICSENTRY']._serialized_start=1388
+  _globals['_RECORDSET_METRICSENTRY']._serialized_end=1461
+  _globals['_RECORDSET_CONFIGSENTRY']._serialized_start=1463
+  _globals['_RECORDSET_CONFIGSENTRY']._serialized_end=1536
 # @@protoc_insertion_point(module_scope)
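Decoded, the serialized descriptor above introduces a `RecordSet` message with three string-keyed map fields: `parameters` (to `ParametersRecord`), `metrics` (to `MetricsRecord`), and `configs` (to `ConfigsRecord`). A short sketch of touching the generated message directly, which may help when reading the stubs below; key names are illustrative, and in normal use the serde helpers above do this for you:

# Sketch: direct use of the generated RecordSet message.
from flwr.proto.recordset_pb2 import Array, ParametersRecord, RecordSet

proto = RecordSet()

# Message-valued map entries can't be assigned wholesale; accessing the key
# creates the entry, and CopyFrom fills it in.
entry = ParametersRecord(
    data_keys=["0"],
    data_values=[Array(dtype="float32", shape=[2, 2], stype="dense", data=b"\x00" * 16)],
)
proto.parameters["model_weights"].CopyFrom(entry)

assert list(proto.parameters["model_weights"].data_keys) == ["0"]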
diff --git a/src/py/flwr/proto/recordset_pb2.pyi b/src/py/flwr/proto/recordset_pb2.pyi
index 1e9556de9ce6..86244697129c 100644
--- a/src/py/flwr/proto/recordset_pb2.pyi
+++ b/src/py/flwr/proto/recordset_pb2.pyi
@@ -238,3 +238,68 @@ class ConfigsRecord(google.protobuf.message.Message):
     ) -> None: ...
     def ClearField(self, field_name: typing_extensions.Literal["data",b"data"]) -> None: ...
 global___ConfigsRecord = ConfigsRecord
+
+class RecordSet(google.protobuf.message.Message):
+    DESCRIPTOR: google.protobuf.descriptor.Descriptor
+    class ParametersEntry(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+        KEY_FIELD_NUMBER: builtins.int
+        VALUE_FIELD_NUMBER: builtins.int
+        key: typing.Text
+        @property
+        def value(self) -> global___ParametersRecord: ...
+        def __init__(self,
+            *,
+            key: typing.Text = ...,
+            value: typing.Optional[global___ParametersRecord] = ...,
+            ) -> None: ...
+        def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
+        def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
+
+    class MetricsEntry(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+        KEY_FIELD_NUMBER: builtins.int
+        VALUE_FIELD_NUMBER: builtins.int
+        key: typing.Text
+        @property
+        def value(self) -> global___MetricsRecord: ...
+        def __init__(self,
+            *,
+            key: typing.Text = ...,
+            value: typing.Optional[global___MetricsRecord] = ...,
+            ) -> None: ...
+        def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
+        def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
+
+    class ConfigsEntry(google.protobuf.message.Message):
+        DESCRIPTOR: google.protobuf.descriptor.Descriptor
+        KEY_FIELD_NUMBER: builtins.int
+        VALUE_FIELD_NUMBER: builtins.int
+        key: typing.Text
+        @property
+        def value(self) -> global___ConfigsRecord: ...
+        def __init__(self,
+            *,
+            key: typing.Text = ...,
+            value: typing.Optional[global___ConfigsRecord] = ...,
+            ) -> None: ...
+        def HasField(self, field_name: typing_extensions.Literal["value",b"value"]) -> builtins.bool: ...
+        def ClearField(self, field_name: typing_extensions.Literal["key",b"key","value",b"value"]) -> None: ...
+
+    PARAMETERS_FIELD_NUMBER: builtins.int
+    METRICS_FIELD_NUMBER: builtins.int
+    CONFIGS_FIELD_NUMBER: builtins.int
+    @property
+    def parameters(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ParametersRecord]: ...
+    @property
+    def metrics(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___MetricsRecord]: ...
+    @property
+    def configs(self) -> google.protobuf.internal.containers.MessageMap[typing.Text, global___ConfigsRecord]: ...
+    def __init__(self,
+        *,
+        parameters: typing.Optional[typing.Mapping[typing.Text, global___ParametersRecord]] = ...,
+        metrics: typing.Optional[typing.Mapping[typing.Text, global___MetricsRecord]] = ...,
+        configs: typing.Optional[typing.Mapping[typing.Text, global___ConfigsRecord]] = ...,
+        ) -> None: ...
+    def ClearField(self, field_name: typing_extensions.Literal["configs",b"configs","metrics",b"metrics","parameters",b"parameters"]) -> None: ...
+global___RecordSet = RecordSet
diff --git a/src/py/flwr/proto/task_pb2.py b/src/py/flwr/proto/task_pb2.py
index 963b07db94f8..f9b2180b15dd 100644
--- a/src/py/flwr/proto/task_pb2.py
+++ b/src/py/flwr/proto/task_pb2.py
@@ -17,7 +17,7 @@
 from flwr.proto import transport_pb2 as flwr_dot_proto_dot_transport__pb2


-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xd1\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12)\n\x02sa\x18\x08 \x01(\x0b\x32\x1d.flwr.proto.SecureAggregation\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xcc\x02\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3')
+DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x15\x66lwr/proto/task.proto\x12\nflwr.proto\x1a\x15\x66lwr/proto/node.proto\x1a\x1a\x66lwr/proto/recordset.proto\x1a\x1a\x66lwr/proto/transport.proto\"\xff\x02\n\x04Task\x12\"\n\x08producer\x18\x01 \x01(\x0b\x32\x10.flwr.proto.Node\x12\"\n\x08\x63onsumer\x18\x02 \x01(\x0b\x32\x10.flwr.proto.Node\x12\x12\n\ncreated_at\x18\x03 \x01(\t\x12\x14\n\x0c\x64\x65livered_at\x18\x04 \x01(\t\x12\x0b\n\x03ttl\x18\x05 \x01(\t\x12\x10\n\x08\x61ncestry\x18\x06 \x03(\t\x12\x11\n\ttask_type\x18\x07 \x01(\t\x12(\n\trecordset\x18\x08 \x01(\x0b\x32\x15.flwr.proto.RecordSet\x12<\n\x15legacy_server_message\x18\x65 \x01(\x0b\x32\x19.flwr.proto.ServerMessageB\x02\x18\x01\x12<\n\x15legacy_client_message\x18\x66 \x01(\x0b\x32\x19.flwr.proto.ClientMessageB\x02\x18\x01\x12-\n\x02sa\x18g \x01(\x0b\x32\x1d.flwr.proto.SecureAggregationB\x02\x18\x01\"\\\n\x07TaskIns\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\\\n\x07TaskRes\x12\x0f\n\x07task_id\x18\x01 \x01(\t\x12\x10\n\x08group_id\x18\x02 \x01(\t\x12\x0e\n\x06run_id\x18\x03 \x01(\x12\x12\x1e\n\x04task\x18\x04 \x01(\x0b\x32\x10.flwr.proto.Task\"\xcc\x02\n\x05Value\x12\x10\n\x06\x64ouble\x18\x01 \x01(\x01H\x00\x12\x10\n\x06sint64\x18\x02 \x01(\x12H\x00\x12\x0e\n\x04\x62ool\x18\x03 \x01(\x08H\x00\x12\x10\n\x06string\x18\x04 \x01(\tH\x00\x12\x0f\n\x05\x62ytes\x18\x05 \x01(\x0cH\x00\x12-\n\x0b\x64ouble_list\x18\x15 \x01(\x0b\x32\x16.flwr.proto.DoubleListH\x00\x12-\n\x0bsint64_list\x18\x16 \x01(\x0b\x32\x16.flwr.proto.Sint64ListH\x00\x12)\n\tbool_list\x18\x17 \x01(\x0b\x32\x14.flwr.proto.BoolListH\x00\x12-\n\x0bstring_list\x18\x18 \x01(\x0b\x32\x16.flwr.proto.StringListH\x00\x12+\n\nbytes_list\x18\x19 \x01(\x0b\x32\x15.flwr.proto.BytesListH\x00\x42\x07\n\x05value\"\xa0\x01\n\x11SecureAggregation\x12\x44\n\x0cnamed_values\x18\x01 \x03(\x0b\x32..flwr.proto.SecureAggregation.NamedValuesEntry\x1a\x45\n\x10NamedValuesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12 \n\x05value\x18\x02 \x01(\x0b\x32\x11.flwr.proto.Value:\x02\x38\x01\x62\x06proto3')

 _globals = globals()
 _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
@@ -28,18 +28,20 @@
   _globals['_TASK'].fields_by_name['legacy_server_message']._serialized_options = b'\030\001'
   _globals['_TASK'].fields_by_name['legacy_client_message']._options = None
   _globals['_TASK'].fields_by_name['legacy_client_message']._serialized_options = b'\030\001'
+  _globals['_TASK'].fields_by_name['sa']._options = None
+  _globals['_TASK'].fields_by_name['sa']._serialized_options = b'\030\001'
   _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._options = None
   _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_options = b'8\001'
   _globals['_TASK']._serialized_start=117
-  _globals['_TASK']._serialized_end=454
-  _globals['_TASKINS']._serialized_start=456
-  _globals['_TASKINS']._serialized_end=548
-  _globals['_TASKRES']._serialized_start=550
-  _globals['_TASKRES']._serialized_end=642
-  _globals['_VALUE']._serialized_start=645
-  _globals['_VALUE']._serialized_end=977
-  _globals['_SECUREAGGREGATION']._serialized_start=980
-  _globals['_SECUREAGGREGATION']._serialized_end=1140
-  _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_start=1071
-  _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_end=1140
+  _globals['_TASK']._serialized_end=500
+  _globals['_TASKINS']._serialized_start=502
+  _globals['_TASKINS']._serialized_end=594
+  _globals['_TASKRES']._serialized_start=596
+  _globals['_TASKRES']._serialized_end=688
+  _globals['_VALUE']._serialized_start=691
+  _globals['_VALUE']._serialized_end=1023
+  _globals['_SECUREAGGREGATION']._serialized_start=1026
+  _globals['_SECUREAGGREGATION']._serialized_end=1186
+  _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_start=1117
+  _globals['_SECUREAGGREGATION_NAMEDVALUESENTRY']._serialized_end=1186
 # @@protoc_insertion_point(module_scope)
diff --git a/src/py/flwr/proto/task_pb2.pyi b/src/py/flwr/proto/task_pb2.pyi
index ebe69d05c974..39119797c9e4 100644
--- a/src/py/flwr/proto/task_pb2.pyi
+++ b/src/py/flwr/proto/task_pb2.pyi
@@ -23,9 +23,10 @@ class Task(google.protobuf.message.Message):
     TTL_FIELD_NUMBER: builtins.int
     ANCESTRY_FIELD_NUMBER: builtins.int
     TASK_TYPE_FIELD_NUMBER: builtins.int
-    SA_FIELD_NUMBER: builtins.int
+    RECORDSET_FIELD_NUMBER: builtins.int
     LEGACY_SERVER_MESSAGE_FIELD_NUMBER: builtins.int
     LEGACY_CLIENT_MESSAGE_FIELD_NUMBER: builtins.int
+    SA_FIELD_NUMBER: builtins.int
     @property
     def producer(self) -> flwr.proto.node_pb2.Node: ...
     @property
@@ -37,11 +38,13 @@ class Task(google.protobuf.message.Message):
     def ancestry(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[typing.Text]: ...
     task_type: typing.Text
     @property
-    def sa(self) -> global___SecureAggregation: ...
+    def recordset(self) -> flwr.proto.recordset_pb2.RecordSet: ...
     @property
     def legacy_server_message(self) -> flwr.proto.transport_pb2.ServerMessage: ...
     @property
     def legacy_client_message(self) -> flwr.proto.transport_pb2.ClientMessage: ...
+    @property
+    def sa(self) -> global___SecureAggregation: ...
     def __init__(self,
         *,
         producer: typing.Optional[flwr.proto.node_pb2.Node] = ...,
@@ -51,12 +54,13 @@ class Task(google.protobuf.message.Message):
         ttl: typing.Text = ...,
         ancestry: typing.Optional[typing.Iterable[typing.Text]] = ...,
         task_type: typing.Text = ...,
-        sa: typing.Optional[global___SecureAggregation] = ...,
+        recordset: typing.Optional[flwr.proto.recordset_pb2.RecordSet] = ...,
         legacy_server_message: typing.Optional[flwr.proto.transport_pb2.ServerMessage] = ...,
         legacy_client_message: typing.Optional[flwr.proto.transport_pb2.ClientMessage] = ...,
+        sa: typing.Optional[global___SecureAggregation] = ...,
         ) -> None: ...
-    def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","sa",b"sa"]) -> builtins.bool: ...
-    def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","sa",b"sa","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
+    def HasField(self, field_name: typing_extensions.Literal["consumer",b"consumer","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","recordset",b"recordset","sa",b"sa"]) -> builtins.bool: ...
+    def ClearField(self, field_name: typing_extensions.Literal["ancestry",b"ancestry","consumer",b"consumer","created_at",b"created_at","delivered_at",b"delivered_at","legacy_client_message",b"legacy_client_message","legacy_server_message",b"legacy_server_message","producer",b"producer","recordset",b"recordset","sa",b"sa","task_type",b"task_type","ttl",b"ttl"]) -> None: ...
 global___Task = Task

 class TaskIns(google.protobuf.message.Message):
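Net effect of the `task.proto` changes: `Task` now carries a first-class `recordset` field (field number 8), while `sa` moves to a deprecated field (103) alongside the two deprecated legacy message fields. Below is a sketch of building a `TaskIns` around a `RecordSet` payload on this branch; the IDs and record names are illustrative, and the `Node` field names (`node_id`, `anonymous`) are assumed from `node.proto` at this revision:

# Sketch: wrapping a RecordSet in a Task/TaskIns.
from flwr.common.metricsrecord import MetricsRecord
from flwr.common.recordset import RecordSet
from flwr.common.serde import recordset_to_proto
from flwr.proto.node_pb2 import Node
from flwr.proto.task_pb2 import Task, TaskIns

recordset = RecordSet(
    parameters={},
    metrics={"eval_metrics": MetricsRecord(metrics_dict={"acc": 0.9}, keep_input=False)},
    configs={},
)

task = Task(
    producer=Node(node_id=0, anonymous=True),   # assumed Node fields
    consumer=Node(node_id=1, anonymous=False),
    task_type="evaluate",
    recordset=recordset_to_proto(recordset),    # the new field replacing legacy payloads
)
task_ins = TaskIns(task_id="id-1", group_id="g-1", run_id=0, task=task)
assert task_ins.task.HasField("recordset")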